mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-24 09:39:55 +00:00
Compare commits
24 Commits
refactor-c
...
sha-pinnin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2b59191665 | ||
|
|
0358be2313 | ||
|
|
37052fd5bc | ||
|
|
454ff66518 | ||
|
|
6137a1fcc5 | ||
|
|
4955c345d5 | ||
|
|
9192b4f029 | ||
|
|
c784b02550 | ||
|
|
d250f92c43 | ||
|
|
80966ab1b0 | ||
|
|
af24fd7796 | ||
|
|
13d32d274f | ||
|
|
705f87fc20 | ||
|
|
3f91f49277 | ||
|
|
347c5bf317 | ||
|
|
22e2519d71 | ||
|
|
e916f12cca | ||
|
|
9ed2e2a5b4 | ||
|
|
2ccae7ec47 | ||
|
|
07e5450117 | ||
|
|
3f914090cb | ||
|
|
ea9fab4396 | ||
|
|
77b479286e | ||
|
|
ab2a8794e7 |
46
.github/actions/parse-semver/action.yml
vendored
Normal file
46
.github/actions/parse-semver/action.yml
vendored
Normal file
@@ -0,0 +1,46 @@
|
||||
name: Parse semver string
|
||||
description: Parse a refs/tags/vX.Y.Z[-prerelease][+build] ref into version components. Falls back to 0.0.0 when not run on a version tag.
|
||||
outputs:
|
||||
major:
|
||||
description: Major version number (e.g. "1" from v1.2.3).
|
||||
value: ${{ steps.parse.outputs.major }}
|
||||
minor:
|
||||
description: Minor version number (e.g. "2" from v1.2.3).
|
||||
value: ${{ steps.parse.outputs.minor }}
|
||||
patch:
|
||||
description: Patch version number (e.g. "3" from v1.2.3).
|
||||
value: ${{ steps.parse.outputs.patch }}
|
||||
prerelease:
|
||||
description: Prerelease identifier (e.g. "rc.1" from v1.2.3-rc.1), empty when absent.
|
||||
value: ${{ steps.parse.outputs.prerelease }}
|
||||
build:
|
||||
description: Build metadata (e.g. "build.7" from v1.2.3+build.7), empty when absent.
|
||||
value: ${{ steps.parse.outputs.build }}
|
||||
fullversion:
|
||||
description: MAJOR.MINOR.PATCH joined (e.g. "1.2.3").
|
||||
value: ${{ steps.parse.outputs.fullversion }}
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- id: parse
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
REF="${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}"
|
||||
if [[ "$REF" =~ ^refs/tags/v([0-9]+)\.([0-9]+)\.([0-9]+)(-([0-9A-Za-z.-]+))?(\+([0-9A-Za-z.-]+))?$ ]]; then
|
||||
MAJOR="${BASH_REMATCH[1]}"
|
||||
MINOR="${BASH_REMATCH[2]}"
|
||||
PATCH="${BASH_REMATCH[3]}"
|
||||
PRERELEASE="${BASH_REMATCH[5]}"
|
||||
BUILD="${BASH_REMATCH[7]}"
|
||||
else
|
||||
MAJOR=0; MINOR=0; PATCH=0; PRERELEASE=""; BUILD=""
|
||||
fi
|
||||
{
|
||||
echo "major=$MAJOR"
|
||||
echo "minor=$MINOR"
|
||||
echo "patch=$PATCH"
|
||||
echo "prerelease=$PRERELEASE"
|
||||
echo "build=$BUILD"
|
||||
echo "fullversion=$MAJOR.$MINOR.$PATCH"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
34
.github/dependabot.yml
vendored
Normal file
34
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,34 @@
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
open-pull-requests-limit: 15
|
||||
groups:
|
||||
actions:
|
||||
patterns:
|
||||
- "*"
|
||||
|
||||
- package-ecosystem: "gomod"
|
||||
directories:
|
||||
- "/"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
open-pull-requests-limit: 15
|
||||
groups:
|
||||
aws-sdk:
|
||||
patterns:
|
||||
- "github.com/aws/aws-sdk-go-v2/*"
|
||||
pion:
|
||||
patterns:
|
||||
- "github.com/pion/*"
|
||||
gorm:
|
||||
patterns:
|
||||
- "gorm.io/*"
|
||||
otel:
|
||||
patterns:
|
||||
- "go.opentelemetry.io/*"
|
||||
testcontainers:
|
||||
patterns:
|
||||
- "github.com/testcontainers/testcontainers-go/*"
|
||||
1
.github/pull_request_template.md
vendored
1
.github/pull_request_template.md
vendored
@@ -12,6 +12,7 @@
|
||||
- [ ] Is a feature enhancement
|
||||
- [ ] It is a refactor
|
||||
- [ ] Created tests that fail without the change (if possible)
|
||||
- [ ] This change does **not** modify the public API, gRPC protocols, functionality behavior, CLI / service flags, or introduce a new feature — **OR** I have discussed it with the NetBird team beforehand (link the issue / Slack thread in the description). See [CONTRIBUTING.md](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTING.md#discuss-changes-with-the-netbird-team-first).
|
||||
|
||||
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).
|
||||
|
||||
|
||||
105
.github/workflows/check-license-dependencies.yml
vendored
105
.github/workflows/check-license-dependencies.yml
vendored
@@ -2,16 +2,16 @@ name: Check License Dependencies
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
branches: [main]
|
||||
paths:
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.github/workflows/check-license-dependencies.yml'
|
||||
- "go.mod"
|
||||
- "go.sum"
|
||||
- ".github/workflows/check-license-dependencies.yml"
|
||||
pull_request:
|
||||
paths:
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.github/workflows/check-license-dependencies.yml'
|
||||
- "go.mod"
|
||||
- "go.sum"
|
||||
- ".github/workflows/check-license-dependencies.yml"
|
||||
|
||||
jobs:
|
||||
check-internal-dependencies:
|
||||
@@ -19,7 +19,8 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Check for problematic license dependencies
|
||||
run: |
|
||||
@@ -56,55 +57,55 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: 'go.mod'
|
||||
cache: true
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: true
|
||||
|
||||
- name: Install go-licenses
|
||||
run: go install github.com/google/go-licenses@v1.6.0
|
||||
- name: Install go-licenses
|
||||
run: go install github.com/google/go-licenses@v1.6.0
|
||||
|
||||
- name: Check for GPL/AGPL licensed dependencies
|
||||
run: |
|
||||
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
|
||||
echo ""
|
||||
|
||||
# Check all Go packages for copyleft licenses, excluding internal netbird packages
|
||||
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
|
||||
|
||||
if [ -n "$COPYLEFT_DEPS" ]; then
|
||||
echo "Found copyleft licensed dependencies:"
|
||||
echo "$COPYLEFT_DEPS"
|
||||
- name: Check for GPL/AGPL licensed dependencies
|
||||
run: |
|
||||
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
|
||||
echo ""
|
||||
|
||||
# Filter out dependencies that are only pulled in by internal AGPL packages
|
||||
INCOMPATIBLE=""
|
||||
while IFS=',' read -r package url license; do
|
||||
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
|
||||
# Find ALL packages that import this GPL package using go list
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
# Check all Go packages for copyleft licenses, excluding internal netbird packages
|
||||
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
|
||||
else
|
||||
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
|
||||
fi
|
||||
fi
|
||||
done <<< "$COPYLEFT_DEPS"
|
||||
|
||||
if [ -n "$INCOMPATIBLE" ]; then
|
||||
if [ -n "$COPYLEFT_DEPS" ]; then
|
||||
echo "Found copyleft licensed dependencies:"
|
||||
echo "$COPYLEFT_DEPS"
|
||||
echo ""
|
||||
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
|
||||
echo -e "$INCOMPATIBLE"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
|
||||
# Filter out dependencies that are only pulled in by internal AGPL packages
|
||||
INCOMPATIBLE=""
|
||||
while IFS=',' read -r package url license; do
|
||||
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
|
||||
# Find ALL packages that import this GPL package using go list
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
|
||||
else
|
||||
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
|
||||
fi
|
||||
fi
|
||||
done <<< "$COPYLEFT_DEPS"
|
||||
|
||||
if [ -n "$INCOMPATIBLE" ]; then
|
||||
echo ""
|
||||
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
|
||||
echo -e "$INCOMPATIBLE"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
|
||||
|
||||
2
.github/workflows/docs-ack.yml
vendored
2
.github/workflows/docs-ack.yml
vendored
@@ -83,7 +83,7 @@ jobs:
|
||||
|
||||
- name: Verify docs PR exists (and is open or merged)
|
||||
if: steps.validate.outputs.mode == 'added'
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
id: verify
|
||||
with:
|
||||
pr_number: ${{ steps.extract.outputs.pr_number }}
|
||||
|
||||
5
.github/workflows/forum.yml
vendored
5
.github/workflows/forum.yml
vendored
@@ -8,11 +8,10 @@ jobs:
|
||||
post:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: roots/discourse-topic-github-release-action@main
|
||||
- uses: roots/discourse-topic-github-release-action@557d74ea05b6cc0c47f555c1d5d28a89d904005b # v1.1.0
|
||||
with:
|
||||
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
|
||||
discourse-base-url: https://forum.netbird.io
|
||||
discourse-author-username: NetBird
|
||||
discourse-category: 17
|
||||
discourse-tags:
|
||||
releases
|
||||
discourse-tags: releases
|
||||
|
||||
6
.github/workflows/git-town.yml
vendored
6
.github/workflows/git-town.yml
vendored
@@ -3,7 +3,7 @@ name: Git Town
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- '**'
|
||||
- "**"
|
||||
|
||||
jobs:
|
||||
git-town:
|
||||
@@ -15,7 +15,7 @@ jobs:
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: git-town/action@v1.2.1
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- uses: git-town/action@670e1f4feb81fdef4226fc09deefe09018eb20d1 # v1.3.3
|
||||
with:
|
||||
skip-single-stacks: true
|
||||
|
||||
7
.github/workflows/golang-test-darwin.yml
vendored
7
.github/workflows/golang-test-darwin.yml
vendored
@@ -16,16 +16,16 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: macos-gotest-${{ hashFiles('**/go.sum') }}
|
||||
@@ -44,4 +44,3 @@ jobs:
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
|
||||
19
.github/workflows/golang-test-freebsd.yml
vendored
19
.github/workflows/golang-test-freebsd.yml
vendored
@@ -15,20 +15,29 @@ jobs:
|
||||
name: "Client / Unit"
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Read Go version from go.mod
|
||||
id: goversion
|
||||
run: echo "version=$(awk '/^go / {print $2}' go.mod)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Test in FreeBSD
|
||||
id: test
|
||||
uses: vmactions/freebsd-vm@v1
|
||||
env:
|
||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
release: "14.2"
|
||||
release: "15.0"
|
||||
envs: "GO_VERSION"
|
||||
prepare: |
|
||||
pkg install -y curl pkgconf xorg
|
||||
GO_TARBALL="go1.25.3.freebsd-amd64.tar.gz"
|
||||
GO_TARBALL="go${GO_VERSION}.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -vLO "$GO_URL"
|
||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||
|
||||
# -x - to print all executed commands
|
||||
# -e - to faile on first error
|
||||
|
||||
118
.github/workflows/golang-test-linux.yml
vendored
118
.github/workflows/golang-test-linux.yml
vendored
@@ -18,9 +18,9 @@ jobs:
|
||||
management: ${{ steps.filter.outputs.management }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- uses: dorny/paths-filter@v3
|
||||
- uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
|
||||
id: filter
|
||||
with:
|
||||
filters: |
|
||||
@@ -28,7 +28,7 @@ jobs:
|
||||
- 'management/**'
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -36,10 +36,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
id: cache
|
||||
with:
|
||||
path: |
|
||||
@@ -113,14 +113,14 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
arch: ["386", "amd64"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -128,10 +128,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -158,14 +158,14 @@ jobs:
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -177,7 +177,7 @@ jobs:
|
||||
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
id: cache-restore
|
||||
with:
|
||||
path: |
|
||||
@@ -231,10 +231,10 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -246,10 +246,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -277,14 +277,14 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
arch: ["386", "amd64"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -298,7 +298,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -324,14 +324,14 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
arch: ["386", "amd64"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -343,10 +343,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -370,19 +370,19 @@ jobs:
|
||||
|
||||
test_management:
|
||||
name: "Management / Unit"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
arch: ["amd64"]
|
||||
store: ["sqlite", "postgres", "mysql"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -390,10 +390,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -410,7 +410,7 @@ jobs:
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -427,7 +427,7 @@ jobs:
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
@@ -437,13 +437,13 @@ jobs:
|
||||
|
||||
benchmark:
|
||||
name: "Management / Benchmark"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
arch: ["amd64"]
|
||||
store: ["sqlite", "postgres"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Create Docker network
|
||||
@@ -474,10 +474,10 @@ jobs:
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -485,10 +485,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -505,7 +505,7 @@ jobs:
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -529,13 +529,13 @@ jobs:
|
||||
|
||||
api_benchmark:
|
||||
name: "Management / Benchmark (API)"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
arch: ["amd64"]
|
||||
store: ["sqlite", "postgres"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Create Docker network
|
||||
@@ -566,10 +566,10 @@ jobs:
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -577,10 +577,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -597,7 +597,7 @@ jobs:
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -623,20 +623,20 @@ jobs:
|
||||
|
||||
api_integration_test:
|
||||
name: "Management / Integration"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres']
|
||||
arch: ["amd64"]
|
||||
store: ["sqlite", "postgres"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -644,10 +644,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
|
||||
26
.github/workflows/golang-test-windows.yml
vendored
26
.github/workflows/golang-test-windows.yml
vendored
@@ -18,10 +18,10 @@ jobs:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
id: go
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
@@ -33,7 +33,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -44,16 +44,22 @@ jobs:
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Download wintun
|
||||
uses: carlosperate/download-file-action@v2
|
||||
id: download-wintun
|
||||
with:
|
||||
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
||||
file-name: wintun.zip
|
||||
location: ${{ env.downloadPath }}
|
||||
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
|
||||
shell: pwsh
|
||||
run: |
|
||||
$url = "https://pkgs.netbird.io/wintun/wintun-0.14.1.zip"
|
||||
$dest = "${env:downloadPath}\wintun.zip"
|
||||
$expected = "07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51"
|
||||
|
||||
curl.exe --fail --location --retry 3 --retry-delay 2 -o $dest $url
|
||||
$actual = (Get-FileHash $dest -Algorithm SHA256).Hash.ToLower()
|
||||
if ($actual -ne $expected) {
|
||||
throw "wintun.zip checksum mismatch: expected $expected, got $actual"
|
||||
}
|
||||
"file-path=$dest" | Out-File -FilePath $env:GITHUB_OUTPUT -Append -Encoding utf8
|
||||
|
||||
- name: Decompressing wintun files
|
||||
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||
run: tar -xvf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||
|
||||
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'
|
||||
|
||||
|
||||
10
.github/workflows/golangci-lint.yml
vendored
10
.github/workflows/golangci-lint.yml
vendored
@@ -15,9 +15,9 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
uses: codespell-project/actions-codespell@8f01853be192eb0f849a5c7d721450e7a467c579 # v2.2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
@@ -38,13 +38,13 @@ jobs:
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- name: Check for duplicate constants
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: |
|
||||
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -52,7 +52,7 @@ jobs:
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
|
||||
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
|
||||
with:
|
||||
version: latest
|
||||
skip-cache: true
|
||||
|
||||
2
.github/workflows/install-script-test.yml
vendored
2
.github/workflows/install-script-test.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: run install script
|
||||
env:
|
||||
|
||||
14
.github/workflows/mobile-build-validation.yml
vendored
14
.github/workflows/mobile-build-validation.yml
vendored
@@ -16,23 +16,23 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Setup Android SDK
|
||||
uses: android-actions/setup-android@v3
|
||||
uses: android-actions/setup-android@40fd30fb8d7440372e1316f5d1809ec01dcd3699 # v4.0.1
|
||||
with:
|
||||
cmdline-tools-version: 8512546
|
||||
- name: Setup Java
|
||||
uses: actions/setup-java@v4
|
||||
uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654
|
||||
with:
|
||||
java-version: "11"
|
||||
distribution: "adopt"
|
||||
- name: NDK Cache
|
||||
id: ndk-cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: /usr/local/lib/android/sdk/ndk
|
||||
key: ndk-cache-23.1.7779620
|
||||
@@ -52,9 +52,9 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: install gomobile
|
||||
|
||||
2
.github/workflows/pr-title-check.yml
vendored
2
.github/workflows/pr-title-check.yml
vendored
@@ -9,7 +9,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Validate PR title prefix
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
with:
|
||||
script: |
|
||||
const title = context.payload.pull_request.title;
|
||||
|
||||
2
.github/workflows/proto-version-check.yml
vendored
2
.github/workflows/proto-version-check.yml
vendored
@@ -10,7 +10,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check for proto tool version changes
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
with:
|
||||
script: |
|
||||
const files = await github.paginate(github.rest.pulls.listFiles, {
|
||||
|
||||
181
.github/workflows/release.yml
vendored
181
.github/workflows/release.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Generate FreeBSD port diff
|
||||
run: bash release_files/freebsd-port-diff.sh
|
||||
@@ -51,19 +51,26 @@ jobs:
|
||||
echo "Generated files for version: $VERSION"
|
||||
cat netbird-*.diff
|
||||
|
||||
- name: Read Go version from go.mod
|
||||
id: goversion
|
||||
run: echo "version=$(awk '/^go / {print $2}' go.mod)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Test FreeBSD port
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
uses: vmactions/freebsd-vm@v1
|
||||
env:
|
||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
release: "15.0"
|
||||
envs: "GO_VERSION"
|
||||
prepare: |
|
||||
# Install required packages
|
||||
pkg install -y git curl portlint go
|
||||
|
||||
# Install Go for building
|
||||
GO_TARBALL="go1.25.5.freebsd-amd64.tar.gz"
|
||||
GO_TARBALL="go${GO_VERSION}.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -LO "$GO_URL"
|
||||
tar -C /usr/local -xzf "$GO_TARBALL"
|
||||
@@ -93,19 +100,19 @@ jobs:
|
||||
|
||||
# Show patched Makefile
|
||||
version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}')
|
||||
|
||||
|
||||
cd /usr/ports/security/netbird
|
||||
export BATCH=yes
|
||||
make package
|
||||
pkg add ./work/pkg/netbird-*.pkg
|
||||
|
||||
|
||||
netbird version | grep "$version"
|
||||
|
||||
echo "FreeBSD port test completed successfully!"
|
||||
|
||||
- name: Upload FreeBSD port files
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: freebsd-port-files
|
||||
path: |
|
||||
@@ -124,26 +131,24 @@ jobs:
|
||||
env:
|
||||
flags: ""
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||
with:
|
||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||
version_extractor_regex: '\/v(.*)$'
|
||||
uses: ./.github/actions/parse-semver
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -156,18 +161,18 @@ jobs:
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v2
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a #v4.0.0
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0
|
||||
- name: Login to Docker hub
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v1
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Log in to the GitHub container registry
|
||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
@@ -191,7 +196,7 @@ jobs:
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --clean ${{ env.flags }}
|
||||
@@ -282,28 +287,28 @@ jobs:
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
- name: upload non tags for debug purposes
|
||||
id: upload_release
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: release
|
||||
path: dist/
|
||||
retention-days: 7
|
||||
- name: upload linux packages
|
||||
id: upload_linux_packages
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: linux-packages
|
||||
path: dist/netbird_linux**
|
||||
retention-days: 7
|
||||
- name: upload windows packages
|
||||
id: upload_windows_packages
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: windows-packages
|
||||
path: dist/netbird_windows**
|
||||
retention-days: 7
|
||||
- name: upload macos packages
|
||||
id: upload_macos_packages
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: macos-packages
|
||||
path: dist/netbird_darwin**
|
||||
@@ -314,27 +319,25 @@ jobs:
|
||||
outputs:
|
||||
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
|
||||
steps:
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||
with:
|
||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||
version_extractor_regex: '\/v(.*)$'
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: ./.github/actions/parse-semver
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -375,7 +378,7 @@ jobs:
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
|
||||
@@ -404,7 +407,7 @@ jobs:
|
||||
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
||||
- name: upload non tags for debug purposes
|
||||
id: upload_release_ui
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: release-ui
|
||||
path: dist/
|
||||
@@ -418,16 +421,16 @@ jobs:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -441,7 +444,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
|
||||
@@ -449,7 +452,7 @@ jobs:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: upload non tags for debug purposes
|
||||
id: upload_release_ui_darwin
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: release-ui-darwin
|
||||
path: dist/
|
||||
@@ -474,27 +477,24 @@ jobs:
|
||||
PackageWorkdir: netbird_windows_${{ matrix.arch }}
|
||||
downloadPath: '${{ github.workspace }}\temp'
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||
with:
|
||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||
version_extractor_regex: '\/v(.*)$'
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: ./.github/actions/parse-semver
|
||||
|
||||
- name: Add 7-Zip to PATH
|
||||
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
|
||||
- name: Download release artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
||||
with:
|
||||
name: release
|
||||
path: release
|
||||
|
||||
- name: Download UI release artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
||||
with:
|
||||
name: release-ui
|
||||
path: release-ui
|
||||
@@ -514,29 +514,39 @@ jobs:
|
||||
Get-ChildItem $workdir
|
||||
|
||||
- name: Download wintun
|
||||
uses: carlosperate/download-file-action@v2
|
||||
id: download-wintun
|
||||
with:
|
||||
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
||||
file-name: wintun.zip
|
||||
location: ${{ env.downloadPath }}
|
||||
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
|
||||
shell: pwsh
|
||||
run: |
|
||||
$url = "https://pkgs.netbird.io/wintun/wintun-0.14.1.zip"
|
||||
$dest = "${env:downloadPath}\wintun.zip"
|
||||
$expected = "07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51"
|
||||
curl.exe --fail --location --retry 3 --retry-delay 2 -o $dest $url
|
||||
$actual = (Get-FileHash $dest -Algorithm SHA256).Hash.ToLower()
|
||||
if ($actual -ne $expected) {
|
||||
throw "wintun.zip checksum mismatch: expected $expected, got $actual"
|
||||
}
|
||||
"file-path=$dest" | Out-File -FilePath $env:GITHUB_OUTPUT -Append -Encoding utf8
|
||||
|
||||
- name: Decompress wintun files
|
||||
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||
run: tar -xvf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||
|
||||
- name: Move wintun.dll into dist
|
||||
run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
|
||||
|
||||
- name: Download Mesa3D (amd64 only)
|
||||
uses: carlosperate/download-file-action@v2
|
||||
id: download-mesa3d
|
||||
if: matrix.arch == 'amd64'
|
||||
with:
|
||||
file-url: https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z
|
||||
file-name: mesa3d.7z
|
||||
location: ${{ env.downloadPath }}
|
||||
sha256: '71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9'
|
||||
shell: pwsh
|
||||
run: |
|
||||
$url = "https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z"
|
||||
$dest = "${env:downloadPath}\mesa3d.7z"
|
||||
$expected = "71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9"
|
||||
curl.exe --fail --location --retry 3 --retry-delay 2 -o $dest $url
|
||||
$actual = (Get-FileHash $dest -Algorithm SHA256).Hash.ToLower()
|
||||
if ($actual -ne $expected) {
|
||||
throw "mesa3d.7z checksum mismatch: expected $expected, got $actual"
|
||||
}
|
||||
"file-path=$dest" | Out-File -FilePath $env:GITHUB_OUTPUT -Append -Encoding utf8
|
||||
|
||||
- name: Extract Mesa3D driver (amd64 only)
|
||||
if: matrix.arch == 'amd64'
|
||||
@@ -547,35 +557,38 @@ jobs:
|
||||
run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
|
||||
|
||||
- name: Download EnVar plugin for NSIS
|
||||
uses: carlosperate/download-file-action@v2
|
||||
with:
|
||||
file-url: https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip
|
||||
file-name: envar_plugin.zip
|
||||
location: ${{ github.workspace }}
|
||||
shell: pwsh
|
||||
run: |
|
||||
curl.exe --fail --location --retry 3 --retry-delay 2 `
|
||||
-o "${{ github.workspace }}\envar_plugin.zip" `
|
||||
"https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip"
|
||||
|
||||
- name: Extract EnVar plugin
|
||||
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/envar_plugin.zip"
|
||||
|
||||
- name: Download ShellExecAsUser plugin for NSIS (amd64 only)
|
||||
uses: carlosperate/download-file-action@v2
|
||||
if: matrix.arch == 'amd64'
|
||||
with:
|
||||
file-url: https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z
|
||||
file-name: ShellExecAsUser_amd64-Unicode.7z
|
||||
location: ${{ github.workspace }}
|
||||
shell: pwsh
|
||||
run: |
|
||||
curl.exe --fail --location --retry 3 --retry-delay 2 `
|
||||
-o "${{ github.workspace }}\ShellExecAsUser_amd64-Unicode.7z" `
|
||||
"https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z"
|
||||
|
||||
- name: Extract ShellExecAsUser plugin (amd64 only)
|
||||
if: matrix.arch == 'amd64'
|
||||
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z"
|
||||
|
||||
- name: Build NSIS installer
|
||||
uses: joncloud/makensis-action@v3.3
|
||||
with:
|
||||
additional-plugin-paths: ${{ github.workspace }}/NSIS_Plugins/Plugins
|
||||
script-file: client/installer.nsis
|
||||
arguments: "/V4 /DARCH=${{ matrix.arch }}"
|
||||
shell: pwsh
|
||||
env:
|
||||
APPVER: ${{ steps.semver_parser.outputs.major }}.${{ steps.semver_parser.outputs.minor }}.${{ steps.semver_parser.outputs.patch }}.${{ github.run_id }}
|
||||
run: |
|
||||
$nsisPluginDir = "C:\Program Files (x86)\NSIS\Plugins\x86-unicode"
|
||||
$srcPlugins = "${{ github.workspace }}\NSIS_Plugins\Plugins"
|
||||
Get-ChildItem -Path $srcPlugins -Recurse -Filter *.dll |
|
||||
Copy-Item -Destination $nsisPluginDir -Force
|
||||
& "C:\Program Files (x86)\NSIS\makensis.exe" /V4 "/DARCH=${{ matrix.arch }}" client\installer.nsis
|
||||
if ($LASTEXITCODE -ne 0) { throw "makensis failed with exit code $LASTEXITCODE" }
|
||||
|
||||
- name: Rename NSIS installer
|
||||
run: mv netbird-installer.exe netbird_installer_test_windows_${{ matrix.arch }}.exe
|
||||
@@ -592,7 +605,7 @@ jobs:
|
||||
|
||||
- name: Upload installer artifacts
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: windows-installer-test-${{ matrix.arch }}
|
||||
path: |
|
||||
@@ -611,7 +624,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Create or update PR comment
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
env:
|
||||
RELEASE_RESULT: ${{ needs.release.result }}
|
||||
RELEASE_UI_RESULT: ${{ needs.release_ui.result }}
|
||||
@@ -703,7 +716,7 @@ jobs:
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- name: Trigger binaries sign pipelines
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: Sign bin and installer
|
||||
repo: netbirdio/sign-pipelines
|
||||
|
||||
4
.github/workflows/sync-main.yml
vendored
4
.github/workflows/sync-main.yml
vendored
@@ -14,9 +14,9 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Trigger main branch sync
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: sync-main.yml
|
||||
repo: ${{ secrets.UPSTREAM_REPO }}
|
||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
||||
inputs: '{ "sha": "${{ github.sha }}" }'
|
||||
inputs: '{ "sha": "${{ github.sha }}" }'
|
||||
|
||||
10
.github/workflows/sync-tag.yml
vendored
10
.github/workflows/sync-tag.yml
vendored
@@ -3,7 +3,7 @@ name: sync tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
- "v*"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
@@ -16,7 +16,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Trigger release tag sync
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: sync-tag.yml
|
||||
ref: main
|
||||
@@ -29,7 +29,7 @@ jobs:
|
||||
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
|
||||
steps:
|
||||
- name: Trigger android-client submodule bump
|
||||
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: bump-netbird.yml
|
||||
ref: main
|
||||
@@ -42,10 +42,10 @@ jobs:
|
||||
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
|
||||
steps:
|
||||
- name: Trigger ios-client submodule bump
|
||||
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: bump-netbird.yml
|
||||
ref: main
|
||||
repo: netbirdio/ios-client
|
||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
||||
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
||||
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
||||
|
||||
22
.github/workflows/test-infrastructure-files.yml
vendored
22
.github/workflows/test-infrastructure-files.yml
vendored
@@ -6,10 +6,10 @@ on:
|
||||
- main
|
||||
pull_request:
|
||||
paths:
|
||||
- 'infrastructure_files/**'
|
||||
- '.github/workflows/test-infrastructure-files.yml'
|
||||
- 'management/cmd/**'
|
||||
- 'signal/cmd/**'
|
||||
- "infrastructure_files/**"
|
||||
- ".github/workflows/test-infrastructure-files.yml"
|
||||
- "management/cmd/**"
|
||||
- "signal/cmd/**"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
@@ -20,7 +20,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
store: ["sqlite", "postgres", "mysql"]
|
||||
services:
|
||||
postgres:
|
||||
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
|
||||
@@ -68,15 +68,15 @@ jobs:
|
||||
run: sudo apt-get install -y curl
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -139,8 +139,8 @@ jobs:
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
||||
CI_NETBIRD_SIGNAL_PORT: 12345
|
||||
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
|
||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
|
||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
|
||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN: "${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$"
|
||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: "${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$"
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
||||
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
|
||||
@@ -254,7 +254,7 @@ jobs:
|
||||
run: sudo apt-get install -y jq
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: run script with Zitadel PostgreSQL
|
||||
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
|
||||
|
||||
8
.github/workflows/update-docs.yml
vendored
8
.github/workflows/update-docs.yml
vendored
@@ -3,9 +3,9 @@ name: update docs
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
- "v*"
|
||||
paths:
|
||||
- 'shared/management/http/api/openapi.yml'
|
||||
- "shared/management/http/api/openapi.yml"
|
||||
|
||||
jobs:
|
||||
trigger_docs_api_update:
|
||||
@@ -13,10 +13,10 @@ jobs:
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- name: Trigger API pages generation
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: generate api pages
|
||||
repo: netbirdio/docs
|
||||
ref: "refs/heads/main"
|
||||
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
||||
inputs: '{ "tag": "${{ github.ref }}" }'
|
||||
inputs: '{ "tag": "${{ github.ref }}" }'
|
||||
|
||||
11
.github/workflows/wasm-build-validation.yml
vendored
11
.github/workflows/wasm-build-validation.yml
vendored
@@ -19,15 +19,15 @@ jobs:
|
||||
GOARCH: wasm
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: Install golangci-lint
|
||||
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
|
||||
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
|
||||
with:
|
||||
version: latest
|
||||
install-mode: binary
|
||||
@@ -42,9 +42,9 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Build Wasm client
|
||||
@@ -65,4 +65,3 @@ jobs:
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ If you haven't already, join our slack workspace [here](https://docs.netbird.io/
|
||||
- [Contributing to NetBird](#contributing-to-netbird)
|
||||
- [Contents](#contents)
|
||||
- [Code of conduct](#code-of-conduct)
|
||||
- [Discuss changes with the NetBird team first](#discuss-changes-with-the-netbird-team-first)
|
||||
- [Directory structure](#directory-structure)
|
||||
- [Development setup](#development-setup)
|
||||
- [Requirements](#requirements)
|
||||
@@ -33,6 +34,14 @@ Conduct which can be found in the file [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md).
|
||||
By participating, you are expected to uphold this code. Please report
|
||||
unacceptable behavior to community@netbird.io.
|
||||
|
||||
## Discuss changes with the NetBird team first
|
||||
|
||||
Changes to the **public API**, **gRPC protocols**, **functionality behavior**, **CLI / service flags**, or **new features** should be discussed with the NetBird team before you start the work. These surfaces are part of NetBird's contract with operators, self-hosters, and downstream integrators, and changes to them have compatibility, security, and release-planning implications that benefit from an early conversation.
|
||||
|
||||
Open an issue or reach out on [Slack](https://docs.netbird.io/slack-url) to talk through what you have in mind. We'll help shape the change, flag any constraints we know about, and confirm the direction so the PR review can focus on implementation rather than design.
|
||||
|
||||
Typical bug fixes, internal refactors, documentation updates, and tests do not need pre-discussion — open the PR directly.
|
||||
|
||||
## Directory structure
|
||||
|
||||
The NetBird project monorepo is organized to maintain most of its individual dependencies code within their directories, except for a few auxiliary or shared packages.
|
||||
|
||||
153
README.md
153
README.md
@@ -1,147 +1,134 @@
|
||||
|
||||
<div align="center">
|
||||
<br/>
|
||||
<br/>
|
||||
<p align="center">
|
||||
<img width="234" src="docs/media/logo-full.png"/>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://img.shields.io/badge/license-BSD--3-blue)">
|
||||
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" />
|
||||
</a>
|
||||
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
|
||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||
</a>
|
||||
<br>
|
||||
<p align="center">
|
||||
<img width="234" src="docs/media/logo-full.png" alt="NetBird logo"/>
|
||||
</p>
|
||||
<p align="center">
|
||||
<a href="https://sonarcloud.io/dashboard?id=netbirdio_netbird">
|
||||
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" alt="SonarCloud alert status"/>
|
||||
</a>
|
||||
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
|
||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" alt="BSD-3 License"/>
|
||||
</a>
|
||||
<a href="https://docs.netbird.io/slack-url">
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||
</a>
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack" alt="NetBird Slack"/>
|
||||
</a>
|
||||
<a href="https://forum.netbird.io">
|
||||
<img src="https://img.shields.io/badge/community forum-@netbird-red.svg?logo=discourse"/>
|
||||
</a>
|
||||
<br>
|
||||
<img src="https://img.shields.io/badge/community%20forum-@netbird-red.svg?logo=discourse" alt="Community forum"/>
|
||||
</a>
|
||||
<a href="https://gurubase.io/g/netbird">
|
||||
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
|
||||
</a>
|
||||
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF" alt="Gurubase: Ask NetBird Guru"/>
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
|
||||
<p align="center">
|
||||
<strong>
|
||||
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
|
||||
<strong>
|
||||
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
|
||||
</strong>
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
|
||||
<br/>
|
||||
|
||||
</strong>
|
||||
<br>
|
||||
<strong>
|
||||
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
|
||||
</strong>
|
||||
<br>
|
||||
<br>
|
||||
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
|
||||
New: NetBird terraform provider
|
||||
</a>
|
||||
<strong>
|
||||
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
|
||||
</strong>
|
||||
</p>
|
||||
|
||||
<br>
|
||||
|
||||
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
|
||||
|
||||
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
||||
|
||||
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
||||
|
||||
### Open Source Network Security in a Single Platform
|
||||
|
||||
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||
|
||||
### Self-Host NetBird (Video)
|
||||
### Self-host NetBird (video)
|
||||
|
||||
[](https://youtu.be/bZAgpT6nzaQ)
|
||||
|
||||
### Key features
|
||||
|
||||
| Connectivity | Management | Security | Automation| Platforms |
|
||||
|----|----|----|----|----|
|
||||
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
|
||||
| <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> |
|
||||
| <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/audit-events-logging)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
|
||||
| <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> |
|
||||
| <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> |
|
||||
||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> |
|
||||
||| <ul><li>- \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ui></li> || <ul><li>- \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)</ui></li> |
|
||||
||||| <ul><li>- \[x] Docker</ui></li> |
|
||||
| Connectivity | Management | Security | Automation | Platforms |
|
||||
|---|---|---|---|---|
|
||||
| ✓ [Kernel WireGuard](https://docs.netbird.io/about-netbird/why-wireguard-with-netbird) | ✓ [Admin Web UI](https://github.com/netbirdio/dashboard) | ✓ [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) | ✓ [Public API](https://docs.netbird.io/api) | ✓ [Linux](https://docs.netbird.io/get-started/install/linux) |
|
||||
| ✓ [Peer-to-peer connections](https://docs.netbird.io/about-netbird/how-netbird-works) | ✓ Auto peer discovery and configuration | ✓ [Access control: groups & rules](https://docs.netbird.io/how-to/manage-network-access) | ✓ [Setup keys for bulk provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) | ✓ [macOS](https://docs.netbird.io/get-started/install/macos) |
|
||||
| ✓ Connection relay fallback | ✓ [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) | ✓ [Activity logging](https://docs.netbird.io/how-to/audit-events-logging) | ✓ [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) | ✓ [Windows](https://docs.netbird.io/get-started/install/windows) |
|
||||
| ✓ [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) | ✓ [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) | ✓ [Traffic events](https://docs.netbird.io/manage/activity/traffic-events-logging) | ✓ [IdP groups sync with JWT](https://docs.netbird.io/manage/team/idp-sync) | ✓ [Android](https://docs.netbird.io/get-started/install/android) |
|
||||
| ✓ [Domain-based DNS routes](https://docs.netbird.io/manage/dns/dns-aliases-for-routed-networks) | ✓ [Custom DNS zones](https://docs.netbird.io/manage/dns/custom-zones) | ✓ [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) | ✓ [Terraform provider](https://registry.terraform.io/providers/netbirdio/netbird/latest) | ✓ [Android TV](https://docs.netbird.io/get-started/install/android-tv) |
|
||||
| ✓ [Exit nodes](https://docs.netbird.io/manage/network-routes/use-cases/exit-nodes) | ✓ [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) | ✓ Peer-to-peer encryption | ✓ [Ansible collection](https://github.com/netbirdio/ansible-netbird) | ✓ [iOS](https://docs.netbird.io/get-started/install/ios) |
|
||||
| ✓ [IPv6 dual-stack overlay](https://docs.netbird.io/manage/settings/ipv6) | ✓ [Multi-account profile switching](https://docs.netbird.io/client/profiles) | ✓ [SSH with central access policies](https://docs.netbird.io/manage/peers/ssh) | | ✓ [Apple TV](https://docs.netbird.io/get-started/install/tvos) |
|
||||
| ✓ [Browser SSH & RDP](https://docs.netbird.io/manage/peers/browser-client) | | ✓ [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) | | ✓ FreeBSD |
|
||||
| ✓ [Reverse proxy with auto-TLS](https://docs.netbird.io/manage/reverse-proxy) | | ✓ [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication) | | ✓ [pfSense](https://docs.netbird.io/get-started/install/pfsense) |
|
||||
| | | | | ✓ [OPNsense](https://docs.netbird.io/get-started/install/opnsense) |
|
||||
| | | | | ✓ [MikroTik RouterOS](https://docs.netbird.io/use-cases/homelab/client-on-mikrotik-router) |
|
||||
| | | | | ✓ OpenWRT |
|
||||
| | | | | ✓ [Synology](https://docs.netbird.io/get-started/install/synology) |
|
||||
| | | | | ✓ [TrueNAS](https://docs.netbird.io/get-started/install/truenas) |
|
||||
| | | | | ✓ [Proxmox](https://docs.netbird.io/get-started/install/proxmox-ve) |
|
||||
| | | | | ✓ [Raspberry Pi](https://docs.netbird.io/get-started/install/raspberrypi) |
|
||||
| | | | | ✓ [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) |
|
||||
| | | | | ✓ [Container](https://docs.netbird.io/get-started/install/docker) |
|
||||
|
||||
### Quickstart with NetBird Cloud
|
||||
|
||||
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
|
||||
- Follow the steps to sign-up with Google, Microsoft, GitHub or your email address.
|
||||
- Check NetBird [admin UI](https://app.netbird.io/).
|
||||
- Add more machines.
|
||||
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install).
|
||||
- Follow the steps to sign up with Google, Microsoft, GitHub or your email address.
|
||||
- Check the NetBird [admin UI](https://app.netbird.io/).
|
||||
|
||||
### Quickstart with self-hosted NetBird
|
||||
|
||||
> This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM.
|
||||
Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IDPs.
|
||||
This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM. Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IdPs.
|
||||
|
||||
**Infrastructure requirements:**
|
||||
- A Linux VM with at least **1CPU** and **2GB** of memory.
|
||||
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port: **3478**.
|
||||
- **Public domain** name pointing to the VM.
|
||||
- A Linux VM with at least **1 CPU** and **2 GB** of memory.
|
||||
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port **3478**.
|
||||
- A **public domain** name pointing to the VM.
|
||||
|
||||
**Software requirements:**
|
||||
- Docker installed on the VM with the docker-compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
|
||||
- [jq](https://jqlang.github.io/jq/) installed. In most distributions
|
||||
Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`
|
||||
- [curl](https://curl.se/) installed.
|
||||
Usually available in the official repositories and can be installed with `sudo apt install curl` or `sudo yum install curl`
|
||||
- Docker with the Compose plugin (Compose v2 or higher). See the [Docker installation guide](https://docs.docker.com/engine/install/).
|
||||
|
||||
**Steps**
|
||||
- Download and run the installation script:
|
||||
```bash
|
||||
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started.sh | bash
|
||||
```
|
||||
- Once finished, you can manage the resources via `docker-compose`
|
||||
|
||||
### A bit on NetBird internals
|
||||
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
|
||||
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
|
||||
- NetBird agent uses WebRTC ICE implemented in [pion/ice library](https://github.com/pion/ice) to discover connection candidates when establishing a peer-to-peer connection between machines.
|
||||
- Connection candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
|
||||
- Agents negotiate a connection through [Signal Service](signal/) passing p2p encrypted messages with candidates.
|
||||
- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and a p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server.
|
||||
|
||||
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
|
||||
- Every machine in the network runs the [NetBird agent](client/), which manages WireGuard.
|
||||
- Every agent connects to the [Management Service](management/), which holds network state, manages peer IPs, and distributes updates to agents.
|
||||
- Agents use ICE (via [pion/ice](https://github.com/pion/ice)) to discover connection candidates for peer-to-peer connections.
|
||||
- Candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
|
||||
- Agents negotiate a connection through the [Signal Service](signal/), exchanging end-to-end encrypted messages with candidates.
|
||||
- When NAT traversal fails (e.g. mobile carrier-grade NAT) and a direct p2p connection isn't possible, the system falls back to a [Relay Service](relay/) and a secure WireGuard tunnel is established through it.
|
||||
|
||||
<p float="left" align="middle">
|
||||
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700"/>
|
||||
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700" alt="NetBird high-level architecture diagram"/>
|
||||
</p>
|
||||
|
||||
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
||||
|
||||
### Community projects
|
||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
|
||||
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) — terminal UI for managing NetBird peers, routes, and settings
|
||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) - terminal UI for managing NetBird peers, routes, and settings
|
||||
- [caddy-netbird](https://github.com/lixmal/caddy-netbird) - Caddy plugin that embeds a NetBird client for proxying HTTP and TCP/UDP traffic through NetBird networks
|
||||
|
||||
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
||||
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
||||
|
||||
### Support acknowledgement
|
||||
|
||||
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by The Federal Ministry of Education and Research of The Federal Republic of Germany. Together with [CISPA Helmholtz Center for Information Security](https://cispa.de/en) NetBird brings the security best practices and simplicity to private networking.
|
||||
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by the Federal Ministry of Education and Research of the Federal Republic of Germany. Together with the [CISPA Helmholtz Center for Information Security](https://cispa.de/en), NetBird brings security best practices and simplicity to private networking.
|
||||
|
||||

|
||||
|
||||
### Testimonials
|
||||
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g., by giving a star or a contribution).
|
||||
### Acknowledgements
|
||||
We build on open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE](https://github.com/pion/ice), and [Rosenpass](https://rosenpass.eu). We greatly appreciate the work these projects are doing, and we'd love it if you could support them too (e.g., by starring or contributing).
|
||||
|
||||
### Legal
|
||||
This repository is licensed under BSD-3-Clause license that applies to all parts of the repository except for the directories management/, signal/ and relay/.
|
||||
This repository is licensed under the BSD-3-Clause license, which applies to all parts of the repository except for the directories management/, signal/ and relay/.
|
||||
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
|
||||
|
||||
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.
|
||||
|
||||
@@ -43,16 +43,16 @@ func init() {
|
||||
ipsFilterMap = make(map[string]struct{})
|
||||
prefixNamesFilterMap = make(map[string]struct{})
|
||||
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information in human-readable format")
|
||||
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
|
||||
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
|
||||
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
|
||||
statusCmd.PersistentFlags().BoolVar(&ipv6Flag, "ipv6", false, "display only NetBird IPv6 of this peer")
|
||||
statusCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display detailed status information in json format")
|
||||
statusCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display detailed status information in yaml format")
|
||||
statusCmd.PersistentFlags().BoolVarP(&ipv4Flag, "ipv4", "4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
|
||||
statusCmd.PersistentFlags().BoolVarP(&ipv6Flag, "ipv6", "6", false, "display only NetBird IPv6 of this peer")
|
||||
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4", "ipv6")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
|
||||
statusCmd.PersistentFlags().StringVar(&checkFlag, "check", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
|
||||
statusCmd.PersistentFlags().StringSliceVarP(&ipsFilter, "filter-by-ips", "I", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
|
||||
statusCmd.PersistentFlags().StringSliceVarP(&prefixNamesFilter, "filter-by-names", "N", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||
statusCmd.PersistentFlags().StringVarP(&statusFilter, "filter-by-status", "S", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||
statusCmd.PersistentFlags().StringVarP(&connectionTypeFilter, "filter-by-connection-type", "T", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
|
||||
statusCmd.PersistentFlags().StringVarP(&checkFlag, "check", "C", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
|
||||
}
|
||||
|
||||
func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
|
||||
@@ -336,7 +336,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("split host port: %w", err)
|
||||
}
|
||||
listenAddr := fmt.Sprintf("%s:%s", addr, port)
|
||||
listenAddr := net.JoinHostPort(addr.String(), port)
|
||||
|
||||
tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr)
|
||||
if err != nil {
|
||||
@@ -357,7 +357,7 @@ func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("split host port: %w", err)
|
||||
}
|
||||
listenAddr := fmt.Sprintf("%s:%s", addr, port)
|
||||
listenAddr := net.JoinHostPort(addr.String(), port)
|
||||
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
|
||||
if err != nil {
|
||||
|
||||
@@ -52,9 +52,10 @@ func (m *externalChainMonitor) start() {
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
m.cancel = cancel
|
||||
m.done = make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
m.done = done
|
||||
|
||||
go m.run(ctx)
|
||||
go m.run(ctx, done)
|
||||
}
|
||||
|
||||
func (m *externalChainMonitor) stop() {
|
||||
@@ -72,8 +73,8 @@ func (m *externalChainMonitor) stop() {
|
||||
<-done
|
||||
}
|
||||
|
||||
func (m *externalChainMonitor) run(ctx context.Context) {
|
||||
defer close(m.done)
|
||||
func (m *externalChainMonitor) run(ctx context.Context, done chan struct{}) {
|
||||
defer close(done)
|
||||
|
||||
bo := &backoff.ExponentialBackOff{
|
||||
InitialInterval: externalMonitorInitInterval,
|
||||
|
||||
@@ -116,7 +116,6 @@ func (c *ConnectClient) RunOniOS(
|
||||
fileDescriptor int32,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsManager dns.IosDnsManager,
|
||||
dnsAddresses []netip.AddrPort,
|
||||
stateFilePath string,
|
||||
) error {
|
||||
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
|
||||
@@ -126,7 +125,6 @@ func (c *ConnectClient) RunOniOS(
|
||||
FileDescriptor: fileDescriptor,
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
DnsManager: dnsManager,
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
StateFilePath: stateFilePath,
|
||||
}
|
||||
return c.run(mobileDependency, nil, "")
|
||||
|
||||
@@ -45,8 +45,11 @@ netbird.out: Most recent, anonymized stdout log file of the NetBird client.
|
||||
routes.txt: Detailed system routing table in tabular format including destination, gateway, interface, metrics, and protocol information, if --system-info flag was provided.
|
||||
interfaces.txt: Anonymized network interface information, if --system-info flag was provided.
|
||||
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
|
||||
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
|
||||
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
|
||||
iptables.txt: Anonymized iptables (IPv4) rules with packet counters, if --system-info flag was provided.
|
||||
ip6tables.txt: Anonymized ip6tables (IPv6) rules with packet counters, if --system-info flag was provided.
|
||||
ipset.txt: Anonymized ipset list output, if --system-info flag was provided.
|
||||
nftables.txt: Anonymized nftables rules with packet counters across all families (ip, ip6, inet, etc.), if --system-info flag was provided.
|
||||
sysctls.txt: Forwarding, reverse-path filter, source-validation, and conntrack accounting sysctl values that the NetBird client may read or modify, if --system-info flag was provided (Linux only).
|
||||
resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided.
|
||||
scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided.
|
||||
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
|
||||
@@ -165,22 +168,33 @@ The config.txt file contains anonymized configuration information of the NetBird
|
||||
Other non-sensitive configuration options are included without anonymization.
|
||||
|
||||
Firewall Rules (Linux only)
|
||||
The bundle includes two separate firewall rule files:
|
||||
The bundle includes the following firewall-related files:
|
||||
|
||||
iptables.txt:
|
||||
- Complete iptables ruleset with packet counters using 'iptables -v -n -L'
|
||||
- IPv4 iptables ruleset with packet counters using 'iptables-save' and 'iptables -v -n -L'
|
||||
- Includes all tables (filter, nat, mangle, raw, security)
|
||||
- Shows packet and byte counters for each rule
|
||||
- All IP addresses are anonymized
|
||||
- Chain names, table names, and other non-sensitive information remain unchanged
|
||||
|
||||
ip6tables.txt:
|
||||
- IPv6 ip6tables ruleset with packet counters using 'ip6tables-save' and 'ip6tables -v -n -L'
|
||||
- Same table coverage and anonymization as iptables.txt
|
||||
- Omitted when ip6tables is not installed or no IPv6 rules are present
|
||||
|
||||
ipset.txt:
|
||||
- Output of 'ipset list' (family-agnostic)
|
||||
- IP addresses are anonymized; set names and types remain unchanged
|
||||
|
||||
nftables.txt:
|
||||
- Complete nftables ruleset obtained via 'nft -a list ruleset'
|
||||
- Complete nftables ruleset across all families (ip, ip6, inet, arp, bridge, netdev) via 'nft -a list ruleset'
|
||||
- Includes rule handle numbers and packet counters
|
||||
- All tables, chains, and rules are included
|
||||
- Shows packet and byte counters for each rule
|
||||
- All IP addresses are anonymized
|
||||
- Chain names, table names, and other non-sensitive information remain unchanged
|
||||
- All IP addresses are anonymized; chain/table names remain unchanged
|
||||
|
||||
sysctls.txt:
|
||||
- Forwarding (IPv4 + IPv6, global and per-interface), reverse-path filter, source-validation, conntrack accounting, and TCP-related sysctls that netbird may read or modify
|
||||
- Per-interface keys are enumerated from /proc/sys/net/ipv{4,6}/conf
|
||||
- Interface names anonymized when --anonymize is set
|
||||
|
||||
IP Rules (Linux only)
|
||||
The ip_rules.txt file contains detailed IP routing rule information:
|
||||
@@ -412,6 +426,10 @@ func (g *BundleGenerator) addSystemInfo() {
|
||||
log.Errorf("failed to add firewall rules to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addSysctls(); err != nil {
|
||||
log.Errorf("failed to add sysctls to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addDNSInfo(); err != nil {
|
||||
log.Errorf("failed to add DNS info to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
@@ -124,15 +124,18 @@ func getSystemdLogs(serviceName string) (string, error) {
|
||||
// addFirewallRules collects and adds firewall rules to the archive
|
||||
func (g *BundleGenerator) addFirewallRules() error {
|
||||
log.Info("Collecting firewall rules")
|
||||
iptablesRules, err := collectIPTablesRules()
|
||||
g.addIPTablesRulesToBundle("iptables-save", "iptables", "iptables.txt")
|
||||
g.addIPTablesRulesToBundle("ip6tables-save", "ip6tables", "ip6tables.txt")
|
||||
|
||||
ipsetOutput, err := collectIPSets()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to collect iptables rules: %v", err)
|
||||
log.Warnf("Failed to collect ipset information: %v", err)
|
||||
} else {
|
||||
if g.anonymize {
|
||||
iptablesRules = g.anonymizer.AnonymizeString(iptablesRules)
|
||||
ipsetOutput = g.anonymizer.AnonymizeString(ipsetOutput)
|
||||
}
|
||||
if err := g.addFileToZip(strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
|
||||
log.Warnf("Failed to add iptables rules to bundle: %v", err)
|
||||
if err := g.addFileToZip(strings.NewReader(ipsetOutput), "ipset.txt"); err != nil {
|
||||
log.Warnf("Failed to add ipset output to bundle: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,44 +154,65 @@ func (g *BundleGenerator) addFirewallRules() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// collectIPTablesRules collects rules using both iptables-save and verbose listing
|
||||
func collectIPTablesRules() (string, error) {
|
||||
var builder strings.Builder
|
||||
|
||||
saveOutput, err := collectIPTablesSave()
|
||||
// addIPTablesRulesToBundle collects iptables/ip6tables rules and writes them to the bundle.
|
||||
func (g *BundleGenerator) addIPTablesRulesToBundle(saveBin, listBin, filename string) {
|
||||
rules, err := collectIPTablesRules(saveBin, listBin)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to collect iptables rules using iptables-save: %v", err)
|
||||
} else {
|
||||
builder.WriteString("=== iptables-save output ===\n")
|
||||
log.Warnf("Failed to collect %s rules: %v", listBin, err)
|
||||
return
|
||||
}
|
||||
if g.anonymize {
|
||||
rules = g.anonymizer.AnonymizeString(rules)
|
||||
}
|
||||
if err := g.addFileToZip(strings.NewReader(rules), filename); err != nil {
|
||||
log.Warnf("Failed to add %s rules to bundle: %v", listBin, err)
|
||||
}
|
||||
}
|
||||
|
||||
// collectIPTablesRules collects rules using both <saveBin> and verbose listing via <listBin>.
|
||||
// Returns an error when neither command produced any output (e.g. the binary is missing),
|
||||
// so the caller can skip writing an empty file.
|
||||
func collectIPTablesRules(saveBin, listBin string) (string, error) {
|
||||
var builder strings.Builder
|
||||
var collected bool
|
||||
var firstErr error
|
||||
|
||||
saveOutput, err := runCommand(saveBin)
|
||||
switch {
|
||||
case err != nil:
|
||||
firstErr = err
|
||||
log.Warnf("Failed to collect %s output: %v", saveBin, err)
|
||||
case strings.TrimSpace(saveOutput) == "":
|
||||
log.Debugf("%s produced no output, skipping", saveBin)
|
||||
default:
|
||||
builder.WriteString(fmt.Sprintf("=== %s output ===\n", saveBin))
|
||||
builder.WriteString(saveOutput)
|
||||
builder.WriteString("\n")
|
||||
collected = true
|
||||
}
|
||||
|
||||
ipsetOutput, err := collectIPSets()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to collect ipset information: %v", err)
|
||||
} else {
|
||||
builder.WriteString("=== ipset list output ===\n")
|
||||
builder.WriteString(ipsetOutput)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
builder.WriteString("=== iptables -v -n -L output ===\n")
|
||||
listHeader := fmt.Sprintf("=== %s -v -n -L output ===\n", listBin)
|
||||
builder.WriteString(listHeader)
|
||||
|
||||
tables := []string{"filter", "nat", "mangle", "raw", "security"}
|
||||
|
||||
for _, table := range tables {
|
||||
builder.WriteString(fmt.Sprintf("*%s\n", table))
|
||||
|
||||
stats, err := getTableStatistics(table)
|
||||
stats, err := runCommand(listBin, "-v", "-n", "-L", "-t", table)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to get statistics for table %s: %v", table, err)
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
log.Warnf("Failed to get %s statistics for table %s: %v", listBin, table, err)
|
||||
continue
|
||||
}
|
||||
builder.WriteString(fmt.Sprintf("*%s\n", table))
|
||||
builder.WriteString(stats)
|
||||
builder.WriteString("\n")
|
||||
collected = true
|
||||
}
|
||||
|
||||
if !collected {
|
||||
return "", fmt.Errorf("collect %s rules: %w", listBin, firstErr)
|
||||
}
|
||||
return builder.String(), nil
|
||||
}
|
||||
|
||||
@@ -214,34 +238,15 @@ func collectIPSets() (string, error) {
|
||||
return ipsets, nil
|
||||
}
|
||||
|
||||
// collectIPTablesSave uses iptables-save to get rule definitions
|
||||
func collectIPTablesSave() (string, error) {
|
||||
cmd := exec.Command("iptables-save")
|
||||
// runCommand executes a command and returns its stdout, wrapping stderr in the error on failure.
|
||||
func runCommand(name string, args ...string) (string, error) {
|
||||
cmd := exec.Command(name, args...)
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", fmt.Errorf("execute iptables-save: %w (stderr: %s)", err, stderr.String())
|
||||
}
|
||||
|
||||
rules := stdout.String()
|
||||
if strings.TrimSpace(rules) == "" {
|
||||
return "", fmt.Errorf("no iptables rules found")
|
||||
}
|
||||
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
// getTableStatistics gets verbose statistics for an entire table using iptables command
|
||||
func getTableStatistics(table string) (string, error) {
|
||||
cmd := exec.Command("iptables", "-v", "-n", "-L", "-t", table)
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", fmt.Errorf("execute iptables -v -n -L: %w (stderr: %s)", err, stderr.String())
|
||||
return "", fmt.Errorf("execute %s: %w (stderr: %s)", name, err, stderr.String())
|
||||
}
|
||||
|
||||
return stdout.String(), nil
|
||||
@@ -804,3 +809,91 @@ func formatSetKeyType(keyType nftables.SetDatatype) string {
|
||||
return fmt.Sprintf("type-%v", keyType)
|
||||
}
|
||||
}
|
||||
|
||||
// addSysctls collects forwarding and netbird-managed sysctl values and writes them to the bundle.
|
||||
func (g *BundleGenerator) addSysctls() error {
|
||||
log.Info("Collecting sysctls")
|
||||
content := collectSysctls()
|
||||
if g.anonymize {
|
||||
content = g.anonymizer.AnonymizeString(content)
|
||||
}
|
||||
if err := g.addFileToZip(strings.NewReader(content), "sysctls.txt"); err != nil {
|
||||
return fmt.Errorf("add sysctls to bundle: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// collectSysctls reads every sysctl that the netbird client may modify, plus
|
||||
// global IPv4/IPv6 forwarding, and returns a formatted dump grouped by topic.
|
||||
// Per-interface values are enumerated by listing /proc/sys/net/ipv{4,6}/conf.
|
||||
func collectSysctls() string {
|
||||
var builder strings.Builder
|
||||
|
||||
writeSysctlGroup(&builder, "forwarding", []string{
|
||||
"net.ipv4.ip_forward",
|
||||
"net.ipv6.conf.all.forwarding",
|
||||
"net.ipv6.conf.default.forwarding",
|
||||
})
|
||||
writeSysctlGroup(&builder, "ipv4 per-interface forwarding", listInterfaceSysctls("ipv4", "forwarding"))
|
||||
writeSysctlGroup(&builder, "ipv6 per-interface forwarding", listInterfaceSysctls("ipv6", "forwarding"))
|
||||
writeSysctlGroup(&builder, "rp_filter", append(
|
||||
[]string{"net.ipv4.conf.all.rp_filter", "net.ipv4.conf.default.rp_filter"},
|
||||
listInterfaceSysctls("ipv4", "rp_filter")...,
|
||||
))
|
||||
writeSysctlGroup(&builder, "src_valid_mark", append(
|
||||
[]string{"net.ipv4.conf.all.src_valid_mark", "net.ipv4.conf.default.src_valid_mark"},
|
||||
listInterfaceSysctls("ipv4", "src_valid_mark")...,
|
||||
))
|
||||
writeSysctlGroup(&builder, "conntrack", []string{
|
||||
"net.netfilter.nf_conntrack_acct",
|
||||
"net.netfilter.nf_conntrack_tcp_loose",
|
||||
})
|
||||
writeSysctlGroup(&builder, "tcp", []string{
|
||||
"net.ipv4.tcp_tw_reuse",
|
||||
})
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func writeSysctlGroup(builder *strings.Builder, title string, keys []string) {
|
||||
builder.WriteString(fmt.Sprintf("=== %s ===\n", title))
|
||||
for _, key := range keys {
|
||||
value, err := readSysctl(key)
|
||||
if err != nil {
|
||||
builder.WriteString(fmt.Sprintf("%s = <error: %v>\n", key, err))
|
||||
continue
|
||||
}
|
||||
builder.WriteString(fmt.Sprintf("%s = %s\n", key, value))
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
// listInterfaceSysctls returns net.ipvX.conf.<iface>.<leaf> keys for every
|
||||
// interface present in /proc/sys/net/ipvX/conf, skipping "all" and "default"
|
||||
// (callers add those explicitly so they appear first).
|
||||
func listInterfaceSysctls(family, leaf string) []string {
|
||||
dir := fmt.Sprintf("/proc/sys/net/%s/conf", family)
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var keys []string
|
||||
for _, e := range entries {
|
||||
name := e.Name()
|
||||
if name == "all" || name == "default" {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, fmt.Sprintf("net.%s.conf.%s.%s", family, name, leaf))
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
func readSysctl(key string) (string, error) {
|
||||
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
|
||||
value, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(string(value)), nil
|
||||
}
|
||||
|
||||
@@ -17,3 +17,8 @@ func (g *BundleGenerator) addIPRules() error {
|
||||
// IP rules are only supported on Linux
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addSysctls() error {
|
||||
// Sysctl collection is only supported on Linux
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -16,6 +16,10 @@ type hostManager interface {
|
||||
restoreHostDNS() error
|
||||
supportCustomPort() bool
|
||||
string() string
|
||||
// getOriginalNameservers returns the OS-side resolvers used as PriorityFallback
|
||||
// upstreams: pre-takeover snapshots on desktop, the OS-pushed list on Android,
|
||||
// hardcoded Quad9 on iOS, nil for noop / mock.
|
||||
getOriginalNameservers() []netip.Addr
|
||||
}
|
||||
|
||||
type SystemDNSSettings struct {
|
||||
@@ -131,3 +135,11 @@ func (n noopHostConfigurator) supportCustomPort() bool {
|
||||
func (n noopHostConfigurator) string() string {
|
||||
return "noop"
|
||||
}
|
||||
|
||||
func (n noopHostConfigurator) getOriginalNameservers() []netip.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockHostConfigurator) getOriginalNameservers() []netip.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// androidHostManager is a noop on the OS side (Android's VPN service handles
|
||||
// DNS for us) but tracks the OS-reported resolver list pushed via
|
||||
// OnUpdatedHostDNSServer so it can serve as the fallback nameserver source.
|
||||
type androidHostManager struct {
|
||||
holder *hostsDNSHolder
|
||||
}
|
||||
|
||||
func newHostManager() (*androidHostManager, error) {
|
||||
return &androidHostManager{}, nil
|
||||
func newHostManager(holder *hostsDNSHolder) (*androidHostManager, error) {
|
||||
return &androidHostManager{holder: holder}, nil
|
||||
}
|
||||
|
||||
func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error {
|
||||
@@ -26,3 +32,12 @@ func (a androidHostManager) supportCustomPort() bool {
|
||||
func (a androidHostManager) string() string {
|
||||
return "none"
|
||||
}
|
||||
|
||||
func (a androidHostManager) getOriginalNameservers() []netip.Addr {
|
||||
hosts := a.holder.get()
|
||||
out := make([]netip.Addr, 0, len(hosts))
|
||||
for ap := range hosts {
|
||||
out = append(out, ap.Addr())
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package dns
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -20,6 +21,14 @@ func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a iosHostManager) getOriginalNameservers() []netip.Addr {
|
||||
// Quad9 v4+v6: 9.9.9.9, 2620:fe::fe.
|
||||
return []netip.Addr{
|
||||
netip.AddrFrom4([4]byte{9, 9, 9, 9}),
|
||||
netip.AddrFrom16([16]byte{0x26, 0x20, 0x00, 0xfe, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xfe}),
|
||||
}
|
||||
}
|
||||
|
||||
func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||
jsonData, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"slices"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -44,9 +45,11 @@ const (
|
||||
|
||||
nrptMaxDomainsPerRule = 50
|
||||
|
||||
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
||||
interfaceConfigNameServerKey = "NameServer"
|
||||
interfaceConfigSearchListKey = "SearchList"
|
||||
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
||||
interfaceConfigPathV6 = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces`
|
||||
interfaceConfigNameServerKey = "NameServer"
|
||||
interfaceConfigDhcpNameSrvKey = "DhcpNameServer"
|
||||
interfaceConfigSearchListKey = "SearchList"
|
||||
|
||||
// Network interface DNS registration settings
|
||||
disableDynamicUpdateKey = "DisableDynamicUpdate"
|
||||
@@ -67,10 +70,11 @@ const (
|
||||
)
|
||||
|
||||
type registryConfigurator struct {
|
||||
guid string
|
||||
routingAll bool
|
||||
gpo bool
|
||||
nrptEntryCount int
|
||||
guid string
|
||||
routingAll bool
|
||||
gpo bool
|
||||
nrptEntryCount int
|
||||
origNameservers []netip.Addr
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||
@@ -94,6 +98,17 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||
gpo: useGPO,
|
||||
}
|
||||
|
||||
origNameservers, err := configurator.captureOriginalNameservers()
|
||||
switch {
|
||||
case err != nil:
|
||||
log.Warnf("capture original nameservers from non-WG adapters: %v", err)
|
||||
case len(origNameservers) == 0:
|
||||
log.Warnf("no original nameservers captured from non-WG adapters; DNS fallback will be empty")
|
||||
default:
|
||||
log.Debugf("captured %d original nameservers from non-WG adapters: %v", len(origNameservers), origNameservers)
|
||||
}
|
||||
configurator.origNameservers = origNameservers
|
||||
|
||||
if err := configurator.configureInterface(); err != nil {
|
||||
log.Errorf("failed to configure interface settings: %v", err)
|
||||
}
|
||||
@@ -101,6 +116,98 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||
return configurator, nil
|
||||
}
|
||||
|
||||
// captureOriginalNameservers reads DNS addresses from every Tcpip(6) interface
|
||||
// registry key except the WG adapter. v4 and v6 servers live in separate
|
||||
// hives (Tcpip vs Tcpip6) keyed by the same interface GUID.
|
||||
func (r *registryConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
|
||||
seen := make(map[netip.Addr]struct{})
|
||||
var out []netip.Addr
|
||||
var merr *multierror.Error
|
||||
for _, root := range []string{interfaceConfigPath, interfaceConfigPathV6} {
|
||||
addrs, err := r.captureFromTcpipRoot(root)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("%s: %w", root, err))
|
||||
continue
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
if _, dup := seen[addr]; dup {
|
||||
continue
|
||||
}
|
||||
seen[addr] = struct{}{}
|
||||
out = append(out, addr)
|
||||
}
|
||||
}
|
||||
return out, nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) captureFromTcpipRoot(rootPath string) ([]netip.Addr, error) {
|
||||
root, err := registry.OpenKey(registry.LOCAL_MACHINE, rootPath, registry.READ)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open key: %w", err)
|
||||
}
|
||||
defer closer(root)
|
||||
|
||||
guids, err := root.ReadSubKeyNames(-1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read subkeys: %w", err)
|
||||
}
|
||||
|
||||
var out []netip.Addr
|
||||
for _, guid := range guids {
|
||||
if strings.EqualFold(guid, r.guid) {
|
||||
continue
|
||||
}
|
||||
out = append(out, readInterfaceNameservers(rootPath, guid)...)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func readInterfaceNameservers(rootPath, guid string) []netip.Addr {
|
||||
keyPath := rootPath + "\\" + guid
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer closer(k)
|
||||
|
||||
// Static NameServer wins over DhcpNameServer for actual resolution.
|
||||
for _, name := range []string{interfaceConfigNameServerKey, interfaceConfigDhcpNameSrvKey} {
|
||||
raw, _, err := k.GetStringValue(name)
|
||||
if err != nil || raw == "" {
|
||||
continue
|
||||
}
|
||||
if out := parseRegistryNameservers(raw); len(out) > 0 {
|
||||
return out
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseRegistryNameservers(raw string) []netip.Addr {
|
||||
var out []netip.Addr
|
||||
for _, field := range strings.FieldsFunc(raw, func(r rune) bool { return r == ',' || r == ' ' || r == '\t' }) {
|
||||
addr, err := netip.ParseAddr(strings.TrimSpace(field))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
addr = addr.Unmap()
|
||||
if !addr.IsValid() || addr.IsUnspecified() {
|
||||
continue
|
||||
}
|
||||
// Drop unzoned link-local: not routable without a scope id. If
|
||||
// the user wrote "fe80::1%eth0" ParseAddr preserves the zone.
|
||||
if addr.IsLinkLocalUnicast() && addr.Zone() == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, addr)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) getOriginalNameservers() []netip.Addr {
|
||||
return slices.Clone(r.origNameservers)
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ func (h *hostsDNSHolder) set(list []netip.AddrPort) {
|
||||
h.mutex.Unlock()
|
||||
}
|
||||
|
||||
//nolint:unused
|
||||
func (h *hostsDNSHolder) get() map[netip.AddrPort]struct{} {
|
||||
h.mutex.RLock()
|
||||
l := h.unprotectedDNSList
|
||||
|
||||
@@ -76,8 +76,6 @@ func (d *Resolver) ID() types.HandlerID {
|
||||
return "local-resolver"
|
||||
}
|
||||
|
||||
func (d *Resolver) ProbeAvailability(context.Context) {}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
logger := log.WithFields(log.Fields{
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
@@ -70,10 +71,6 @@ func (m *MockServer) SearchDomains() []string {
|
||||
return make([]string, 0)
|
||||
}
|
||||
|
||||
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
|
||||
func (m *MockServer) ProbeAvailability() {
|
||||
}
|
||||
|
||||
func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
|
||||
if m.UpdateServerConfigFunc != nil {
|
||||
return m.UpdateServerConfigFunc(domains)
|
||||
@@ -85,8 +82,8 @@ func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetRouteChecker mock implementation of SetRouteChecker from Server interface
|
||||
func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
|
||||
// SetRouteSources mock implementation of SetRouteSources from Server interface
|
||||
func (m *MockServer) SetRouteSources(selected, active func() route.HAMap) {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -32,6 +33,15 @@ const (
|
||||
networkManagerDbusDeviceGetAppliedConnectionMethod = networkManagerDbusDeviceInterface + ".GetAppliedConnection"
|
||||
networkManagerDbusDeviceReapplyMethod = networkManagerDbusDeviceInterface + ".Reapply"
|
||||
networkManagerDbusDeviceDeleteMethod = networkManagerDbusDeviceInterface + ".Delete"
|
||||
networkManagerDbusDeviceIp4ConfigProperty = networkManagerDbusDeviceInterface + ".Ip4Config"
|
||||
networkManagerDbusDeviceIp6ConfigProperty = networkManagerDbusDeviceInterface + ".Ip6Config"
|
||||
networkManagerDbusDeviceIfaceProperty = networkManagerDbusDeviceInterface + ".Interface"
|
||||
networkManagerDbusGetDevicesMethod = networkManagerDest + ".GetDevices"
|
||||
networkManagerDbusIp4ConfigInterface = "org.freedesktop.NetworkManager.IP4Config"
|
||||
networkManagerDbusIp6ConfigInterface = "org.freedesktop.NetworkManager.IP6Config"
|
||||
networkManagerDbusIp4ConfigNameserverDataProperty = networkManagerDbusIp4ConfigInterface + ".NameserverData"
|
||||
networkManagerDbusIp4ConfigNameserversProperty = networkManagerDbusIp4ConfigInterface + ".Nameservers"
|
||||
networkManagerDbusIp6ConfigNameserversProperty = networkManagerDbusIp6ConfigInterface + ".Nameservers"
|
||||
networkManagerDbusDefaultBehaviorFlag networkManagerConfigBehavior = 0
|
||||
networkManagerDbusIPv4Key = "ipv4"
|
||||
networkManagerDbusIPv6Key = "ipv6"
|
||||
@@ -51,9 +61,10 @@ var supportedNetworkManagerVersionConstraints = []string{
|
||||
}
|
||||
|
||||
type networkManagerDbusConfigurator struct {
|
||||
dbusLinkObject dbus.ObjectPath
|
||||
routingAll bool
|
||||
ifaceName string
|
||||
dbusLinkObject dbus.ObjectPath
|
||||
routingAll bool
|
||||
ifaceName string
|
||||
origNameservers []netip.Addr
|
||||
}
|
||||
|
||||
// the types below are based on dbus specification, each field is mapped to a dbus type
|
||||
@@ -92,10 +103,200 @@ func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusC
|
||||
|
||||
log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface)
|
||||
|
||||
return &networkManagerDbusConfigurator{
|
||||
c := &networkManagerDbusConfigurator{
|
||||
dbusLinkObject: dbus.ObjectPath(s),
|
||||
ifaceName: wgInterface,
|
||||
}, nil
|
||||
}
|
||||
|
||||
origNameservers, err := c.captureOriginalNameservers()
|
||||
switch {
|
||||
case err != nil:
|
||||
log.Warnf("capture original nameservers from NetworkManager: %v", err)
|
||||
case len(origNameservers) == 0:
|
||||
log.Warnf("no original nameservers captured from non-WG NetworkManager devices; DNS fallback will be empty")
|
||||
default:
|
||||
log.Debugf("captured %d original nameservers from non-WG NetworkManager devices: %v", len(origNameservers), origNameservers)
|
||||
}
|
||||
c.origNameservers = origNameservers
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// captureOriginalNameservers reads DNS servers from every NM device's
|
||||
// IP4Config / IP6Config except our WG device.
|
||||
func (n *networkManagerDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
|
||||
devices, err := networkManagerListDevices()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list devices: %w", err)
|
||||
}
|
||||
|
||||
seen := make(map[netip.Addr]struct{})
|
||||
var out []netip.Addr
|
||||
for _, dev := range devices {
|
||||
if dev == n.dbusLinkObject {
|
||||
continue
|
||||
}
|
||||
ifaceName := readNetworkManagerDeviceInterface(dev)
|
||||
for _, addr := range readNetworkManagerDeviceDNS(dev) {
|
||||
addr = addr.Unmap()
|
||||
if !addr.IsValid() || addr.IsUnspecified() {
|
||||
continue
|
||||
}
|
||||
// IP6Config.Nameservers is a byte slice without zone info;
|
||||
// reattach the device's interface name so a captured fe80::…
|
||||
// stays routable.
|
||||
if addr.IsLinkLocalUnicast() && ifaceName != "" {
|
||||
addr = addr.WithZone(ifaceName)
|
||||
}
|
||||
if _, dup := seen[addr]; dup {
|
||||
continue
|
||||
}
|
||||
seen[addr] = struct{}{}
|
||||
out = append(out, addr)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func readNetworkManagerDeviceInterface(devicePath dbus.ObjectPath) string {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, devicePath)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer closeConn()
|
||||
v, err := obj.GetProperty(networkManagerDbusDeviceIfaceProperty)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
s, _ := v.Value().(string)
|
||||
return s
|
||||
}
|
||||
|
||||
func networkManagerListDevices() ([]dbus.ObjectPath, error) {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dbus NetworkManager: %w", err)
|
||||
}
|
||||
defer closeConn()
|
||||
var devs []dbus.ObjectPath
|
||||
if err := obj.Call(networkManagerDbusGetDevicesMethod, dbusDefaultFlag).Store(&devs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return devs, nil
|
||||
}
|
||||
|
||||
func readNetworkManagerDeviceDNS(devicePath dbus.ObjectPath) []netip.Addr {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, devicePath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer closeConn()
|
||||
|
||||
var out []netip.Addr
|
||||
if path := readNetworkManagerConfigPath(obj, networkManagerDbusDeviceIp4ConfigProperty); path != "" {
|
||||
out = append(out, readIPv4ConfigDNS(path)...)
|
||||
}
|
||||
if path := readNetworkManagerConfigPath(obj, networkManagerDbusDeviceIp6ConfigProperty); path != "" {
|
||||
out = append(out, readIPv6ConfigDNS(path)...)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func readNetworkManagerConfigPath(obj dbus.BusObject, property string) dbus.ObjectPath {
|
||||
v, err := obj.GetProperty(property)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
path, ok := v.Value().(dbus.ObjectPath)
|
||||
if !ok || path == "/" {
|
||||
return ""
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func readIPv4ConfigDNS(path dbus.ObjectPath) []netip.Addr {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer closeConn()
|
||||
|
||||
// NameserverData (NM 1.13+) carries strings; older NMs only expose the
|
||||
// legacy uint32 Nameservers property.
|
||||
if out := readIPv4NameserverData(obj); len(out) > 0 {
|
||||
return out
|
||||
}
|
||||
return readIPv4LegacyNameservers(obj)
|
||||
}
|
||||
|
||||
func readIPv4NameserverData(obj dbus.BusObject) []netip.Addr {
|
||||
v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserverDataProperty)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
entries, ok := v.Value().([]map[string]dbus.Variant)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
var out []netip.Addr
|
||||
for _, entry := range entries {
|
||||
addrVar, ok := entry["address"]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
s, ok := addrVar.Value().(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if a, err := netip.ParseAddr(s); err == nil {
|
||||
out = append(out, a)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func readIPv4LegacyNameservers(obj dbus.BusObject) []netip.Addr {
|
||||
v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserversProperty)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := v.Value().([]uint32)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
out := make([]netip.Addr, 0, len(raw))
|
||||
for _, n := range raw {
|
||||
var b [4]byte
|
||||
binary.LittleEndian.PutUint32(b[:], n)
|
||||
out = append(out, netip.AddrFrom4(b))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func readIPv6ConfigDNS(path dbus.ObjectPath) []netip.Addr {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer closeConn()
|
||||
v, err := obj.GetProperty(networkManagerDbusIp6ConfigNameserversProperty)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := v.Value().([][]byte)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
out := make([]netip.Addr, 0, len(raw))
|
||||
for _, b := range raw {
|
||||
if a, ok := netip.AddrFromSlice(b); ok {
|
||||
out = append(out, a)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (n *networkManagerDbusConfigurator) getOriginalNameservers() []netip.Addr {
|
||||
return slices.Clone(n.origNameservers)
|
||||
}
|
||||
|
||||
func (n *networkManagerDbusConfigurator) supportCustomPort() bool {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
||||
package dns
|
||||
|
||||
func (s *DefaultServer) initialize() (manager hostManager, err error) {
|
||||
return newHostManager()
|
||||
return newHostManager(s.hostsDNSHolder)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
@@ -31,8 +32,10 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
@@ -101,16 +104,17 @@ func init() {
|
||||
formatter.SetTextFormatter(log.StandardLogger())
|
||||
}
|
||||
|
||||
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||
var srvs []netip.AddrPort
|
||||
for _, srv := range servers {
|
||||
srvs = append(srvs, srv.AddrPort())
|
||||
}
|
||||
return &upstreamResolverBase{
|
||||
domain: domain,
|
||||
upstreamServers: srvs,
|
||||
cancel: func() {},
|
||||
u := &upstreamResolverBase{
|
||||
domain: domain.Domain(d),
|
||||
cancel: func() {},
|
||||
}
|
||||
u.addRace(srvs)
|
||||
return u
|
||||
}
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
@@ -653,74 +657,8 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
hostManager := &mockHostConfigurator{}
|
||||
server := DefaultServer{
|
||||
ctx: context.Background(),
|
||||
service: NewServiceViaMemory(&mocWGIface{}),
|
||||
localResolver: local.NewResolver(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: hostManager,
|
||||
currentConfig: HostDNSConfig{
|
||||
Domains: []DomainConfig{
|
||||
{false, "domain0", false},
|
||||
{false, "domain1", false},
|
||||
{false, "domain2", false},
|
||||
},
|
||||
},
|
||||
statusRecorder: peer.NewRecorder("mgm"),
|
||||
}
|
||||
|
||||
var domainsUpdate string
|
||||
hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error {
|
||||
domains := []string{}
|
||||
for _, item := range config.Domains {
|
||||
if item.Disabled {
|
||||
continue
|
||||
}
|
||||
domains = append(domains, item.Domain)
|
||||
}
|
||||
domainsUpdate = strings.Join(domains, ",")
|
||||
return nil
|
||||
}
|
||||
|
||||
deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{
|
||||
Domains: []string{"domain1"},
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
},
|
||||
}, nil, 0)
|
||||
|
||||
deactivate(nil)
|
||||
expected := "domain0,domain2"
|
||||
domains := []string{}
|
||||
for _, item := range server.currentConfig.Domains {
|
||||
if item.Disabled {
|
||||
continue
|
||||
}
|
||||
domains = append(domains, item.Domain)
|
||||
}
|
||||
got := strings.Join(domains, ",")
|
||||
if expected != got {
|
||||
t.Errorf("expected domains list: %q, got %q", expected, got)
|
||||
}
|
||||
|
||||
reactivate()
|
||||
expected = "domain0,domain1,domain2"
|
||||
domains = []string{}
|
||||
for _, item := range server.currentConfig.Domains {
|
||||
if item.Disabled {
|
||||
continue
|
||||
}
|
||||
domains = append(domains, item.Domain)
|
||||
}
|
||||
got = strings.Join(domains, ",")
|
||||
if expected != got {
|
||||
t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
||||
skipUnlessAndroid(t)
|
||||
wgIFace, err := createWgInterfaceWithBind(t)
|
||||
if err != nil {
|
||||
t.Fatal("failed to initialize wg interface")
|
||||
@@ -748,6 +686,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDNSPermanent_updateUpstream(t *testing.T) {
|
||||
skipUnlessAndroid(t)
|
||||
wgIFace, err := createWgInterfaceWithBind(t)
|
||||
if err != nil {
|
||||
t.Fatal("failed to initialize wg interface")
|
||||
@@ -841,6 +780,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
skipUnlessAndroid(t)
|
||||
wgIFace, err := createWgInterfaceWithBind(t)
|
||||
if err != nil {
|
||||
t.Fatal("failed to initialize wg interface")
|
||||
@@ -913,6 +853,18 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// skipUnlessAndroid marks tests that exercise the mobile-permanent DNS path,
|
||||
// which only matches a real production setup on android (NewDefaultServerPermanentUpstream
|
||||
// + androidHostManager). On non-android the desktop host manager replaces it
|
||||
// during Initialize and the assertion stops making sense. Skipped here until we
|
||||
// have an android CI runner.
|
||||
func skipUnlessAndroid(t *testing.T) {
|
||||
t.Helper()
|
||||
if runtime.GOOS != "android" {
|
||||
t.Skip("requires android runner; mobile-permanent path doesn't match production on this OS")
|
||||
}
|
||||
}
|
||||
|
||||
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
t.Helper()
|
||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||
@@ -1065,7 +1017,6 @@ type mockHandler struct {
|
||||
|
||||
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
||||
func (m *mockHandler) Stop() {}
|
||||
func (m *mockHandler) ProbeAvailability(context.Context) {}
|
||||
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
|
||||
|
||||
type mockService struct{}
|
||||
@@ -2085,6 +2036,598 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
|
||||
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
|
||||
}
|
||||
|
||||
// TestBuildUpstreamHandler_MergesGroupsPerDomain verifies that multiple
|
||||
// admin-defined nameserver groups targeting the same domain collapse into a
|
||||
// single handler with each group preserved as a sequential inner list.
|
||||
func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) {
|
||||
wgInterface := &mocWGIface{}
|
||||
service := NewServiceViaMemory(wgInterface)
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
wgInterface: wgInterface,
|
||||
service: service,
|
||||
localResolver: local.NewResolver(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: &noopHostConfigurator{},
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
}
|
||||
|
||||
groups := []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("192.0.2.1"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
},
|
||||
Domains: []string{"example.com"},
|
||||
},
|
||||
{
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("192.0.2.2"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
{IP: netip.MustParseAddr("192.0.2.3"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
},
|
||||
Domains: []string{"example.com"},
|
||||
},
|
||||
}
|
||||
|
||||
muxUpdates, err := server.buildUpstreamHandlerUpdate(groups)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, muxUpdates, 1, "same-domain groups should merge into one handler")
|
||||
assert.Equal(t, "example.com", muxUpdates[0].domain)
|
||||
assert.Equal(t, PriorityUpstream, muxUpdates[0].priority)
|
||||
|
||||
handler := muxUpdates[0].handler.(*upstreamResolver)
|
||||
require.Len(t, handler.upstreamServers, 2, "handler should have two groups")
|
||||
assert.Equal(t, upstreamRace{netip.MustParseAddrPort("192.0.2.1:53")}, handler.upstreamServers[0])
|
||||
assert.Equal(t, upstreamRace{
|
||||
netip.MustParseAddrPort("192.0.2.2:53"),
|
||||
netip.MustParseAddrPort("192.0.2.3:53"),
|
||||
}, handler.upstreamServers[1])
|
||||
}
|
||||
|
||||
// TestEvaluateNSGroupHealth covers the records-only verdict. The gate
|
||||
// (overlay route selected-but-no-active-peer) is intentionally NOT an
|
||||
// input to the evaluator anymore: the verdict drives the Enabled flag,
|
||||
// which must always reflect what we actually observed. Gate-aware event
|
||||
// suppression is tested separately in the projection test.
|
||||
//
|
||||
// Matrix per upstream: {no record, fresh Ok, fresh Fail, stale Fail,
|
||||
// stale Ok, Ok newer than Fail, Fail newer than Ok}.
|
||||
// Group verdict: any fresh-working → Healthy; any fresh-broken with no
|
||||
// fresh-working → Unhealthy; otherwise Undecided.
|
||||
func TestEvaluateNSGroupHealth(t *testing.T) {
|
||||
now := time.Now()
|
||||
a := netip.MustParseAddrPort("192.0.2.1:53")
|
||||
b := netip.MustParseAddrPort("192.0.2.2:53")
|
||||
|
||||
recentOk := UpstreamHealth{LastOk: now.Add(-2 * time.Second)}
|
||||
recentFail := UpstreamHealth{LastFail: now.Add(-1 * time.Second), LastErr: "timeout"}
|
||||
staleOk := UpstreamHealth{LastOk: now.Add(-10 * time.Minute)}
|
||||
staleFail := UpstreamHealth{LastFail: now.Add(-10 * time.Minute), LastErr: "timeout"}
|
||||
okThenFail := UpstreamHealth{
|
||||
LastOk: now.Add(-10 * time.Second),
|
||||
LastFail: now.Add(-1 * time.Second),
|
||||
LastErr: "timeout",
|
||||
}
|
||||
failThenOk := UpstreamHealth{
|
||||
LastOk: now.Add(-1 * time.Second),
|
||||
LastFail: now.Add(-10 * time.Second),
|
||||
LastErr: "timeout",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
health map[netip.AddrPort]UpstreamHealth
|
||||
servers []netip.AddrPort
|
||||
wantVerdict nsGroupVerdict
|
||||
wantErrSubst string
|
||||
}{
|
||||
{
|
||||
name: "no record, undecided",
|
||||
servers: []netip.AddrPort{a},
|
||||
wantVerdict: nsVerdictUndecided,
|
||||
},
|
||||
{
|
||||
name: "fresh success, healthy",
|
||||
health: map[netip.AddrPort]UpstreamHealth{a: recentOk},
|
||||
servers: []netip.AddrPort{a},
|
||||
wantVerdict: nsVerdictHealthy,
|
||||
},
|
||||
{
|
||||
name: "fresh failure, unhealthy",
|
||||
health: map[netip.AddrPort]UpstreamHealth{a: recentFail},
|
||||
servers: []netip.AddrPort{a},
|
||||
wantVerdict: nsVerdictUnhealthy,
|
||||
wantErrSubst: "timeout",
|
||||
},
|
||||
{
|
||||
name: "only stale success, undecided",
|
||||
health: map[netip.AddrPort]UpstreamHealth{a: staleOk},
|
||||
servers: []netip.AddrPort{a},
|
||||
wantVerdict: nsVerdictUndecided,
|
||||
},
|
||||
{
|
||||
name: "only stale failure, undecided",
|
||||
health: map[netip.AddrPort]UpstreamHealth{a: staleFail},
|
||||
servers: []netip.AddrPort{a},
|
||||
wantVerdict: nsVerdictUndecided,
|
||||
},
|
||||
{
|
||||
name: "both fresh, fail newer, unhealthy",
|
||||
health: map[netip.AddrPort]UpstreamHealth{a: okThenFail},
|
||||
servers: []netip.AddrPort{a},
|
||||
wantVerdict: nsVerdictUnhealthy,
|
||||
wantErrSubst: "timeout",
|
||||
},
|
||||
{
|
||||
name: "both fresh, ok newer, healthy",
|
||||
health: map[netip.AddrPort]UpstreamHealth{a: failThenOk},
|
||||
servers: []netip.AddrPort{a},
|
||||
wantVerdict: nsVerdictHealthy,
|
||||
},
|
||||
{
|
||||
name: "two upstreams, one success wins",
|
||||
health: map[netip.AddrPort]UpstreamHealth{
|
||||
a: recentFail,
|
||||
b: recentOk,
|
||||
},
|
||||
servers: []netip.AddrPort{a, b},
|
||||
wantVerdict: nsVerdictHealthy,
|
||||
},
|
||||
{
|
||||
name: "two upstreams, one fail one unseen, unhealthy",
|
||||
health: map[netip.AddrPort]UpstreamHealth{
|
||||
a: recentFail,
|
||||
},
|
||||
servers: []netip.AddrPort{a, b},
|
||||
wantVerdict: nsVerdictUnhealthy,
|
||||
wantErrSubst: "timeout",
|
||||
},
|
||||
{
|
||||
name: "two upstreams, all recent failures, unhealthy",
|
||||
health: map[netip.AddrPort]UpstreamHealth{
|
||||
a: {LastFail: now.Add(-5 * time.Second), LastErr: "timeout"},
|
||||
b: {LastFail: now.Add(-1 * time.Second), LastErr: "SERVFAIL"},
|
||||
},
|
||||
servers: []netip.AddrPort{a, b},
|
||||
wantVerdict: nsVerdictUnhealthy,
|
||||
wantErrSubst: "SERVFAIL",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
verdict, err := evaluateNSGroupHealth(tc.health, tc.servers, now)
|
||||
assert.Equal(t, tc.wantVerdict, verdict, "verdict mismatch")
|
||||
if tc.wantErrSubst != "" {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tc.wantErrSubst)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed
|
||||
// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates
|
||||
// without spinning up real handlers.
|
||||
type healthStubHandler struct {
|
||||
health map[netip.AddrPort]UpstreamHealth
|
||||
}
|
||||
|
||||
func (h *healthStubHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
||||
func (h *healthStubHandler) Stop() {}
|
||||
func (h *healthStubHandler) ID() types.HandlerID { return "health-stub" }
|
||||
func (h *healthStubHandler) UpstreamHealth() map[netip.AddrPort]UpstreamHealth {
|
||||
return h.health
|
||||
}
|
||||
|
||||
// TestProjection_SteadyStateIsSilent guards against duplicate events:
|
||||
// while a group stays Unhealthy tick after tick, only the first
|
||||
// Unhealthy transition may emit. Same for staying Healthy.
|
||||
func TestProjection_SteadyStateIsSilent(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
fx.tick()
|
||||
fx.expectEvent("unreachable", "first fail emits warning")
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
fx.tick()
|
||||
fx.tick()
|
||||
fx.expectNoEvent("staying unhealthy must not re-emit")
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||
fx.tick()
|
||||
fx.expectEvent("recovered", "recovery on transition")
|
||||
|
||||
fx.tick()
|
||||
fx.tick()
|
||||
fx.expectNoEvent("staying healthy must not re-emit")
|
||||
}
|
||||
|
||||
// projTestFixture is the common setup for the projection tests: a
|
||||
// single-upstream group whose route classification the test can flip by
|
||||
// assigning to selected/active. Callers drive failures/successes by
|
||||
// mutating stub.health and calling refreshHealth.
|
||||
type projTestFixture struct {
|
||||
t *testing.T
|
||||
recorder *peer.Status
|
||||
events <-chan *proto.SystemEvent
|
||||
server *DefaultServer
|
||||
stub *healthStubHandler
|
||||
group *nbdns.NameServerGroup
|
||||
srv netip.AddrPort
|
||||
selected route.HAMap
|
||||
active route.HAMap
|
||||
}
|
||||
|
||||
func newProjTestFixture(t *testing.T) *projTestFixture {
|
||||
t.Helper()
|
||||
recorder := peer.NewRecorder("mgm")
|
||||
sub := recorder.SubscribeToEvents()
|
||||
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
|
||||
|
||||
srv := netip.MustParseAddrPort("100.64.0.1:53")
|
||||
fx := &projTestFixture{
|
||||
t: t,
|
||||
recorder: recorder,
|
||||
events: sub.Events(),
|
||||
stub: &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{}},
|
||||
srv: srv,
|
||||
group: &nbdns.NameServerGroup{
|
||||
Domains: []string{"example.com"},
|
||||
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
|
||||
},
|
||||
}
|
||||
fx.server = &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
wgInterface: &mocWGIface{},
|
||||
statusRecorder: recorder,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
selectedRoutes: func() route.HAMap { return fx.selected },
|
||||
activeRoutes: func() route.HAMap { return fx.active },
|
||||
warningDelayBase: defaultWarningDelayBase,
|
||||
}
|
||||
fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}
|
||||
|
||||
fx.server.mux.Lock()
|
||||
fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group})
|
||||
fx.server.mux.Unlock()
|
||||
return fx
|
||||
}
|
||||
|
||||
func (f *projTestFixture) setHealth(h UpstreamHealth) {
|
||||
f.stub.health = map[netip.AddrPort]UpstreamHealth{f.srv: h}
|
||||
}
|
||||
|
||||
func (f *projTestFixture) tick() []peer.NSGroupState {
|
||||
f.server.refreshHealth()
|
||||
return f.recorder.GetDNSStates()
|
||||
}
|
||||
|
||||
func (f *projTestFixture) expectNoEvent(why string) {
|
||||
f.t.Helper()
|
||||
select {
|
||||
case evt := <-f.events:
|
||||
f.t.Fatalf("unexpected event (%s): %+v", why, evt)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func (f *projTestFixture) expectEvent(substr, why string) *proto.SystemEvent {
|
||||
f.t.Helper()
|
||||
select {
|
||||
case evt := <-f.events:
|
||||
assert.Contains(f.t, evt.Message, substr, why)
|
||||
return evt
|
||||
case <-time.After(time.Second):
|
||||
f.t.Fatalf("expected event (%s) with %q", why, substr)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
var overlayNetForTest = netip.MustParsePrefix("100.64.0.0/16")
|
||||
var overlayMapForTest = route.HAMap{"overlay": {{Network: overlayNetForTest}}}
|
||||
|
||||
// TestProjection_PublicFailEmitsImmediately covers rule 1: an upstream
|
||||
// that is not inside any selected route (public DNS) fires the warning
|
||||
// on the first Unhealthy tick, no grace period.
|
||||
func TestProjection_PublicFailEmitsImmediately(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
states := fx.tick()
|
||||
require.Len(t, states, 1)
|
||||
assert.False(t, states[0].Enabled)
|
||||
fx.expectEvent("unreachable", "public DNS failure")
|
||||
}
|
||||
|
||||
// TestProjection_OverlayConnectedFailEmitsImmediately covers rule 2:
|
||||
// the upstream is inside a selected route AND the route has a Connected
|
||||
// peer. Tunnel is up, failure is real, emit immediately.
|
||||
func TestProjection_OverlayConnectedFailEmitsImmediately(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
fx.selected = overlayMapForTest
|
||||
fx.active = overlayMapForTest
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
states := fx.tick()
|
||||
require.Len(t, states, 1)
|
||||
assert.False(t, states[0].Enabled)
|
||||
fx.expectEvent("unreachable", "overlay + connected failure")
|
||||
}
|
||||
|
||||
// TestProjection_OverlayNotConnectedDelaysWarning covers rule 3: the
|
||||
// upstream is routed but no peer is Connected (Connecting/Idle/missing).
|
||||
// First tick: Unhealthy display, no warning. After the grace window
|
||||
// elapses with no recovery, the warning fires.
|
||||
func TestProjection_OverlayNotConnectedDelaysWarning(t *testing.T) {
|
||||
grace := 50 * time.Millisecond
|
||||
fx := newProjTestFixture(t)
|
||||
fx.server.warningDelayBase = grace
|
||||
fx.selected = overlayMapForTest
|
||||
// active stays nil: routed but not connected.
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
states := fx.tick()
|
||||
require.Len(t, states, 1)
|
||||
assert.False(t, states[0].Enabled, "display must reflect failure even during grace window")
|
||||
fx.expectNoEvent("first fail tick within grace window")
|
||||
|
||||
time.Sleep(grace + 10*time.Millisecond)
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
fx.tick()
|
||||
fx.expectEvent("unreachable", "warning after grace window")
|
||||
}
|
||||
|
||||
// TestProjection_OverlayAddrNoRouteDelaysWarning covers an upstream
|
||||
// whose address is inside the WireGuard overlay range but is not
|
||||
// covered by any selected route (peer-to-peer DNS without an explicit
|
||||
// route). Until a peer reports Connected for that address, startup
|
||||
// failures must be held just like the routed case.
|
||||
func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
|
||||
recorder := peer.NewRecorder("mgm")
|
||||
sub := recorder.SubscribeToEvents()
|
||||
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
|
||||
|
||||
overlayPeer := netip.MustParseAddrPort("100.66.100.5:53")
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
wgInterface: &mocWGIface{},
|
||||
statusRecorder: recorder,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
selectedRoutes: func() route.HAMap { return nil },
|
||||
activeRoutes: func() route.HAMap { return nil },
|
||||
warningDelayBase: 50 * time.Millisecond,
|
||||
}
|
||||
group := &nbdns.NameServerGroup{
|
||||
Domains: []string{"example.com"},
|
||||
NameServers: []nbdns.NameServer{{IP: overlayPeer.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlayPeer.Port())}},
|
||||
}
|
||||
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{
|
||||
overlayPeer: {LastFail: time.Now(), LastErr: "timeout"},
|
||||
}}
|
||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
||||
|
||||
server.mux.Lock()
|
||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||
server.mux.Unlock()
|
||||
server.refreshHealth()
|
||||
|
||||
select {
|
||||
case evt := <-sub.Events():
|
||||
t.Fatalf("unexpected event during grace window: %+v", evt)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
stub.health = map[netip.AddrPort]UpstreamHealth{overlayPeer: {LastFail: time.Now(), LastErr: "timeout"}}
|
||||
server.refreshHealth()
|
||||
|
||||
select {
|
||||
case evt := <-sub.Events():
|
||||
assert.Contains(t, evt.Message, "unreachable")
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected warning after grace window")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProjection_StopClearsHealthState verifies that Stop wipes the
|
||||
// per-group projection state so a subsequent Start doesn't inherit
|
||||
// sticky flags (notably everHealthy) that would bypass the grace
|
||||
// window during the next peer handshake.
|
||||
func TestProjection_StopClearsHealthState(t *testing.T) {
|
||||
wgIface := &mocWGIface{}
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
wgInterface: wgIface,
|
||||
service: NewServiceViaMemory(wgIface),
|
||||
hostManager: &noopHostConfigurator{},
|
||||
extraDomains: map[domain.Domain]int{},
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
statusRecorder: peer.NewRecorder("mgm"),
|
||||
selectedRoutes: func() route.HAMap { return nil },
|
||||
activeRoutes: func() route.HAMap { return nil },
|
||||
warningDelayBase: defaultWarningDelayBase,
|
||||
currentConfigHash: ^uint64(0),
|
||||
}
|
||||
server.ctx, server.ctxCancel = context.WithCancel(context.Background())
|
||||
|
||||
srv := netip.MustParseAddrPort("8.8.8.8:53")
|
||||
group := &nbdns.NameServerGroup{
|
||||
Domains: []string{"example.com"},
|
||||
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
|
||||
}
|
||||
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}}
|
||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
||||
|
||||
server.mux.Lock()
|
||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||
server.mux.Unlock()
|
||||
server.refreshHealth()
|
||||
|
||||
server.healthProjectMu.Lock()
|
||||
p, ok := server.nsGroupProj[generateGroupKey(group)]
|
||||
server.healthProjectMu.Unlock()
|
||||
require.True(t, ok, "projection state should exist after tick")
|
||||
require.True(t, p.everHealthy, "tick with success must set everHealthy")
|
||||
|
||||
server.Stop()
|
||||
|
||||
server.healthProjectMu.Lock()
|
||||
cleared := server.nsGroupProj == nil
|
||||
server.healthProjectMu.Unlock()
|
||||
assert.True(t, cleared, "Stop must clear nsGroupProj")
|
||||
}
|
||||
|
||||
// TestProjection_OverlayRecoversDuringGrace covers the happy path of
|
||||
// rule 3: startup failures while the peer is handshaking, then the peer
|
||||
// comes up and a query succeeds before the grace window elapses. No
|
||||
// warning should ever have fired, and no recovery either.
|
||||
func TestProjection_OverlayRecoversDuringGrace(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
fx.server.warningDelayBase = 200 * time.Millisecond
|
||||
fx.selected = overlayMapForTest
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
fx.tick()
|
||||
fx.expectNoEvent("fail within grace, warning suppressed")
|
||||
|
||||
fx.active = overlayMapForTest
|
||||
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||
states := fx.tick()
|
||||
require.Len(t, states, 1)
|
||||
assert.True(t, states[0].Enabled)
|
||||
fx.expectNoEvent("recovery without prior warning must not emit")
|
||||
}
|
||||
|
||||
// TestProjection_RecoveryOnlyAfterWarning enforces the invariant the
|
||||
// whole design leans on: recovery events only appear when a warning
|
||||
// event was actually emitted for the current streak. A Healthy verdict
|
||||
// without a prior warning is silent, so the user never sees "recovered"
|
||||
// out of thin air.
|
||||
func TestProjection_RecoveryOnlyAfterWarning(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||
states := fx.tick()
|
||||
require.Len(t, states, 1)
|
||||
assert.True(t, states[0].Enabled)
|
||||
fx.expectNoEvent("first healthy tick should not recover anything")
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
fx.tick()
|
||||
fx.expectEvent("unreachable", "public fail emits immediately")
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||
fx.tick()
|
||||
fx.expectEvent("recovered", "recovery follows real warning")
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
fx.tick()
|
||||
fx.expectEvent("unreachable", "second cycle warning")
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||
fx.tick()
|
||||
fx.expectEvent("recovered", "second cycle recovery")
|
||||
}
|
||||
|
||||
// TestProjection_EverHealthyOverridesDelay covers rule 4: once a group
|
||||
// has ever been Healthy, subsequent failures skip the grace window even
|
||||
// if classification says "routed + not connected". The system has
|
||||
// proved it can work, so any new failure is real.
|
||||
func TestProjection_EverHealthyOverridesDelay(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
// Large base so any emission must come from the everHealthy bypass, not elapsed time.
|
||||
fx.server.warningDelayBase = time.Hour
|
||||
fx.selected = overlayMapForTest
|
||||
fx.active = overlayMapForTest
|
||||
|
||||
// Establish "ever healthy".
|
||||
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||
fx.tick()
|
||||
fx.expectNoEvent("first healthy tick")
|
||||
|
||||
// Peer drops. Query fails. Routed + not connected → normally grace,
|
||||
// but everHealthy flag bypasses it.
|
||||
fx.active = nil
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
fx.tick()
|
||||
fx.expectEvent("unreachable", "failure after ever-healthy must be immediate")
|
||||
}
|
||||
|
||||
// TestProjection_ReconnectBlipEmitsPair covers the explicit tradeoff
|
||||
// from the design discussion: once a group has been healthy, a brief
|
||||
// reconnect that produces a failing tick will fire warning + recovery.
|
||||
// This is by design: user-visible blips are accurate signal, not noise.
|
||||
func TestProjection_ReconnectBlipEmitsPair(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
fx.selected = overlayMapForTest
|
||||
fx.active = overlayMapForTest
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||
fx.tick()
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
fx.tick()
|
||||
fx.expectEvent("unreachable", "blip warning")
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||
fx.tick()
|
||||
fx.expectEvent("recovered", "blip recovery")
|
||||
}
|
||||
|
||||
// TestProjection_MixedGroupEmitsImmediately covers the multi-upstream
|
||||
// rule: a group with at least one public upstream is in the "immediate"
|
||||
// category regardless of the other upstreams' routing, because the
|
||||
// public one has no peer-startup excuse. Prevents public-DNS failures
|
||||
// from being hidden behind a routed sibling.
|
||||
func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
|
||||
recorder := peer.NewRecorder("mgm")
|
||||
sub := recorder.SubscribeToEvents()
|
||||
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
|
||||
events := sub.Events()
|
||||
|
||||
public := netip.MustParseAddrPort("8.8.8.8:53")
|
||||
overlay := netip.MustParseAddrPort("100.64.0.1:53")
|
||||
overlayMap := route.HAMap{"overlay": {{Network: netip.MustParsePrefix("100.64.0.0/16")}}}
|
||||
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
statusRecorder: recorder,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
selectedRoutes: func() route.HAMap { return overlayMap },
|
||||
activeRoutes: func() route.HAMap { return nil },
|
||||
warningDelayBase: time.Hour,
|
||||
}
|
||||
group := &nbdns.NameServerGroup{
|
||||
Domains: []string{"example.com"},
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: public.Addr(), NSType: nbdns.UDPNameServerType, Port: int(public.Port())},
|
||||
{IP: overlay.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlay.Port())},
|
||||
},
|
||||
}
|
||||
stub := &healthStubHandler{
|
||||
health: map[netip.AddrPort]UpstreamHealth{
|
||||
public: {LastFail: time.Now(), LastErr: "servfail"},
|
||||
overlay: {LastFail: time.Now(), LastErr: "timeout"},
|
||||
},
|
||||
}
|
||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
||||
|
||||
server.mux.Lock()
|
||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||
server.mux.Unlock()
|
||||
server.refreshHealth()
|
||||
|
||||
select {
|
||||
case evt := <-events:
|
||||
assert.Contains(t, evt.Message, "unreachable")
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected immediate warning because group contains a public upstream")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSLoopPrevention(t *testing.T) {
|
||||
wgInterface := &mocWGIface{}
|
||||
service := NewServiceViaMemory(wgInterface)
|
||||
@@ -2183,17 +2726,18 @@ func TestDNSLoopPrevention(t *testing.T) {
|
||||
|
||||
if tt.expectedHandlers > 0 {
|
||||
handler := muxUpdates[0].handler.(*upstreamResolver)
|
||||
assert.Len(t, handler.upstreamServers, len(tt.expectedServers))
|
||||
flat := handler.flatUpstreams()
|
||||
assert.Len(t, flat, len(tt.expectedServers))
|
||||
|
||||
if tt.shouldFilterOwnIP {
|
||||
for _, upstream := range handler.upstreamServers {
|
||||
for _, upstream := range flat {
|
||||
assert.NotEqual(t, dnsServerIP, upstream.Addr())
|
||||
}
|
||||
}
|
||||
|
||||
for _, expected := range tt.expectedServers {
|
||||
found := false
|
||||
for _, upstream := range handler.upstreamServers {
|
||||
for _, upstream := range flat {
|
||||
if upstream.Addr() == expected {
|
||||
found = true
|
||||
break
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
@@ -40,10 +41,17 @@ const (
|
||||
)
|
||||
|
||||
type systemdDbusConfigurator struct {
|
||||
dbusLinkObject dbus.ObjectPath
|
||||
ifaceName string
|
||||
dbusLinkObject dbus.ObjectPath
|
||||
ifaceName string
|
||||
wgIndex int
|
||||
origNameservers []netip.Addr
|
||||
}
|
||||
|
||||
const (
|
||||
systemdDbusLinkDNSProperty = systemdDbusLinkInterface + ".DNS"
|
||||
systemdDbusLinkDefaultRouteProperty = systemdDbusLinkInterface + ".DefaultRoute"
|
||||
)
|
||||
|
||||
// the types below are based on dbus specification, each field is mapped to a dbus type
|
||||
// see https://dbus.freedesktop.org/doc/dbus-specification.html#basic-types for more details on dbus types
|
||||
// see https://www.freedesktop.org/software/systemd/man/org.freedesktop.resolve1.html on resolve1 input types
|
||||
@@ -79,10 +87,145 @@ func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, e
|
||||
|
||||
log.Debugf("got dbus Link interface: %s from net interface %s and index %d", s, iface.Name, iface.Index)
|
||||
|
||||
return &systemdDbusConfigurator{
|
||||
c := &systemdDbusConfigurator{
|
||||
dbusLinkObject: dbus.ObjectPath(s),
|
||||
ifaceName: wgInterface,
|
||||
}, nil
|
||||
wgIndex: iface.Index,
|
||||
}
|
||||
|
||||
origNameservers, err := c.captureOriginalNameservers()
|
||||
switch {
|
||||
case err != nil:
|
||||
log.Warnf("capture original nameservers from systemd-resolved: %v", err)
|
||||
case len(origNameservers) == 0:
|
||||
log.Warnf("no original nameservers captured from systemd-resolved default-route links; DNS fallback will be empty")
|
||||
default:
|
||||
log.Debugf("captured %d original nameservers from systemd-resolved default-route links: %v", len(origNameservers), origNameservers)
|
||||
}
|
||||
c.origNameservers = origNameservers
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// captureOriginalNameservers reads per-link DNS from systemd-resolved for
|
||||
// every default-route link except our own WG link. Non-default-route links
|
||||
// (VPNs, docker bridges) are skipped because their upstreams wouldn't
|
||||
// actually serve host queries.
|
||||
func (s *systemdDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list interfaces: %w", err)
|
||||
}
|
||||
|
||||
seen := make(map[netip.Addr]struct{})
|
||||
var out []netip.Addr
|
||||
for _, iface := range ifaces {
|
||||
if !s.isCandidateLink(iface) {
|
||||
continue
|
||||
}
|
||||
linkPath, err := getSystemdLinkPath(iface.Index)
|
||||
if err != nil || !isSystemdLinkDefaultRoute(linkPath) {
|
||||
continue
|
||||
}
|
||||
for _, addr := range readSystemdLinkDNS(linkPath) {
|
||||
addr = normalizeSystemdAddr(addr, iface.Name)
|
||||
if !addr.IsValid() {
|
||||
continue
|
||||
}
|
||||
if _, dup := seen[addr]; dup {
|
||||
continue
|
||||
}
|
||||
seen[addr] = struct{}{}
|
||||
out = append(out, addr)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *systemdDbusConfigurator) isCandidateLink(iface net.Interface) bool {
|
||||
if iface.Index == s.wgIndex {
|
||||
return false
|
||||
}
|
||||
if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// normalizeSystemdAddr unmaps v4-mapped-v6, drops unspecified, and reattaches
|
||||
// the link's iface name as zone for link-local v6 (Link.DNS strips it).
|
||||
// Returns the zero Addr to signal "skip this entry".
|
||||
func normalizeSystemdAddr(addr netip.Addr, ifaceName string) netip.Addr {
|
||||
addr = addr.Unmap()
|
||||
if !addr.IsValid() || addr.IsUnspecified() {
|
||||
return netip.Addr{}
|
||||
}
|
||||
if addr.IsLinkLocalUnicast() {
|
||||
return addr.WithZone(ifaceName)
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
func getSystemdLinkPath(ifIndex int) (dbus.ObjectPath, error) {
|
||||
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("dbus resolve1: %w", err)
|
||||
}
|
||||
defer closeConn()
|
||||
var p string
|
||||
if err := obj.Call(systemdDbusGetLinkMethod, dbusDefaultFlag, int32(ifIndex)).Store(&p); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return dbus.ObjectPath(p), nil
|
||||
}
|
||||
|
||||
func isSystemdLinkDefaultRoute(linkPath dbus.ObjectPath) bool {
|
||||
obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer closeConn()
|
||||
v, err := obj.GetProperty(systemdDbusLinkDefaultRouteProperty)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
b, ok := v.Value().(bool)
|
||||
return ok && b
|
||||
}
|
||||
|
||||
func readSystemdLinkDNS(linkPath dbus.ObjectPath) []netip.Addr {
|
||||
obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer closeConn()
|
||||
v, err := obj.GetProperty(systemdDbusLinkDNSProperty)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
entries, ok := v.Value().([][]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
var out []netip.Addr
|
||||
for _, entry := range entries {
|
||||
if len(entry) < 2 {
|
||||
continue
|
||||
}
|
||||
raw, ok := entry[1].([]byte)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
addr, ok := netip.AddrFromSlice(raw)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
out = append(out, addr)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *systemdDbusConfigurator) getOriginalNameservers() []netip.Addr {
|
||||
return slices.Clone(s.origNameservers)
|
||||
}
|
||||
|
||||
func (s *systemdDbusConfigurator) supportCustomPort() bool {
|
||||
|
||||
@@ -1,3 +1,32 @@
|
||||
// Package dns implements the client-side DNS stack: listener/service on the
|
||||
// peer's tunnel address, handler chain that routes questions by domain and
|
||||
// priority, and upstream resolvers that forward what remains to configured
|
||||
// nameservers.
|
||||
//
|
||||
// # Upstream resolution and the race model
|
||||
//
|
||||
// When two or more nameserver groups target the same domain, DefaultServer
|
||||
// merges them into one upstream handler whose state is:
|
||||
//
|
||||
// upstreamResolverBase
|
||||
// └── upstreamServers []upstreamRace // one entry per source NS group
|
||||
// └── []netip.AddrPort // primary, fallback, ...
|
||||
//
|
||||
// Each source nameserver group contributes one upstreamRace. Within a race
|
||||
// upstreams are tried in order: the next is used only on failure (timeout,
|
||||
// SERVFAIL, REFUSED, no response). NXDOMAIN is a valid answer and stops
|
||||
// the walk. When more than one race exists, ServeDNS fans out one
|
||||
// goroutine per race and returns the first valid answer, cancelling the
|
||||
// rest. A handler with a single race skips the fan-out.
|
||||
//
|
||||
// # Health projection
|
||||
//
|
||||
// Query outcomes are recorded per-upstream in UpstreamHealth. The server
|
||||
// periodically merges these snapshots across handlers and projects them
|
||||
// into peer.NSGroupState. There is no active probing: a group is marked
|
||||
// unhealthy only when every seen upstream has a recent failure and none
|
||||
// has a recent success. Healthy→unhealthy fires a single
|
||||
// SystemEvent_WARNING; steady-state refreshes do not duplicate it.
|
||||
package dns
|
||||
|
||||
import (
|
||||
@@ -11,11 +40,8 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
@@ -25,7 +51,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
var currentMTU uint16 = iface.DefaultMTU
|
||||
@@ -67,15 +94,17 @@ const (
|
||||
// Set longer than UpstreamTimeout to ensure context timeout takes precedence
|
||||
ClientTimeout = 5 * time.Second
|
||||
|
||||
reactivatePeriod = 30 * time.Second
|
||||
probeTimeout = 2 * time.Second
|
||||
|
||||
// ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP
|
||||
// payload from the tunnel MTU.
|
||||
ipUDPHeaderSize = 60 + 8
|
||||
)
|
||||
|
||||
const testRecord = "com."
|
||||
// raceMaxTotalTimeout caps the combined time spent walking all upstreams
|
||||
// within one race, so a slow primary can't eat the whole race budget.
|
||||
raceMaxTotalTimeout = 5 * time.Second
|
||||
// raceMinPerUpstreamTimeout is the floor applied when dividing
|
||||
// raceMaxTotalTimeout across upstreams within a race.
|
||||
raceMinPerUpstreamTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
const (
|
||||
protoUDP = "udp"
|
||||
@@ -84,6 +113,69 @@ const (
|
||||
|
||||
type dnsProtocolKey struct{}
|
||||
|
||||
type upstreamProtocolKey struct{}
|
||||
|
||||
// upstreamProtocolResult holds the protocol used for the upstream exchange.
|
||||
// Stored as a pointer in context so the exchange function can set it.
|
||||
type upstreamProtocolResult struct {
|
||||
protocol string
|
||||
}
|
||||
|
||||
type upstreamClient interface {
|
||||
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
}
|
||||
|
||||
type UpstreamResolver interface {
|
||||
serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
}
|
||||
|
||||
// upstreamRace is an ordered list of upstreams derived from one configured
|
||||
// nameserver group. Order matters: the first upstream is tried first, the
|
||||
// second only on failure, and so on. Multiple upstreamRace values coexist
|
||||
// inside one resolver when overlapping nameserver groups target the same
|
||||
// domain; those races run in parallel and the first valid answer wins.
|
||||
type upstreamRace []netip.AddrPort
|
||||
|
||||
// UpstreamHealth is the last query-path outcome for a single upstream,
|
||||
// consumed by nameserver-group status projection.
|
||||
type UpstreamHealth struct {
|
||||
LastOk time.Time
|
||||
LastFail time.Time
|
||||
LastErr string
|
||||
}
|
||||
|
||||
type upstreamResolverBase struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
upstreamClient upstreamClient
|
||||
upstreamServers []upstreamRace
|
||||
domain domain.Domain
|
||||
upstreamTimeout time.Duration
|
||||
|
||||
healthMu sync.RWMutex
|
||||
health map[netip.AddrPort]*UpstreamHealth
|
||||
|
||||
statusRecorder *peer.Status
|
||||
// selectedRoutes returns the current set of client routes the admin
|
||||
// has enabled. Called lazily from the query hot path when an upstream
|
||||
// might need a tunnel-bound client (iOS) and from health projection.
|
||||
selectedRoutes func() route.HAMap
|
||||
}
|
||||
|
||||
type upstreamFailure struct {
|
||||
upstream netip.AddrPort
|
||||
reason string
|
||||
}
|
||||
|
||||
type raceResult struct {
|
||||
msg *dns.Msg
|
||||
upstream netip.AddrPort
|
||||
protocol string
|
||||
ede string
|
||||
failures []upstreamFailure
|
||||
}
|
||||
|
||||
// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context.
|
||||
func contextWithDNSProtocol(ctx context.Context, network string) context.Context {
|
||||
return context.WithValue(ctx, dnsProtocolKey{}, network)
|
||||
@@ -100,16 +192,8 @@ func dnsProtocolFromContext(ctx context.Context) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
type upstreamProtocolKey struct{}
|
||||
|
||||
// upstreamProtocolResult holds the protocol used for the upstream exchange.
|
||||
// Stored as a pointer in context so the exchange function can set it.
|
||||
type upstreamProtocolResult struct {
|
||||
protocol string
|
||||
}
|
||||
|
||||
// contextWithupstreamProtocolResult stores a mutable result holder in the context.
|
||||
func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
|
||||
// contextWithUpstreamProtocolResult stores a mutable result holder in the context.
|
||||
func contextWithUpstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
|
||||
r := &upstreamProtocolResult{}
|
||||
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
|
||||
}
|
||||
@@ -124,67 +208,37 @@ func setUpstreamProtocol(ctx context.Context, protocol string) {
|
||||
}
|
||||
}
|
||||
|
||||
type upstreamClient interface {
|
||||
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
}
|
||||
|
||||
type UpstreamResolver interface {
|
||||
serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
}
|
||||
|
||||
type upstreamResolverBase struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
upstreamClient upstreamClient
|
||||
upstreamServers []netip.AddrPort
|
||||
domain string
|
||||
disabled bool
|
||||
successCount atomic.Int32
|
||||
mutex sync.Mutex
|
||||
reactivatePeriod time.Duration
|
||||
upstreamTimeout time.Duration
|
||||
wg sync.WaitGroup
|
||||
|
||||
deactivate func(error)
|
||||
reactivate func()
|
||||
statusRecorder *peer.Status
|
||||
routeMatch func(netip.Addr) bool
|
||||
}
|
||||
|
||||
type upstreamFailure struct {
|
||||
upstream netip.AddrPort
|
||||
reason string
|
||||
}
|
||||
|
||||
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
|
||||
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d domain.Domain) *upstreamResolverBase {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
return &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
domain: domain,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
reactivatePeriod: reactivatePeriod,
|
||||
statusRecorder: statusRecorder,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
domain: d,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
statusRecorder: statusRecorder,
|
||||
}
|
||||
}
|
||||
|
||||
// String returns a string representation of the upstream resolver
|
||||
func (u *upstreamResolverBase) String() string {
|
||||
return fmt.Sprintf("Upstream %s", u.upstreamServers)
|
||||
return fmt.Sprintf("Upstream %s", u.flatUpstreams())
|
||||
}
|
||||
|
||||
// ID returns the unique handler ID
|
||||
// ID returns the unique handler ID. Race groupings and within-race
|
||||
// ordering are both part of the identity: [[A,B]] and [[A],[B]] query
|
||||
// the same servers but with different semantics (serial fallback vs
|
||||
// parallel race), so their handlers must not collide.
|
||||
func (u *upstreamResolverBase) ID() types.HandlerID {
|
||||
servers := slices.Clone(u.upstreamServers)
|
||||
slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) })
|
||||
|
||||
hash := sha256.New()
|
||||
hash.Write([]byte(u.domain + ":"))
|
||||
for _, s := range servers {
|
||||
hash.Write([]byte(s.String()))
|
||||
hash.Write([]byte("|"))
|
||||
hash.Write([]byte(u.domain.PunycodeString() + ":"))
|
||||
for _, race := range u.upstreamServers {
|
||||
hash.Write([]byte("["))
|
||||
for _, s := range race {
|
||||
hash.Write([]byte(s.String()))
|
||||
hash.Write([]byte("|"))
|
||||
}
|
||||
hash.Write([]byte("]"))
|
||||
}
|
||||
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
|
||||
}
|
||||
@@ -194,13 +248,31 @@ func (u *upstreamResolverBase) MatchSubdomains() bool {
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) Stop() {
|
||||
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
||||
log.Debugf("stopping serving DNS for upstreams %s", u.flatUpstreams())
|
||||
u.cancel()
|
||||
}
|
||||
|
||||
u.mutex.Lock()
|
||||
u.wg.Wait()
|
||||
u.mutex.Unlock()
|
||||
// flatUpstreams is for logging and ID hashing only, not for dispatch.
|
||||
func (u *upstreamResolverBase) flatUpstreams() []netip.AddrPort {
|
||||
var out []netip.AddrPort
|
||||
for _, g := range u.upstreamServers {
|
||||
out = append(out, g...)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// setSelectedRoutes swaps the accessor used to classify overlay-routed
|
||||
// upstreams. Called when route sources are wired after the handler was
|
||||
// built (permanent / iOS constructors).
|
||||
func (u *upstreamResolverBase) setSelectedRoutes(selected func() route.HAMap) {
|
||||
u.selectedRoutes = selected
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) addRace(servers []netip.AddrPort) {
|
||||
if len(servers) == 0 {
|
||||
return
|
||||
}
|
||||
u.upstreamServers = append(u.upstreamServers, slices.Clone(servers))
|
||||
}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
@@ -242,82 +314,201 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
|
||||
timeout := u.upstreamTimeout
|
||||
if len(u.upstreamServers) > 1 {
|
||||
maxTotal := 5 * time.Second
|
||||
minPerUpstream := 2 * time.Second
|
||||
scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers))
|
||||
if scaledTimeout > minPerUpstream {
|
||||
timeout = scaledTimeout
|
||||
} else {
|
||||
timeout = minPerUpstream
|
||||
}
|
||||
groups := u.upstreamServers
|
||||
switch len(groups) {
|
||||
case 0:
|
||||
return false, nil
|
||||
case 1:
|
||||
return u.tryOnlyRace(ctx, w, r, groups[0], logger)
|
||||
default:
|
||||
return u.raceAll(ctx, w, r, groups, logger)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) tryOnlyRace(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, group upstreamRace, logger *log.Entry) (bool, []upstreamFailure) {
|
||||
res := u.tryRace(ctx, r, group)
|
||||
if res.msg == nil {
|
||||
return false, res.failures
|
||||
}
|
||||
if res.ede != "" {
|
||||
resutil.SetMeta(w, "ede", res.ede)
|
||||
}
|
||||
u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger)
|
||||
return true, res.failures
|
||||
}
|
||||
|
||||
// raceAll runs one worker per group in parallel, taking the first valid
|
||||
// answer and cancelling the rest.
|
||||
func (u *upstreamResolverBase) raceAll(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, groups []upstreamRace, logger *log.Entry) (bool, []upstreamFailure) {
|
||||
raceCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// Buffer sized to len(groups) so workers never block on send, even
|
||||
// after the coordinator has returned.
|
||||
results := make(chan raceResult, len(groups))
|
||||
for _, g := range groups {
|
||||
// tryRace clones the request per attempt, so workers never share
|
||||
// a *dns.Msg and concurrent EDNS0 mutations can't race.
|
||||
go func(g upstreamRace) {
|
||||
results <- u.tryRace(raceCtx, r, g)
|
||||
}(g)
|
||||
}
|
||||
|
||||
var failures []upstreamFailure
|
||||
for _, upstream := range u.upstreamServers {
|
||||
if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil {
|
||||
failures = append(failures, *failure)
|
||||
} else {
|
||||
return true, failures
|
||||
for range groups {
|
||||
select {
|
||||
case res := <-results:
|
||||
failures = append(failures, res.failures...)
|
||||
if res.msg != nil {
|
||||
if res.ede != "" {
|
||||
resutil.SetMeta(w, "ede", res.ede)
|
||||
}
|
||||
u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger)
|
||||
return true, failures
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return false, failures
|
||||
}
|
||||
}
|
||||
return false, failures
|
||||
}
|
||||
|
||||
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
|
||||
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
|
||||
var rm *dns.Msg
|
||||
var t time.Duration
|
||||
var err error
|
||||
func (u *upstreamResolverBase) tryRace(ctx context.Context, r *dns.Msg, group upstreamRace) raceResult {
|
||||
timeout := u.upstreamTimeout
|
||||
if len(group) > 1 {
|
||||
// Cap the whole walk at raceMaxTotalTimeout: per-upstream timeouts
|
||||
// still honor raceMinPerUpstreamTimeout as a floor for correctness
|
||||
// on slow links, but the outer context ensures the combined walk
|
||||
// cannot exceed the cap regardless of group size.
|
||||
timeout = max(raceMaxTotalTimeout/time.Duration(len(group)), raceMinPerUpstreamTimeout)
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, raceMaxTotalTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
var failures []upstreamFailure
|
||||
for _, upstream := range group {
|
||||
if ctx.Err() != nil {
|
||||
return raceResult{failures: failures}
|
||||
}
|
||||
// Clone the request per attempt: the exchange path mutates EDNS0
|
||||
// options in-place, so reusing the same *dns.Msg across sequential
|
||||
// upstreams would carry those mutations (e.g. a reduced UDP size)
|
||||
// into the next attempt.
|
||||
res, failure := u.queryUpstream(ctx, r.Copy(), upstream, timeout)
|
||||
if failure != nil {
|
||||
failures = append(failures, *failure)
|
||||
continue
|
||||
}
|
||||
res.failures = failures
|
||||
return res
|
||||
}
|
||||
return raceResult{failures: failures}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration) (raceResult, *upstreamFailure) {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, timeout)
|
||||
defer cancel()
|
||||
ctx, upstreamProto := contextWithUpstreamProtocolResult(ctx)
|
||||
|
||||
// Advertise EDNS0 so the upstream may include Extended DNS Errors
|
||||
// (RFC 8914) in failure responses; we use those to short-circuit
|
||||
// failover for definitive answers like DNSSEC validation failures.
|
||||
// Operate on a copy so the inbound request is unchanged: a client that
|
||||
// did not advertise EDNS0 must not see an OPT in the response.
|
||||
// The caller already passed a per-attempt copy, so we can mutate r
|
||||
// directly; hadEdns reflects the original client request's state and
|
||||
// controls whether we strip the OPT from the response.
|
||||
hadEdns := r.IsEdns0() != nil
|
||||
reqUp := r
|
||||
if !hadEdns {
|
||||
reqUp = r.Copy()
|
||||
reqUp.SetEdns0(upstreamUDPSize(), false)
|
||||
r.SetEdns0(upstreamUDPSize(), false)
|
||||
}
|
||||
|
||||
var startTime time.Time
|
||||
var upstreamProto *upstreamProtocolResult
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, timeout)
|
||||
defer cancel()
|
||||
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
|
||||
startTime = time.Now()
|
||||
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), reqUp)
|
||||
}()
|
||||
startTime := time.Now()
|
||||
rm, _, err := u.upstreamClient.exchange(ctx, upstream.String(), r)
|
||||
|
||||
if err != nil {
|
||||
return u.handleUpstreamError(err, upstream, startTime)
|
||||
// A parent cancellation (e.g., another race won and the coordinator
|
||||
// cancelled the losers) is not an upstream failure. Check both the
|
||||
// error chain and the parent context: a transport may surface the
|
||||
// cancellation as a read/deadline error rather than context.Canceled.
|
||||
if errors.Is(err, context.Canceled) || errors.Is(parentCtx.Err(), context.Canceled) {
|
||||
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "canceled"}
|
||||
}
|
||||
failure := u.handleUpstreamError(err, upstream, startTime)
|
||||
u.markUpstreamFail(upstream, failure.reason)
|
||||
return raceResult{}, failure
|
||||
}
|
||||
|
||||
if rm == nil || !rm.Response {
|
||||
return &upstreamFailure{upstream: upstream, reason: "no response"}
|
||||
u.markUpstreamFail(upstream, "no response")
|
||||
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"}
|
||||
}
|
||||
|
||||
proto := ""
|
||||
if upstreamProto != nil {
|
||||
proto = upstreamProto.protocol
|
||||
}
|
||||
|
||||
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
|
||||
if code, ok := nonRetryableEDE(rm); ok {
|
||||
resutil.SetMeta(w, "ede", edeName(code))
|
||||
if !hadEdns {
|
||||
stripOPT(rm)
|
||||
}
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
|
||||
return nil
|
||||
u.markUpstreamOk(upstream)
|
||||
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
|
||||
}
|
||||
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
|
||||
reason := dns.RcodeToString[rm.Rcode]
|
||||
u.markUpstreamFail(upstream, reason)
|
||||
return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason}
|
||||
}
|
||||
|
||||
if !hadEdns {
|
||||
stripOPT(rm)
|
||||
}
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
|
||||
return nil
|
||||
|
||||
u.markUpstreamOk(upstream)
|
||||
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
|
||||
}
|
||||
|
||||
// healthEntry returns the mutable health record for addr, lazily creating
|
||||
// the map and the entry. Caller must hold u.healthMu.
|
||||
func (u *upstreamResolverBase) healthEntry(addr netip.AddrPort) *UpstreamHealth {
|
||||
if u.health == nil {
|
||||
u.health = make(map[netip.AddrPort]*UpstreamHealth)
|
||||
}
|
||||
h := u.health[addr]
|
||||
if h == nil {
|
||||
h = &UpstreamHealth{}
|
||||
u.health[addr] = h
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) markUpstreamOk(addr netip.AddrPort) {
|
||||
u.healthMu.Lock()
|
||||
defer u.healthMu.Unlock()
|
||||
h := u.healthEntry(addr)
|
||||
h.LastOk = time.Now()
|
||||
h.LastFail = time.Time{}
|
||||
h.LastErr = ""
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) markUpstreamFail(addr netip.AddrPort, reason string) {
|
||||
u.healthMu.Lock()
|
||||
defer u.healthMu.Unlock()
|
||||
h := u.healthEntry(addr)
|
||||
h.LastFail = time.Now()
|
||||
h.LastErr = reason
|
||||
}
|
||||
|
||||
// UpstreamHealth returns a snapshot of per-upstream query outcomes.
|
||||
func (u *upstreamResolverBase) UpstreamHealth() map[netip.AddrPort]UpstreamHealth {
|
||||
u.healthMu.RLock()
|
||||
defer u.healthMu.RUnlock()
|
||||
out := make(map[netip.AddrPort]UpstreamHealth, len(u.health))
|
||||
for k, v := range u.health {
|
||||
out[k] = *v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// upstreamUDPSize returns the EDNS0 UDP buffer size we advertise to upstreams,
|
||||
@@ -358,12 +549,23 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
|
||||
return &upstreamFailure{upstream: upstream, reason: reason}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool {
|
||||
u.successCount.Add(1)
|
||||
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
|
||||
if u.statusRecorder == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
|
||||
if peerInfo == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, proto string, logger *log.Entry) {
|
||||
resutil.SetMeta(w, "upstream", upstream.String())
|
||||
if upstreamProto != nil && upstreamProto.protocol != "" {
|
||||
resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol)
|
||||
if proto != "" {
|
||||
resutil.SetMeta(w, "upstream_protocol", proto)
|
||||
}
|
||||
|
||||
// Clear Zero bit from external responses to prevent upstream servers from
|
||||
@@ -372,14 +574,11 @@ func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dn
|
||||
|
||||
if err := w.WriteMsg(rm); err != nil {
|
||||
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
|
||||
return true
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) {
|
||||
totalUpstreams := len(u.upstreamServers)
|
||||
totalUpstreams := len(u.flatUpstreams())
|
||||
failedCount := len(failures)
|
||||
failureSummary := formatFailures(failures)
|
||||
|
||||
@@ -434,119 +633,6 @@ func edeName(code uint16) string {
|
||||
return fmt.Sprintf("EDE %d", code)
|
||||
}
|
||||
|
||||
// ProbeAvailability tests all upstream servers simultaneously and
|
||||
// disables the resolver if none work
|
||||
func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
|
||||
u.mutex.Lock()
|
||||
defer u.mutex.Unlock()
|
||||
|
||||
// avoid probe if upstreams could resolve at least one query
|
||||
if u.successCount.Load() > 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var success bool
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
|
||||
var errs *multierror.Error
|
||||
for _, upstream := range u.upstreamServers {
|
||||
wg.Add(1)
|
||||
go func(upstream netip.AddrPort) {
|
||||
defer wg.Done()
|
||||
err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond)
|
||||
if err != nil {
|
||||
mu.Lock()
|
||||
errs = multierror.Append(errs, err)
|
||||
mu.Unlock()
|
||||
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
|
||||
return
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
success = true
|
||||
mu.Unlock()
|
||||
}(upstream)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-u.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// didn't find a working upstream server, let's disable and try later
|
||||
if !success {
|
||||
u.disable(errs.ErrorOrNil())
|
||||
|
||||
if u.statusRecorder == nil {
|
||||
return
|
||||
}
|
||||
|
||||
u.statusRecorder.PublishEvent(
|
||||
proto.SystemEvent_WARNING,
|
||||
proto.SystemEvent_DNS,
|
||||
"All upstream servers failed (probe failed)",
|
||||
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
|
||||
map[string]string{"upstreams": u.upstreamServersString()},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// waitUntilResponse retries, in an exponential interval, querying the upstream servers until it gets a positive response
|
||||
func (u *upstreamResolverBase) waitUntilResponse() {
|
||||
exponentialBackOff := &backoff.ExponentialBackOff{
|
||||
InitialInterval: 500 * time.Millisecond,
|
||||
RandomizationFactor: 0.5,
|
||||
Multiplier: 1.1,
|
||||
MaxInterval: u.reactivatePeriod,
|
||||
MaxElapsedTime: 0,
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}
|
||||
|
||||
operation := func() error {
|
||||
select {
|
||||
case <-u.ctx.Done():
|
||||
return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServersString()))
|
||||
default:
|
||||
}
|
||||
|
||||
for _, upstream := range u.upstreamServers {
|
||||
if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil {
|
||||
log.Tracef("upstream check for %s: %s", upstream, err)
|
||||
} else {
|
||||
// at least one upstream server is available, stop probing
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServersString(), exponentialBackOff.NextBackOff())
|
||||
return fmt.Errorf("upstream check call error")
|
||||
}
|
||||
|
||||
err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx))
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString())
|
||||
} else {
|
||||
log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
|
||||
u.successCount.Add(1)
|
||||
u.reactivate()
|
||||
u.mutex.Lock()
|
||||
u.disabled = false
|
||||
u.mutex.Unlock()
|
||||
}
|
||||
|
||||
// isTimeout returns true if the given error is a network timeout error.
|
||||
//
|
||||
// Copied from k8s.io/apimachinery/pkg/util/net.IsTimeout
|
||||
@@ -558,45 +644,6 @@ func isTimeout(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) disable(err error) {
|
||||
if u.disabled {
|
||||
return
|
||||
}
|
||||
|
||||
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
|
||||
u.successCount.Store(0)
|
||||
u.deactivate(err)
|
||||
u.disabled = true
|
||||
u.wg.Add(1)
|
||||
go func() {
|
||||
defer u.wg.Done()
|
||||
u.waitUntilResponse()
|
||||
}()
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) upstreamServersString() string {
|
||||
var servers []string
|
||||
for _, server := range u.upstreamServers {
|
||||
servers = append(servers, server.String())
|
||||
}
|
||||
return strings.Join(servers, ", ")
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error {
|
||||
mergedCtx, cancel := context.WithTimeout(baseCtx, timeout)
|
||||
defer cancel()
|
||||
|
||||
if externalCtx != nil {
|
||||
stop2 := context.AfterFunc(externalCtx, cancel)
|
||||
defer stop2()
|
||||
}
|
||||
|
||||
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
|
||||
|
||||
_, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r)
|
||||
return err
|
||||
}
|
||||
|
||||
// clientUDPMaxSize returns the maximum UDP response size the client accepts.
|
||||
func clientUDPMaxSize(r *dns.Msg) int {
|
||||
if opt := r.IsEdns0(); opt != nil {
|
||||
@@ -608,13 +655,10 @@ func clientUDPMaxSize(r *dns.Msg) int {
|
||||
// ExchangeWithFallback exchanges a DNS message with the upstream server.
|
||||
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
|
||||
// If the inbound request came over TCP (via context), it skips the UDP attempt.
|
||||
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
|
||||
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
|
||||
// If the request came in over TCP, go straight to TCP upstream.
|
||||
if dnsProtocolFromContext(ctx) == protoTCP {
|
||||
tcpClient := *client
|
||||
tcpClient.Net = protoTCP
|
||||
rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
|
||||
rm, t, err := toTCPClient(client).ExchangeContext(ctx, r, upstream)
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||
}
|
||||
@@ -634,18 +678,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
opt.SetUDPSize(maxUDPPayload)
|
||||
}
|
||||
|
||||
var (
|
||||
rm *dns.Msg
|
||||
t time.Duration
|
||||
err error
|
||||
)
|
||||
|
||||
if ctx == nil {
|
||||
rm, t, err = client.Exchange(r, upstream)
|
||||
} else {
|
||||
rm, t, err = client.ExchangeContext(ctx, r, upstream)
|
||||
}
|
||||
|
||||
rm, t, err := client.ExchangeContext(ctx, r, upstream)
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with udp: %w", err)
|
||||
}
|
||||
@@ -659,15 +692,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
// data than the client's buffer, we could truncate locally and skip
|
||||
// the TCP retry.
|
||||
|
||||
tcpClient := *client
|
||||
tcpClient.Net = protoTCP
|
||||
|
||||
if ctx == nil {
|
||||
rm, t, err = tcpClient.Exchange(r, upstream)
|
||||
} else {
|
||||
rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
|
||||
}
|
||||
|
||||
rm, t, err = toTCPClient(client).ExchangeContext(ctx, r, upstream)
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||
}
|
||||
@@ -681,6 +706,25 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
return rm, t, nil
|
||||
}
|
||||
|
||||
// toTCPClient returns a copy of c configured for TCP. If c's Dialer has a
|
||||
// *net.UDPAddr bound as LocalAddr (iOS does this to keep the source IP on
|
||||
// the tunnel interface), it is converted to the equivalent *net.TCPAddr
|
||||
// so net.Dialer doesn't reject the TCP dial with "mismatched local
|
||||
// address type".
|
||||
func toTCPClient(c *dns.Client) *dns.Client {
|
||||
tcp := *c
|
||||
tcp.Net = protoTCP
|
||||
if tcp.Dialer == nil {
|
||||
return &tcp
|
||||
}
|
||||
d := *tcp.Dialer
|
||||
if ua, ok := d.LocalAddr.(*net.UDPAddr); ok {
|
||||
d.LocalAddr = &net.TCPAddr{IP: ua.IP, Port: ua.Port, Zone: ua.Zone}
|
||||
}
|
||||
tcp.Dialer = &d
|
||||
return &tcp
|
||||
}
|
||||
|
||||
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
|
||||
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
|
||||
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
|
||||
@@ -822,15 +866,36 @@ func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State {
|
||||
return bestMatch
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
|
||||
if u.statusRecorder == nil {
|
||||
return ""
|
||||
// haMapRouteCount returns the total number of routes across all HA
|
||||
// groups in the map. route.HAMap is keyed by HAUniqueID with slices of
|
||||
// routes per key, so len(hm) is the number of HA groups, not routes.
|
||||
func haMapRouteCount(hm route.HAMap) int {
|
||||
total := 0
|
||||
for _, routes := range hm {
|
||||
total += len(routes)
|
||||
}
|
||||
|
||||
peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
|
||||
if peerInfo == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
|
||||
return total
|
||||
}
|
||||
|
||||
// haMapContains checks whether ip is covered by any concrete prefix in
|
||||
// the HA map. haveDynamic is reported separately: dynamic (domain-based)
|
||||
// routes carry a placeholder Network that can't be prefix-checked, so we
|
||||
// can't know at this point whether ip is reached through one. Callers
|
||||
// decide how to interpret the unknown: health projection treats it as
|
||||
// "possibly routed" to avoid emitting false-positive warnings during
|
||||
// startup, while iOS dial selection requires a concrete match before
|
||||
// binding to the tunnel.
|
||||
func haMapContains(hm route.HAMap, ip netip.Addr) (matched, haveDynamic bool) {
|
||||
for _, routes := range hm {
|
||||
for _, r := range routes {
|
||||
if r.IsDynamic() {
|
||||
haveDynamic = true
|
||||
continue
|
||||
}
|
||||
if r.Network.Contains(ip) {
|
||||
return true, haveDynamic
|
||||
}
|
||||
}
|
||||
}
|
||||
return false, haveDynamic
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
type upstreamResolver struct {
|
||||
@@ -26,9 +27,9 @@ func newUpstreamResolver(
|
||||
_ WGIface,
|
||||
statusRecorder *peer.Status,
|
||||
hostsDNSHolder *hostsDNSHolder,
|
||||
domain string,
|
||||
d domain.Domain,
|
||||
) (*upstreamResolver, error) {
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
|
||||
c := &upstreamResolver{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
hostsDNSHolder: hostsDNSHolder,
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
type upstreamResolver struct {
|
||||
@@ -24,9 +25,9 @@ func newUpstreamResolver(
|
||||
wgIface WGIface,
|
||||
statusRecorder *peer.Status,
|
||||
_ *hostsDNSHolder,
|
||||
domain string,
|
||||
d domain.Domain,
|
||||
) (*upstreamResolver, error) {
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
|
||||
nonIOS := &upstreamResolver{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
nsNet: wgIface.GetNet(),
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
type upstreamResolverIOS struct {
|
||||
@@ -27,9 +28,9 @@ func newUpstreamResolver(
|
||||
wgIface WGIface,
|
||||
statusRecorder *peer.Status,
|
||||
_ *hostsDNSHolder,
|
||||
domain string,
|
||||
d domain.Domain,
|
||||
) (*upstreamResolverIOS, error) {
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
|
||||
|
||||
ios := &upstreamResolverIOS{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
@@ -62,9 +63,16 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
upstreamIP = upstreamIP.Unmap()
|
||||
}
|
||||
addr := u.wgIface.Address()
|
||||
var routed bool
|
||||
if u.selectedRoutes != nil {
|
||||
// Only a concrete prefix match binds to the tunnel: dialing
|
||||
// through a private client for an upstream we can't prove is
|
||||
// routed would break public resolvers.
|
||||
routed, _ = haMapContains(u.selectedRoutes(), upstreamIP)
|
||||
}
|
||||
needsPrivate := addr.Network.Contains(upstreamIP) ||
|
||||
addr.IPv6Net.Contains(upstreamIP) ||
|
||||
(u.routeMatch != nil && u.routeMatch(upstreamIP))
|
||||
routed
|
||||
if needsPrivate {
|
||||
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
|
||||
client, err = GetClientPrivate(u.wgIface, upstreamIP, timeout)
|
||||
@@ -73,8 +81,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
}
|
||||
}
|
||||
|
||||
// Cannot use client.ExchangeContext because it overwrites our Dialer
|
||||
return ExchangeWithFallback(nil, client, r, upstream)
|
||||
return ExchangeWithFallback(ctx, client, r, upstream)
|
||||
}
|
||||
|
||||
// GetClientPrivate returns a new DNS client bound to the local IP of the Netbird interface.
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -73,7 +74,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
servers = append(servers, netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()))
|
||||
}
|
||||
}
|
||||
resolver.upstreamServers = servers
|
||||
resolver.addRace(servers)
|
||||
resolver.upstreamTimeout = testCase.timeout
|
||||
if testCase.cancelCTX {
|
||||
cancel()
|
||||
@@ -132,20 +133,10 @@ func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
type mockUpstreamResolver struct {
|
||||
r *dns.Msg
|
||||
rtt time.Duration
|
||||
err error
|
||||
}
|
||||
|
||||
// exchange mock implementation of exchange from upstreamResolver
|
||||
func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
|
||||
return c.r, c.rtt, c.err
|
||||
}
|
||||
|
||||
type mockUpstreamResponse struct {
|
||||
msg *dns.Msg
|
||||
err error
|
||||
msg *dns.Msg
|
||||
err error
|
||||
delay time.Duration
|
||||
}
|
||||
|
||||
type mockUpstreamResolverPerServer struct {
|
||||
@@ -153,63 +144,19 @@ type mockUpstreamResolverPerServer struct {
|
||||
rtt time.Duration
|
||||
}
|
||||
|
||||
func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
|
||||
if r, ok := c.responses[upstream]; ok {
|
||||
return r.msg, c.rtt, r.err
|
||||
func (c mockUpstreamResolverPerServer) exchange(ctx context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
|
||||
r, ok := c.responses[upstream]
|
||||
if !ok {
|
||||
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
|
||||
}
|
||||
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
|
||||
}
|
||||
|
||||
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
mockClient := &mockUpstreamResolver{
|
||||
err: dns.ErrTime,
|
||||
r: new(dns.Msg),
|
||||
rtt: time.Millisecond,
|
||||
}
|
||||
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: context.TODO(),
|
||||
upstreamClient: mockClient,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
reactivatePeriod: time.Microsecond * 100,
|
||||
}
|
||||
addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection
|
||||
resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())}
|
||||
|
||||
failed := false
|
||||
resolver.deactivate = func(error) {
|
||||
failed = true
|
||||
// After deactivation, make the mock client work again
|
||||
mockClient.err = nil
|
||||
}
|
||||
|
||||
reactivated := false
|
||||
resolver.reactivate = func() {
|
||||
reactivated = true
|
||||
}
|
||||
|
||||
resolver.ProbeAvailability(context.TODO())
|
||||
|
||||
if !failed {
|
||||
t.Errorf("expected that resolving was deactivated")
|
||||
return
|
||||
}
|
||||
|
||||
if !resolver.disabled {
|
||||
t.Errorf("resolver should be Disabled")
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond * 200)
|
||||
|
||||
if !reactivated {
|
||||
t.Errorf("expected that resolving was reactivated")
|
||||
return
|
||||
}
|
||||
|
||||
if resolver.disabled {
|
||||
t.Errorf("should be enabled")
|
||||
if r.delay > 0 {
|
||||
select {
|
||||
case <-time.After(r.delay):
|
||||
case <-ctx.Done():
|
||||
return nil, c.rtt, ctx.Err()
|
||||
}
|
||||
}
|
||||
return r.msg, c.rtt, r.err
|
||||
}
|
||||
|
||||
func TestUpstreamResolver_Failover(t *testing.T) {
|
||||
@@ -339,9 +286,9 @@ func TestUpstreamResolver_Failover(t *testing.T) {
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: trackingClient,
|
||||
upstreamServers: []netip.AddrPort{upstream1, upstream2},
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
resolver.addRace([]netip.AddrPort{upstream1, upstream2})
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
@@ -421,9 +368,9 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: mockClient,
|
||||
upstreamServers: []netip.AddrPort{upstream},
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
resolver.addRace([]netip.AddrPort{upstream})
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
@@ -440,6 +387,136 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
|
||||
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode, "single upstream SERVFAIL should return SERVFAIL")
|
||||
}
|
||||
|
||||
// TestUpstreamResolver_RaceAcrossGroups covers two nameserver groups
|
||||
// configured for the same domain, with one broken group. The merge+race
|
||||
// path should answer as fast as the working group and not pay the timeout
|
||||
// of the broken one on every query.
|
||||
func TestUpstreamResolver_RaceAcrossGroups(t *testing.T) {
|
||||
broken := netip.MustParseAddrPort("192.0.2.1:53")
|
||||
working := netip.MustParseAddrPort("192.0.2.2:53")
|
||||
successAnswer := "192.0.2.100"
|
||||
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
|
||||
|
||||
mockClient := &mockUpstreamResolverPerServer{
|
||||
responses: map[string]mockUpstreamResponse{
|
||||
// Force the broken upstream to only unblock via timeout /
|
||||
// cancellation so the assertion below can't pass if races
|
||||
// were run serially.
|
||||
broken.String(): {err: timeoutErr, delay: 500 * time.Millisecond},
|
||||
working.String(): {msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||
},
|
||||
rtt: time.Millisecond,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: mockClient,
|
||||
upstreamTimeout: 250 * time.Millisecond,
|
||||
}
|
||||
resolver.addRace([]netip.AddrPort{broken})
|
||||
resolver.addRace([]netip.AddrPort{working})
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
start := time.Now()
|
||||
resolver.ServeDNS(responseWriter, inputMSG)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.NotNil(t, responseMSG, "should write a response")
|
||||
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode)
|
||||
require.NotEmpty(t, responseMSG.Answer)
|
||||
assert.Contains(t, responseMSG.Answer[0].String(), successAnswer)
|
||||
// Working group answers in a single RTT; the broken group's
|
||||
// timeout (100ms) must not block the response.
|
||||
assert.Less(t, elapsed, 100*time.Millisecond, "race must not wait for broken group's timeout")
|
||||
}
|
||||
|
||||
// TestUpstreamResolver_AllGroupsFail checks that when every group fails the
|
||||
// resolver returns SERVFAIL rather than leaking a partial response.
|
||||
func TestUpstreamResolver_AllGroupsFail(t *testing.T) {
|
||||
a := netip.MustParseAddrPort("192.0.2.1:53")
|
||||
b := netip.MustParseAddrPort("192.0.2.2:53")
|
||||
|
||||
mockClient := &mockUpstreamResolverPerServer{
|
||||
responses: map[string]mockUpstreamResponse{
|
||||
a.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
b.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
},
|
||||
rtt: time.Millisecond,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: mockClient,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
resolver.addRace([]netip.AddrPort{a})
|
||||
resolver.addRace([]netip.AddrPort{b})
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
|
||||
require.NotNil(t, responseMSG)
|
||||
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode)
|
||||
}
|
||||
|
||||
// TestUpstreamResolver_HealthTracking verifies that query-path results are
|
||||
// recorded into per-upstream health, which is what projects back to
|
||||
// NSGroupState for status reporting.
|
||||
func TestUpstreamResolver_HealthTracking(t *testing.T) {
|
||||
ok := netip.MustParseAddrPort("192.0.2.10:53")
|
||||
bad := netip.MustParseAddrPort("192.0.2.11:53")
|
||||
|
||||
mockClient := &mockUpstreamResolverPerServer{
|
||||
responses: map[string]mockUpstreamResponse{
|
||||
ok.String(): {msg: buildMockResponse(dns.RcodeSuccess, "192.0.2.100")},
|
||||
bad.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
},
|
||||
rtt: time.Millisecond,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: mockClient,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
resolver.addRace([]netip.AddrPort{ok, bad})
|
||||
|
||||
responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}
|
||||
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
|
||||
|
||||
health := resolver.UpstreamHealth()
|
||||
require.Contains(t, health, ok)
|
||||
assert.False(t, health[ok].LastOk.IsZero(), "ok upstream should have LastOk set")
|
||||
assert.Empty(t, health[ok].LastErr)
|
||||
|
||||
// bad upstream was never tried because ok answered first; its health
|
||||
// should remain unset.
|
||||
assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers")
|
||||
}
|
||||
|
||||
func TestFormatFailures(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -665,10 +742,10 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
|
||||
// Verify that a client EDNS0 larger than our MTU-derived limit gets
|
||||
// capped in the outgoing request so the upstream doesn't send a
|
||||
// response larger than our read buffer.
|
||||
var receivedUDPSize uint16
|
||||
var receivedUDPSize atomic.Uint32
|
||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if opt := r.IsEdns0(); opt != nil {
|
||||
receivedUDPSize = opt.UDPSize()
|
||||
receivedUDPSize.Store(uint32(opt.UDPSize()))
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
@@ -699,7 +776,7 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
|
||||
require.NotNil(t, rm)
|
||||
|
||||
expectedMax := uint16(currentMTU - ipUDPHeaderSize)
|
||||
assert.Equal(t, expectedMax, receivedUDPSize,
|
||||
assert.Equal(t, expectedMax, uint16(receivedUDPSize.Load()),
|
||||
"upstream should see capped EDNS0, not the client's 4096")
|
||||
}
|
||||
|
||||
@@ -874,7 +951,7 @@ func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: tracking,
|
||||
upstreamServers: []netip.AddrPort{upstream1, upstream2},
|
||||
upstreamServers: []upstreamRace{{upstream1, upstream2}},
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
|
||||
|
||||
@@ -512,16 +512,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||
|
||||
e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
|
||||
for _, routes := range e.routeManager.GetSelectedClientRoutes() {
|
||||
for _, r := range routes {
|
||||
if r.Network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
e.dnsServer.SetRouteSources(e.routeManager.GetSelectedClientRoutes, e.routeManager.GetActiveClientRoutes)
|
||||
|
||||
if err = e.wgInterfaceCreate(); err != nil {
|
||||
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
||||
@@ -1386,9 +1377,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
|
||||
e.networkSerial = serial
|
||||
|
||||
// Test received (upstream) servers for availability right away instead of upon usage.
|
||||
// If no server of a server group responds this will disable the respective handler and retry later.
|
||||
go e.dnsServer.ProbeAvailability()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1932,7 +1920,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
|
||||
return dnsServer, nil
|
||||
|
||||
case "ios":
|
||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS)
|
||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
|
||||
return dnsServer, nil
|
||||
|
||||
default:
|
||||
|
||||
@@ -53,6 +53,7 @@ type Manager interface {
|
||||
GetRouteSelector() *routeselector.RouteSelector
|
||||
GetClientRoutes() route.HAMap
|
||||
GetSelectedClientRoutes() route.HAMap
|
||||
GetActiveClientRoutes() route.HAMap
|
||||
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
|
||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||
InitialRouteRange() []string
|
||||
@@ -485,6 +486,39 @@ func (m *DefaultManager) GetSelectedClientRoutes() route.HAMap {
|
||||
return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
|
||||
}
|
||||
|
||||
// GetActiveClientRoutes returns the subset of selected client routes
|
||||
// that are currently reachable: the route's peer is Connected and is
|
||||
// the one actively carrying the route (not just an HA sibling).
|
||||
func (m *DefaultManager) GetActiveClientRoutes() route.HAMap {
|
||||
m.mux.Lock()
|
||||
selected := m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
|
||||
recorder := m.statusRecorder
|
||||
m.mux.Unlock()
|
||||
|
||||
if recorder == nil {
|
||||
return selected
|
||||
}
|
||||
|
||||
out := make(route.HAMap, len(selected))
|
||||
for id, routes := range selected {
|
||||
for _, r := range routes {
|
||||
st, err := recorder.GetPeer(r.Peer)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if st.ConnStatus != peer.StatusConnected {
|
||||
continue
|
||||
}
|
||||
if _, hasRoute := st.GetRoutes()[r.Network.String()]; !hasRoute {
|
||||
continue
|
||||
}
|
||||
out[id] = routes
|
||||
break
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
||||
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||
m.mux.Lock()
|
||||
@@ -704,7 +738,10 @@ func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeI
|
||||
}
|
||||
|
||||
func (m *DefaultManager) isExitNodeRoute(routes []*route.Route) bool {
|
||||
return len(routes) > 0 && routes[0].Network.String() == vars.ExitNodeCIDR
|
||||
if len(routes) == 0 {
|
||||
return false
|
||||
}
|
||||
return route.IsV4DefaultRoute(routes[0].Network) || route.IsV6DefaultRoute(routes[0].Network)
|
||||
}
|
||||
|
||||
func (m *DefaultManager) categorizeUserSelection(netID route.NetID, info *exitNodeInfo) {
|
||||
|
||||
@@ -19,6 +19,7 @@ type MockManager struct {
|
||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||
GetClientRoutesFunc func() route.HAMap
|
||||
GetSelectedClientRoutesFunc func() route.HAMap
|
||||
GetActiveClientRoutesFunc func() route.HAMap
|
||||
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
|
||||
StopFunc func(manager *statemanager.Manager)
|
||||
}
|
||||
@@ -78,6 +79,14 @@ func (m *MockManager) GetSelectedClientRoutes() route.HAMap {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetActiveClientRoutes mock implementation of GetActiveClientRoutes from the Manager interface
|
||||
func (m *MockManager) GetActiveClientRoutes() route.HAMap {
|
||||
if m.GetActiveClientRoutesFunc != nil {
|
||||
return m.GetActiveClientRoutesFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
|
||||
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||
if m.GetClientRoutesWithNetIDFunc != nil {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
@@ -12,10 +13,6 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const (
|
||||
exitNodeCIDR = "0.0.0.0/0"
|
||||
)
|
||||
|
||||
type RouteSelector struct {
|
||||
mu sync.RWMutex
|
||||
deselectedRoutes map[route.NetID]struct{}
|
||||
@@ -124,13 +121,7 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
||||
rs.mu.RLock()
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
if rs.deselectAll {
|
||||
return false
|
||||
}
|
||||
|
||||
_, deselected := rs.deselectedRoutes[routeID]
|
||||
isSelected := !deselected
|
||||
return isSelected
|
||||
return rs.isSelectedLocked(routeID)
|
||||
}
|
||||
|
||||
// FilterSelected removes unselected routes from the provided map.
|
||||
@@ -144,23 +135,22 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
|
||||
|
||||
filtered := route.HAMap{}
|
||||
for id, rt := range routes {
|
||||
netID := id.NetID()
|
||||
_, deselected := rs.deselectedRoutes[netID]
|
||||
if !deselected {
|
||||
if !rs.isDeselectedLocked(id.NetID()) {
|
||||
filtered[id] = rt
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this specific route
|
||||
// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this route.
|
||||
// Intended for exit-node code paths: a v6 exit-node pair (e.g. "MyExit-v6") with no explicit state of
|
||||
// its own inherits its v4 base's state, so legacy persisted selections that predate v6 pairing
|
||||
// transparently apply to the synthesized v6 entry.
|
||||
func (rs *RouteSelector) HasUserSelectionForRoute(routeID route.NetID) bool {
|
||||
rs.mu.RLock()
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
_, selected := rs.selectedRoutes[routeID]
|
||||
_, deselected := rs.deselectedRoutes[routeID]
|
||||
return selected || deselected
|
||||
return rs.hasUserSelectionForRouteLocked(rs.effectiveNetID(routeID))
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap {
|
||||
@@ -174,7 +164,7 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap
|
||||
filtered := make(route.HAMap, len(routes))
|
||||
for id, rt := range routes {
|
||||
netID := id.NetID()
|
||||
if rs.isDeselected(netID) {
|
||||
if rs.isDeselectedLocked(netID) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -189,13 +179,48 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap
|
||||
return filtered
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) isDeselected(netID route.NetID) bool {
|
||||
// effectiveNetID returns the v4 base for a "-v6" exit pair entry that has no explicit
|
||||
// state of its own, so selections made on the v4 entry govern the v6 entry automatically.
|
||||
// Only call this from exit-node-specific code paths: applying it to a non-exit "-v6" route
|
||||
// would make it inherit unrelated v4 state. Must be called with rs.mu held.
|
||||
func (rs *RouteSelector) effectiveNetID(id route.NetID) route.NetID {
|
||||
name := string(id)
|
||||
if !strings.HasSuffix(name, route.V6ExitSuffix) {
|
||||
return id
|
||||
}
|
||||
if _, ok := rs.selectedRoutes[id]; ok {
|
||||
return id
|
||||
}
|
||||
if _, ok := rs.deselectedRoutes[id]; ok {
|
||||
return id
|
||||
}
|
||||
return route.NetID(strings.TrimSuffix(name, route.V6ExitSuffix))
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool {
|
||||
if rs.deselectAll {
|
||||
return false
|
||||
}
|
||||
_, deselected := rs.deselectedRoutes[routeID]
|
||||
return !deselected
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool {
|
||||
if rs.deselectAll {
|
||||
return true
|
||||
}
|
||||
_, deselected := rs.deselectedRoutes[netID]
|
||||
return deselected || rs.deselectAll
|
||||
return deselected
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool {
|
||||
_, selected := rs.selectedRoutes[routeID]
|
||||
_, deselected := rs.deselectedRoutes[routeID]
|
||||
return selected || deselected
|
||||
}
|
||||
|
||||
func isExitNode(rt []*route.Route) bool {
|
||||
return len(rt) > 0 && rt[0].Network.String() == exitNodeCIDR
|
||||
return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network))
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) applyExitNodeFilter(
|
||||
@@ -204,26 +229,23 @@ func (rs *RouteSelector) applyExitNodeFilter(
|
||||
rt []*route.Route,
|
||||
out route.HAMap,
|
||||
) {
|
||||
|
||||
if rs.hasUserSelections() {
|
||||
// user made explicit selects/deselects
|
||||
if rs.IsSelected(netID) {
|
||||
// Exit-node path: apply the v4/v6 pair mirror so a deselect on the v4 base also
|
||||
// drops the synthesized v6 entry that lacks its own explicit state.
|
||||
effective := rs.effectiveNetID(netID)
|
||||
if rs.hasUserSelectionForRouteLocked(effective) {
|
||||
if rs.isSelectedLocked(effective) {
|
||||
out[id] = rt
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// no explicit selections: only include routes marked !SkipAutoApply (=AutoApply)
|
||||
// no explicit selection for this route: defer to management's SkipAutoApply flag
|
||||
sel := collectSelected(rt)
|
||||
if len(sel) > 0 {
|
||||
out[id] = sel
|
||||
}
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) hasUserSelections() bool {
|
||||
return len(rs.selectedRoutes) > 0 || len(rs.deselectedRoutes) > 0
|
||||
}
|
||||
|
||||
func collectSelected(rt []*route.Route) []*route.Route {
|
||||
var sel []*route.Route
|
||||
for _, r := range rt {
|
||||
|
||||
@@ -330,6 +330,137 @@ func TestRouteSelector_FilterSelectedExitNodes(t *testing.T) {
|
||||
assert.Len(t, filtered, 0) // No routes should be selected
|
||||
}
|
||||
|
||||
// TestRouteSelector_V6ExitPairInherits covers the v4/v6 exit-node pair selection
|
||||
// mirror. The mirror is scoped to exit-node code paths: HasUserSelectionForRoute
|
||||
// and FilterSelectedExitNodes resolve a "-v6" entry without explicit state to its
|
||||
// v4 base, so legacy persisted selections that predate v6 pairing transparently
|
||||
// apply to the synthesized v6 entry. General lookups (IsSelected, FilterSelected)
|
||||
// stay literal so unrelated routes named "*-v6" don't inherit unrelated state.
|
||||
func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
|
||||
all := []route.NetID{"exit1", "exit1-v6", "exit2", "exit2-v6", "corp", "corp-v6"}
|
||||
|
||||
t.Run("HasUserSelectionForRoute mirrors deselected v4 base", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
|
||||
|
||||
assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 pair sees v4 base's user selection")
|
||||
|
||||
// unrelated v6 with no v4 base touched is unaffected
|
||||
assert.False(t, rs.HasUserSelectionForRoute("exit2-v6"))
|
||||
})
|
||||
|
||||
t.Run("IsSelected stays literal for non-exit lookups", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
|
||||
|
||||
// A non-exit route literally named "corp-v6" must not inherit "corp"'s state
|
||||
// via the mirror; the mirror only applies in exit-node code paths.
|
||||
assert.False(t, rs.IsSelected("corp"))
|
||||
assert.True(t, rs.IsSelected("corp-v6"), "non-exit *-v6 routes must not inherit unrelated v4 state")
|
||||
})
|
||||
|
||||
t.Run("explicit v6 state overrides v4 base in filter", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
|
||||
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1-v6"}, true, all))
|
||||
|
||||
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
|
||||
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
|
||||
routes := route.HAMap{
|
||||
"exit1|0.0.0.0/0": {v4Route},
|
||||
"exit1-v6|::/0": {v6Route},
|
||||
}
|
||||
|
||||
filtered := rs.FilterSelectedExitNodes(routes)
|
||||
assert.NotContains(t, filtered, route.HAUniqueID("exit1|0.0.0.0/0"))
|
||||
assert.Contains(t, filtered, route.HAUniqueID("exit1-v6|::/0"), "explicit v6 select wins over v4 base")
|
||||
})
|
||||
|
||||
t.Run("non-v6-suffix routes unaffected", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
|
||||
|
||||
// A route literally named "exit1-something" must not pair-resolve.
|
||||
assert.False(t, rs.HasUserSelectionForRoute("exit1-something"))
|
||||
})
|
||||
|
||||
t.Run("filter v6 paired with deselected v4 base", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
|
||||
|
||||
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
|
||||
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
|
||||
routes := route.HAMap{
|
||||
"exit1|0.0.0.0/0": {v4Route},
|
||||
"exit1-v6|::/0": {v6Route},
|
||||
}
|
||||
|
||||
filtered := rs.FilterSelectedExitNodes(routes)
|
||||
assert.Empty(t, filtered, "deselecting v4 base must also drop the v6 pair")
|
||||
})
|
||||
|
||||
t.Run("non-exit *-v6 routes pass through FilterSelectedExitNodes", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
|
||||
|
||||
// A non-default-route entry named "corp-v6" is not an exit node and
|
||||
// must not be skipped because its v4 base "corp" is deselected.
|
||||
corpV6 := &route.Route{NetID: "corp-v6", Network: netip.MustParsePrefix("10.0.0.0/8")}
|
||||
routes := route.HAMap{
|
||||
"corp-v6|10.0.0.0/8": {corpV6},
|
||||
}
|
||||
|
||||
filtered := rs.FilterSelectedExitNodes(routes)
|
||||
assert.Contains(t, filtered, route.HAUniqueID("corp-v6|10.0.0.0/8"),
|
||||
"non-exit *-v6 routes must not inherit unrelated v4 state in FilterSelectedExitNodes")
|
||||
})
|
||||
}
|
||||
|
||||
// TestRouteSelector_SkipAutoApplyPerRoute verifies that management's
|
||||
// SkipAutoApply flag governs each untouched route independently, even when
|
||||
// the user has explicit selections on other routes.
|
||||
func TestRouteSelector_SkipAutoApplyPerRoute(t *testing.T) {
|
||||
autoApplied := &route.Route{
|
||||
NetID: "Auto",
|
||||
Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
SkipAutoApply: false,
|
||||
}
|
||||
skipApply := &route.Route{
|
||||
NetID: "Skip",
|
||||
Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
SkipAutoApply: true,
|
||||
}
|
||||
routes := route.HAMap{
|
||||
"Auto|0.0.0.0/0": {autoApplied},
|
||||
"Skip|0.0.0.0/0": {skipApply},
|
||||
}
|
||||
|
||||
rs := routeselector.NewRouteSelector()
|
||||
// User makes an unrelated explicit selection elsewhere.
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"Unrelated"}, []route.NetID{"Auto", "Skip", "Unrelated"}))
|
||||
|
||||
filtered := rs.FilterSelectedExitNodes(routes)
|
||||
assert.Contains(t, filtered, route.HAUniqueID("Auto|0.0.0.0/0"), "AutoApply route should be included")
|
||||
assert.NotContains(t, filtered, route.HAUniqueID("Skip|0.0.0.0/0"), "SkipAutoApply route should be excluded without explicit user selection")
|
||||
}
|
||||
|
||||
// TestRouteSelector_V6ExitIsExitNode verifies that ::/0 routes are recognized
|
||||
// as exit nodes by the selector's filter path.
|
||||
func TestRouteSelector_V6ExitIsExitNode(t *testing.T) {
|
||||
v6Exit := &route.Route{
|
||||
NetID: "V6Only",
|
||||
Network: netip.MustParsePrefix("::/0"),
|
||||
SkipAutoApply: true,
|
||||
}
|
||||
routes := route.HAMap{
|
||||
"V6Only|::/0": {v6Exit},
|
||||
}
|
||||
|
||||
rs := routeselector.NewRouteSelector()
|
||||
filtered := rs.FilterSelectedExitNodes(routes)
|
||||
assert.Empty(t, filtered, "::/0 should be treated as an exit node and respect SkipAutoApply")
|
||||
}
|
||||
|
||||
func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
|
||||
initialRoutes := []route.NetID{"route1", "route2", "route3"}
|
||||
newRoutes := []route.NetID{"route1", "route2", "route3", "route4", "route5"}
|
||||
|
||||
@@ -188,7 +188,9 @@ func (d *Detector) triggerCallback(event EventType, cb func(event EventType), do
|
||||
}
|
||||
|
||||
doneChan := make(chan struct{})
|
||||
timeout := time.NewTimer(500 * time.Millisecond)
|
||||
// macOS forces sleep ~30s after kIOMessageSystemWillSleep, so block long
|
||||
// enough for teardown to finish while staying under that deadline.
|
||||
timeout := time.NewTimer(20 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
go func() {
|
||||
|
||||
@@ -162,11 +162,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
||||
cfg.WgIface = interfaceName
|
||||
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
hostDNS := []netip.AddrPort{
|
||||
netip.MustParseAddrPort("9.9.9.9:53"),
|
||||
netip.MustParseAddrPort("149.112.112.112:53"),
|
||||
}
|
||||
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, hostDNS, c.stateFile)
|
||||
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
|
||||
@@ -193,7 +193,15 @@ func getOverlappingNetworks(routes []*proto.Network) []*proto.Network {
|
||||
}
|
||||
|
||||
func isDefaultRoute(routeRange string) bool {
|
||||
return routeRange == "0.0.0.0/0" || routeRange == "::/0"
|
||||
// routeRange is the merged display string from the daemon, e.g. "0.0.0.0/0",
|
||||
// "::/0", or "0.0.0.0/0, ::/0" when a v4 exit node has a paired v6 entry.
|
||||
for _, part := range strings.Split(routeRange, ",") {
|
||||
switch strings.TrimSpace(part) {
|
||||
case "0.0.0.0/0", "::/0":
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getExitNodeNetworks(routes []*proto.Network) []*proto.Network {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
@@ -47,6 +48,11 @@ type EphemeralManager struct {
|
||||
|
||||
lifeTime time.Duration
|
||||
cleanupWindow time.Duration
|
||||
|
||||
// metrics is nil-safe; methods on telemetry.EphemeralPeersMetrics
|
||||
// no-op when the receiver is nil so deployments without an app
|
||||
// metrics provider work unchanged.
|
||||
metrics *telemetry.EphemeralPeersMetrics
|
||||
}
|
||||
|
||||
// NewEphemeralManager instantiate new EphemeralManager
|
||||
@@ -60,6 +66,15 @@ func NewEphemeralManager(store store.Store, peersManager peers.Manager) *Ephemer
|
||||
}
|
||||
}
|
||||
|
||||
// SetMetrics attaches a metrics collector. Safe to call once before
|
||||
// LoadInitialPeers; later attachment is fine but earlier loads won't be
|
||||
// reflected in the gauge. Pass nil to detach.
|
||||
func (e *EphemeralManager) SetMetrics(m *telemetry.EphemeralPeersMetrics) {
|
||||
e.peersLock.Lock()
|
||||
e.metrics = m
|
||||
e.peersLock.Unlock()
|
||||
}
|
||||
|
||||
// LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head
|
||||
// of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new
|
||||
// head.
|
||||
@@ -97,7 +112,9 @@ func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Pee
|
||||
e.peersLock.Lock()
|
||||
defer e.peersLock.Unlock()
|
||||
|
||||
e.removePeer(peer.ID)
|
||||
if e.removePeer(peer.ID) {
|
||||
e.metrics.DecPending(1)
|
||||
}
|
||||
|
||||
// stop the unnecessary timer
|
||||
if e.headPeer == nil && e.timer != nil {
|
||||
@@ -123,6 +140,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
|
||||
}
|
||||
|
||||
e.addPeer(peer.AccountID, peer.ID, e.newDeadLine())
|
||||
e.metrics.IncPending()
|
||||
if e.timer == nil {
|
||||
delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
|
||||
if delay < 0 {
|
||||
@@ -145,6 +163,7 @@ func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
|
||||
for _, p := range peers {
|
||||
e.addPeer(p.AccountID, p.ID, t)
|
||||
}
|
||||
e.metrics.AddPending(int64(len(peers)))
|
||||
|
||||
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", len(peers))
|
||||
}
|
||||
@@ -181,6 +200,15 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
|
||||
|
||||
e.peersLock.Unlock()
|
||||
|
||||
// Drop the gauge by the number of entries we just took off the list,
|
||||
// regardless of whether the subsequent DeletePeers call succeeds. The
|
||||
// list invariant is what the gauge tracks; failed delete batches are
|
||||
// counted separately via CountCleanupError so we can still see them.
|
||||
if len(deletePeers) > 0 {
|
||||
e.metrics.CountCleanupRun()
|
||||
e.metrics.DecPending(int64(len(deletePeers)))
|
||||
}
|
||||
|
||||
peerIDsPerAccount := make(map[string][]string)
|
||||
for id, p := range deletePeers {
|
||||
peerIDsPerAccount[p.accountID] = append(peerIDsPerAccount[p.accountID], id)
|
||||
@@ -191,7 +219,10 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
|
||||
err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete ephemeral peers: %s", err)
|
||||
e.metrics.CountCleanupError()
|
||||
continue
|
||||
}
|
||||
e.metrics.CountPeersCleaned(int64(len(peerIDs)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -211,9 +242,12 @@ func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline tim
|
||||
e.tailPeer = ep
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) removePeer(id string) {
|
||||
// removePeer drops the entry from the linked list. Returns true if a
|
||||
// matching entry was found and removed so callers can keep the pending
|
||||
// metric gauge in sync.
|
||||
func (e *EphemeralManager) removePeer(id string) bool {
|
||||
if e.headPeer == nil {
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
if e.headPeer.id == id {
|
||||
@@ -221,7 +255,7 @@ func (e *EphemeralManager) removePeer(id string) {
|
||||
if e.tailPeer.id == id {
|
||||
e.tailPeer = nil
|
||||
}
|
||||
return
|
||||
return true
|
||||
}
|
||||
|
||||
for p := e.headPeer; p.next != nil; p = p.next {
|
||||
@@ -231,9 +265,10 @@ func (e *EphemeralManager) removePeer(id string) {
|
||||
e.tailPeer = p
|
||||
}
|
||||
p.next = p.next.next
|
||||
return
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) isPeerOnList(id string) bool {
|
||||
|
||||
@@ -304,10 +304,27 @@ func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]s
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get BYOP cluster addresses: %w", err)
|
||||
}
|
||||
if len(byopAddresses) > 0 {
|
||||
return byopAddresses, nil
|
||||
publicAddresses, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get public cluster addresses: %w", err)
|
||||
}
|
||||
return m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
seen := make(map[string]struct{}, len(byopAddresses)+len(publicAddresses))
|
||||
merged := make([]string, 0, len(byopAddresses)+len(publicAddresses))
|
||||
for _, addr := range byopAddresses {
|
||||
if _, ok := seen[addr]; ok {
|
||||
continue
|
||||
}
|
||||
seen[addr] = struct{}{}
|
||||
merged = append(merged, addr)
|
||||
}
|
||||
for _, addr := range publicAddresses {
|
||||
if _, ok := seen[addr]; ok {
|
||||
continue
|
||||
}
|
||||
seen[addr] = struct{}{}
|
||||
merged = append(merged, addr)
|
||||
}
|
||||
return merged, nil
|
||||
}
|
||||
|
||||
func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) {
|
||||
|
||||
@@ -40,22 +40,37 @@ func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYOPProxy(t *testing.T) {
|
||||
func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
|
||||
assert.Equal(t, "acc-123", accID)
|
||||
return []string{"byop.example.com"}, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
t.Fatal("should not call GetActiveClusterAddresses when BYOP addresses exist")
|
||||
return nil, nil
|
||||
return []string{"eu.proxy.netbird.io"}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"byop.example.com"}, result)
|
||||
assert.Equal(t, []string{"byop.example.com", "eu.proxy.netbird.io"}, result)
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_DeduplicatesBYOPAndPublic(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return []string{"shared.example.com", "byop.example.com"}, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
return []string{"shared.example.com", "eu.proxy.netbird.io"}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"shared.example.com", "byop.example.com", "eu.proxy.netbird.io"}, result)
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_NoBYOP_FallbackToShared(t *testing.T) {
|
||||
@@ -79,10 +94,6 @@ func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) {
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return nil, errors.New("db error")
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
t.Fatal("should not call GetActiveClusterAddresses when BYOP lookup fails")
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
@@ -92,6 +103,23 @@ func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "BYOP cluster addresses")
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_PublicError_ReturnsError(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return []string{"byop.example.com"}, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
return nil, errors.New("db error")
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "public cluster addresses")
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
@@ -108,3 +136,19 @@ func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) {
|
||||
assert.Equal(t, []string{"eu.proxy.netbird.io"}, result)
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_PublicEmpty_BYOPOnly(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return []string{"byop.example.com"}, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"byop.example.com"}, result)
|
||||
}
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ type store interface {
|
||||
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
|
||||
GetProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
|
||||
@@ -57,7 +57,7 @@ func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) GetActiveProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) {
|
||||
func (m *mockStore) GetProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error {
|
||||
|
||||
@@ -42,10 +42,35 @@ func (Proxy) TableName() string {
|
||||
return "proxies"
|
||||
}
|
||||
|
||||
// ClusterType is the source of a proxy cluster.
|
||||
type ClusterType string
|
||||
|
||||
const (
|
||||
// ClusterTypeAccount is a cluster operated by the account itself (BYOP) —
|
||||
// at least one proxy row in the cluster carries a non-NULL account_id.
|
||||
ClusterTypeAccount ClusterType = "account"
|
||||
// ClusterTypeShared is a cluster operated by NetBird and shared across
|
||||
// accounts — all proxy rows in the cluster have account_id IS NULL.
|
||||
ClusterTypeShared ClusterType = "shared"
|
||||
)
|
||||
|
||||
// Cluster represents a group of proxy nodes serving the same address.
|
||||
//
|
||||
// Online and ConnectedProxies derive from the same 2-min active window
|
||||
// the rest of the module uses, but Cluster rows are not gated on it —
|
||||
// the cluster listing surfaces offline clusters too so operators can
|
||||
// see and clean them up. The 1-hour heartbeat reaper still bounds the
|
||||
// table eventually.
|
||||
type Cluster struct {
|
||||
ID string
|
||||
Address string
|
||||
Type ClusterType
|
||||
Online bool
|
||||
ConnectedProxies int
|
||||
SelfHosted bool
|
||||
// Capability flags. *bool because nil means "no proxy reported a
|
||||
// capability for this cluster" — the dashboard renders these as
|
||||
// unknown rather than false.
|
||||
SupportsCustomPorts *bool
|
||||
RequireSubdomain *bool
|
||||
SupportsCrowdSec *bool
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
|
||||
GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
|
||||
DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error
|
||||
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
|
||||
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
|
||||
|
||||
@@ -65,20 +65,6 @@ func (mr *MockManagerMockRecorder) CreateServiceFromPeer(ctx, accountID, peerID,
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateServiceFromPeer", reflect.TypeOf((*MockManager)(nil).CreateServiceFromPeer), ctx, accountID, peerID, req)
|
||||
}
|
||||
|
||||
// DeleteAllServices mocks base method.
|
||||
func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAllServices", ctx, accountID, userID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAllServices indicates an expected call of DeleteAllServices.
|
||||
func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// DeleteAccountCluster mocks base method.
|
||||
func (m *MockManager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -93,6 +79,20 @@ func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, accountID, userID,
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, accountID, userID, clusterAddress)
|
||||
}
|
||||
|
||||
// DeleteAllServices mocks base method.
|
||||
func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAllServices", ctx, accountID, userID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAllServices indicates an expected call of DeleteAllServices.
|
||||
func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// DeleteService mocks base method.
|
||||
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -122,21 +122,6 @@ func (mr *MockManagerMockRecorder) GetAccountServices(ctx, accountID interface{}
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockManager)(nil).GetAccountServices), ctx, accountID)
|
||||
}
|
||||
|
||||
// GetActiveClusters mocks base method.
|
||||
func (m *MockManager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveClusters", ctx, accountID, userID)
|
||||
ret0, _ := ret[0].([]proxy.Cluster)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveClusters indicates an expected call of GetActiveClusters.
|
||||
func (mr *MockManagerMockRecorder) GetActiveClusters(ctx, accountID, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// GetAllServices mocks base method.
|
||||
func (m *MockManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -152,19 +137,19 @@ func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interfa
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// GetServiceByDomain mocks base method.
|
||||
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
|
||||
// GetClusters mocks base method.
|
||||
func (m *MockManager) GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
|
||||
ret0, _ := ret[0].(*Service)
|
||||
ret := m.ctrl.Call(m, "GetClusters", ctx, accountID, userID)
|
||||
ret0, _ := ret[0].([]proxy.Cluster)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
|
||||
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
|
||||
// GetClusters indicates an expected call of GetClusters.
|
||||
func (mr *MockManagerMockRecorder) GetClusters(ctx, accountID, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusters", reflect.TypeOf((*MockManager)(nil).GetClusters), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// GetGlobalServices mocks base method.
|
||||
@@ -197,6 +182,21 @@ func (mr *MockManagerMockRecorder) GetService(ctx, accountID, userID, serviceID
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetService", reflect.TypeOf((*MockManager)(nil).GetService), ctx, accountID, userID, serviceID)
|
||||
}
|
||||
|
||||
// GetServiceByDomain mocks base method.
|
||||
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
|
||||
ret0, _ := ret[0].(*Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
|
||||
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
|
||||
}
|
||||
|
||||
// GetServiceByID mocks base method.
|
||||
func (m *MockManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -187,7 +187,7 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
clusters, err := h.manager.GetActiveClusters(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
clusters, err := h.manager.GetClusters(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -196,10 +196,14 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
|
||||
apiClusters := make([]api.ProxyCluster, 0, len(clusters))
|
||||
for _, c := range clusters {
|
||||
apiClusters = append(apiClusters, api.ProxyCluster{
|
||||
Id: c.ID,
|
||||
Address: c.Address,
|
||||
ConnectedProxies: c.ConnectedProxies,
|
||||
SelfHosted: c.SelfHosted,
|
||||
Id: c.ID,
|
||||
Address: c.Address,
|
||||
Type: api.ProxyClusterType(c.Type),
|
||||
Online: c.Online,
|
||||
ConnectedProxies: c.ConnectedProxies,
|
||||
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||
RequireSubdomain: c.RequireSubdomain,
|
||||
SupportsCrowdsec: c.SupportsCrowdSec,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -81,6 +81,7 @@ type ClusterDeriver interface {
|
||||
type CapabilityProvider interface {
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
@@ -112,8 +113,12 @@ func (m *Manager) StartExposeReaper(ctx context.Context) {
|
||||
m.exposeReaper.StartExposeReaper(ctx)
|
||||
}
|
||||
|
||||
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
|
||||
func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
|
||||
// GetClusters returns every proxy cluster visible to the account
|
||||
// (shared + its own BYOP), regardless of whether any proxy in the
|
||||
// cluster is currently heartbeating. Each cluster is enriched with the
|
||||
// capability flags reported by its active proxies so the dashboard can
|
||||
// render feature support without a second round-trip.
|
||||
func (m *Manager) GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
@@ -122,7 +127,18 @@ func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID strin
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetActiveProxyClusters(ctx, accountID)
|
||||
clusters, err := m.store.GetProxyClusters(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range clusters {
|
||||
clusters[i].SupportsCustomPorts = m.capabilities.ClusterSupportsCustomPorts(ctx, clusters[i].Address)
|
||||
clusters[i].RequireSubdomain = m.capabilities.ClusterRequireSubdomain(ctx, clusters[i].Address)
|
||||
clusters[i].SupportsCrowdSec = m.capabilities.ClusterSupportsCrowdSec(ctx, clusters[i].Address)
|
||||
}
|
||||
|
||||
return clusters, nil
|
||||
}
|
||||
|
||||
// DeleteAccountCluster removes all proxy registrations for the given cluster address
|
||||
@@ -306,6 +322,10 @@ func (m *Manager) validateSubdomainRequirement(ctx context.Context, domain, clus
|
||||
func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error {
|
||||
customPorts := m.clusterCustomPorts(ctx, svc)
|
||||
|
||||
if err := validateTargetReferences(ctx, m.store, accountID, svc.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if svc.Domain != "" {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
|
||||
@@ -321,10 +341,6 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateService(ctx, svc); err != nil {
|
||||
return fmt.Errorf("create service: %w", err)
|
||||
}
|
||||
@@ -435,6 +451,10 @@ func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string
|
||||
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
|
||||
customPorts := m.clusterCustomPorts(ctx, svc)
|
||||
|
||||
if err := validateTargetReferences(ctx, m.store, accountID, svc.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil {
|
||||
return err
|
||||
@@ -448,10 +468,6 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateService(ctx, svc); err != nil {
|
||||
return fmt.Errorf("create service: %w", err)
|
||||
}
|
||||
@@ -552,10 +568,22 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
|
||||
svcForCaps.ProxyCluster = effectiveCluster
|
||||
customPorts := m.clusterCustomPorts(ctx, &svcForCaps)
|
||||
|
||||
if err := validateTargetReferences(ctx, m.store, accountID, service.Targets); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate subdomain requirement *before* the transaction: the underlying
|
||||
// capability lookup talks to the main DB pool, and SQLite's single-connection
|
||||
// pool would self-deadlock if this ran while the tx already held the only
|
||||
// connection.
|
||||
if err := m.validateSubdomainRequirement(ctx, service.Domain, effectiveCluster); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var updateInfo serviceUpdateInfo
|
||||
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts)
|
||||
return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts, effectiveCluster)
|
||||
})
|
||||
|
||||
return &updateInfo, err
|
||||
@@ -585,7 +613,7 @@ func (m *Manager) resolveEffectiveCluster(ctx context.Context, accountID string,
|
||||
return existing.ProxyCluster, nil
|
||||
}
|
||||
|
||||
func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool) error {
|
||||
func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool, effectiveCluster string) error {
|
||||
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -603,17 +631,13 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
|
||||
updateInfo.domainChanged = existingService.Domain != service.Domain
|
||||
|
||||
if updateInfo.domainChanged {
|
||||
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
|
||||
if err := m.handleDomainChange(ctx, transaction, service, effectiveCluster); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
service.ProxyCluster = existingService.ProxyCluster
|
||||
}
|
||||
|
||||
if err := m.validateSubdomainRequirement(ctx, service.Domain, service.ProxyCluster); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.preserveExistingAuthSecrets(service, existingService)
|
||||
if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil {
|
||||
return err
|
||||
@@ -628,9 +652,6 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
|
||||
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := transaction.UpdateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("update service: %w", err)
|
||||
}
|
||||
@@ -638,20 +659,18 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, svc *service.Service) error {
|
||||
// handleDomainChange validates the new domain is free inside the transaction
|
||||
// and applies the pre-resolved cluster (computed outside the tx by
|
||||
// resolveEffectiveCluster). It must NOT call clusterDeriver here: that talks
|
||||
// to the main DB pool and would self-deadlock under SQLite (max_open_conns=1)
|
||||
// because the transaction already holds the only connection.
|
||||
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, svc *service.Service, effectiveCluster string) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, svc.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.clusterDeriver != nil {
|
||||
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain)
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain)
|
||||
} else {
|
||||
svc.ProxyCluster = newCluster
|
||||
}
|
||||
if effectiveCluster != "" {
|
||||
svc.ProxyCluster = effectiveCluster
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -381,13 +381,14 @@ func (s *Service) buildPathMappings() []*proto.PathMapping {
|
||||
}
|
||||
|
||||
// HTTP/HTTPS: build full URL
|
||||
hostNoBrackets := strings.TrimSuffix(strings.TrimPrefix(target.Host, "["), "]")
|
||||
targetURL := url.URL{
|
||||
Scheme: target.Protocol,
|
||||
Host: target.Host,
|
||||
Host: bracketIPv6Host(hostNoBrackets),
|
||||
Path: "/",
|
||||
}
|
||||
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
|
||||
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.FormatUint(uint64(target.Port), 10))
|
||||
targetURL.Host = net.JoinHostPort(hostNoBrackets, strconv.FormatUint(uint64(target.Port), 10))
|
||||
}
|
||||
|
||||
path := "/"
|
||||
@@ -405,6 +406,19 @@ func (s *Service) buildPathMappings() []*proto.PathMapping {
|
||||
return pathMappings
|
||||
}
|
||||
|
||||
// bracketIPv6Host wraps host in square brackets when it is an IPv6 literal, as
|
||||
// required for the Host field of net/url.URL (RFC 3986 §3.2.2). v4-mapped IPv6
|
||||
// addresses are bracketed too since their textual form contains colons.
|
||||
func bracketIPv6Host(host string) string {
|
||||
if strings.HasPrefix(host, "[") {
|
||||
return host
|
||||
}
|
||||
if addr, err := netip.ParseAddr(host); err == nil && addr.Is6() {
|
||||
return "[" + host + "]"
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
|
||||
switch op {
|
||||
case Create:
|
||||
|
||||
@@ -351,6 +351,83 @@ func TestToProtoMapping_PortInTargetURL(t *testing.T) {
|
||||
port: 80,
|
||||
wantTarget: "https://10.0.0.1:80/",
|
||||
},
|
||||
{
|
||||
name: "domain host without port is unchanged",
|
||||
protocol: "http",
|
||||
host: "example.com",
|
||||
port: 0,
|
||||
wantTarget: "http://example.com/",
|
||||
},
|
||||
{
|
||||
name: "domain host with non-default port is unchanged",
|
||||
protocol: "http",
|
||||
host: "example.com",
|
||||
port: 8080,
|
||||
wantTarget: "http://example.com:8080/",
|
||||
},
|
||||
{
|
||||
name: "ipv6 host without port is bracketed",
|
||||
protocol: "http",
|
||||
host: "fb00:cafe:1::3",
|
||||
port: 0,
|
||||
wantTarget: "http://[fb00:cafe:1::3]/",
|
||||
},
|
||||
{
|
||||
name: "ipv6 host with default port omits port and brackets host",
|
||||
protocol: "http",
|
||||
host: "fb00:cafe:1::3",
|
||||
port: 80,
|
||||
wantTarget: "http://[fb00:cafe:1::3]/",
|
||||
},
|
||||
{
|
||||
name: "ipv6 host with non-default port is bracketed",
|
||||
protocol: "http",
|
||||
host: "fb00:cafe:1::3",
|
||||
port: 8080,
|
||||
wantTarget: "http://[fb00:cafe:1::3]:8080/",
|
||||
},
|
||||
{
|
||||
name: "ipv6 loopback without port is bracketed",
|
||||
protocol: "http",
|
||||
host: "::1",
|
||||
port: 0,
|
||||
wantTarget: "http://[::1]/",
|
||||
},
|
||||
{
|
||||
name: "ipv6 host with 5-digit port is bracketed",
|
||||
protocol: "http",
|
||||
host: "fb00:cafe::1",
|
||||
port: 18080,
|
||||
wantTarget: "http://[fb00:cafe::1]:18080/",
|
||||
},
|
||||
{
|
||||
name: "pre-bracketed ipv6 without port stays single-bracketed",
|
||||
protocol: "http",
|
||||
host: "[fb00:cafe::1]",
|
||||
port: 0,
|
||||
wantTarget: "http://[fb00:cafe::1]/",
|
||||
},
|
||||
{
|
||||
name: "pre-bracketed ipv6 with port is not double-bracketed",
|
||||
protocol: "http",
|
||||
host: "[fb00:cafe::1]",
|
||||
port: 8080,
|
||||
wantTarget: "http://[fb00:cafe::1]:8080/",
|
||||
},
|
||||
{
|
||||
name: "v4-mapped ipv6 host without port is bracketed",
|
||||
protocol: "http",
|
||||
host: "::ffff:10.0.0.1",
|
||||
port: 0,
|
||||
wantTarget: "http://[::ffff:10.0.0.1]/",
|
||||
},
|
||||
{
|
||||
name: "full-form 8-group ipv6 without port is bracketed",
|
||||
protocol: "http",
|
||||
host: "fb00:cafe:1:0:0:0:0:3",
|
||||
port: 0,
|
||||
wantTarget: "http://[fb00:cafe:1:0:0:0:0:3]/",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
@@ -112,7 +112,11 @@ func (s *BaseServer) AuthManager() auth.Manager {
|
||||
|
||||
func (s *BaseServer) EphemeralManager() ephemeral.Manager {
|
||||
return Create(s, func() ephemeral.Manager {
|
||||
return manager.NewEphemeralManager(s.Store(), s.PeersManager())
|
||||
em := manager.NewEphemeralManager(s.Store(), s.PeersManager())
|
||||
if metrics := s.Metrics(); metrics != nil {
|
||||
em.SetMetrics(metrics.EphemeralPeersMetrics())
|
||||
}
|
||||
return em
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -136,9 +137,12 @@ type proxyConnection struct {
|
||||
tokenID string
|
||||
capabilities *proto.ProxyCapabilities
|
||||
stream proto.ProxyService_GetMappingUpdateServer
|
||||
sendChan chan *proto.GetMappingUpdateResponse
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
// syncStream is set when the proxy connected via SyncMappings.
|
||||
// When non-nil, the sender goroutine uses this instead of stream.
|
||||
syncStream proto.ProxyService_SyncMappingsServer
|
||||
sendChan chan *proto.GetMappingUpdateResponse
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func enforceAccountScope(ctx context.Context, requestAccountID string) error {
|
||||
@@ -206,145 +210,322 @@ func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller
|
||||
s.proxyController = proxyController
|
||||
}
|
||||
|
||||
// proxyConnectParams holds the validated parameters extracted from either
|
||||
// a GetMappingUpdateRequest or a SyncMappingsInit message.
|
||||
type proxyConnectParams struct {
|
||||
proxyID string
|
||||
address string
|
||||
capabilities *proto.ProxyCapabilities
|
||||
}
|
||||
|
||||
// GetMappingUpdate handles the control stream with proxy clients
|
||||
func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error {
|
||||
ctx := stream.Context()
|
||||
params, err := s.validateProxyConnect(req.GetProxyId(), req.GetAddress(), stream.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
params.capabilities = req.GetCapabilities()
|
||||
|
||||
peerInfo := PeerIPFromContext(ctx)
|
||||
log.Infof("New proxy connection from %s", peerInfo)
|
||||
conn, proxyRecord, err := s.registerProxyConnection(stream.Context(), params, &proxyConnection{
|
||||
stream: stream,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
proxyID := req.GetProxyId()
|
||||
if err := s.sendSnapshot(stream.Context(), conn); err != nil {
|
||||
s.cleanupFailedSnapshot(stream.Context(), conn)
|
||||
return fmt.Errorf("send snapshot to proxy %s: %w", params.proxyID, err)
|
||||
}
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
go s.sender(conn, errChan)
|
||||
|
||||
return s.serveProxyConnection(conn, proxyRecord, errChan, false)
|
||||
}
|
||||
|
||||
// SyncMappings implements the bidirectional SyncMappings RPC.
|
||||
// It mirrors GetMappingUpdate but provides application-level back-pressure:
|
||||
// management waits for an ack from the proxy before sending the next batch.
|
||||
func (s *ProxyServiceServer) SyncMappings(stream proto.ProxyService_SyncMappingsServer) error {
|
||||
init, err := recvSyncInit(stream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
params, err := s.validateProxyConnect(init.GetProxyId(), init.GetAddress(), stream.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
params.capabilities = init.GetCapabilities()
|
||||
|
||||
conn, proxyRecord, err := s.registerProxyConnection(stream.Context(), params, &proxyConnection{
|
||||
syncStream: stream,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.sendSnapshotSync(stream.Context(), conn, stream); err != nil {
|
||||
s.cleanupFailedSnapshot(stream.Context(), conn)
|
||||
return fmt.Errorf("send snapshot to proxy %s: %w", params.proxyID, err)
|
||||
}
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
go s.sender(conn, errChan)
|
||||
go s.drainRecv(stream, errChan)
|
||||
|
||||
return s.serveProxyConnection(conn, proxyRecord, errChan, true)
|
||||
}
|
||||
|
||||
// recvSyncInit receives and validates the first message on a SyncMappings stream.
|
||||
func recvSyncInit(stream proto.ProxyService_SyncMappingsServer) (*proto.SyncMappingsInit, error) {
|
||||
firstMsg, err := stream.Recv()
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "receive init: %v", err)
|
||||
}
|
||||
init := firstMsg.GetInit()
|
||||
if init == nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "first message must be init")
|
||||
}
|
||||
return init, nil
|
||||
}
|
||||
|
||||
// validateProxyConnect validates the proxy ID and address, and checks cluster
|
||||
// address availability for account-scoped tokens.
|
||||
func (s *ProxyServiceServer) validateProxyConnect(proxyID, address string, ctx context.Context) (proxyConnectParams, error) {
|
||||
if proxyID == "" {
|
||||
return status.Errorf(codes.InvalidArgument, "proxy_id is required")
|
||||
return proxyConnectParams{}, status.Errorf(codes.InvalidArgument, "proxy_id is required")
|
||||
}
|
||||
if !isProxyAddressValid(address) {
|
||||
return proxyConnectParams{}, status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
||||
}
|
||||
|
||||
proxyAddress := req.GetAddress()
|
||||
if !isProxyAddressValid(proxyAddress) {
|
||||
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
||||
}
|
||||
|
||||
var accountID *string
|
||||
token := GetProxyTokenFromContext(ctx)
|
||||
if token != nil && token.AccountID != nil {
|
||||
accountID = token.AccountID
|
||||
|
||||
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID)
|
||||
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, address, *token.AccountID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "check cluster address: %v", err)
|
||||
return proxyConnectParams{}, status.Errorf(codes.Internal, "check cluster address: %v", err)
|
||||
}
|
||||
if !available {
|
||||
return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress)
|
||||
return proxyConnectParams{}, status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", address)
|
||||
}
|
||||
}
|
||||
|
||||
return proxyConnectParams{proxyID: proxyID, address: address}, nil
|
||||
}
|
||||
|
||||
// registerProxyConnection creates a proxyConnection, registers it with the
|
||||
// proxy manager and cluster, and stores it in connectedProxies. The caller
|
||||
// provides a partially initialised connSeed with stream-specific fields set;
|
||||
// the remaining fields are filled in here.
|
||||
func (s *ProxyServiceServer) registerProxyConnection(ctx context.Context, params proxyConnectParams, connSeed *proxyConnection) (*proxyConnection, *proxy.Proxy, error) {
|
||||
peerInfo := PeerIPFromContext(ctx)
|
||||
|
||||
var accountID *string
|
||||
var tokenID string
|
||||
if token != nil {
|
||||
if token := GetProxyTokenFromContext(ctx); token != nil {
|
||||
if token.AccountID != nil {
|
||||
accountID = token.AccountID
|
||||
}
|
||||
tokenID = token.ID
|
||||
}
|
||||
|
||||
sessionID := uuid.NewString()
|
||||
|
||||
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
|
||||
oldConn := old.(*proxyConnection)
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"old_session_id": oldConn.sessionID,
|
||||
"new_session_id": sessionID,
|
||||
}).Info("Superseding existing proxy connection")
|
||||
oldConn.cancel()
|
||||
}
|
||||
s.supersedePriorConnection(params.proxyID, sessionID)
|
||||
|
||||
connCtx, cancel := context.WithCancel(ctx)
|
||||
conn := &proxyConnection{
|
||||
proxyID: proxyID,
|
||||
sessionID: sessionID,
|
||||
address: proxyAddress,
|
||||
accountID: accountID,
|
||||
tokenID: tokenID,
|
||||
capabilities: req.GetCapabilities(),
|
||||
stream: stream,
|
||||
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
|
||||
ctx: connCtx,
|
||||
cancel: cancel,
|
||||
}
|
||||
connSeed.proxyID = params.proxyID
|
||||
connSeed.sessionID = sessionID
|
||||
connSeed.address = params.address
|
||||
connSeed.accountID = accountID
|
||||
connSeed.tokenID = tokenID
|
||||
connSeed.capabilities = params.capabilities
|
||||
connSeed.sendChan = make(chan *proto.GetMappingUpdateResponse, 100)
|
||||
connSeed.ctx = connCtx
|
||||
connSeed.cancel = cancel
|
||||
|
||||
var caps *proxy.Capabilities
|
||||
if c := req.GetCapabilities(); c != nil {
|
||||
if c := params.capabilities; c != nil {
|
||||
caps = &proxy.Capabilities{
|
||||
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||
RequireSubdomain: c.RequireSubdomain,
|
||||
SupportsCrowdsec: c.SupportsCrowdsec,
|
||||
}
|
||||
}
|
||||
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, accountID, caps)
|
||||
|
||||
proxyRecord, err := s.proxyManager.Connect(ctx, params.proxyID, sessionID, params.address, peerInfo, accountID, caps)
|
||||
if err != nil {
|
||||
cancel()
|
||||
if accountID != nil {
|
||||
return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
|
||||
return nil, nil, status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
|
||||
}
|
||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
|
||||
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
|
||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", params.proxyID, err)
|
||||
return nil, nil, status.Errorf(codes.Internal, "register proxy in database: %v", err)
|
||||
}
|
||||
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
||||
s.connectedProxies.Store(params.proxyID, connSeed)
|
||||
if err := s.proxyController.RegisterProxyToCluster(ctx, params.address, params.proxyID); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", params.proxyID, err)
|
||||
}
|
||||
|
||||
if err := s.sendSnapshot(ctx, conn); err != nil {
|
||||
if s.connectedProxies.CompareAndDelete(proxyID, conn) {
|
||||
if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr)
|
||||
}
|
||||
}
|
||||
cancel()
|
||||
if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr)
|
||||
}
|
||||
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
|
||||
return connSeed, proxyRecord, nil
|
||||
}
|
||||
|
||||
// supersedePriorConnection cancels any existing connection for the given proxy.
|
||||
func (s *ProxyServiceServer) supersedePriorConnection(proxyID, newSessionID string) {
|
||||
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
|
||||
oldConn := old.(*proxyConnection)
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"old_session_id": oldConn.sessionID,
|
||||
"new_session_id": newSessionID,
|
||||
}).Info("Superseding existing proxy connection")
|
||||
oldConn.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
go s.sender(conn, errChan)
|
||||
// cleanupFailedSnapshot removes the connection from the cluster and store
|
||||
// after a snapshot send failure.
|
||||
func (s *ProxyServiceServer) cleanupFailedSnapshot(ctx context.Context, conn *proxyConnection) {
|
||||
if s.connectedProxies.CompareAndDelete(conn.proxyID, conn) {
|
||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", conn.proxyID, err)
|
||||
}
|
||||
}
|
||||
conn.cancel()
|
||||
if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", conn.proxyID, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"session_id": sessionID,
|
||||
"address": proxyAddress,
|
||||
"cluster_addr": proxyAddress,
|
||||
"account_id": accountID,
|
||||
"total_proxies": len(s.GetConnectedProxies()),
|
||||
}).Info("Proxy registered in cluster")
|
||||
defer func() {
|
||||
if !s.connectedProxies.CompareAndDelete(proxyID, conn) {
|
||||
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", proxyID, sessionID)
|
||||
cancel()
|
||||
// drainRecv consumes and discards messages from a bidirectional stream.
|
||||
// The proxy sends an ack for every incremental update; we don't need them
|
||||
// after the snapshot phase. Recv errors are forwarded to errChan.
|
||||
func (s *ProxyServiceServer) drainRecv(stream proto.ProxyService_SyncMappingsServer, errChan chan<- error) {
|
||||
for {
|
||||
if _, err := stream.Recv(); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
|
||||
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
|
||||
}
|
||||
if err := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); err != nil {
|
||||
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
|
||||
}
|
||||
// serveProxyConnection runs the post-snapshot lifecycle: heartbeat, sender,
|
||||
// and wait for termination. When bidi is true, normal stream closure (EOF,
|
||||
// canceled) is treated as a clean disconnect rather than an error.
|
||||
func (s *ProxyServiceServer) serveProxyConnection(conn *proxyConnection, proxyRecord *proxy.Proxy, errChan <-chan error, bidi bool) error {
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": conn.proxyID,
|
||||
"session_id": conn.sessionID,
|
||||
"address": conn.address,
|
||||
"cluster_addr": conn.address,
|
||||
"account_id": conn.accountID,
|
||||
"total_proxies": len(s.GetConnectedProxies()),
|
||||
}).Info("Proxy registered in cluster")
|
||||
|
||||
cancel()
|
||||
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
|
||||
}()
|
||||
|
||||
go s.heartbeat(connCtx, conn, proxyRecord)
|
||||
defer s.disconnectProxy(conn)
|
||||
go s.heartbeat(conn.ctx, conn, proxyRecord)
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
log.WithContext(ctx).Warnf("Failed to send update: %v", err)
|
||||
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
|
||||
case <-connCtx.Done():
|
||||
log.WithContext(ctx).Infof("Proxy %s context canceled", proxyID)
|
||||
return connCtx.Err()
|
||||
if bidi && isStreamClosed(err) {
|
||||
log.Infof("Proxy %s stream closed", conn.proxyID)
|
||||
return nil
|
||||
}
|
||||
log.Warnf("Failed to send update: %v", err)
|
||||
return fmt.Errorf("send update to proxy %s: %w", conn.proxyID, err)
|
||||
case <-conn.ctx.Done():
|
||||
log.Infof("Proxy %s context canceled", conn.proxyID)
|
||||
return conn.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// disconnectProxy removes the connection from cluster and store, unless it
|
||||
// has already been superseded by a newer connection.
|
||||
func (s *ProxyServiceServer) disconnectProxy(conn *proxyConnection) {
|
||||
if !s.connectedProxies.CompareAndDelete(conn.proxyID, conn) {
|
||||
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", conn.proxyID, conn.sessionID)
|
||||
conn.cancel()
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil {
|
||||
log.Warnf("Failed to unregister proxy %s from cluster: %v", conn.proxyID, err)
|
||||
}
|
||||
if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil {
|
||||
log.Warnf("Failed to mark proxy %s as disconnected: %v", conn.proxyID, err)
|
||||
}
|
||||
|
||||
conn.cancel()
|
||||
log.Infof("Proxy %s session %s disconnected", conn.proxyID, conn.sessionID)
|
||||
}
|
||||
|
||||
// sendSnapshotSync sends the initial snapshot with back-pressure: it sends
|
||||
// one batch, then waits for the proxy to ack before sending the next.
|
||||
func (s *ProxyServiceServer) sendSnapshotSync(ctx context.Context, conn *proxyConnection, stream proto.ProxyService_SyncMappingsServer) error {
|
||||
if !isProxyAddressValid(conn.address) {
|
||||
return fmt.Errorf("proxy address is invalid")
|
||||
}
|
||||
if s.snapshotBatchSize <= 0 {
|
||||
return fmt.Errorf("invalid snapshot batch size: %d", s.snapshotBatchSize)
|
||||
}
|
||||
|
||||
mappings, err := s.snapshotServiceMappings(ctx, conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := 0; i < len(mappings); i += s.snapshotBatchSize {
|
||||
end := i + s.snapshotBatchSize
|
||||
if end > len(mappings) {
|
||||
end = len(mappings)
|
||||
}
|
||||
for _, m := range mappings[i:end] {
|
||||
token, err := s.tokenStore.GenerateToken(m.AccountId, m.Id, s.proxyTokenTTL())
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate auth token for service %s: %w", m.Id, err)
|
||||
}
|
||||
m.AuthToken = token
|
||||
}
|
||||
if err := stream.Send(&proto.SyncMappingsResponse{
|
||||
Mapping: mappings[i:end],
|
||||
InitialSyncComplete: end == len(mappings),
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send snapshot batch: %w", err)
|
||||
}
|
||||
|
||||
if err := waitForAck(stream); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(mappings) == 0 {
|
||||
if err := stream.Send(&proto.SyncMappingsResponse{
|
||||
InitialSyncComplete: true,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send snapshot completion: %w", err)
|
||||
}
|
||||
|
||||
if err := waitForAck(stream); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func waitForAck(stream proto.ProxyService_SyncMappingsServer) error {
|
||||
msg, err := stream.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("receive ack: %w", err)
|
||||
}
|
||||
if msg.GetAck() == nil {
|
||||
return fmt.Errorf("expected ack, got %T", msg.GetMsg())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// heartbeat updates the proxy's last_seen timestamp every minute and
|
||||
// disconnects the proxy if its access token has been revoked.
|
||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, p *proxy.Proxy) {
|
||||
@@ -381,6 +562,9 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
if !isProxyAddressValid(conn.address) {
|
||||
return fmt.Errorf("proxy address is invalid")
|
||||
}
|
||||
if s.snapshotBatchSize <= 0 {
|
||||
return fmt.Errorf("invalid snapshot batch size: %d", s.snapshotBatchSize)
|
||||
}
|
||||
|
||||
mappings, err := s.snapshotServiceMappings(ctx, conn)
|
||||
if err != nil {
|
||||
@@ -394,6 +578,13 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
if end > len(mappings) {
|
||||
end = len(mappings)
|
||||
}
|
||||
for _, m := range mappings[i:end] {
|
||||
token, err := s.tokenStore.GenerateToken(m.AccountId, m.Id, s.proxyTokenTTL())
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate auth token for service %s: %w", m.Id, err)
|
||||
}
|
||||
m.AuthToken = token
|
||||
}
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
Mapping: mappings[i:end],
|
||||
InitialSyncComplete: end == len(mappings),
|
||||
@@ -425,18 +616,14 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
|
||||
return nil, fmt.Errorf("get services from store: %w", err)
|
||||
}
|
||||
|
||||
oidcCfg := s.GetOIDCValidationConfig()
|
||||
var mappings []*proto.ProxyMapping
|
||||
for _, service := range services {
|
||||
if !service.Enabled || service.ProxyCluster == "" || service.ProxyCluster != conn.address {
|
||||
continue
|
||||
}
|
||||
|
||||
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
|
||||
m := service.ToProtoMapping(rpservice.Create, "", oidcCfg)
|
||||
if !proxyAcceptsMapping(conn, m) {
|
||||
continue
|
||||
}
|
||||
@@ -457,12 +644,26 @@ func isProxyAddressValid(addr string) bool {
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// sender handles sending messages to proxy
|
||||
// isStreamClosed returns true for errors that indicate normal stream
|
||||
// termination: io.EOF, context cancellation, or gRPC Canceled.
|
||||
func isStreamClosed(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
|
||||
return true
|
||||
}
|
||||
return status.Code(err) == codes.Canceled
|
||||
}
|
||||
|
||||
// sender handles sending messages to proxy.
|
||||
// When conn.syncStream is set the message is sent as SyncMappingsResponse;
|
||||
// otherwise the legacy GetMappingUpdateResponse stream is used.
|
||||
func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) {
|
||||
for {
|
||||
select {
|
||||
case resp := <-conn.sendChan:
|
||||
if err := conn.stream.Send(resp); err != nil {
|
||||
if err := conn.sendResponse(resp); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
@@ -472,6 +673,17 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
|
||||
}
|
||||
}
|
||||
|
||||
// sendResponse sends a mapping update on whichever stream the proxy connected with.
|
||||
func (conn *proxyConnection) sendResponse(resp *proto.GetMappingUpdateResponse) error {
|
||||
if conn.syncStream != nil {
|
||||
return conn.syncStream.Send(&proto.SyncMappingsResponse{
|
||||
Mapping: resp.Mapping,
|
||||
InitialSyncComplete: resp.InitialSyncComplete,
|
||||
})
|
||||
}
|
||||
return conn.stream.Send(resp)
|
||||
}
|
||||
|
||||
// SendAccessLog processes access log from proxy
|
||||
func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) {
|
||||
accessLog := req.GetLog()
|
||||
@@ -538,8 +750,8 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
|
||||
return true
|
||||
}
|
||||
connUpdate = &proto.GetMappingUpdateResponse{
|
||||
Mapping: filtered,
|
||||
InitialSyncComplete: update.InitialSyncComplete,
|
||||
Mapping: filtered,
|
||||
InitialSyncComplete: update.InitialSyncComplete,
|
||||
}
|
||||
}
|
||||
resp := s.perProxyMessage(connUpdate, conn.proxyID)
|
||||
|
||||
@@ -109,7 +109,7 @@ func (m *mockReverseProxyManager) GetServiceByDomain(_ context.Context, domain s
|
||||
return nil, errors.New("service not found for domain: " + domain)
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||
func (m *mockReverseProxyManager) GetClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -172,3 +173,55 @@ func TestSendSnapshot_EmptySnapshot(t *testing.T) {
|
||||
assert.Empty(t, stream.messages[0].Mapping)
|
||||
assert.True(t, stream.messages[0].InitialSyncComplete)
|
||||
}
|
||||
|
||||
type hookingStream struct {
|
||||
grpc.ServerStream
|
||||
onSend func(*proto.GetMappingUpdateResponse)
|
||||
}
|
||||
|
||||
func (s *hookingStream) Send(m *proto.GetMappingUpdateResponse) error {
|
||||
if s.onSend != nil {
|
||||
s.onSend(m)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *hookingStream) Context() context.Context { return context.Background() }
|
||||
func (s *hookingStream) SetHeader(metadata.MD) error { return nil }
|
||||
func (s *hookingStream) SendHeader(metadata.MD) error { return nil }
|
||||
func (s *hookingStream) SetTrailer(metadata.MD) {}
|
||||
func (s *hookingStream) SendMsg(any) error { return nil }
|
||||
func (s *hookingStream) RecvMsg(any) error { return nil }
|
||||
|
||||
func TestSendSnapshot_TokensRemainValidUnderSlowSend(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 2
|
||||
const totalServices = 6
|
||||
const ttl = 100 * time.Millisecond
|
||||
const sendDelay = 200 * time.Millisecond
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
s.tokenTTL = ttl
|
||||
|
||||
var validateErrs []error
|
||||
stream := &hookingStream{
|
||||
onSend: func(resp *proto.GetMappingUpdateResponse) {
|
||||
for _, m := range resp.Mapping {
|
||||
if err := s.tokenStore.ValidateAndConsume(m.AuthToken, m.AccountId, m.Id); err != nil {
|
||||
validateErrs = append(validateErrs, fmt.Errorf("svc %s: %w", m.Id, err))
|
||||
}
|
||||
}
|
||||
time.Sleep(sendDelay)
|
||||
},
|
||||
}
|
||||
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
|
||||
|
||||
require.NoError(t, s.sendSnapshot(context.Background(), conn))
|
||||
require.Empty(t, validateErrs,
|
||||
"tokens must remain valid even when batches are sent slowly: lazy per-batch generation guarantees freshness")
|
||||
}
|
||||
|
||||
@@ -522,10 +522,11 @@ func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Even
|
||||
}
|
||||
|
||||
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
|
||||
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
||||
uncanceledCTX := context.WithoutCancel(ctx)
|
||||
unlock := s.acquirePeerLockByUID(uncanceledCTX, peer.Key)
|
||||
defer unlock()
|
||||
|
||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime)
|
||||
s.cancelPeerRoutinesWithoutLock(uncanceledCTX, accountID, peer, streamStartTime)
|
||||
}
|
||||
|
||||
func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
|
||||
|
||||
411
management/internals/shared/grpc/sync_mappings_test.go
Normal file
411
management/internals/shared/grpc/sync_mappings_test.go
Normal file
@@ -0,0 +1,411 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// syncRecordingStream is a mock ProxyService_SyncMappingsServer that records
|
||||
// sent messages and returns pre-loaded ack responses from Recv.
|
||||
type syncRecordingStream struct {
|
||||
grpc.ServerStream
|
||||
|
||||
mu sync.Mutex
|
||||
sent []*proto.SyncMappingsResponse
|
||||
recvMsgs []*proto.SyncMappingsRequest
|
||||
recvIdx int
|
||||
}
|
||||
|
||||
func (s *syncRecordingStream) Send(m *proto.SyncMappingsResponse) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sent = append(s.sent, m)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *syncRecordingStream) Recv() (*proto.SyncMappingsRequest, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.recvIdx >= len(s.recvMsgs) {
|
||||
return nil, fmt.Errorf("no more recv messages")
|
||||
}
|
||||
msg := s.recvMsgs[s.recvIdx]
|
||||
s.recvIdx++
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (s *syncRecordingStream) Context() context.Context { return context.Background() }
|
||||
func (s *syncRecordingStream) SetHeader(metadata.MD) error { return nil }
|
||||
func (s *syncRecordingStream) SendHeader(metadata.MD) error { return nil }
|
||||
func (s *syncRecordingStream) SetTrailer(metadata.MD) {}
|
||||
func (s *syncRecordingStream) SendMsg(any) error { return nil }
|
||||
func (s *syncRecordingStream) RecvMsg(any) error { return nil }
|
||||
|
||||
func ackMsg() *proto.SyncMappingsRequest {
|
||||
return &proto.SyncMappingsRequest{
|
||||
Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}},
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendSnapshotSync_BatchesWithAcks(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 3
|
||||
const totalServices = 7 // 3 + 3 + 1 → 3 batches, 3 acks (one per batch, including final)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &syncRecordingStream{
|
||||
recvMsgs: []*proto.SyncMappingsRequest{ackMsg(), ackMsg(), ackMsg()},
|
||||
}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, stream.sent, 3, "should send ceil(7/3) = 3 batches")
|
||||
|
||||
assert.Len(t, stream.sent[0].Mapping, 3)
|
||||
assert.False(t, stream.sent[0].InitialSyncComplete)
|
||||
|
||||
assert.Len(t, stream.sent[1].Mapping, 3)
|
||||
assert.False(t, stream.sent[1].InitialSyncComplete)
|
||||
|
||||
assert.Len(t, stream.sent[2].Mapping, 1)
|
||||
assert.True(t, stream.sent[2].InitialSyncComplete)
|
||||
|
||||
// All 3 acks consumed — including the final batch.
|
||||
assert.Equal(t, 3, stream.recvIdx)
|
||||
}
|
||||
|
||||
func TestSendSnapshotSync_SingleBatchWaitsForAck(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 100
|
||||
const totalServices = 5
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &syncRecordingStream{
|
||||
recvMsgs: []*proto.SyncMappingsRequest{ackMsg()},
|
||||
}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, stream.sent, 1)
|
||||
assert.Len(t, stream.sent[0].Mapping, totalServices)
|
||||
assert.True(t, stream.sent[0].InitialSyncComplete)
|
||||
assert.Equal(t, 1, stream.recvIdx, "final batch ack must be consumed")
|
||||
}
|
||||
|
||||
func TestSendSnapshotSync_EmptySnapshot(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(nil, nil)
|
||||
|
||||
s := newSnapshotTestServer(t, 500)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &syncRecordingStream{
|
||||
recvMsgs: []*proto.SyncMappingsRequest{ackMsg()},
|
||||
}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, stream.sent, 1, "empty snapshot must still send sync-complete")
|
||||
assert.Empty(t, stream.sent[0].Mapping)
|
||||
assert.True(t, stream.sent[0].InitialSyncComplete)
|
||||
assert.Equal(t, 1, stream.recvIdx, "empty snapshot ack must be consumed")
|
||||
}
|
||||
|
||||
func TestSendSnapshotSync_MissingAckReturnsError(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 2
|
||||
const totalServices = 4 // 2 batches → 1 ack needed, but we provide none
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
// No acks available — Recv will return error.
|
||||
stream := &syncRecordingStream{}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "receive ack")
|
||||
// First batch should have been sent before the error.
|
||||
require.Len(t, stream.sent, 1)
|
||||
}
|
||||
|
||||
func TestSendSnapshotSync_WrongMessageInsteadOfAck(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 2
|
||||
const totalServices = 4
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
// Send an init message instead of an ack.
|
||||
stream := &syncRecordingStream{
|
||||
recvMsgs: []*proto.SyncMappingsRequest{
|
||||
{Msg: &proto.SyncMappingsRequest_Init{Init: &proto.SyncMappingsInit{ProxyId: "bad"}}},
|
||||
},
|
||||
}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "expected ack")
|
||||
}
|
||||
|
||||
func TestSendSnapshotSync_BackPressureOrdering(t *testing.T) {
|
||||
// Verify batches are sent strictly sequentially — batch N+1 is not sent
|
||||
// until the ack for batch N is received, including the final batch.
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 2
|
||||
const totalServices = 6 // 3 batches, 3 acks
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
var mu sync.Mutex
|
||||
var events []string
|
||||
|
||||
// Build a stream that logs send/recv events so we can verify ordering.
|
||||
ackCh := make(chan struct{}, 3)
|
||||
stream := &orderTrackingStream{
|
||||
mu: &mu,
|
||||
events: &events,
|
||||
ackCh: ackCh,
|
||||
}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
// Feed acks asynchronously after a short delay to simulate real proxy.
|
||||
go func() {
|
||||
for range 3 {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
ackCh <- struct{}{}
|
||||
}
|
||||
}()
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.NoError(t, err)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
// Expected: send, recv-ack, send, recv-ack, send, recv-ack.
|
||||
require.Len(t, events, 6)
|
||||
assert.Equal(t, "send", events[0])
|
||||
assert.Equal(t, "recv", events[1])
|
||||
assert.Equal(t, "send", events[2])
|
||||
assert.Equal(t, "recv", events[3])
|
||||
assert.Equal(t, "send", events[4])
|
||||
assert.Equal(t, "recv", events[5])
|
||||
}
|
||||
|
||||
// orderTrackingStream logs "send" and "recv" events and blocks Recv until
|
||||
// an ack is signaled via ackCh.
|
||||
type orderTrackingStream struct {
|
||||
grpc.ServerStream
|
||||
mu *sync.Mutex
|
||||
events *[]string
|
||||
ackCh chan struct{}
|
||||
}
|
||||
|
||||
func (s *orderTrackingStream) Send(_ *proto.SyncMappingsResponse) error {
|
||||
s.mu.Lock()
|
||||
*s.events = append(*s.events, "send")
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *orderTrackingStream) Recv() (*proto.SyncMappingsRequest, error) {
|
||||
<-s.ackCh
|
||||
s.mu.Lock()
|
||||
*s.events = append(*s.events, "recv")
|
||||
s.mu.Unlock()
|
||||
return ackMsg(), nil
|
||||
}
|
||||
|
||||
func (s *orderTrackingStream) Context() context.Context { return context.Background() }
|
||||
func (s *orderTrackingStream) SetHeader(metadata.MD) error { return nil }
|
||||
func (s *orderTrackingStream) SendHeader(metadata.MD) error { return nil }
|
||||
func (s *orderTrackingStream) SetTrailer(metadata.MD) {}
|
||||
func (s *orderTrackingStream) SendMsg(any) error { return nil }
|
||||
func (s *orderTrackingStream) RecvMsg(any) error { return nil }
|
||||
|
||||
func TestSendSnapshotSync_TokensGeneratedPerBatch(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 2
|
||||
const totalServices = 4
|
||||
const ttl = 100 * time.Millisecond
|
||||
const ackDelay = 200 * time.Millisecond
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
s.tokenTTL = ttl
|
||||
|
||||
// Build a stream that validates tokens immediately on Send, then
|
||||
// delays the ack to ensure the next batch's tokens are generated fresh.
|
||||
var validateErrs []error
|
||||
ackCh := make(chan struct{}, 2)
|
||||
stream := &tokenValidatingSyncStream{
|
||||
tokenStore: s.tokenStore,
|
||||
validateErrs: &validateErrs,
|
||||
ackCh: ackCh,
|
||||
}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
go func() {
|
||||
// Delay first ack so that if tokens were all generated upfront they'd expire.
|
||||
time.Sleep(ackDelay)
|
||||
ackCh <- struct{}{}
|
||||
// Final batch ack — immediate.
|
||||
ackCh <- struct{}{}
|
||||
}()
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, validateErrs,
|
||||
"tokens must remain valid: per-batch generation guarantees freshness")
|
||||
}
|
||||
|
||||
type tokenValidatingSyncStream struct {
|
||||
grpc.ServerStream
|
||||
tokenStore *OneTimeTokenStore
|
||||
validateErrs *[]error
|
||||
ackCh chan struct{}
|
||||
}
|
||||
|
||||
func (s *tokenValidatingSyncStream) Send(m *proto.SyncMappingsResponse) error {
|
||||
for _, mapping := range m.Mapping {
|
||||
if err := s.tokenStore.ValidateAndConsume(mapping.AuthToken, mapping.AccountId, mapping.Id); err != nil {
|
||||
*s.validateErrs = append(*s.validateErrs, fmt.Errorf("svc %s: %w", mapping.Id, err))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *tokenValidatingSyncStream) Recv() (*proto.SyncMappingsRequest, error) {
|
||||
<-s.ackCh
|
||||
return ackMsg(), nil
|
||||
}
|
||||
|
||||
func (s *tokenValidatingSyncStream) Context() context.Context { return context.Background() }
|
||||
func (s *tokenValidatingSyncStream) SetHeader(metadata.MD) error { return nil }
|
||||
func (s *tokenValidatingSyncStream) SendHeader(metadata.MD) error { return nil }
|
||||
func (s *tokenValidatingSyncStream) SetTrailer(metadata.MD) {}
|
||||
func (s *tokenValidatingSyncStream) SendMsg(any) error { return nil }
|
||||
func (s *tokenValidatingSyncStream) RecvMsg(any) error { return nil }
|
||||
|
||||
func TestConnectionSendResponse_RoutesToSyncStream(t *testing.T) {
|
||||
stream := &syncRecordingStream{}
|
||||
conn := &proxyConnection{
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
resp := &proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{Id: "svc-1", AccountId: "acct-1", Domain: "example.com"},
|
||||
},
|
||||
InitialSyncComplete: true,
|
||||
}
|
||||
|
||||
err := conn.sendResponse(resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, stream.sent, 1)
|
||||
assert.Len(t, stream.sent[0].Mapping, 1)
|
||||
assert.Equal(t, "svc-1", stream.sent[0].Mapping[0].Id)
|
||||
assert.True(t, stream.sent[0].InitialSyncComplete)
|
||||
}
|
||||
|
||||
func TestConnectionSendResponse_RoutesToLegacyStream(t *testing.T) {
|
||||
stream := &recordingStream{}
|
||||
conn := &proxyConnection{
|
||||
stream: stream,
|
||||
}
|
||||
|
||||
resp := &proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{Id: "svc-2", AccountId: "acct-2"},
|
||||
},
|
||||
}
|
||||
|
||||
err := conn.sendResponse(resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, stream.messages, 1)
|
||||
assert.Equal(t, "svc-2", stream.messages[0].Mapping[0].Id)
|
||||
}
|
||||
@@ -322,21 +322,29 @@ func (m *testValidateSessionServiceManager) GetServiceByDomain(ctx context.Conte
|
||||
return m.store.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||
func (m *testValidateSessionServiceManager) GetClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type testValidateSessionProxyManager struct{}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *string, _ *proxy.Capabilities) error {
|
||||
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _, _ string, _ *string, _ *proxy.Capabilities) (*proxy.Proxy, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string) error {
|
||||
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ *proxy.Proxy) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _, _, _ string) error {
|
||||
func (m *testValidateSessionProxyManager) DeleteAccountCluster(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -291,10 +291,15 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
// Canonicalize the incoming range so a caller-supplied prefix with host bits
|
||||
// (e.g. 100.64.1.1/16) compares equal to the masked form stored on network.Net.
|
||||
newSettings.NetworkRange = newSettings.NetworkRange.Masked()
|
||||
|
||||
var oldSettings *types.Settings
|
||||
var updateAccountPeers bool
|
||||
var groupChangesAffectPeers bool
|
||||
var reloadReverseProxy bool
|
||||
var effectiveOldNetworkRange netip.Prefix
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var groupsUpdated bool
|
||||
@@ -308,6 +313,16 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return err
|
||||
}
|
||||
|
||||
// No lock: the transaction already holds Settings(Update), and network.Net is
|
||||
// only mutated by reallocateAccountPeerIPs, which is reachable only through
|
||||
// this same code path. A Share lock here would extend an unnecessary row lock
|
||||
// and complicate ordering against updatePeerIPv6InTransaction.
|
||||
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get account network: %w", err)
|
||||
}
|
||||
effectiveOldNetworkRange = prefixFromIPNet(network.Net)
|
||||
|
||||
if oldSettings.Extra != nil && newSettings.Extra != nil &&
|
||||
oldSettings.Extra.PeerApprovalEnabled && !newSettings.Extra.PeerApprovalEnabled {
|
||||
approvedCount, err := transaction.ApproveAccountPeers(ctx, accountID)
|
||||
@@ -321,7 +336,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
}
|
||||
}
|
||||
|
||||
if oldSettings.NetworkRange != newSettings.NetworkRange {
|
||||
if newSettings.NetworkRange.IsValid() && newSettings.NetworkRange != effectiveOldNetworkRange {
|
||||
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -396,9 +411,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, eventMeta)
|
||||
}
|
||||
if oldSettings.NetworkRange != newSettings.NetworkRange {
|
||||
if newSettings.NetworkRange.IsValid() && newSettings.NetworkRange != effectiveOldNetworkRange {
|
||||
eventMeta := map[string]any{
|
||||
"old_network_range": oldSettings.NetworkRange.String(),
|
||||
"old_network_range": effectiveOldNetworkRange.String(),
|
||||
"new_network_range": newSettings.NetworkRange.String(),
|
||||
}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta)
|
||||
@@ -443,6 +458,22 @@ func ipv6SettingsChanged(old, updated *types.Settings) bool {
|
||||
return !slices.Equal(oldGroups, newGroups)
|
||||
}
|
||||
|
||||
// prefixFromIPNet returns the overlay prefix actually allocated on the account
|
||||
// network, or an invalid prefix if none is set. Settings.NetworkRange is a
|
||||
// user-facing override that is empty on legacy accounts, so the effective
|
||||
// range must be read from network.Net to compare against an incoming update.
|
||||
func prefixFromIPNet(ipNet net.IPNet) netip.Prefix {
|
||||
if ipNet.IP == nil {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
addr, ok := netip.AddrFromSlice(ipNet.IP)
|
||||
if !ok {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
ones, _ := ipNet.Mask.Size()
|
||||
return netip.PrefixFrom(addr.Unmap(), ones)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error {
|
||||
halfYearLimit := 180 * 24 * time.Hour
|
||||
if newSettings.PeerLoginExpiration > halfYearLimit {
|
||||
@@ -1837,35 +1868,32 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
|
||||
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
|
||||
}
|
||||
|
||||
// SyncAndMarkPeer is the per-Sync entry point: it refreshes the peer's
|
||||
// network map and then marks the peer connected with a session token
|
||||
// derived from syncTime (the moment the gRPC stream opened). Any
|
||||
// concurrent stream that started earlier loses the optimistic-lock race
|
||||
// in MarkPeerConnected and bails without writing.
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID, syncTime)
|
||||
if err != nil {
|
||||
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano()); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
return peer, netMap, postureChecks, dnsfwdPort, nil
|
||||
}
|
||||
|
||||
// OnPeerDisconnected is invoked when a sync stream ends. It marks the
|
||||
// peer disconnected only when the stored SessionStartedAt matches the
|
||||
// nanosecond token derived from streamStartTime — i.e. only when this
|
||||
// is the stream that currently owns the peer's session. A mismatch
|
||||
// means a newer stream has already replaced us, so the disconnect is
|
||||
// dropped.
|
||||
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get peer %s for disconnect check: %v", peerPubKey, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if peer.Status.LastSeen.After(streamStartTime) {
|
||||
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s > streamStart=%s), skipping disconnect",
|
||||
peerPubKey, peer.Status.LastSeen.Format(time.RFC3339), streamStartTime.Format(time.RFC3339))
|
||||
return nil
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID, time.Now().UTC())
|
||||
if err != nil {
|
||||
if err := am.MarkPeerDisconnected(ctx, peerPubKey, accountID, streamStartTime.UnixNano()); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
|
||||
}
|
||||
return nil
|
||||
@@ -2487,6 +2515,18 @@ func (am *DefaultAccountManager) buildIPv6AllowedPeers(ctx context.Context, tran
|
||||
allowedPeers[peerID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Embedded proxy peers sit outside regular group membership but must
|
||||
// participate in any v6-enabled overlay to reach v6-only peers.
|
||||
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get peers: %w", err)
|
||||
}
|
||||
for _, p := range peers {
|
||||
if p.ProxyMeta.Embedded {
|
||||
allowedPeers[p.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
return allowedPeers, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -61,7 +61,8 @@ type Manager interface {
|
||||
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error
|
||||
MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
|
||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
|
||||
|
||||
@@ -1305,17 +1305,31 @@ func (mr *MockManagerMockRecorder) LoginPeer(ctx, login interface{}) *gomock.Cal
|
||||
}
|
||||
|
||||
// MarkPeerConnected mocks base method.
|
||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
|
||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, connected, realIP, accountID, syncTime)
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// MarkPeerConnected indicates an expected call of MarkPeerConnected.
|
||||
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, connected, realIP, accountID, syncTime interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, connected, realIP, accountID, syncTime)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt)
|
||||
}
|
||||
|
||||
// MarkPeerDisconnected mocks base method.
|
||||
func (m *MockManager) MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerDisconnected", ctx, peerKey, accountID, sessionStartedAt)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// MarkPeerDisconnected indicates an expected call of MarkPeerDisconnected.
|
||||
func (mr *MockManagerMockRecorder) MarkPeerDisconnected(ctx, peerKey, accountID, sessionStartedAt interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnected", reflect.TypeOf((*MockManager)(nil).MarkPeerDisconnected), ctx, peerKey, accountID, sessionStartedAt)
|
||||
}
|
||||
|
||||
// OnPeerDisconnected mocks base method.
|
||||
|
||||
@@ -1813,7 +1813,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
@@ -1884,7 +1884,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1910,15 +1910,16 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) {
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
|
||||
t.Run("disconnect peer when session token matches", func(t *testing.T) {
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err, "unable to get peer")
|
||||
require.True(t, peer.Status.Connected, "peer should be connected")
|
||||
|
||||
streamStartTime := time.Now().UTC()
|
||||
require.Equal(t, streamStartTime.UnixNano(), peer.Status.SessionStartedAt,
|
||||
"SessionStartedAt should equal the token we passed in")
|
||||
|
||||
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
|
||||
require.NoError(t, err)
|
||||
@@ -1926,49 +1927,127 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.False(t, peer.Status.Connected, "peer should be disconnected")
|
||||
require.Equal(t, int64(0), peer.Status.SessionStartedAt, "SessionStartedAt should be reset to 0")
|
||||
})
|
||||
|
||||
t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) {
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
|
||||
t.Run("skip disconnect when stored session is newer (zombie stream protection)", func(t *testing.T) {
|
||||
// Newer stream wins on connect (sets SessionStartedAt = now ns).
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, peer.Status.Connected, "peer should be connected")
|
||||
|
||||
streamStartTime := peer.Status.LastSeen.Add(-1 * time.Hour)
|
||||
// Older stream tries to mark disconnect with its own (older) session token —
|
||||
// fencing kicks in and the write is dropped.
|
||||
staleStreamStartTime := streamStartTime.Add(-1 * time.Hour)
|
||||
|
||||
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
|
||||
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, staleStreamStartTime)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, peer.Status.Connected,
|
||||
"peer should remain connected because LastSeen > streamStartTime (zombie stream protection)")
|
||||
"peer should remain connected because the stored session is newer than the disconnect token")
|
||||
require.Equal(t, streamStartTime.UnixNano(), peer.Status.SessionStartedAt,
|
||||
"SessionStartedAt should still hold the winning stream's token")
|
||||
})
|
||||
|
||||
t.Run("skip stale connect when peer already has newer LastSeen (blocked goroutine protection)", func(t *testing.T) {
|
||||
t.Run("skip stale connect when stored session is newer (blocked goroutine protection)", func(t *testing.T) {
|
||||
node2SyncTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node2SyncTime)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano())
|
||||
require.NoError(t, err, "node 2 should connect peer")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, peer.Status.Connected, "peer should be connected")
|
||||
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), "LastSeen should be node2SyncTime")
|
||||
require.Equal(t, node2SyncTime.UnixNano(), peer.Status.SessionStartedAt,
|
||||
"SessionStartedAt should equal node2SyncTime token")
|
||||
|
||||
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node1StaleSyncTime)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano())
|
||||
require.NoError(t, err, "stale connect should not return error")
|
||||
|
||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, peer.Status.Connected, "peer should still be connected")
|
||||
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(),
|
||||
"LastSeen should NOT be overwritten by stale syncTime from blocked goroutine")
|
||||
require.Equal(t, node2SyncTime.UnixNano(), peer.Status.SessionStartedAt,
|
||||
"SessionStartedAt should NOT be overwritten by stale token from blocked goroutine")
|
||||
})
|
||||
}
|
||||
|
||||
// TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace exercises the
|
||||
// fencing protocol under contention: many goroutines race to mark the
|
||||
// same peer connected with distinct session tokens at the same time.
|
||||
// The contract is that the highest token always wins and is what remains
|
||||
// in the store, regardless of execution order.
|
||||
func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to get account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peerPubKey := key.PublicKey().String()
|
||||
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: peerPubKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "race-peer"},
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
const workers = 16
|
||||
base := time.Now().UTC().UnixNano()
|
||||
tokens := make([]int64, workers)
|
||||
for i := range tokens {
|
||||
// Spread tokens by 1ms so the comparison is unambiguous; the
|
||||
// largest is index workers-1.
|
||||
tokens[i] = base + int64(i)*int64(time.Millisecond)
|
||||
}
|
||||
expected := tokens[workers-1]
|
||||
|
||||
var ready sync.WaitGroup
|
||||
ready.Add(workers)
|
||||
var start sync.WaitGroup
|
||||
start.Add(1)
|
||||
var done sync.WaitGroup
|
||||
done.Add(workers)
|
||||
|
||||
// require.* calls t.FailNow which is documented as unsafe from
|
||||
// non-test goroutines (it calls runtime.Goexit on the wrong stack and
|
||||
// races with the WaitGroup). Collect errors here and assert from the
|
||||
// main goroutine after done.Wait().
|
||||
errs := make(chan error, workers)
|
||||
|
||||
for i := 0; i < workers; i++ {
|
||||
token := tokens[i]
|
||||
go func() {
|
||||
defer done.Done()
|
||||
ready.Done()
|
||||
start.Wait()
|
||||
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token)
|
||||
}()
|
||||
}
|
||||
|
||||
ready.Wait()
|
||||
start.Done()
|
||||
done.Wait()
|
||||
close(errs)
|
||||
for err := range errs {
|
||||
require.NoError(t, err, "MarkPeerConnected must not error under contention")
|
||||
}
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, peer.Status.Connected, "peer should be connected after the race")
|
||||
require.Equal(t, expected, peer.Status.SessionStartedAt,
|
||||
"the largest token must win regardless of execution order")
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
@@ -1991,7 +2070,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
@@ -3970,6 +4049,96 @@ func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange(t *testi
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultAccountManager_UpdateAccountSettings_NetworkRangePreserved guards against
|
||||
// peer IP reallocation when a settings update carries the network range that is already
|
||||
// in use. Legacy accounts have Settings.NetworkRange unset in the DB while network.Net
|
||||
// holds the actual allocated overlay; the dashboard backfills the GET response from
|
||||
// network.Net and echoes the value back on PUT, so the diff must be against the
|
||||
// effective range to avoid renumbering every peer on an unrelated settings change.
|
||||
func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangePreserved(t *testing.T) {
|
||||
manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
settings, err := manager.Store.GetAccountSettings(ctx, store.LockingStrengthNone, account.Id)
|
||||
require.NoError(t, err)
|
||||
require.False(t, settings.NetworkRange.IsValid(), "precondition: new accounts leave Settings.NetworkRange unset")
|
||||
|
||||
network, err := manager.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, account.Id)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, network.Net.IP, "precondition: network.Net should be allocated")
|
||||
addr, ok := netip.AddrFromSlice(network.Net.IP)
|
||||
require.True(t, ok)
|
||||
ones, _ := network.Net.Mask.Size()
|
||||
effective := netip.PrefixFrom(addr.Unmap(), ones)
|
||||
require.True(t, effective.IsValid())
|
||||
|
||||
before := map[string]netip.Addr{peer1.ID: peer1.IP, peer2.ID: peer2.IP, peer3.ID: peer3.IP}
|
||||
|
||||
// Round-trip the effective range as if the dashboard echoed back the GET-backfilled value.
|
||||
_, err = manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
|
||||
NetworkRange: effective,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
peers, err := manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peers, len(before))
|
||||
for _, p := range peers {
|
||||
assert.Equal(t, before[p.ID], p.IP, "peer %s IP should not change when range matches effective", p.ID)
|
||||
}
|
||||
|
||||
// Carrying the same range with host bits set must also be a no-op once canonicalized.
|
||||
hostBitsForm := netip.PrefixFrom(peer1.IP, ones)
|
||||
require.NotEqual(t, effective, hostBitsForm, "precondition: host-bit form should differ before masking")
|
||||
_, err = manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
|
||||
NetworkRange: hostBitsForm,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "")
|
||||
require.NoError(t, err)
|
||||
for _, p := range peers {
|
||||
assert.Equal(t, before[p.ID], p.IP, "peer %s IP should not change for host-bit-set equivalent range", p.ID)
|
||||
}
|
||||
|
||||
// Omitting NetworkRange (invalid prefix) must also be a no-op.
|
||||
_, err = manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "")
|
||||
require.NoError(t, err)
|
||||
for _, p := range peers {
|
||||
assert.Equal(t, before[p.ID], p.IP, "peer %s IP should not change when NetworkRange omitted", p.ID)
|
||||
}
|
||||
|
||||
// Sanity: an actually different range still triggers reallocation.
|
||||
newRange := netip.MustParsePrefix("100.99.0.0/16")
|
||||
_, err = manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
|
||||
NetworkRange: newRange,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "")
|
||||
require.NoError(t, err)
|
||||
for _, p := range peers {
|
||||
assert.True(t, newRange.Contains(p.IP), "peer %s should be in new range %s, got %s", p.ID, newRange, p.IP)
|
||||
assert.NotEqual(t, before[p.ID], p.IP, "peer %s IP should change on real range update", p.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdateAccountSettings_IPv6EnabledGroups(t *testing.T) {
|
||||
manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -444,7 +444,7 @@ func (m *testServiceManager) GetServiceByDomain(ctx context.Context, domain stri
|
||||
return m.store.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
func (m *testServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
|
||||
func (m *testServiceManager) GetClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1319,7 +1319,7 @@ func Test_NetworkRouters_Update(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Update non-existing router creates it",
|
||||
name: "Update non-existing router returns not found",
|
||||
networkId: "testNetworkId",
|
||||
routerId: "nonExistingRouterId",
|
||||
requestBody: &api.NetworkRouterRequest{
|
||||
@@ -1328,11 +1328,7 @@ func Test_NetworkRouters_Update(t *testing.T) {
|
||||
Metric: 100,
|
||||
Enabled: true,
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, router *api.NetworkRouter) {
|
||||
t.Helper()
|
||||
assert.Equal(t, "nonExistingRouterId", router.Id)
|
||||
},
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "Update router with both peer and peer_groups",
|
||||
|
||||
@@ -6,7 +6,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
@@ -138,10 +140,13 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
return nil, fmt.Errorf("invalid IdP storage config: %w", err)
|
||||
}
|
||||
|
||||
// Build CLI redirect URIs including the device callback (both relative and absolute)
|
||||
// Build CLI redirect URIs including the device callback. Dex uses the issuer-relative
|
||||
// path (for example, /oauth2/device/callback) when completing the device flow, so
|
||||
// include it explicitly in addition to the legacy bare path and absolute URL.
|
||||
cliRedirectURIs := c.CLIRedirectURIs
|
||||
cliRedirectURIs = append(cliRedirectURIs, "/device/callback")
|
||||
cliRedirectURIs = append(cliRedirectURIs, c.Issuer+"/device/callback")
|
||||
cliRedirectURIs = append(cliRedirectURIs, issuerRelativeDeviceCallback(c.Issuer))
|
||||
cliRedirectURIs = append(cliRedirectURIs, strings.TrimSuffix(c.Issuer, "/")+"/device/callback")
|
||||
|
||||
// Build dashboard redirect URIs including the OAuth callback for proxy authentication
|
||||
dashboardRedirectURIs := c.DashboardRedirectURIs
|
||||
@@ -154,6 +159,10 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
// MGMT api and the dashboard, adding baseURL means less configuration for the instance admin
|
||||
dashboardPostLogoutRedirectURIs = append(dashboardPostLogoutRedirectURIs, baseURL)
|
||||
|
||||
redirectURIs := make([]string, 0)
|
||||
redirectURIs = append(redirectURIs, cliRedirectURIs...)
|
||||
redirectURIs = append(redirectURIs, dashboardRedirectURIs...)
|
||||
|
||||
cfg := &dex.YAMLConfig{
|
||||
Issuer: c.Issuer,
|
||||
Storage: dex.Storage{
|
||||
@@ -179,14 +188,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
ID: staticClientDashboard,
|
||||
Name: "NetBird Dashboard",
|
||||
Public: true,
|
||||
RedirectURIs: dashboardRedirectURIs,
|
||||
RedirectURIs: redirectURIs,
|
||||
PostLogoutRedirectURIs: sanitizePostLogoutRedirectURIs(dashboardPostLogoutRedirectURIs),
|
||||
},
|
||||
{
|
||||
ID: staticClientCLI,
|
||||
Name: "NetBird CLI",
|
||||
Public: true,
|
||||
RedirectURIs: cliRedirectURIs,
|
||||
RedirectURIs: redirectURIs,
|
||||
},
|
||||
},
|
||||
StaticConnectors: c.StaticConnectors,
|
||||
@@ -217,6 +226,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func issuerRelativeDeviceCallback(issuer string) string {
|
||||
u, err := url.Parse(issuer)
|
||||
if err != nil || u.Path == "" {
|
||||
return "/device/callback"
|
||||
}
|
||||
return path.Join(u.Path, "/device/callback")
|
||||
}
|
||||
|
||||
// Due to how the frontend generates the logout, sometimes it appends a trailing slash
|
||||
// and because Dex only allows exact matches, we need to make sure we always have both
|
||||
// versions of each provided uri
|
||||
@@ -299,7 +316,7 @@ func resolveSessionCookieEncryptionKey(configuredKey string) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("invalid embedded IdP session cookie encryption key: %s (or sessionCookieEncryptionKey) must be 16, 24, or 32 bytes as a raw string or base64-encoded to one of those lengths; got %d raw bytes", sessionCookieEncryptionKeyEnv, len([]byte(key)))
|
||||
return "", fmt.Errorf("invalid embedded IdP session cookie encryption key:%s (or sessionCookieEncryptionKey) must be 16, 24, or 32 bytes as a raw string or base64-encoded to one of those lengths; got %d raw bytes", sessionCookieEncryptionKeyEnv, len([]byte(key)))
|
||||
}
|
||||
|
||||
func validSessionCookieEncryptionKeyLength(length int) bool {
|
||||
|
||||
@@ -314,6 +314,34 @@ func TestEmbeddedIdPManager_UpdateUserPassword(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestEmbeddedIdPConfig_ToYAMLConfig_IncludesDeviceCallbackRedirectURI(t *testing.T) {
|
||||
config := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "https://example.com/oauth2",
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: filepath.Join(t.TempDir(), "dex.db"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
yamlConfig, err := config.ToYAMLConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
var cliRedirectURIs []string
|
||||
for _, client := range yamlConfig.StaticClients {
|
||||
if client.ID == staticClientCLI {
|
||||
cliRedirectURIs = client.RedirectURIs
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotEmpty(t, cliRedirectURIs)
|
||||
assert.Contains(t, cliRedirectURIs, "/device/callback")
|
||||
assert.Contains(t, cliRedirectURIs, "/oauth2/device/callback")
|
||||
assert.Contains(t, cliRedirectURIs, "https://example.com/oauth2/device/callback")
|
||||
}
|
||||
|
||||
func TestEmbeddedIdPConfig_ToYAMLConfig_SessionCookieEncryptionKey(t *testing.T) {
|
||||
t.Setenv(sessionCookieEncryptionKeyEnv, "")
|
||||
|
||||
|
||||
@@ -198,7 +198,11 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
|
||||
require.NoError(t, err, "Failed to insert account")
|
||||
|
||||
account.PeersG = []nbpeer.Peer{
|
||||
{AccountID: "1234", Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}},
|
||||
{
|
||||
AccountID: "1234",
|
||||
Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}},
|
||||
Status: &nbpeer.PeerStatus{LastSeen: time.Now()},
|
||||
},
|
||||
}
|
||||
|
||||
err = db.Save(account).Error
|
||||
|
||||
@@ -38,7 +38,8 @@ type MockAccountManager struct {
|
||||
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP, syncTime time.Time) error
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error
|
||||
MarkPeerDisconnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
|
||||
@@ -227,7 +228,14 @@ func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID str
|
||||
return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
|
||||
func (am *MockAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
|
||||
// Mirror DefaultAccountManager.OnPeerDisconnected: drive the fencing
|
||||
// hook so tests that inject MarkPeerDisconnectedFunc actually observe
|
||||
// disconnect events. Falls through to nil when no hook is set, which
|
||||
// is the original behaviour.
|
||||
if am.MarkPeerDisconnectedFunc != nil {
|
||||
return am.MarkPeerDisconnectedFunc(ctx, peerPubKey, accountID, streamStartTime.UnixNano())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -328,13 +336,21 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth
|
||||
}
|
||||
|
||||
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
|
||||
if am.MarkPeerConnectedFunc != nil {
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP, syncTime)
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, realIP, accountID, sessionStartedAt)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
// MarkPeerDisconnected mock implementation of MarkPeerDisconnected from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error {
|
||||
if am.MarkPeerDisconnectedFunc != nil {
|
||||
return am.MarkPeerDisconnectedFunc(ctx, peerKey, accountID, sessionStartedAt)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPeerDisconnected is not implemented")
|
||||
}
|
||||
|
||||
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
|
||||
func (am *MockAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error {
|
||||
if am.DeleteAccountFunc != nil {
|
||||
|
||||
@@ -34,8 +34,11 @@ func Test_GetAllNetworksReturnsNetworks(t *testing.T) {
|
||||
|
||||
networks, err := manager.GetAllNetworks(ctx, accountID, userID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, networks, 1)
|
||||
require.Equal(t, "testNetworkId", networks[0].ID)
|
||||
ids := make([]string, 0, len(networks))
|
||||
for _, n := range networks {
|
||||
ids = append(ids, n.ID)
|
||||
}
|
||||
require.ElementsMatch(t, []string{"testNetworkId", "secondNetworkId"}, ids)
|
||||
}
|
||||
|
||||
func Test_GetAllNetworksReturnsPermissionDenied(t *testing.T) {
|
||||
|
||||
@@ -102,7 +102,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
router.ID = xid.New().String()
|
||||
|
||||
err = transaction.SaveNetworkRouter(ctx, router)
|
||||
err = transaction.CreateNetworkRouter(ctx, router)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create network router: %w", err)
|
||||
}
|
||||
@@ -162,11 +162,20 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
|
||||
return fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
|
||||
if network.ID != router.NetworkID {
|
||||
existing, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthUpdate, router.AccountID, router.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network router: %w", err)
|
||||
}
|
||||
|
||||
if existing.AccountID != router.AccountID {
|
||||
return status.NewNetworkRouterNotFoundError(router.ID)
|
||||
}
|
||||
|
||||
if existing.NetworkID != router.NetworkID {
|
||||
return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
|
||||
}
|
||||
|
||||
err = transaction.SaveNetworkRouter(ctx, router)
|
||||
err = transaction.UpdateNetworkRouter(ctx, router)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update network router: %w", err)
|
||||
}
|
||||
|
||||
@@ -195,6 +195,7 @@ func Test_UpdateRouterSuccessfully(t *testing.T) {
|
||||
if err != nil {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
router.ID = "testRouterId"
|
||||
|
||||
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
|
||||
if err != nil {
|
||||
@@ -210,6 +211,102 @@ func Test_UpdateRouterSuccessfully(t *testing.T) {
|
||||
require.Equal(t, router.Metric, updatedRouter.Metric)
|
||||
}
|
||||
|
||||
func Test_UpdateRouterRejectsCrossAccountID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
userID := "testAdminId"
|
||||
|
||||
// Admin of testAccountId tries to update a router that belongs to otherAccountId
|
||||
// by passing the other account's router ID through the URL.
|
||||
router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 1, true)
|
||||
if err != nil {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
router.ID = "otherRouterId"
|
||||
|
||||
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
am := mock_server.MockAccountManager{}
|
||||
manager := NewManager(s, permissionsManager, &am)
|
||||
|
||||
updatedRouter, err := manager.UpdateRouter(ctx, userID, router)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, updatedRouter)
|
||||
|
||||
// The other account's router must be untouched.
|
||||
stored, err := s.GetNetworkRouterByID(ctx, store.LockingStrengthNone, "otherAccountId", "otherRouterId")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "otherAccountId", stored.AccountID)
|
||||
require.Equal(t, "otherNetworkId", stored.NetworkID)
|
||||
require.Equal(t, "otherPeer", stored.Peer)
|
||||
require.Equal(t, 1, stored.Metric)
|
||||
}
|
||||
|
||||
func Test_CreateRouterRejectsCrossAccountID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
userID := "testAdminId"
|
||||
|
||||
// Admin of testAccountId tries to create a router in otherAccountId's network.
|
||||
// The permission check is on router.AccountID (their own), but the network
|
||||
// lookup must fail because (testAccountId, otherNetworkId) does not exist.
|
||||
router, err := types.NewNetworkRouter("testAccountId", "otherNetworkId", "testPeerId", []string{}, false, 1, true)
|
||||
if err != nil {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
am := mock_server.MockAccountManager{}
|
||||
manager := NewManager(s, permissionsManager, &am)
|
||||
|
||||
createdRouter, err := manager.CreateRouter(ctx, userID, router)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, createdRouter)
|
||||
|
||||
// No router should have been created in either account's scope under otherNetworkId.
|
||||
routersInOther, err := s.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, "otherAccountId", "otherNetworkId")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, routersInOther, 1)
|
||||
require.Equal(t, "otherRouterId", routersInOther[0].ID)
|
||||
}
|
||||
|
||||
func Test_UpdateRouterRejectsNetworkMismatch(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
userID := "testAdminId"
|
||||
|
||||
// The router exists in testNetworkId, but the caller submits secondNetworkId
|
||||
// (a different network in the same account). The update must be refused.
|
||||
router, err := types.NewNetworkRouter("testAccountId", "secondNetworkId", "testPeerId", []string{}, false, 1, true)
|
||||
if err != nil {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
router.ID = "testRouterId"
|
||||
|
||||
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
am := mock_server.MockAccountManager{}
|
||||
manager := NewManager(s, permissionsManager, &am)
|
||||
|
||||
updatedRouter, err := manager.UpdateRouter(ctx, userID, router)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, updatedRouter)
|
||||
|
||||
stored, err := s.GetNetworkRouterByID(ctx, store.LockingStrengthNone, "testAccountId", "testRouterId")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "testNetworkId", stored.NetworkID)
|
||||
}
|
||||
|
||||
func Test_UpdateRouterFailsWithPermissionDenied(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
userID := "testUserId"
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
@@ -29,6 +28,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -63,56 +63,64 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
return am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
|
||||
}
|
||||
|
||||
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
||||
// syncTime is used as the LastSeen timestamp and for stale request detection
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
|
||||
var peer *nbpeer.Peer
|
||||
var settings *types.Settings
|
||||
var expired bool
|
||||
var err error
|
||||
var skipped bool
|
||||
// MarkPeerConnected marks a peer as connected with optimistic-locked
|
||||
// fencing on PeerStatus.SessionStartedAt. The sessionStartedAt argument
|
||||
// is the start time of the gRPC sync stream that owns this update,
|
||||
// expressed as Unix nanoseconds — only the call whose token is greater
|
||||
// than what's stored wins. LastSeen is written by the database itself;
|
||||
// we never pass it down.
|
||||
//
|
||||
// Disconnects use MarkPeerDisconnected and require the session to match
|
||||
// exactly; see PeerStatus.SessionStartedAt for the protocol.
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusConnect, time.Since(start))
|
||||
}()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
|
||||
if err != nil {
|
||||
outcome := telemetry.PeerStatusError
|
||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
outcome = telemetry.PeerStatusPeerNotFound
|
||||
}
|
||||
|
||||
if connected && !syncTime.After(peer.Status.LastSeen) {
|
||||
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s >= syncTime=%s), skipping connect",
|
||||
peer.ID, peer.Status.LastSeen.Format(time.RFC3339), syncTime.Format(time.RFC3339))
|
||||
skipped = true
|
||||
return nil
|
||||
}
|
||||
|
||||
expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID, syncTime)
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, outcome)
|
||||
return err
|
||||
})
|
||||
if skipped {
|
||||
}
|
||||
|
||||
updated, err := am.Store.MarkPeerConnectedIfNewerSession(ctx, accountID, peer.ID, sessionStartedAt)
|
||||
if err != nil {
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusError)
|
||||
return err
|
||||
}
|
||||
if !updated {
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusStale)
|
||||
log.WithContext(ctx).Tracef("peer %s already has a newer session in store, skipping connect", peer.ID)
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusApplied)
|
||||
|
||||
if am.geo != nil && realIP != nil {
|
||||
am.updatePeerLocationIfChanged(ctx, accountID, peer, realIP)
|
||||
}
|
||||
|
||||
expired := peer.Status != nil && peer.Status.LoginExpired
|
||||
|
||||
if peer.AddedWithSSOLogin() {
|
||||
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
|
||||
am.schedulePeerLoginExpiration(ctx, accountID)
|
||||
}
|
||||
|
||||
if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
if expired {
|
||||
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
|
||||
if err != nil {
|
||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
|
||||
return fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -120,41 +128,60 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
return nil
|
||||
}
|
||||
|
||||
func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string, syncTime time.Time) (bool, error) {
|
||||
oldStatus := peer.Status.Copy()
|
||||
newStatus := oldStatus
|
||||
newStatus.LastSeen = syncTime
|
||||
newStatus.Connected = connected
|
||||
// whenever peer got connected that means that it logged in successfully
|
||||
if newStatus.Connected {
|
||||
newStatus.LoginExpired = false
|
||||
}
|
||||
peer.Status = newStatus
|
||||
// MarkPeerDisconnected marks a peer as disconnected, but only when the
|
||||
// stored session token matches the one passed in. A mismatch means a
|
||||
// newer stream has already taken ownership of the peer — disconnects from
|
||||
// the older stream are ignored. LastSeen is written by the database.
|
||||
func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerPubKey string, accountID string, sessionStartedAt int64) error {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusDisconnect, time.Since(start))
|
||||
}()
|
||||
|
||||
if geo != nil && realIP != nil {
|
||||
location, err := geo.Lookup(realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
|
||||
} else {
|
||||
peer.Location.ConnectionIP = realIP
|
||||
peer.Location.CountryCode = location.Country.ISOCode
|
||||
peer.Location.CityName = location.City.Names.En
|
||||
peer.Location.GeoNameID = location.City.GeonameID
|
||||
err = transaction.SavePeerLocation(ctx, accountID, peer)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("saving peer status for peer %s is connected: %t", peer.ID, connected)
|
||||
|
||||
err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus)
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
|
||||
if err != nil {
|
||||
return false, err
|
||||
outcome := telemetry.PeerStatusError
|
||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
outcome = telemetry.PeerStatusPeerNotFound
|
||||
}
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, outcome)
|
||||
return err
|
||||
}
|
||||
|
||||
return oldStatus.LoginExpired, nil
|
||||
updated, err := am.Store.MarkPeerDisconnectedIfSameSession(ctx, accountID, peer.ID, sessionStartedAt)
|
||||
if err != nil {
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusError)
|
||||
return err
|
||||
}
|
||||
if !updated {
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusStale)
|
||||
log.WithContext(ctx).Tracef("peer %s session token mismatch on disconnect (token=%d), skipping",
|
||||
peer.ID, sessionStartedAt)
|
||||
return nil
|
||||
}
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusApplied)
|
||||
return nil
|
||||
}
|
||||
|
||||
// updatePeerLocationIfChanged refreshes the geolocation on a separate
|
||||
// row update, only when the connection IP actually changed. Geo lookups
|
||||
// are expensive so we skip same-IP reconnects.
|
||||
func (am *DefaultAccountManager) updatePeerLocationIfChanged(ctx context.Context, accountID string, peer *nbpeer.Peer, realIP net.IP) {
|
||||
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) {
|
||||
return
|
||||
}
|
||||
location, err := am.geo.Lookup(realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
|
||||
return
|
||||
}
|
||||
peer.Location.ConnectionIP = realIP
|
||||
peer.Location.CountryCode = location.Country.ISOCode
|
||||
peer.Location.CityName = location.City.Names.En
|
||||
peer.Location.GeoNameID = location.City.GeonameID
|
||||
if err := am.Store.SavePeerLocation(ctx, accountID, peer); err != nil {
|
||||
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated.
|
||||
@@ -762,16 +789,19 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
newPeer.IP = freeIP
|
||||
|
||||
if len(settings.IPv6EnabledGroups) > 0 && network.NetV6.IP != nil {
|
||||
var allGroupID string
|
||||
if !peer.ProxyMeta.Embedded {
|
||||
allGroup, err := am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, "All")
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("get All group for IPv6 allocation: %v", err)
|
||||
} else {
|
||||
// Embedded proxy peers are not group members but participate in any
|
||||
// IPv6-enabled overlay so reverse-proxy traffic reaches v6-only peers.
|
||||
allocate := peer.ProxyMeta.Embedded
|
||||
if !allocate {
|
||||
var allGroupID string
|
||||
if allGroup, err := am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, types.GroupAllName); err == nil {
|
||||
allGroupID = allGroup.ID
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("get All group for IPv6 allocation: %v", err)
|
||||
}
|
||||
allocate = peerWillHaveIPv6(settings, peerAddConfig.GroupsToAdd, allGroupID)
|
||||
}
|
||||
if peerWillHaveIPv6(settings, peerAddConfig.GroupsToAdd, allGroupID) {
|
||||
if allocate {
|
||||
v6Prefix, err := netip.ParsePrefix(network.NetV6.String())
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("parse IPv6 prefix: %w", err)
|
||||
|
||||
@@ -74,8 +74,19 @@ type ProxyMeta struct {
|
||||
}
|
||||
|
||||
type PeerStatus struct { //nolint:revive
|
||||
// LastSeen is the last time peer was connected to the management service
|
||||
// LastSeen is the last time the peer status was updated (i.e. the last
|
||||
// time we observed the peer being alive on a sync stream). Written by
|
||||
// the database (CURRENT_TIMESTAMP) — callers do not supply it.
|
||||
LastSeen time.Time
|
||||
// SessionStartedAt records when the currently-active sync stream began,
|
||||
// stored as Unix nanoseconds. It acts as the optimistic-locking token
|
||||
// for status updates: a stream is only allowed to mutate the peer's
|
||||
// status when its own token strictly exceeds the stored token (when connecting)
|
||||
// or matches it exactly (for disconnects). Zero means "no
|
||||
// active session". Integer nanoseconds are used so equality is
|
||||
// precision-safe across drivers, and so the predicates compose to a
|
||||
// single bigint comparison.
|
||||
SessionStartedAt int64 `gorm:"not null;default:0"`
|
||||
// Connected indicates whether peer is connected to the management service or not
|
||||
Connected bool
|
||||
// LoginExpired
|
||||
@@ -375,10 +386,14 @@ func (p *Peer) EventMeta(dnsDomain string) map[string]any {
|
||||
return meta
|
||||
}
|
||||
|
||||
// Copy PeerStatus
|
||||
// Copy PeerStatus. SessionStartedAt must be propagated so clone-based
|
||||
// callers (Peer.Copy, MarkLoginExpired, UpdateLastLogin) don't silently
|
||||
// reset the fencing token to zero — that would let any subsequent
|
||||
// SavePeerStatus write reopen the optimistic-lock window.
|
||||
func (p *PeerStatus) Copy() *PeerStatus {
|
||||
return &PeerStatus{
|
||||
LastSeen: p.LastSeen,
|
||||
SessionStartedAt: p.SessionStartedAt,
|
||||
Connected: p.Connected,
|
||||
LoginExpired: p.LoginExpired,
|
||||
RequiresApproval: p.RequiresApproval,
|
||||
|
||||
@@ -2218,6 +2218,9 @@ func Test_IsUniqueConstraintError(t *testing.T) {
|
||||
ID: "test-peer-id",
|
||||
AccountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
DNSLabel: "test-peer-dns-label",
|
||||
Status: &nbpeer.PeerStatus{
|
||||
LastSeen: time.Now(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -497,8 +498,9 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, accountID, peerID string,
|
||||
peerCopy.Status = &peerStatus
|
||||
|
||||
fieldsToUpdate := []string{
|
||||
"peer_status_last_seen", "peer_status_connected",
|
||||
"peer_status_login_expired", "peer_status_required_approval",
|
||||
"peer_status_last_seen", "peer_status_session_started_at",
|
||||
"peer_status_connected", "peer_status_login_expired",
|
||||
"peer_status_requires_approval",
|
||||
}
|
||||
result := s.db.Model(&nbpeer.Peer{}).
|
||||
Select(fieldsToUpdate).
|
||||
@@ -515,6 +517,69 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, accountID, peerID string,
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkPeerConnectedIfNewerSession is an atomic optimistic-locked update.
|
||||
// The peer is marked connected with the given session token only when
|
||||
// the stored SessionStartedAt is strictly smaller than the incoming
|
||||
// one — equivalently, when no newer stream has already taken ownership.
|
||||
// The sentinel zero (set on peer creation or after a disconnect) counts
|
||||
// as the smallest possible token. This is the write half of the
|
||||
// fencing protocol described on PeerStatus.SessionStartedAt.
|
||||
//
|
||||
// The post-write side effects in the caller — geo lookup,
|
||||
// schedulePeerLoginExpiration, checkAndSchedulePeerInactivityExpiration,
|
||||
// OnPeersUpdated — all run AFTER this method returns and are deliberately
|
||||
// outside the database write so they cannot extend the row-lock window.
|
||||
//
|
||||
// LastSeen is set to the database's clock (CURRENT_TIMESTAMP) at the
|
||||
// moment the row is written. The caller never supplies LastSeen because
|
||||
// the value would otherwise drift under lock contention — a Go-side
|
||||
// time.Now() taken before the write can land minutes later than the
|
||||
// actual UPDATE under load, which previously caused real ordering bugs.
|
||||
func (s *SqlStore) MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) {
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&nbpeer.Peer{}).
|
||||
Where(accountAndIDQueryCondition, accountID, peerID).
|
||||
Where("peer_status_session_started_at < ?", newSessionStartedAt).
|
||||
Updates(map[string]any{
|
||||
"peer_status_connected": true,
|
||||
"peer_status_last_seen": gorm.Expr("CURRENT_TIMESTAMP"),
|
||||
"peer_status_session_started_at": newSessionStartedAt,
|
||||
"peer_status_login_expired": false,
|
||||
})
|
||||
if result.Error != nil {
|
||||
return false, status.Errorf(status.Internal, "mark peer connected: %v", result.Error)
|
||||
}
|
||||
return result.RowsAffected > 0, nil
|
||||
}
|
||||
|
||||
// MarkPeerDisconnectedIfSameSession is an atomic optimistic-locked update.
|
||||
// The peer is marked disconnected only when the stored SessionStartedAt
|
||||
// matches the incoming token — meaning the stream that owns the current
|
||||
// session is the one ending. If a newer stream has already replaced the
|
||||
// session, the update is skipped. LastSeen is set to CURRENT_TIMESTAMP at
|
||||
// write time; see MarkPeerConnectedIfNewerSession for the rationale.
|
||||
//
|
||||
// A zero sessionStartedAt is rejected at the call site; the underlying
|
||||
// WHERE on equality would otherwise match every never-connected peer.
|
||||
func (s *SqlStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) {
|
||||
if sessionStartedAt == 0 {
|
||||
return false, nil
|
||||
}
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&nbpeer.Peer{}).
|
||||
Where(accountAndIDQueryCondition, accountID, peerID).
|
||||
Where("peer_status_session_started_at = ?", sessionStartedAt).
|
||||
Updates(map[string]any{
|
||||
"peer_status_connected": false,
|
||||
"peer_status_last_seen": gorm.Expr("CURRENT_TIMESTAMP"),
|
||||
"peer_status_session_started_at": int64(0),
|
||||
})
|
||||
if result.Error != nil {
|
||||
return false, status.Errorf(status.Internal, "mark peer disconnected: %v", result.Error)
|
||||
}
|
||||
return result.RowsAffected > 0, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerWithLocation *nbpeer.Peer) error {
|
||||
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
|
||||
var peerCopy nbpeer.Peer
|
||||
@@ -1722,9 +1787,10 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
|
||||
inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname,
|
||||
meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version,
|
||||
meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer,
|
||||
meta_environment, meta_flags, meta_files, meta_capabilities, peer_status_last_seen, peer_status_connected, peer_status_login_expired,
|
||||
peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name,
|
||||
location_geo_name_id, proxy_meta_embedded, proxy_meta_cluster, ipv6 FROM peers WHERE account_id = $1`
|
||||
meta_environment, meta_flags, meta_files, meta_capabilities, peer_status_last_seen, peer_status_session_started_at,
|
||||
peer_status_connected, peer_status_login_expired, peer_status_requires_approval, location_connection_ip,
|
||||
location_country_code, location_city_name, location_geo_name_id, proxy_meta_embedded, proxy_meta_cluster, ipv6
|
||||
FROM peers WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1737,6 +1803,7 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
|
||||
lastLogin, createdAt sql.NullTime
|
||||
sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool
|
||||
peerStatusLastSeen sql.NullTime
|
||||
peerStatusSessionStartedAt sql.NullInt64
|
||||
peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval, proxyEmbedded sql.NullBool
|
||||
ip, extraDNS, netAddr, env, flags, files, capabilities, connIP, ipv6 []byte
|
||||
metaHostname, metaGoOS, metaKernel, metaCore, metaPlatform sql.NullString
|
||||
@@ -1751,8 +1818,9 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
|
||||
&allowExtraDNSLabels, &metaHostname, &metaGoOS, &metaKernel, &metaCore, &metaPlatform,
|
||||
&metaOS, &metaOSVersion, &metaWtVersion, &metaUIVersion, &metaKernelVersion, &netAddr,
|
||||
&metaSystemSerialNumber, &metaSystemProductName, &metaSystemManufacturer, &env, &flags, &files, &capabilities,
|
||||
&peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP,
|
||||
&locationCountryCode, &locationCityName, &locationGeoNameID, &proxyEmbedded, &proxyCluster, &ipv6)
|
||||
&peerStatusLastSeen, &peerStatusSessionStartedAt, &peerStatusConnected, &peerStatusLoginExpired,
|
||||
&peerStatusRequiresApproval, &connIP, &locationCountryCode, &locationCityName, &locationGeoNameID,
|
||||
&proxyEmbedded, &proxyCluster, &ipv6)
|
||||
|
||||
if err == nil {
|
||||
if lastLogin.Valid {
|
||||
@@ -1779,6 +1847,9 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
|
||||
if peerStatusLastSeen.Valid {
|
||||
p.Status.LastSeen = peerStatusLastSeen.Time
|
||||
}
|
||||
if peerStatusSessionStartedAt.Valid {
|
||||
p.Status.SessionStartedAt = peerStatusSessionStartedAt.Int64
|
||||
}
|
||||
if peerStatusConnected.Valid {
|
||||
p.Status.Connected = peerStatusConnected.Bool
|
||||
}
|
||||
@@ -2794,12 +2865,27 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe
|
||||
connStr = filepath.Join(dataDir, filePath)
|
||||
}
|
||||
|
||||
// Append query parameters: user-provided take precedence, otherwise default to cache=shared on non-Windows
|
||||
if hasQuery {
|
||||
connStr += "?" + query
|
||||
} else if runtime.GOOS != "windows" {
|
||||
// Compose query parameters. User-provided ?_busy_timeout (or its mattn alias
|
||||
// ?_timeout) overrides our default; otherwise inject 30s so SQLite waits at
|
||||
// most that long on a lock instead of blocking the only Go-side connection.
|
||||
// mattn/go-sqlite3 applies PRAGMA from the DSN on every fresh connection, so
|
||||
// the value survives ConnMaxIdleTime/ConnMaxLifetime recycling. cache=shared
|
||||
// stays the default on non-Windows for the same reason as before.
|
||||
parsed, _ := url.ParseQuery(query)
|
||||
var defaults []string
|
||||
if parsed.Get("_busy_timeout") == "" && parsed.Get("_timeout") == "" {
|
||||
defaults = append(defaults, "_busy_timeout=30000")
|
||||
}
|
||||
if !hasQuery && runtime.GOOS != "windows" {
|
||||
// To avoid `The process cannot access the file because it is being used by another process` on Windows
|
||||
connStr += "?cache=shared"
|
||||
defaults = append(defaults, "cache=shared")
|
||||
}
|
||||
parts := defaults
|
||||
if hasQuery {
|
||||
parts = append(parts, query)
|
||||
}
|
||||
if len(parts) > 0 {
|
||||
connStr += "?" + strings.Join(parts, "&")
|
||||
}
|
||||
|
||||
db, err := gorm.Open(sqlite.Open(connStr), getGormConfig())
|
||||
@@ -3402,7 +3488,7 @@ func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string)
|
||||
}
|
||||
|
||||
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
|
||||
timeoutCtx, cancel := context.WithTimeout(context.Background(), s.transactionTimeout)
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, s.transactionTimeout)
|
||||
defer cancel()
|
||||
|
||||
startTime := time.Now()
|
||||
@@ -4229,11 +4315,27 @@ func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength Lockin
|
||||
return netRouter, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) SaveNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error {
|
||||
result := s.db.Save(router)
|
||||
func (s *SqlStore) CreateNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error {
|
||||
if err := s.db.Create(router).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to create network router in store: %v", err)
|
||||
return status.Errorf(status.Internal, "failed to create network router in store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) UpdateNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error {
|
||||
result := s.db.
|
||||
Select("*").
|
||||
Where(accountAndIDQueryCondition, router.AccountID, router.ID).
|
||||
Updates(router)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save network router to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save network router to store")
|
||||
log.WithContext(ctx).Errorf("failed to update network router in store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to update network router in store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewNetworkRouterNotFoundError(router.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -5650,19 +5752,67 @@ func (s *SqlStore) DeleteAccountCluster(ctx context.Context, clusterAddress, acc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) {
|
||||
var clusters []proxy.Cluster
|
||||
// GetProxyClusters returns every cluster the account can see (shared
|
||||
// plus its own BYOP), regardless of whether any proxy in the cluster
|
||||
// is currently heartbeating. Online and ConnectedProxies are derived
|
||||
// from the 2-min active window so the dashboard can render offline
|
||||
// clusters distinctly; the 1-hour heartbeat reaper still removes rows
|
||||
// that go quiet for too long.
|
||||
//
|
||||
// AccountOwned is determined by whether any proxy row in the group
|
||||
// carries a non-NULL account_id; the caller maps that to Cluster.Type.
|
||||
// Capability flags are NOT filled here — the handler enriches them via
|
||||
// the per-cluster capability lookups.
|
||||
func (s *SqlStore) GetProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) {
|
||||
activeCutoff := time.Now().Add(-proxyActiveThreshold)
|
||||
|
||||
type clusterRow struct {
|
||||
ID string
|
||||
Address string
|
||||
ConnectedProxies int
|
||||
Online bool
|
||||
AccountOwned bool
|
||||
}
|
||||
|
||||
var rows []clusterRow
|
||||
result := s.db.Model(&proxy.Proxy{}).
|
||||
Select("MIN(id) as id, cluster_address as address, COUNT(*) as connected_proxies, COUNT(account_id) > 0 as self_hosted").
|
||||
Where("status = ? AND last_seen > ? AND (account_id IS NULL OR account_id = ?)",
|
||||
proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold), accountID).
|
||||
Select(
|
||||
"MIN(id) AS id, "+
|
||||
"cluster_address AS address, "+
|
||||
// COUNT(CASE WHEN ... THEN 1 END) counts only non-NULL — i.e. only
|
||||
// rows that satisfy the predicate — so it works portably across
|
||||
// sqlite/postgres/mysql without dialect-specific FILTER syntax.
|
||||
"COUNT(CASE WHEN status = ? AND last_seen > ? THEN 1 END) AS connected_proxies, "+
|
||||
// MAX(CASE …) > 0 expresses BOOL_OR in a way Postgres tolerates
|
||||
// (Postgres can't MAX a boolean column).
|
||||
"MAX(CASE WHEN status = ? AND last_seen > ? THEN 1 ELSE 0 END) > 0 AS online, "+
|
||||
"MAX(CASE WHEN account_id IS NOT NULL THEN 1 ELSE 0 END) > 0 AS account_owned",
|
||||
proxy.StatusConnected, activeCutoff,
|
||||
proxy.StatusConnected, activeCutoff,
|
||||
).
|
||||
Where("account_id IS NULL OR account_id = ?", accountID).
|
||||
Group("cluster_address").
|
||||
Scan(&clusters)
|
||||
Scan(&rows)
|
||||
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "get active proxy clusters")
|
||||
log.WithContext(ctx).Errorf("failed to get proxy clusters: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "get proxy clusters")
|
||||
}
|
||||
|
||||
clusters := make([]proxy.Cluster, 0, len(rows))
|
||||
for _, r := range rows {
|
||||
c := proxy.Cluster{
|
||||
ID: r.ID,
|
||||
Address: r.Address,
|
||||
Online: r.Online,
|
||||
ConnectedProxies: r.ConnectedProxies,
|
||||
}
|
||||
if r.AccountOwned {
|
||||
c.Type = proxy.ClusterTypeAccount
|
||||
} else {
|
||||
c.Type = proxy.ClusterTypeShared
|
||||
}
|
||||
clusters = append(clusters, c)
|
||||
}
|
||||
|
||||
return clusters, nil
|
||||
|
||||
109
management/server/store/sql_store_proxy_clusters_test.go
Normal file
109
management/server/store/sql_store_proxy_clusters_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
)
|
||||
|
||||
// TestSqlStore_GetProxyClusters_DerivesOnlineAndType guards the
|
||||
// account-visible cluster list against silent regressions in two
|
||||
// dimensions:
|
||||
//
|
||||
// 1. Online derivation: a cluster with one stale and one fresh proxy
|
||||
// is online and counts only the fresh proxy; a cluster whose
|
||||
// proxies all heartbeated outside the 2-min window appears offline
|
||||
// with connected_proxies = 0 (rather than disappearing, which is
|
||||
// what the old query did).
|
||||
// 2. Type derivation: a cluster scoped to the calling account is
|
||||
// reported as `account`; a cluster with account_id IS NULL is
|
||||
// reported as `shared`. Clusters scoped to other accounts must not
|
||||
// leak into the result.
|
||||
//
|
||||
// Capability flags are intentionally not asserted here — they're filled
|
||||
// by the manager (handler) layer from the per-cluster capability
|
||||
// lookups, not by the store query.
|
||||
func TestSqlStore_GetProxyClusters_DerivesOnlineAndType(t *testing.T) {
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
ctx := context.Background()
|
||||
accountID := "acct-clusters"
|
||||
require.NoError(t, store.SaveAccount(ctx, newAccountWithId(ctx, accountID, "user-1", "")))
|
||||
|
||||
otherAccountID := "acct-other"
|
||||
require.NoError(t, store.SaveAccount(ctx, newAccountWithId(ctx, otherAccountID, "user-2", "")))
|
||||
|
||||
acctID := accountID
|
||||
otherID := otherAccountID
|
||||
|
||||
fresh := time.Now().Add(-30 * time.Second)
|
||||
stale := time.Now().Add(-30 * time.Minute)
|
||||
|
||||
mustSave := func(id, cluster string, accID *string, status string, lastSeen time.Time) {
|
||||
require.NoError(t, store.SaveProxy(ctx, &rpproxy.Proxy{
|
||||
ID: id,
|
||||
SessionID: id + "-sess",
|
||||
ClusterAddress: cluster,
|
||||
IPAddress: "10.0.0.1",
|
||||
AccountID: accID,
|
||||
LastSeen: lastSeen,
|
||||
Status: status,
|
||||
}))
|
||||
}
|
||||
|
||||
// shared-mixed: one fresh + one stale proxy → online, connected=1
|
||||
mustSave("p-shared-fresh", "shared-mixed.netbird.io", nil, rpproxy.StatusConnected, fresh)
|
||||
mustSave("p-shared-stale", "shared-mixed.netbird.io", nil, rpproxy.StatusConnected, stale)
|
||||
|
||||
// shared-offline: only stale proxies → offline, connected=0,
|
||||
// but row must still appear (this is the new semantic — old
|
||||
// query would have dropped it entirely).
|
||||
mustSave("p-shared-off", "shared-offline.netbird.io", nil, rpproxy.StatusConnected, stale)
|
||||
|
||||
// account-online: BYOP cluster owned by acctID, fresh
|
||||
mustSave("p-acct-fresh", "byop.acct.example", &acctID, rpproxy.StatusConnected, fresh)
|
||||
|
||||
// other-account: must not surface for acctID
|
||||
mustSave("p-other", "byop.other.example", &otherID, rpproxy.StatusConnected, fresh)
|
||||
|
||||
clusters, err := store.GetProxyClusters(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
byAddr := map[string]rpproxy.Cluster{}
|
||||
for _, c := range clusters {
|
||||
byAddr[c.Address] = c
|
||||
}
|
||||
|
||||
assert.NotContains(t, byAddr, "byop.other.example",
|
||||
"another account's BYOP cluster must not leak into this account's listing")
|
||||
|
||||
require.Contains(t, byAddr, "shared-mixed.netbird.io")
|
||||
mixed := byAddr["shared-mixed.netbird.io"]
|
||||
assert.Equal(t, rpproxy.ClusterTypeShared, mixed.Type, "shared cluster (account_id IS NULL) must be reported as Type=shared")
|
||||
assert.True(t, mixed.Online, "cluster with a fresh proxy must be online")
|
||||
assert.Equal(t, 1, mixed.ConnectedProxies, "connected_proxies must count only fresh proxies; the stale one should not bump the count")
|
||||
|
||||
require.Contains(t, byAddr, "shared-offline.netbird.io",
|
||||
"offline clusters must still appear so the dashboard can render them — the old GetActiveProxyClusters would have dropped this row, which is the regression this test guards against")
|
||||
offline := byAddr["shared-offline.netbird.io"]
|
||||
assert.Equal(t, rpproxy.ClusterTypeShared, offline.Type)
|
||||
assert.False(t, offline.Online, "no fresh heartbeat → offline")
|
||||
assert.Equal(t, 0, offline.ConnectedProxies, "no fresh proxies → connected_proxies=0")
|
||||
|
||||
require.Contains(t, byAddr, "byop.acct.example")
|
||||
acct := byAddr["byop.acct.example"]
|
||||
assert.Equal(t, rpproxy.ClusterTypeAccount, acct.Type, "BYOP cluster owned by the account must be reported as Type=account")
|
||||
assert.True(t, acct.Online)
|
||||
assert.Equal(t, 1, acct.ConnectedProxies)
|
||||
})
|
||||
}
|
||||
@@ -2399,7 +2399,7 @@ func TestSqlStore_GetNetworkRouterByID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveNetworkRouter(t *testing.T) {
|
||||
func TestSqlStore_CreateNetworkRouter(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
@@ -2410,7 +2410,7 @@ func TestSqlStore_SaveNetworkRouter(t *testing.T) {
|
||||
netRouter, err := routerTypes.NewNetworkRouter(accountID, networkID, "", []string{"net-router-grp"}, true, 0, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.SaveNetworkRouter(context.Background(), netRouter)
|
||||
err = store.CreateNetworkRouter(context.Background(), netRouter)
|
||||
require.NoError(t, err)
|
||||
|
||||
savedNetRouter, err := store.GetNetworkRouterByID(context.Background(), LockingStrengthNone, accountID, netRouter.ID)
|
||||
@@ -2418,6 +2418,39 @@ func TestSqlStore_SaveNetworkRouter(t *testing.T) {
|
||||
require.Equal(t, netRouter, savedNetRouter)
|
||||
}
|
||||
|
||||
func TestSqlStore_UpdateNetworkRouter(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
networkID := "ct286bi7qv930dsrrug0"
|
||||
routerID := "ctc20ji7qv9ck2sebc80"
|
||||
|
||||
netRouter := &routerTypes.NetworkRouter{
|
||||
ID: routerID,
|
||||
AccountID: accountID,
|
||||
NetworkID: networkID,
|
||||
Peer: "",
|
||||
PeerGroups: []string{"net-router-grp"},
|
||||
Masquerade: true,
|
||||
Metric: 42,
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
err = store.UpdateNetworkRouter(context.Background(), netRouter)
|
||||
require.NoError(t, err)
|
||||
|
||||
savedNetRouter, err := store.GetNetworkRouterByID(context.Background(), LockingStrengthNone, accountID, routerID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, netRouter, savedNetRouter)
|
||||
|
||||
// Updating a router under a different account must not match any row.
|
||||
netRouter.AccountID = "non-existent-account"
|
||||
err = store.UpdateNetworkRouter(context.Background(), netRouter)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteNetworkRouter(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
@@ -4592,3 +4625,55 @@ func TestSqlStore_DeleteZoneDNSRecords(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, len(remainingRecords))
|
||||
}
|
||||
|
||||
// TestNewSqliteStore_BusyTimeoutApplied opens a fresh SQLite store and verifies
|
||||
// that the _busy_timeout DSN parameter took effect at the driver level. Without
|
||||
// this, lock contention on the single SQLite connection waits indefinitely on
|
||||
// the Go side and can be hidden behind the 5-minute transactionTimeout.
|
||||
func TestNewSqliteStore_BusyTimeoutApplied(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
store, err := NewSqliteStore(context.Background(), dir, nil, true)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = store.Close(context.Background())
|
||||
})
|
||||
|
||||
sqlDB, err := store.db.DB()
|
||||
require.NoError(t, err)
|
||||
row := sqlDB.QueryRow("PRAGMA busy_timeout")
|
||||
var busyTimeout int
|
||||
require.NoError(t, row.Scan(&busyTimeout))
|
||||
assert.Equal(t, 30000, busyTimeout, "SQLite busy_timeout must be set via DSN so it survives connection recycling")
|
||||
}
|
||||
|
||||
// TestNewSqliteStore_BusyTimeoutRespectsUserOverride confirms that an operator
|
||||
// passing _busy_timeout or its mattn alias _timeout via NB_STORE_ENGINE_SQLITE_FILE
|
||||
// wins over our 30s default. This guards the DSN merge logic in NewSqliteStore.
|
||||
func TestNewSqliteStore_BusyTimeoutRespectsUserOverride(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
envFile string
|
||||
expected int
|
||||
}{
|
||||
{name: "explicit _busy_timeout wins", envFile: "store.db?_busy_timeout=5000", expected: 5000},
|
||||
{name: "alias _timeout wins", envFile: "store.db?_timeout=7000", expected: 7000},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv("NB_STORE_ENGINE_SQLITE_FILE", tc.envFile)
|
||||
dir := t.TempDir()
|
||||
store, err := NewSqliteStore(context.Background(), dir, nil, true)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = store.Close(context.Background())
|
||||
})
|
||||
|
||||
sqlDB, err := store.db.DB()
|
||||
require.NoError(t, err)
|
||||
row := sqlDB.QueryRow("PRAGMA busy_timeout")
|
||||
var busyTimeout int
|
||||
require.NoError(t, row.Scan(&busyTimeout))
|
||||
assert.Equal(t, tc.expected, busyTimeout)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -167,6 +167,21 @@ type Store interface {
|
||||
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
|
||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error
|
||||
// MarkPeerConnectedIfNewerSession sets the peer to connected with the
|
||||
// given session token, but only when the stored SessionStartedAt is
|
||||
// strictly less than newSessionStartedAt (the sentinel zero counts as
|
||||
// "older"). LastSeen is recorded by the database at the moment the
|
||||
// row is updated — never by the caller — so it always reflects the
|
||||
// real write time even under lock contention.
|
||||
// Returns true when the update happened, false when this stream lost
|
||||
// the race against a newer session.
|
||||
MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error)
|
||||
// MarkPeerDisconnectedIfSameSession sets the peer to disconnected and
|
||||
// resets SessionStartedAt to zero, but only when the stored
|
||||
// SessionStartedAt equals the given sessionStartedAt. LastSeen is
|
||||
// recorded by the database. Returns true when the update happened,
|
||||
// false when a newer session has taken over.
|
||||
MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error)
|
||||
SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
ApproveAccountPeers(ctx context.Context, accountID string) (int, error)
|
||||
DeletePeer(ctx context.Context, accountID string, peerID string) error
|
||||
@@ -213,7 +228,8 @@ type Store interface {
|
||||
GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error)
|
||||
GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error)
|
||||
GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error)
|
||||
SaveNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error
|
||||
CreateNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error
|
||||
UpdateNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error
|
||||
DeleteNetworkRouter(ctx context.Context, accountID, routerID string) error
|
||||
|
||||
GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*resourceTypes.NetworkResource, error)
|
||||
@@ -292,7 +308,7 @@ type Store interface {
|
||||
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
|
||||
GetProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
@@ -456,6 +472,9 @@ func getMigrationsPreAuto(ctx context.Context) []migrationFunc {
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNewField[types.User](ctx, db, "email", "")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNewField[nbpeer.Peer](ctx, db, "peer_status_session_started_at", int64(0))
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.RemoveDuplicatePeerKeys(ctx, db)
|
||||
},
|
||||
|
||||
@@ -310,6 +310,20 @@ func (mr *MockStoreMockRecorder) CreateGroups(ctx, accountID, groups interface{}
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateGroups", reflect.TypeOf((*MockStore)(nil).CreateGroups), ctx, accountID, groups)
|
||||
}
|
||||
|
||||
// CreateNetworkRouter mocks base method.
|
||||
func (m *MockStore) CreateNetworkRouter(ctx context.Context, router *types0.NetworkRouter) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateNetworkRouter", ctx, router)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CreateNetworkRouter indicates an expected call of CreateNetworkRouter.
|
||||
func (mr *MockStoreMockRecorder) CreateNetworkRouter(ctx, router interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateNetworkRouter", reflect.TypeOf((*MockStore)(nil).CreateNetworkRouter), ctx, router)
|
||||
}
|
||||
|
||||
// CreatePeerJob mocks base method.
|
||||
func (m *MockStore) CreatePeerJob(ctx context.Context, job *types2.Job) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -380,6 +394,20 @@ func (mr *MockStoreMockRecorder) DeleteAccount(ctx, account interface{}) *gomock
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccount", reflect.TypeOf((*MockStore)(nil).DeleteAccount), ctx, account)
|
||||
}
|
||||
|
||||
// DeleteAccountCluster mocks base method.
|
||||
func (m *MockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
|
||||
func (mr *MockStoreMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockStore)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID)
|
||||
}
|
||||
|
||||
// DeleteCustomDomain mocks base method.
|
||||
func (m *MockStore) DeleteCustomDomain(ctx context.Context, accountID, domainID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -577,20 +605,6 @@ func (mr *MockStoreMockRecorder) DeletePostureChecks(ctx, accountID, postureChec
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePostureChecks", reflect.TypeOf((*MockStore)(nil).DeletePostureChecks), ctx, accountID, postureChecksID)
|
||||
}
|
||||
|
||||
// DeleteAccountCluster mocks base method.
|
||||
func (m *MockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
|
||||
func (mr *MockStoreMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockStore)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID)
|
||||
}
|
||||
|
||||
// DeleteRoute mocks base method.
|
||||
func (m *MockStore) DeleteRoute(ctx context.Context, accountID, routeID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -731,6 +745,20 @@ func (mr *MockStoreMockRecorder) DeleteZoneDNSRecords(ctx, accountID, zoneID int
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteZoneDNSRecords", reflect.TypeOf((*MockStore)(nil).DeleteZoneDNSRecords), ctx, accountID, zoneID)
|
||||
}
|
||||
|
||||
// DisconnectProxy mocks base method.
|
||||
func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DisconnectProxy", ctx, proxyID, sessionID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DisconnectProxy indicates an expected call of DisconnectProxy.
|
||||
func (mr *MockStoreMockRecorder) DisconnectProxy(ctx, proxyID, sessionID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectProxy", reflect.TypeOf((*MockStore)(nil).DisconnectProxy), ctx, proxyID, sessionID)
|
||||
}
|
||||
|
||||
// EphemeralServiceExists mocks base method.
|
||||
func (m *MockStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1332,21 +1360,6 @@ func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddressesForAccount(ctx, a
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddressesForAccount", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddressesForAccount), ctx, accountID)
|
||||
}
|
||||
|
||||
// GetActiveProxyClusters mocks base method.
|
||||
func (m *MockStore) GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx, accountID)
|
||||
ret0, _ := ret[0].([]proxy.Cluster)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveProxyClusters indicates an expected call of GetActiveProxyClusters.
|
||||
func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx, accountID)
|
||||
}
|
||||
|
||||
// GetAllAccounts mocks base method.
|
||||
func (m *MockStore) GetAllAccounts(ctx context.Context) []*types2.Account {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2048,6 +2061,21 @@ func (mr *MockStoreMockRecorder) GetProxyByAccountID(ctx, accountID interface{})
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyByAccountID), ctx, accountID)
|
||||
}
|
||||
|
||||
// GetProxyClusters mocks base method.
|
||||
func (m *MockStore) GetProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetProxyClusters", ctx, accountID)
|
||||
ret0, _ := ret[0].([]proxy.Cluster)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetProxyClusters indicates an expected call of GetProxyClusters.
|
||||
func (mr *MockStoreMockRecorder) GetProxyClusters(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyClusters", reflect.TypeOf((*MockStore)(nil).GetProxyClusters), ctx, accountID)
|
||||
}
|
||||
|
||||
// GetResourceGroups mocks base method.
|
||||
func (m *MockStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types2.Group, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2598,6 +2626,36 @@ func (mr *MockStoreMockRecorder) MarkPATUsed(ctx, patID interface{}) *gomock.Cal
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPATUsed", reflect.TypeOf((*MockStore)(nil).MarkPATUsed), ctx, patID)
|
||||
}
|
||||
|
||||
// MarkPeerConnectedIfNewerSession mocks base method.
|
||||
func (m *MockStore) MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnectedIfNewerSession", ctx, accountID, peerID, newSessionStartedAt)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// MarkPeerConnectedIfNewerSession indicates an expected call of MarkPeerConnectedIfNewerSession.
|
||||
func (mr *MockStoreMockRecorder) MarkPeerConnectedIfNewerSession(ctx, accountID, peerID, newSessionStartedAt interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnectedIfNewerSession", reflect.TypeOf((*MockStore)(nil).MarkPeerConnectedIfNewerSession), ctx, accountID, peerID, newSessionStartedAt)
|
||||
}
|
||||
|
||||
// MarkPeerDisconnectedIfSameSession mocks base method.
|
||||
func (m *MockStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerDisconnectedIfSameSession", ctx, accountID, peerID, sessionStartedAt)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// MarkPeerDisconnectedIfSameSession indicates an expected call of MarkPeerDisconnectedIfSameSession.
|
||||
func (mr *MockStoreMockRecorder) MarkPeerDisconnectedIfSameSession(ctx, accountID, peerID, sessionStartedAt interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnectedIfSameSession", reflect.TypeOf((*MockStore)(nil).MarkPeerDisconnectedIfSameSession), ctx, accountID, peerID, sessionStartedAt)
|
||||
}
|
||||
|
||||
// MarkPendingJobsAsFailed mocks base method.
|
||||
func (m *MockStore) MarkPendingJobsAsFailed(ctx context.Context, accountID, peerID, jobID, reason string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2808,20 +2866,6 @@ func (mr *MockStoreMockRecorder) SaveNetworkResource(ctx, resource interface{})
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveNetworkResource", reflect.TypeOf((*MockStore)(nil).SaveNetworkResource), ctx, resource)
|
||||
}
|
||||
|
||||
// SaveNetworkRouter mocks base method.
|
||||
func (m *MockStore) SaveNetworkRouter(ctx context.Context, router *types0.NetworkRouter) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SaveNetworkRouter", ctx, router)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SaveNetworkRouter indicates an expected call of SaveNetworkRouter.
|
||||
func (mr *MockStoreMockRecorder) SaveNetworkRouter(ctx, router interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveNetworkRouter", reflect.TypeOf((*MockStore)(nil).SaveNetworkRouter), ctx, router)
|
||||
}
|
||||
|
||||
// SavePAT mocks base method.
|
||||
func (m *MockStore) SavePAT(ctx context.Context, pat *types2.PersonalAccessToken) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2920,20 +2964,6 @@ func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy)
|
||||
}
|
||||
|
||||
// DisconnectProxy mocks base method.
|
||||
func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DisconnectProxy", ctx, proxyID, sessionID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DisconnectProxy indicates an expected call of DisconnectProxy.
|
||||
func (mr *MockStoreMockRecorder) DisconnectProxy(ctx, proxyID, sessionID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectProxy", reflect.TypeOf((*MockStore)(nil).DisconnectProxy), ctx, proxyID, sessionID)
|
||||
}
|
||||
|
||||
// SaveProxyAccessToken mocks base method.
|
||||
func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -3143,6 +3173,20 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{}
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGroups", reflect.TypeOf((*MockStore)(nil).UpdateGroups), ctx, accountID, groups)
|
||||
}
|
||||
|
||||
// UpdateNetworkRouter mocks base method.
|
||||
func (m *MockStore) UpdateNetworkRouter(ctx context.Context, router *types0.NetworkRouter) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateNetworkRouter", ctx, router)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateNetworkRouter indicates an expected call of UpdateNetworkRouter.
|
||||
func (mr *MockStoreMockRecorder) UpdateNetworkRouter(ctx, router interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateNetworkRouter", reflect.TypeOf((*MockStore)(nil).UpdateNetworkRouter), ctx, router)
|
||||
}
|
||||
|
||||
// UpdateProxyHeartbeat mocks base method.
|
||||
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -16,6 +16,8 @@ type AccountManagerMetrics struct {
|
||||
getPeerNetworkMapDurationMs metric.Float64Histogram
|
||||
networkMapObjectCount metric.Int64Histogram
|
||||
peerMetaUpdateCount metric.Int64Counter
|
||||
peerStatusUpdateCounter metric.Int64Counter
|
||||
peerStatusUpdateDurationMs metric.Float64Histogram
|
||||
}
|
||||
|
||||
// NewAccountManagerMetrics creates an instance of AccountManagerMetrics
|
||||
@@ -64,6 +66,24 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// peerStatusUpdateCounter records every attempt to mark a peer as connected or disconnected
|
||||
peerStatusUpdateCounter, err := meter.Int64Counter("management.account.peer.status.update.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of peer status update attempts, labeled by operation (connect|disconnect) and outcome (applied|stale|error|peer_not_found)"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerStatusUpdateDurationMs, err := meter.Float64Histogram("management.account.peer.status.update.duration.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithExplicitBucketBoundaries(
|
||||
1, 5, 15, 25, 50, 100, 250, 500, 1000, 2000, 5000,
|
||||
),
|
||||
metric.WithDescription("Duration of a peer status update (fence UPDATE + post-write side effects), labeled by operation"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AccountManagerMetrics{
|
||||
ctx: ctx,
|
||||
getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs,
|
||||
@@ -71,10 +91,35 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
|
||||
updateAccountPeersCounter: updateAccountPeersCounter,
|
||||
networkMapObjectCount: networkMapObjectCount,
|
||||
peerMetaUpdateCount: peerMetaUpdateCount,
|
||||
peerStatusUpdateCounter: peerStatusUpdateCounter,
|
||||
peerStatusUpdateDurationMs: peerStatusUpdateDurationMs,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
// PeerStatusOperation labels the kind of fence-locked peer status write.
|
||||
type PeerStatusOperation string
|
||||
|
||||
// PeerStatusOutcome labels how a fence-locked peer status write resolved.
|
||||
type PeerStatusOutcome string
|
||||
|
||||
const (
|
||||
PeerStatusConnect PeerStatusOperation = "connect"
|
||||
PeerStatusDisconnect PeerStatusOperation = "disconnect"
|
||||
|
||||
// PeerStatusApplied — the fence WHERE matched and the UPDATE landed.
|
||||
PeerStatusApplied PeerStatusOutcome = "applied"
|
||||
// PeerStatusStale — the fence WHERE rejected the write because a
|
||||
// newer session has already taken ownership (connect: stored token
|
||||
// >= incoming; disconnect: stored token != incoming).
|
||||
PeerStatusStale PeerStatusOutcome = "stale"
|
||||
// PeerStatusError — the store returned a non-NotFound error.
|
||||
PeerStatusError PeerStatusOutcome = "error"
|
||||
// PeerStatusPeerNotFound — the peer lookup failed (the peer was
|
||||
// deleted between the gRPC sync handshake and the status write).
|
||||
PeerStatusPeerNotFound PeerStatusOutcome = "peer_not_found"
|
||||
)
|
||||
|
||||
// CountUpdateAccountPeersDuration counts the duration of updating account peers
|
||||
func (metrics *AccountManagerMetrics) CountUpdateAccountPeersDuration(duration time.Duration) {
|
||||
metrics.updateAccountPeersDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6)
|
||||
@@ -104,3 +149,23 @@ func (metrics *AccountManagerMetrics) CountUpdateAccountPeersTriggered(resource,
|
||||
func (metrics *AccountManagerMetrics) CountPeerMetUpdate() {
|
||||
metrics.peerMetaUpdateCount.Add(metrics.ctx, 1)
|
||||
}
|
||||
|
||||
// CountPeerStatusUpdate increments the connect/disconnect counter,
|
||||
// labeled by operation and outcome. Both labels are bounded enums.
|
||||
func (metrics *AccountManagerMetrics) CountPeerStatusUpdate(op PeerStatusOperation, outcome PeerStatusOutcome) {
|
||||
metrics.peerStatusUpdateCounter.Add(metrics.ctx, 1,
|
||||
metric.WithAttributes(
|
||||
attribute.String("operation", string(op)),
|
||||
attribute.String("outcome", string(outcome)),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// RecordPeerStatusUpdateDuration records the wall-clock time spent
|
||||
// running a peer status update (including post-write side effects),
|
||||
// labeled by operation.
|
||||
func (metrics *AccountManagerMetrics) RecordPeerStatusUpdateDuration(op PeerStatusOperation, d time.Duration) {
|
||||
metrics.peerStatusUpdateDurationMs.Record(metrics.ctx, float64(d.Nanoseconds())/1e6,
|
||||
metric.WithAttributes(attribute.String("operation", string(op))),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ type MockAppMetrics struct {
|
||||
StoreMetricsFunc func() *StoreMetrics
|
||||
UpdateChannelMetricsFunc func() *UpdateChannelMetrics
|
||||
AddAccountManagerMetricsFunc func() *AccountManagerMetrics
|
||||
EphemeralPeersMetricsFunc func() *EphemeralPeersMetrics
|
||||
}
|
||||
|
||||
// GetMeter mocks the GetMeter function of the AppMetrics interface
|
||||
@@ -103,6 +104,14 @@ func (mock *MockAppMetrics) AccountManagerMetrics() *AccountManagerMetrics {
|
||||
return nil
|
||||
}
|
||||
|
||||
// EphemeralPeersMetrics mocks the MockAppMetrics function of the EphemeralPeersMetrics interface
|
||||
func (mock *MockAppMetrics) EphemeralPeersMetrics() *EphemeralPeersMetrics {
|
||||
if mock.EphemeralPeersMetricsFunc != nil {
|
||||
return mock.EphemeralPeersMetricsFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AppMetrics is metrics interface
|
||||
type AppMetrics interface {
|
||||
GetMeter() metric2.Meter
|
||||
@@ -114,6 +123,7 @@ type AppMetrics interface {
|
||||
StoreMetrics() *StoreMetrics
|
||||
UpdateChannelMetrics() *UpdateChannelMetrics
|
||||
AccountManagerMetrics() *AccountManagerMetrics
|
||||
EphemeralPeersMetrics() *EphemeralPeersMetrics
|
||||
}
|
||||
|
||||
// defaultAppMetrics are core application metrics based on OpenTelemetry https://opentelemetry.io/
|
||||
@@ -129,6 +139,7 @@ type defaultAppMetrics struct {
|
||||
storeMetrics *StoreMetrics
|
||||
updateChannelMetrics *UpdateChannelMetrics
|
||||
accountManagerMetrics *AccountManagerMetrics
|
||||
ephemeralMetrics *EphemeralPeersMetrics
|
||||
}
|
||||
|
||||
// IDPMetrics returns metrics for the idp package
|
||||
@@ -161,6 +172,11 @@ func (appMetrics *defaultAppMetrics) AccountManagerMetrics() *AccountManagerMetr
|
||||
return appMetrics.accountManagerMetrics
|
||||
}
|
||||
|
||||
// EphemeralPeersMetrics returns metrics for the ephemeral peer cleanup loop
|
||||
func (appMetrics *defaultAppMetrics) EphemeralPeersMetrics() *EphemeralPeersMetrics {
|
||||
return appMetrics.ephemeralMetrics
|
||||
}
|
||||
|
||||
// Close stop application metrics HTTP handler and closes listener.
|
||||
func (appMetrics *defaultAppMetrics) Close() error {
|
||||
if appMetrics.listener == nil {
|
||||
@@ -245,6 +261,11 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
|
||||
return nil, fmt.Errorf("failed to initialize account manager metrics: %w", err)
|
||||
}
|
||||
|
||||
ephemeralMetrics, err := NewEphemeralPeersMetrics(ctx, meter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize ephemeral peers metrics: %w", err)
|
||||
}
|
||||
|
||||
return &defaultAppMetrics{
|
||||
Meter: meter,
|
||||
ctx: ctx,
|
||||
@@ -254,6 +275,7 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
|
||||
storeMetrics: storeMetrics,
|
||||
updateChannelMetrics: updateChannelMetrics,
|
||||
accountManagerMetrics: accountManagerMetrics,
|
||||
ephemeralMetrics: ephemeralMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -290,6 +312,11 @@ func NewAppMetricsWithMeter(ctx context.Context, meter metric2.Meter) (AppMetric
|
||||
return nil, fmt.Errorf("failed to initialize account manager metrics: %w", err)
|
||||
}
|
||||
|
||||
ephemeralMetrics, err := NewEphemeralPeersMetrics(ctx, meter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize ephemeral peers metrics: %w", err)
|
||||
}
|
||||
|
||||
return &defaultAppMetrics{
|
||||
Meter: meter,
|
||||
ctx: ctx,
|
||||
@@ -300,5 +327,6 @@ func NewAppMetricsWithMeter(ctx context.Context, meter metric2.Meter) (AppMetric
|
||||
storeMetrics: storeMetrics,
|
||||
updateChannelMetrics: updateChannelMetrics,
|
||||
accountManagerMetrics: accountManagerMetrics,
|
||||
ephemeralMetrics: ephemeralMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
115
management/server/telemetry/ephemeral_metrics.go
Normal file
115
management/server/telemetry/ephemeral_metrics.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
)
|
||||
|
||||
// EphemeralPeersMetrics tracks the ephemeral peer cleanup pipeline: how
|
||||
// many peers are currently scheduled for deletion, how many tick runs
|
||||
// the cleaner has performed, how many peers it has removed, and how
|
||||
// many delete batches failed.
|
||||
type EphemeralPeersMetrics struct {
|
||||
ctx context.Context
|
||||
|
||||
pending metric.Int64UpDownCounter
|
||||
cleanupRuns metric.Int64Counter
|
||||
peersCleaned metric.Int64Counter
|
||||
errors metric.Int64Counter
|
||||
}
|
||||
|
||||
// NewEphemeralPeersMetrics constructs the ephemeral cleanup counters.
|
||||
func NewEphemeralPeersMetrics(ctx context.Context, meter metric.Meter) (*EphemeralPeersMetrics, error) {
|
||||
pending, err := meter.Int64UpDownCounter("management.ephemeral.peers.pending",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of ephemeral peers currently waiting to be cleaned up"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cleanupRuns, err := meter.Int64Counter("management.ephemeral.cleanup.runs.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of ephemeral cleanup ticks that processed at least one peer"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peersCleaned, err := meter.Int64Counter("management.ephemeral.peers.cleaned.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Total number of ephemeral peers deleted by the cleanup loop"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
errors, err := meter.Int64Counter("management.ephemeral.cleanup.errors.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of ephemeral cleanup batches (per account) that failed to delete"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &EphemeralPeersMetrics{
|
||||
ctx: ctx,
|
||||
pending: pending,
|
||||
cleanupRuns: cleanupRuns,
|
||||
peersCleaned: peersCleaned,
|
||||
errors: errors,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// All methods are nil-receiver safe so callers that haven't wired metrics
|
||||
// (tests, self-hosted with metrics off) can invoke them unconditionally.
|
||||
|
||||
// IncPending bumps the pending gauge when a peer is added to the cleanup list.
|
||||
func (m *EphemeralPeersMetrics) IncPending() {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.pending.Add(m.ctx, 1)
|
||||
}
|
||||
|
||||
// AddPending bumps the pending gauge by n — used at startup when the
|
||||
// initial set of ephemeral peers is loaded from the store.
|
||||
func (m *EphemeralPeersMetrics) AddPending(n int64) {
|
||||
if m == nil || n <= 0 {
|
||||
return
|
||||
}
|
||||
m.pending.Add(m.ctx, n)
|
||||
}
|
||||
|
||||
// DecPending decreases the pending gauge — used both when a peer reconnects
|
||||
// before its deadline (removed from the list) and when a cleanup tick
|
||||
// actually deletes it.
|
||||
func (m *EphemeralPeersMetrics) DecPending(n int64) {
|
||||
if m == nil || n <= 0 {
|
||||
return
|
||||
}
|
||||
m.pending.Add(m.ctx, -n)
|
||||
}
|
||||
|
||||
// CountCleanupRun records one cleanup pass that processed >0 peers. Idle
|
||||
// ticks (nothing to do) deliberately don't increment so the rate
|
||||
// reflects useful work.
|
||||
func (m *EphemeralPeersMetrics) CountCleanupRun() {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.cleanupRuns.Add(m.ctx, 1)
|
||||
}
|
||||
|
||||
// CountPeersCleaned records the number of peers a single tick deleted.
|
||||
func (m *EphemeralPeersMetrics) CountPeersCleaned(n int64) {
|
||||
if m == nil || n <= 0 {
|
||||
return
|
||||
}
|
||||
m.peersCleaned.Add(m.ctx, n)
|
||||
}
|
||||
|
||||
// CountCleanupError records a failed delete batch.
|
||||
func (m *EphemeralPeersMetrics) CountCleanupError() {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.errors.Add(m.ctx, 1)
|
||||
}
|
||||
4
management/server/testdata/networks.sql
vendored
4
management/server/testdata/networks.sql
vendored
@@ -9,9 +9,13 @@ INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM
|
||||
|
||||
CREATE TABLE `networks` (`id` text,`account_id` text,`name` text,`description` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_networks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
INSERT INTO networks VALUES('testNetworkId','testAccountId','some-name','some-description');
|
||||
INSERT INTO networks VALUES('secondNetworkId','testAccountId','second-name','second-description');
|
||||
|
||||
CREATE TABLE `network_routers` (`id` text,`network_id` text,`account_id` text,`peer` text,`peer_groups` text,`masquerade` numeric,`metric` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_network_routers` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
INSERT INTO network_routers VALUES('testRouterId','testNetworkId','testAccountId','','["csquuo4jcko732k1ag00"]',0,9999);
|
||||
INSERT INTO accounts VALUES('otherAccountId','','2024-10-02 16:01:38.000000000+00:00','other.com','private',1,'otherNetworkIdentifier','{"IP":"100.65.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||
INSERT INTO networks VALUES('otherNetworkId','otherAccountId','other-net','other-description');
|
||||
INSERT INTO network_routers VALUES('otherRouterId','otherNetworkId','otherAccountId','otherPeer',NULL,0,1);
|
||||
|
||||
CREATE TABLE `network_resources` (`id` text,`network_id` text,`account_id` text,`name` text,`description` text,`type` text,`address` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_network_resources` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
INSERT INTO network_resources VALUES('testResourceId','testNetworkId','testAccountId','some-name','some-description','host','3.3.3.3/32');
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user