mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-05 07:29:55 +00:00
Compare commits
55 Commits
refactor/m
...
feature/af
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ad3989b34e | ||
|
|
2102349574 | ||
|
|
08b25e1d19 | ||
|
|
c2471510b6 | ||
|
|
42e25ad602 | ||
|
|
54fc223ba6 | ||
|
|
9189625487 | ||
|
|
e9dbf9db6f | ||
|
|
5a9e9e7bc9 | ||
|
|
43e041cf9f | ||
|
|
77e5693200 | ||
|
|
174dc24867 | ||
|
|
31aa39572d | ||
|
|
7ea5e37dd4 | ||
|
|
9d7ef9b255 | ||
|
|
944a258459 | ||
|
|
1f9a829f2c | ||
|
|
77604bb467 | ||
|
|
1072f73a73 | ||
|
|
4b6721c878 | ||
|
|
0b78e310b7 | ||
|
|
68b942722c | ||
|
|
7af7630e5b | ||
|
|
12361e5479 | ||
|
|
d3ae81e601 | ||
|
|
5c3f2ab0df | ||
|
|
ba554a73d0 | ||
|
|
9e236ac20e | ||
|
|
c948d7398f | ||
|
|
13d26106f8 | ||
|
|
3e83164bcd | ||
|
|
6568c905c6 | ||
|
|
aa9a1a42f5 | ||
|
|
5ae6c25ac0 | ||
|
|
1d906e411d | ||
|
|
3012228b91 | ||
|
|
85851bc477 | ||
|
|
fed4f1b024 | ||
|
|
70e84d5228 | ||
|
|
57529c7f18 | ||
|
|
fd99bc072d | ||
|
|
40e6ec16c6 | ||
|
|
ec476d5072 | ||
|
|
550ae5558e | ||
|
|
46494bd860 | ||
|
|
c7bff8f074 | ||
|
|
3a95f39f2c | ||
|
|
6b4d4076f4 | ||
|
|
63d2217d8a | ||
|
|
0bfccd65d2 | ||
|
|
26d778374b | ||
|
|
5ec8bebfa5 | ||
|
|
cefb37e920 | ||
|
|
5a16c812fd | ||
|
|
285bbc5ffb |
45
.github/dependabot.yml
vendored
Normal file
45
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
open-pull-requests-limit: 15
|
||||
groups:
|
||||
actions:
|
||||
patterns:
|
||||
- "*"
|
||||
ignore:
|
||||
# git-town/action v1.3.x crashes on cyclic PR graphs (self-loop main->main
|
||||
# fork PRs) via its topological-sort visualization. Pinned to v1.2.1 in
|
||||
# git-town.yml; block v1.3.x until upstream tolerates cyclic edges.
|
||||
- dependency-name: "git-town/action"
|
||||
update-types:
|
||||
- "version-update:semver-minor"
|
||||
- "version-update:semver-major"
|
||||
|
||||
- package-ecosystem: "gomod"
|
||||
directories:
|
||||
- "/"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
open-pull-requests-limit: 15
|
||||
groups:
|
||||
aws-sdk:
|
||||
patterns:
|
||||
- "github.com/aws/aws-sdk-go-v2/*"
|
||||
pion:
|
||||
patterns:
|
||||
- "github.com/pion/*"
|
||||
gorm:
|
||||
patterns:
|
||||
- "gorm.io/*"
|
||||
otel:
|
||||
patterns:
|
||||
- "go.opentelemetry.io/*"
|
||||
testcontainers:
|
||||
patterns:
|
||||
- "github.com/testcontainers/testcontainers-go/*"
|
||||
wireguard:
|
||||
patterns:
|
||||
- "golang.zx2c4.com/wireguard*"
|
||||
109
.github/workflows/check-license-dependencies.yml
vendored
109
.github/workflows/check-license-dependencies.yml
vendored
@@ -2,16 +2,16 @@ name: Check License Dependencies
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
branches: [main]
|
||||
paths:
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.github/workflows/check-license-dependencies.yml'
|
||||
- "go.mod"
|
||||
- "go.sum"
|
||||
- ".github/workflows/check-license-dependencies.yml"
|
||||
pull_request:
|
||||
paths:
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.github/workflows/check-license-dependencies.yml'
|
||||
- "go.mod"
|
||||
- "go.sum"
|
||||
- ".github/workflows/check-license-dependencies.yml"
|
||||
|
||||
jobs:
|
||||
check-internal-dependencies:
|
||||
@@ -19,7 +19,10 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check for problematic license dependencies
|
||||
run: |
|
||||
@@ -56,55 +59,57 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: 'go.mod'
|
||||
cache: true
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: true
|
||||
|
||||
- name: Install go-licenses
|
||||
run: go install github.com/google/go-licenses@v1.6.0
|
||||
- name: Install go-licenses
|
||||
run: go install github.com/google/go-licenses@v1.6.0
|
||||
|
||||
- name: Check for GPL/AGPL licensed dependencies
|
||||
run: |
|
||||
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
|
||||
echo ""
|
||||
|
||||
# Check all Go packages for copyleft licenses, excluding internal netbird packages
|
||||
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
|
||||
|
||||
if [ -n "$COPYLEFT_DEPS" ]; then
|
||||
echo "Found copyleft licensed dependencies:"
|
||||
echo "$COPYLEFT_DEPS"
|
||||
- name: Check for GPL/AGPL licensed dependencies
|
||||
run: |
|
||||
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
|
||||
echo ""
|
||||
|
||||
# Filter out dependencies that are only pulled in by internal AGPL packages
|
||||
INCOMPATIBLE=""
|
||||
while IFS=',' read -r package url license; do
|
||||
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
|
||||
# Find ALL packages that import this GPL package using go list
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
# Check all Go packages for copyleft licenses, excluding internal netbird packages
|
||||
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
|
||||
else
|
||||
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
|
||||
fi
|
||||
fi
|
||||
done <<< "$COPYLEFT_DEPS"
|
||||
|
||||
if [ -n "$INCOMPATIBLE" ]; then
|
||||
if [ -n "$COPYLEFT_DEPS" ]; then
|
||||
echo "Found copyleft licensed dependencies:"
|
||||
echo "$COPYLEFT_DEPS"
|
||||
echo ""
|
||||
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
|
||||
echo -e "$INCOMPATIBLE"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
|
||||
# Filter out dependencies that are only pulled in by internal AGPL packages
|
||||
INCOMPATIBLE=""
|
||||
while IFS=',' read -r package url license; do
|
||||
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
|
||||
# Find ALL packages that import this GPL package using go list
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
|
||||
else
|
||||
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
|
||||
fi
|
||||
fi
|
||||
done <<< "$COPYLEFT_DEPS"
|
||||
|
||||
if [ -n "$INCOMPATIBLE" ]; then
|
||||
echo ""
|
||||
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
|
||||
echo -e "$INCOMPATIBLE"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
|
||||
|
||||
2
.github/workflows/docs-ack.yml
vendored
2
.github/workflows/docs-ack.yml
vendored
@@ -83,7 +83,7 @@ jobs:
|
||||
|
||||
- name: Verify docs PR exists (and is open or merged)
|
||||
if: steps.validate.outputs.mode == 'added'
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
id: verify
|
||||
with:
|
||||
pr_number: ${{ steps.extract.outputs.pr_number }}
|
||||
|
||||
5
.github/workflows/forum.yml
vendored
5
.github/workflows/forum.yml
vendored
@@ -8,11 +8,10 @@ jobs:
|
||||
post:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: roots/discourse-topic-github-release-action@main
|
||||
- uses: roots/discourse-topic-github-release-action@557d74ea05b6cc0c47f555c1d5d28a89d904005b # v1.1.0
|
||||
with:
|
||||
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
|
||||
discourse-base-url: https://forum.netbird.io
|
||||
discourse-author-username: NetBird
|
||||
discourse-category: 17
|
||||
discourse-tags:
|
||||
releases
|
||||
discourse-tags: releases
|
||||
|
||||
8
.github/workflows/git-town.yml
vendored
8
.github/workflows/git-town.yml
vendored
@@ -3,7 +3,7 @@ name: Git Town
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- '**'
|
||||
- "**"
|
||||
|
||||
jobs:
|
||||
git-town:
|
||||
@@ -15,7 +15,9 @@ jobs:
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: git-town/action@v1.2.1
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: git-town/action@3d8b878379abb1ee393fb49865a28b4a6c2cd3b0 # v1.2.1
|
||||
with:
|
||||
skip-single-stacks: true
|
||||
|
||||
9
.github/workflows/golang-test-darwin.yml
vendored
9
.github/workflows/golang-test-darwin.yml
vendored
@@ -16,16 +16,18 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: macos-gotest-${{ hashFiles('**/go.sum') }}
|
||||
@@ -44,4 +46,3 @@ jobs:
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
|
||||
21
.github/workflows/golang-test-freebsd.yml
vendored
21
.github/workflows/golang-test-freebsd.yml
vendored
@@ -15,20 +15,31 @@ jobs:
|
||||
name: "Client / Unit"
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Read Go version from go.mod
|
||||
id: goversion
|
||||
run: echo "version=$(awk '/^go / {print $2}' go.mod)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Test in FreeBSD
|
||||
id: test
|
||||
uses: vmactions/freebsd-vm@v1
|
||||
env:
|
||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
release: "14.2"
|
||||
release: "15.0"
|
||||
envs: "GO_VERSION"
|
||||
prepare: |
|
||||
pkg install -y curl pkgconf xorg
|
||||
GO_TARBALL="go1.25.3.freebsd-amd64.tar.gz"
|
||||
GO_TARBALL="go${GO_VERSION}.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -vLO "$GO_URL"
|
||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||
|
||||
# -x - to print all executed commands
|
||||
# -e - to faile on first error
|
||||
|
||||
138
.github/workflows/golang-test-linux.yml
vendored
138
.github/workflows/golang-test-linux.yml
vendored
@@ -18,9 +18,11 @@ jobs:
|
||||
management: ${{ steps.filter.outputs.management }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- uses: dorny/paths-filter@v3
|
||||
- uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
|
||||
id: filter
|
||||
with:
|
||||
filters: |
|
||||
@@ -28,7 +30,7 @@ jobs:
|
||||
- 'management/**'
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -36,10 +38,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
id: cache
|
||||
with:
|
||||
path: |
|
||||
@@ -113,14 +115,16 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
arch: ["386", "amd64"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -128,10 +132,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -158,14 +162,16 @@ jobs:
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -177,7 +183,7 @@ jobs:
|
||||
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
id: cache-restore
|
||||
with:
|
||||
path: |
|
||||
@@ -231,10 +237,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -246,10 +254,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -277,14 +285,16 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
arch: ["386", "amd64"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -298,7 +308,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -324,14 +334,16 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
arch: ["386", "amd64"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -343,10 +355,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -370,19 +382,21 @@ jobs:
|
||||
|
||||
test_management:
|
||||
name: "Management / Unit"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
arch: ["amd64"]
|
||||
store: ["sqlite", "postgres", "mysql"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -390,10 +404,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -410,7 +424,7 @@ jobs:
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -427,7 +441,7 @@ jobs:
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
@@ -437,13 +451,13 @@ jobs:
|
||||
|
||||
benchmark:
|
||||
name: "Management / Benchmark"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
arch: ["amd64"]
|
||||
store: ["sqlite", "postgres"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Create Docker network
|
||||
@@ -474,10 +488,12 @@ jobs:
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -485,10 +501,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -505,7 +521,7 @@ jobs:
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -529,13 +545,13 @@ jobs:
|
||||
|
||||
api_benchmark:
|
||||
name: "Management / Benchmark (API)"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
arch: ["amd64"]
|
||||
store: ["sqlite", "postgres"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Create Docker network
|
||||
@@ -566,10 +582,12 @@ jobs:
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -577,10 +595,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -597,7 +615,7 @@ jobs:
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -623,20 +641,22 @@ jobs:
|
||||
|
||||
api_integration_test:
|
||||
name: "Management / Integration"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres']
|
||||
arch: ["amd64"]
|
||||
store: ["sqlite", "postgres"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -644,10 +664,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
|
||||
19
.github/workflows/golang-test-windows.yml
vendored
19
.github/workflows/golang-test-windows.yml
vendored
@@ -18,10 +18,12 @@ jobs:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
id: go
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
@@ -33,7 +35,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -44,16 +46,15 @@ jobs:
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Download wintun
|
||||
uses: carlosperate/download-file-action@v2
|
||||
id: download-wintun
|
||||
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
with:
|
||||
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
||||
file-name: wintun.zip
|
||||
location: ${{ env.downloadPath }}
|
||||
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
|
||||
url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
||||
destination: ${{ env.downloadPath }}\wintun.zip
|
||||
sha256: 07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51
|
||||
|
||||
- name: Decompressing wintun files
|
||||
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||
run: tar -xvf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||
|
||||
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'
|
||||
|
||||
|
||||
14
.github/workflows/golangci-lint.yml
vendored
14
.github/workflows/golangci-lint.yml
vendored
@@ -15,9 +15,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
uses: codespell-project/actions-codespell@8f01853be192eb0f849a5c7d721450e7a467c579 # v2.2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
@@ -38,13 +40,15 @@ jobs:
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Check for duplicate constants
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: |
|
||||
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -52,7 +56,7 @@ jobs:
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
|
||||
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
|
||||
with:
|
||||
version: latest
|
||||
skip-cache: true
|
||||
|
||||
4
.github/workflows/install-script-test.yml
vendored
4
.github/workflows/install-script-test.yml
vendored
@@ -22,7 +22,9 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: run install script
|
||||
env:
|
||||
|
||||
18
.github/workflows/mobile-build-validation.yml
vendored
18
.github/workflows/mobile-build-validation.yml
vendored
@@ -16,23 +16,25 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Setup Android SDK
|
||||
uses: android-actions/setup-android@v3
|
||||
uses: android-actions/setup-android@40fd30fb8d7440372e1316f5d1809ec01dcd3699 # v4.0.1
|
||||
with:
|
||||
cmdline-tools-version: 8512546
|
||||
- name: Setup Java
|
||||
uses: actions/setup-java@v4
|
||||
uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654
|
||||
with:
|
||||
java-version: "11"
|
||||
distribution: "adopt"
|
||||
- name: NDK Cache
|
||||
id: ndk-cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: /usr/local/lib/android/sdk/ndk
|
||||
key: ndk-cache-23.1.7779620
|
||||
@@ -52,9 +54,11 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: install gomobile
|
||||
|
||||
2
.github/workflows/pr-title-check.yml
vendored
2
.github/workflows/pr-title-check.yml
vendored
@@ -9,7 +9,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Validate PR title prefix
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
with:
|
||||
script: |
|
||||
const title = context.payload.pull_request.title;
|
||||
|
||||
2
.github/workflows/proto-version-check.yml
vendored
2
.github/workflows/proto-version-check.yml
vendored
@@ -10,7 +10,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check for proto tool version changes
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
with:
|
||||
script: |
|
||||
const files = await github.paginate(github.rest.pulls.listFiles, {
|
||||
|
||||
168
.github/workflows/release.yml
vendored
168
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.1.4"
|
||||
SIGN_PIPE_VER: "v0.1.5"
|
||||
GORELEASER_VER: "v2.14.3"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
@@ -24,7 +24,9 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Generate FreeBSD port diff
|
||||
run: bash release_files/freebsd-port-diff.sh
|
||||
@@ -51,19 +53,26 @@ jobs:
|
||||
echo "Generated files for version: $VERSION"
|
||||
cat netbird-*.diff
|
||||
|
||||
- name: Read Go version from go.mod
|
||||
id: goversion
|
||||
run: echo "version=$(awk '/^go / {print $2}' go.mod)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Test FreeBSD port
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
uses: vmactions/freebsd-vm@v1
|
||||
env:
|
||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
release: "15.0"
|
||||
envs: "GO_VERSION"
|
||||
prepare: |
|
||||
# Install required packages
|
||||
pkg install -y git curl portlint go
|
||||
pkg install -y git curl portlint
|
||||
|
||||
# Install Go for building
|
||||
GO_TARBALL="go1.25.5.freebsd-amd64.tar.gz"
|
||||
GO_TARBALL="go${GO_VERSION}.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -LO "$GO_URL"
|
||||
tar -C /usr/local -xzf "$GO_TARBALL"
|
||||
@@ -93,19 +102,19 @@ jobs:
|
||||
|
||||
# Show patched Makefile
|
||||
version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}')
|
||||
|
||||
|
||||
cd /usr/ports/security/netbird
|
||||
export BATCH=yes
|
||||
make package
|
||||
pkg add ./work/pkg/netbird-*.pkg
|
||||
|
||||
|
||||
netbird version | grep "$version"
|
||||
|
||||
echo "FreeBSD port test completed successfully!"
|
||||
|
||||
- name: Upload FreeBSD port files
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: freebsd-port-files
|
||||
path: |
|
||||
@@ -124,26 +133,25 @@ jobs:
|
||||
env:
|
||||
flags: ""
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||
with:
|
||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||
version_extractor_regex: '\/v(.*)$'
|
||||
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -156,18 +164,18 @@ jobs:
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v2
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a #v4.0.0
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0
|
||||
- name: Login to Docker hub
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v1
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Log in to the GitHub container registry
|
||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
@@ -191,7 +199,7 @@ jobs:
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --clean ${{ env.flags }}
|
||||
@@ -282,28 +290,28 @@ jobs:
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
- name: upload non tags for debug purposes
|
||||
id: upload_release
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: release
|
||||
path: dist/
|
||||
retention-days: 7
|
||||
- name: upload linux packages
|
||||
id: upload_linux_packages
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: linux-packages
|
||||
path: dist/netbird_linux**
|
||||
retention-days: 7
|
||||
- name: upload windows packages
|
||||
id: upload_windows_packages
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: windows-packages
|
||||
path: dist/netbird_windows**
|
||||
retention-days: 7
|
||||
- name: upload macos packages
|
||||
id: upload_macos_packages
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: macos-packages
|
||||
path: dist/netbird_darwin**
|
||||
@@ -314,27 +322,26 @@ jobs:
|
||||
outputs:
|
||||
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||
with:
|
||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||
version_extractor_regex: '\/v(.*)$'
|
||||
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -375,7 +382,7 @@ jobs:
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
|
||||
@@ -404,7 +411,7 @@ jobs:
|
||||
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
||||
- name: upload non tags for debug purposes
|
||||
id: upload_release_ui
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: release-ui
|
||||
path: dist/
|
||||
@@ -418,16 +425,17 @@ jobs:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -441,7 +449,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
|
||||
@@ -449,7 +457,7 @@ jobs:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: upload non tags for debug purposes
|
||||
id: upload_release_ui_darwin
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: release-ui-darwin
|
||||
path: dist/
|
||||
@@ -474,27 +482,26 @@ jobs:
|
||||
PackageWorkdir: netbird_windows_${{ matrix.arch }}
|
||||
downloadPath: '${{ github.workspace }}\temp'
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||
with:
|
||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||
version_extractor_regex: '\/v(.*)$'
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
|
||||
- name: Add 7-Zip to PATH
|
||||
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
|
||||
- name: Download release artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
||||
with:
|
||||
name: release
|
||||
path: release
|
||||
|
||||
- name: Download UI release artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
||||
with:
|
||||
name: release-ui
|
||||
path: release-ui
|
||||
@@ -514,29 +521,27 @@ jobs:
|
||||
Get-ChildItem $workdir
|
||||
|
||||
- name: Download wintun
|
||||
uses: carlosperate/download-file-action@v2
|
||||
id: download-wintun
|
||||
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
with:
|
||||
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
||||
file-name: wintun.zip
|
||||
location: ${{ env.downloadPath }}
|
||||
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
|
||||
url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
||||
destination: ${{ env.downloadPath }}\wintun.zip
|
||||
sha256: 07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51
|
||||
|
||||
- name: Decompress wintun files
|
||||
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||
run: tar -xvf "${{ env.downloadPath }}\wintun.zip" -C ${{ env.downloadPath }}
|
||||
|
||||
- name: Move wintun.dll into dist
|
||||
run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
|
||||
|
||||
- name: Download Mesa3D (amd64 only)
|
||||
uses: carlosperate/download-file-action@v2
|
||||
id: download-mesa3d
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
with:
|
||||
file-url: https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z
|
||||
file-name: mesa3d.7z
|
||||
location: ${{ env.downloadPath }}
|
||||
sha256: '71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9'
|
||||
url: https://pkgs.netbird.io/mesa3d/MesaForWindows-x64-20.1.8.7z
|
||||
destination: ${{ env.downloadPath }}\mesa3d.7z
|
||||
sha256: 71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9
|
||||
|
||||
- name: Extract Mesa3D driver (amd64 only)
|
||||
if: matrix.arch == 'amd64'
|
||||
@@ -547,35 +552,38 @@ jobs:
|
||||
run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
|
||||
|
||||
- name: Download EnVar plugin for NSIS
|
||||
uses: carlosperate/download-file-action@v2
|
||||
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
with:
|
||||
file-url: https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip
|
||||
file-name: envar_plugin.zip
|
||||
location: ${{ github.workspace }}
|
||||
url: https://pkgs.netbird.io/nsis/EnVar_plugin.zip
|
||||
destination: ${{ github.workspace }}\envar_plugin.zip
|
||||
sha256: e9aa92de351345ed82795251d838f1ae9041ba35af9d381a5780c7843b01f56a
|
||||
|
||||
- name: Extract EnVar plugin
|
||||
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/envar_plugin.zip"
|
||||
|
||||
- name: Download ShellExecAsUser plugin for NSIS (amd64 only)
|
||||
uses: carlosperate/download-file-action@v2
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
with:
|
||||
file-url: https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z
|
||||
file-name: ShellExecAsUser_amd64-Unicode.7z
|
||||
location: ${{ github.workspace }}
|
||||
url: https://pkgs.netbird.io/nsis/ShellExecAsUser_amd64-Unicode.7z
|
||||
destination: ${{ github.workspace }}\ShellExecAsUser_amd64-Unicode.7z
|
||||
sha256: 0a55ea25c7330a92cec028eda8afcaf1b1a7092e0dfb77c21c8f654564b4ff9d
|
||||
|
||||
- name: Extract ShellExecAsUser plugin (amd64 only)
|
||||
if: matrix.arch == 'amd64'
|
||||
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z"
|
||||
|
||||
- name: Build NSIS installer
|
||||
uses: joncloud/makensis-action@v3.3
|
||||
with:
|
||||
additional-plugin-paths: ${{ github.workspace }}/NSIS_Plugins/Plugins
|
||||
script-file: client/installer.nsis
|
||||
arguments: "/V4 /DARCH=${{ matrix.arch }}"
|
||||
shell: pwsh
|
||||
env:
|
||||
APPVER: ${{ steps.semver_parser.outputs.major }}.${{ steps.semver_parser.outputs.minor }}.${{ steps.semver_parser.outputs.patch }}.${{ github.run_id }}
|
||||
run: |
|
||||
$nsisPluginDir = "C:\Program Files (x86)\NSIS\Plugins\x86-unicode"
|
||||
$srcPlugins = "${{ github.workspace }}\NSIS_Plugins\Plugins"
|
||||
Get-ChildItem -Path $srcPlugins -Recurse -Filter *.dll |
|
||||
Copy-Item -Destination $nsisPluginDir -Force
|
||||
& "C:\Program Files (x86)\NSIS\makensis.exe" /V4 "/DARCH=${{ matrix.arch }}" client\installer.nsis
|
||||
if ($LASTEXITCODE -ne 0) { throw "makensis failed with exit code $LASTEXITCODE" }
|
||||
|
||||
- name: Rename NSIS installer
|
||||
run: mv netbird-installer.exe netbird_installer_test_windows_${{ matrix.arch }}.exe
|
||||
@@ -592,7 +600,7 @@ jobs:
|
||||
|
||||
- name: Upload installer artifacts
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: windows-installer-test-${{ matrix.arch }}
|
||||
path: |
|
||||
@@ -611,7 +619,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Create or update PR comment
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
env:
|
||||
RELEASE_RESULT: ${{ needs.release.result }}
|
||||
RELEASE_UI_RESULT: ${{ needs.release_ui.result }}
|
||||
@@ -703,7 +711,7 @@ jobs:
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- name: Trigger binaries sign pipelines
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: Sign bin and installer
|
||||
repo: netbirdio/sign-pipelines
|
||||
|
||||
4
.github/workflows/sync-main.yml
vendored
4
.github/workflows/sync-main.yml
vendored
@@ -14,9 +14,9 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Trigger main branch sync
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: sync-main.yml
|
||||
repo: ${{ secrets.UPSTREAM_REPO }}
|
||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
||||
inputs: '{ "sha": "${{ github.sha }}" }'
|
||||
inputs: '{ "sha": "${{ github.sha }}" }'
|
||||
|
||||
10
.github/workflows/sync-tag.yml
vendored
10
.github/workflows/sync-tag.yml
vendored
@@ -3,7 +3,7 @@ name: sync tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
- "v*"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
@@ -16,7 +16,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Trigger release tag sync
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: sync-tag.yml
|
||||
ref: main
|
||||
@@ -29,7 +29,7 @@ jobs:
|
||||
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
|
||||
steps:
|
||||
- name: Trigger android-client submodule bump
|
||||
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: bump-netbird.yml
|
||||
ref: main
|
||||
@@ -42,10 +42,10 @@ jobs:
|
||||
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
|
||||
steps:
|
||||
- name: Trigger ios-client submodule bump
|
||||
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: bump-netbird.yml
|
||||
ref: main
|
||||
repo: netbirdio/ios-client
|
||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
||||
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
||||
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
||||
|
||||
26
.github/workflows/test-infrastructure-files.yml
vendored
26
.github/workflows/test-infrastructure-files.yml
vendored
@@ -6,10 +6,10 @@ on:
|
||||
- main
|
||||
pull_request:
|
||||
paths:
|
||||
- 'infrastructure_files/**'
|
||||
- '.github/workflows/test-infrastructure-files.yml'
|
||||
- 'management/cmd/**'
|
||||
- 'signal/cmd/**'
|
||||
- "infrastructure_files/**"
|
||||
- ".github/workflows/test-infrastructure-files.yml"
|
||||
- "management/cmd/**"
|
||||
- "signal/cmd/**"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
@@ -20,7 +20,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
store: ["sqlite", "postgres", "mysql"]
|
||||
services:
|
||||
postgres:
|
||||
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
|
||||
@@ -68,15 +68,17 @@ jobs:
|
||||
run: sudo apt-get install -y curl
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -139,8 +141,8 @@ jobs:
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
||||
CI_NETBIRD_SIGNAL_PORT: 12345
|
||||
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
|
||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
|
||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
|
||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN: "${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$"
|
||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: "${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$"
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
||||
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
|
||||
@@ -254,7 +256,9 @@ jobs:
|
||||
run: sudo apt-get install -y jq
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: run script with Zitadel PostgreSQL
|
||||
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
|
||||
|
||||
8
.github/workflows/update-docs.yml
vendored
8
.github/workflows/update-docs.yml
vendored
@@ -3,9 +3,9 @@ name: update docs
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
- "v*"
|
||||
paths:
|
||||
- 'shared/management/http/api/openapi.yml'
|
||||
- "shared/management/http/api/openapi.yml"
|
||||
|
||||
jobs:
|
||||
trigger_docs_api_update:
|
||||
@@ -13,10 +13,10 @@ jobs:
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- name: Trigger API pages generation
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: generate api pages
|
||||
repo: netbirdio/docs
|
||||
ref: "refs/heads/main"
|
||||
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
||||
inputs: '{ "tag": "${{ github.ref }}" }'
|
||||
inputs: '{ "tag": "${{ github.ref }}" }'
|
||||
|
||||
15
.github/workflows/wasm-build-validation.yml
vendored
15
.github/workflows/wasm-build-validation.yml
vendored
@@ -19,15 +19,17 @@ jobs:
|
||||
GOARCH: wasm
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: Install golangci-lint
|
||||
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
|
||||
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
|
||||
with:
|
||||
version: latest
|
||||
install-mode: binary
|
||||
@@ -42,9 +44,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Build Wasm client
|
||||
@@ -65,4 +69,3 @@ jobs:
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
@@ -360,7 +360,13 @@ func isRedirectURLPortUsed(redirectURL string, excludedRanges []excludedPortRang
|
||||
return true
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf(":%s", port)
|
||||
// FreeBSD 15 disables connecting to INADDR_ANY (0.0.0.0) as a localhost
|
||||
// alias by default, ensure explicit ip for localhost.
|
||||
host := parsedURL.Hostname()
|
||||
if host == "" {
|
||||
host = "127.0.0.1"
|
||||
}
|
||||
addr := net.JoinHostPort(host, port)
|
||||
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
|
||||
if err != nil {
|
||||
return false
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/worker"
|
||||
"github.com/netbirdio/netbird/client/internal/portforward"
|
||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
@@ -899,7 +900,7 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
|
||||
}
|
||||
|
||||
// Fallback to deterministic key if no NetBird PSK is configured
|
||||
determKey, err := conn.rosenpassDetermKey()
|
||||
determKey, err := rosenpass.DeterministicSeedKey(conn.config.LocalKey, conn.config.Key)
|
||||
if err != nil {
|
||||
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
|
||||
return nil
|
||||
@@ -908,26 +909,6 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
|
||||
return determKey
|
||||
}
|
||||
|
||||
// todo: move this logic into Rosenpass package
|
||||
func (conn *Conn) rosenpassDetermKey() (*wgtypes.Key, error) {
|
||||
lk := []byte(conn.config.LocalKey)
|
||||
rk := []byte(conn.config.Key) // remote key
|
||||
var keyInput []byte
|
||||
if string(lk) > string(rk) {
|
||||
//nolint:gocritic
|
||||
keyInput = append(lk[:16], rk[:16]...)
|
||||
} else {
|
||||
//nolint:gocritic
|
||||
keyInput = append(rk[:16], lk[:16]...)
|
||||
}
|
||||
|
||||
key, err := wgtypes.NewKey(keyInput)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
func isController(config ConnConfig) bool {
|
||||
return config.LocalKey > config.Key
|
||||
}
|
||||
|
||||
@@ -179,8 +179,10 @@ func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) {
|
||||
}
|
||||
|
||||
dst := net.IPv4zero
|
||||
if runtime.GOOS == "linux" {
|
||||
// go-netroute v0.4.0 rejects unspecified destinations client-side on Linux.
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
// go-netroute v0.4.0 rejects unspecified destinations client-side on Linux/Android.
|
||||
// TODO: on android/ios, use platform APIs (ConnectivityManager.getLinkProperties /
|
||||
// NWPathMonitor) when netlink-based lookup is restricted or unavailable.
|
||||
dst = net.IPv4(0, 0, 0, 1)
|
||||
}
|
||||
_, gateway, localIP, err = router.Route(dst)
|
||||
@@ -203,7 +205,7 @@ func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) {
|
||||
}
|
||||
|
||||
dst := net.IPv6zero
|
||||
if runtime.GOOS == "linux" {
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
// ::2
|
||||
dst = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}
|
||||
}
|
||||
|
||||
@@ -28,6 +28,15 @@ func hashRosenpassKey(key []byte) string {
|
||||
return hex.EncodeToString(hasher.Sum(nil))
|
||||
}
|
||||
|
||||
// rpServer is the subset of rp.Server used by Manager. Defined as an interface
|
||||
// so tests can substitute a mock without spinning up a real UDP server.
|
||||
type rpServer interface {
|
||||
AddPeer(rp.PeerConfig) (rp.PeerID, error)
|
||||
RemovePeer(rp.PeerID) error
|
||||
Run() error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
ifaceName string
|
||||
spk []byte
|
||||
@@ -36,7 +45,7 @@ type Manager struct {
|
||||
preSharedKey *[32]byte
|
||||
rpPeerIDs map[string]*rp.PeerID
|
||||
rpWgHandler *NetbirdHandler
|
||||
server *rp.Server
|
||||
server rpServer
|
||||
lock sync.Mutex
|
||||
port int
|
||||
wgIface PresharedKeySetter
|
||||
@@ -51,7 +60,22 @@ func NewManager(preSharedKey *wgtypes.Key, wgIfaceName string) (*Manager, error)
|
||||
|
||||
rpKeyHash := hashRosenpassKey(public)
|
||||
log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash)
|
||||
return &Manager{ifaceName: wgIfaceName, rpKeyHash: rpKeyHash, spk: public, ssk: secret, preSharedKey: (*[32]byte)(preSharedKey), rpPeerIDs: make(map[string]*rp.PeerID), lock: sync.Mutex{}}, nil
|
||||
return &Manager{
|
||||
ifaceName: wgIfaceName,
|
||||
rpKeyHash: rpKeyHash,
|
||||
spk: public,
|
||||
ssk: secret,
|
||||
preSharedKey: (*[32]byte)(preSharedKey),
|
||||
rpPeerIDs: make(map[string]*rp.PeerID),
|
||||
// rpWgHandler is created here (instead of only in generateConfig) so it
|
||||
// is never nil between NewManager and Run(). Otherwise an early
|
||||
// OnConnected call (race observed on Android, issue #4341) panics on
|
||||
// nil receiver in addPeer -> m.rpWgHandler.AddPeer. generateConfig will
|
||||
// replace it with a fresh handler on each Run() to clear stale peer
|
||||
// state from previous engine sessions.
|
||||
rpWgHandler: NewNetbirdHandler(),
|
||||
lock: sync.Mutex{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetPubKey() []byte {
|
||||
@@ -65,6 +89,16 @@ func (m *Manager) GetAddress() *net.UDPAddr {
|
||||
|
||||
// addPeer adds a new peer to the Rosenpass server
|
||||
func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuardIP string, wireGuardPubKey string) error {
|
||||
// Defense in depth against issue #4341 (Android crash): if Run() has not
|
||||
// completed yet, m.server / m.rpWgHandler may be nil. Return an explicit
|
||||
// error instead of panicking on nil-receiver dereference.
|
||||
if m.server == nil {
|
||||
return fmt.Errorf("rosenpass server not initialized")
|
||||
}
|
||||
if m.rpWgHandler == nil {
|
||||
return fmt.Errorf("rosenpass wg handler not initialized")
|
||||
}
|
||||
|
||||
var err error
|
||||
pcfg := rp.PeerConfig{PublicKey: rosenpassPubKey}
|
||||
if m.preSharedKey != nil {
|
||||
@@ -79,6 +113,16 @@ func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuar
|
||||
if pcfg.Endpoint, err = net.ResolveUDPAddr("udp", peerAddr); err != nil {
|
||||
return fmt.Errorf("failed to resolve peer endpoint address: %w", err)
|
||||
}
|
||||
// Our local Rosenpass UDP server binds on the IPv6 wildcard ([::]) — see
|
||||
// GetAddress(). The remote peer's endpoint (pcfg.Endpoint) is the destination
|
||||
// our server will sendto when initiating handshakes. ResolveUDPAddr returns a
|
||||
// 4-byte IPv4 for IPv4 hosts, which the kernel rejects (EDESTADDRREQ) when
|
||||
// sent from an AF_INET6 socket. Normalize the remote endpoint to IPv4-mapped
|
||||
// IPv6 so its address family matches our listening socket.
|
||||
// TODO: maybe bind the Rosenpass UDP server to the peer wg IP addr
|
||||
if v4 := pcfg.Endpoint.IP.To4(); v4 != nil {
|
||||
pcfg.Endpoint.IP = v4.To16()
|
||||
}
|
||||
}
|
||||
peerID, err := m.server.AddPeer(pcfg)
|
||||
if err != nil {
|
||||
@@ -182,24 +226,31 @@ func (m *Manager) Run() error {
|
||||
return err
|
||||
}
|
||||
|
||||
m.server, err = rp.NewUDPServer(conf)
|
||||
server, err := rp.NewUDPServer(conf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.lock.Lock()
|
||||
m.server = server
|
||||
m.lock.Unlock()
|
||||
|
||||
log.Infof("starting rosenpass server on port %d", m.port)
|
||||
|
||||
return m.server.Run()
|
||||
return server.Run()
|
||||
}
|
||||
|
||||
// Close closes the Rosenpass server
|
||||
func (m *Manager) Close() error {
|
||||
if m.server != nil {
|
||||
err := m.server.Close()
|
||||
if err != nil {
|
||||
log.Errorf("failed closing local rosenpass server")
|
||||
}
|
||||
m.server = nil
|
||||
m.lock.Lock()
|
||||
server := m.server
|
||||
m.server = nil
|
||||
m.lock.Unlock()
|
||||
if server == nil {
|
||||
return nil
|
||||
}
|
||||
if err := server.Close(); err != nil {
|
||||
log.Errorf("failed closing local rosenpass server: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,14 +1,412 @@
|
||||
package rosenpass
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
rp "cunicu.li/go-rosenpass"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// --- test doubles -----------------------------------------------------------
|
||||
|
||||
type addPeerCall struct {
|
||||
cfg rp.PeerConfig
|
||||
}
|
||||
|
||||
type removePeerCall struct {
|
||||
id rp.PeerID
|
||||
}
|
||||
|
||||
type mockServer struct {
|
||||
mu sync.Mutex
|
||||
addCalls []addPeerCall
|
||||
removed []removePeerCall
|
||||
nextID rp.PeerID
|
||||
addErr error
|
||||
removeErr error
|
||||
closed bool
|
||||
ran bool
|
||||
}
|
||||
|
||||
func (m *mockServer) AddPeer(cfg rp.PeerConfig) (rp.PeerID, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.addCalls = append(m.addCalls, addPeerCall{cfg: cfg})
|
||||
if m.addErr != nil {
|
||||
return rp.PeerID{}, m.addErr
|
||||
}
|
||||
// Increment a byte in nextID so distinct peers get distinct IDs.
|
||||
m.nextID[0]++
|
||||
return m.nextID, nil
|
||||
}
|
||||
|
||||
func (m *mockServer) RemovePeer(id rp.PeerID) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.removed = append(m.removed, removePeerCall{id: id})
|
||||
return m.removeErr
|
||||
}
|
||||
|
||||
func (m *mockServer) Run() error { m.ran = true; return nil }
|
||||
func (m *mockServer) Close() error { m.closed = true; return nil }
|
||||
|
||||
type setPSKCall struct {
|
||||
peerKey string
|
||||
psk wgtypes.Key
|
||||
updateOnly bool
|
||||
}
|
||||
|
||||
type mockIface struct {
|
||||
mu sync.Mutex
|
||||
calls []setPSKCall
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.calls = append(m.calls, setPSKCall{peerKey: peerKey, psk: psk, updateOnly: updateOnly})
|
||||
return m.err
|
||||
}
|
||||
|
||||
// newTestManager builds a Manager with deterministic spk so tie-break
|
||||
// against a peer pubkey is controllable from tests. The provided spk byte
|
||||
// becomes the first byte; remaining bytes are zero.
|
||||
func newTestManager(spkFirstByte byte, mock *mockServer) *Manager {
|
||||
spk := make([]byte, 32)
|
||||
spk[0] = spkFirstByte
|
||||
return &Manager{
|
||||
ifaceName: "wt0",
|
||||
spk: spk,
|
||||
ssk: make([]byte, 32),
|
||||
rpKeyHash: "test-hash",
|
||||
rpPeerIDs: make(map[string]*rp.PeerID),
|
||||
rpWgHandler: NewNetbirdHandler(),
|
||||
server: mock,
|
||||
}
|
||||
}
|
||||
|
||||
// validWGKey returns a deterministic 32-byte wireguard public key (base64).
|
||||
func validWGKey(t *testing.T, lastByte byte) string {
|
||||
t.Helper()
|
||||
var k wgtypes.Key
|
||||
k[31] = lastByte
|
||||
return k.String()
|
||||
}
|
||||
|
||||
// --- pure helpers ----------------------------------------------------------
|
||||
|
||||
func TestHashRosenpassKey_Deterministic(t *testing.T) {
|
||||
key := []byte("hello-rosenpass")
|
||||
require.Equal(t, hashRosenpassKey(key), hashRosenpassKey(key))
|
||||
require.Len(t, hashRosenpassKey(key), 64) // sha256 hex
|
||||
}
|
||||
|
||||
func TestHashRosenpassKey_DifferentInputsDifferOutputs(t *testing.T) {
|
||||
require.NotEqual(t, hashRosenpassKey([]byte("a")), hashRosenpassKey([]byte("b")))
|
||||
}
|
||||
|
||||
func TestGetLogLevel_DefaultWhenUnset(t *testing.T) {
|
||||
// Snapshot + unset to exercise the LookupEnv ok=false branch. t.Setenv
|
||||
// can only set, not delete, so do it manually with restore via t.Cleanup.
|
||||
prev, hadPrev := os.LookupEnv(defaultLogLevelVar)
|
||||
require.NoError(t, os.Unsetenv(defaultLogLevelVar))
|
||||
t.Cleanup(func() {
|
||||
if hadPrev {
|
||||
_ = os.Setenv(defaultLogLevelVar, prev)
|
||||
} else {
|
||||
_ = os.Unsetenv(defaultLogLevelVar)
|
||||
}
|
||||
})
|
||||
require.Equal(t, defaultLog.String(), getLogLevel().String())
|
||||
}
|
||||
|
||||
func TestGetLogLevel_Cases(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"debug": "DEBUG",
|
||||
"info": "INFO",
|
||||
"warn": "WARN",
|
||||
"error": "ERROR",
|
||||
"unknown": "INFO", // default fallback
|
||||
}
|
||||
for input, wantStr := range cases {
|
||||
input, wantStr := input, wantStr
|
||||
t.Run(input, func(t *testing.T) {
|
||||
t.Setenv(defaultLogLevelVar, input)
|
||||
require.Equal(t, wantStr, getLogLevel().String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindRandomAvailableUDPPort(t *testing.T) {
|
||||
port, err := findRandomAvailableUDPPort()
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, port, 0)
|
||||
require.LessOrEqual(t, port, 65535)
|
||||
}
|
||||
|
||||
// --- addPeer ---------------------------------------------------------------
|
||||
|
||||
func TestAddPeer_HigherLocalPubkey_SetsEndpoint(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv) // local spk lexicographically larger
|
||||
|
||||
remotePubKey := make([]byte, 32) // remote spk = all zeros (smaller)
|
||||
err := m.addPeer(remotePubKey, "rosenpass-host:7000", "100.1.1.1", validWGKey(t, 1))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, srv.addCalls, 1)
|
||||
|
||||
ep := srv.addCalls[0].cfg.Endpoint
|
||||
require.NotNil(t, ep, "initiator side must set Endpoint")
|
||||
require.Equal(t, 7000, ep.Port)
|
||||
require.Equal(t, "100.1.1.1", ep.IP.String())
|
||||
}
|
||||
|
||||
func TestAddPeer_HigherLocalPubkey_EndpointIPIsIPv4Mapped(t *testing.T) {
|
||||
// Regression guard for the EDESTADDRREQ fix: Endpoint.IP must be 16-byte
|
||||
// (IPv4-mapped IPv6) so it matches the AF_INET6 listening socket family.
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
|
||||
require.NoError(t, err)
|
||||
|
||||
ep := srv.addCalls[0].cfg.Endpoint
|
||||
require.NotNil(t, ep)
|
||||
require.Len(t, ep.IP, 16, "IPv4 endpoint must be normalized to 16-byte v4-mapped form")
|
||||
require.True(t, ep.IP.To4() != nil, "Endpoint must still be detected as IPv4")
|
||||
}
|
||||
|
||||
func TestAddPeer_LowerLocalPubkey_LeavesEndpointNil(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0x00, srv) // local spk smaller
|
||||
|
||||
remotePubKey := make([]byte, 32)
|
||||
remotePubKey[0] = 0xFF
|
||||
err := m.addPeer(remotePubKey, "rp:5000", "100.1.1.1", validWGKey(t, 2))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Nil(t, srv.addCalls[0].cfg.Endpoint, "responder side must NOT set Endpoint")
|
||||
}
|
||||
|
||||
func TestAddPeer_PresharedKeyPropagated(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
psk := &wgtypes.Key{0x42}
|
||||
m := newTestManager(0xFF, srv)
|
||||
m.preSharedKey = (*[32]byte)(psk)
|
||||
|
||||
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 3))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, [32]byte(*psk), [32]byte(srv.addCalls[0].cfg.PresharedKey))
|
||||
}
|
||||
|
||||
func TestAddPeer_InvalidRosenpassAddr_ReturnsError(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv) // initiator path → parses rosenpassAddr
|
||||
|
||||
err := m.addPeer(make([]byte, 32), "not-a-host-port", "100.1.1.1", validWGKey(t, 1))
|
||||
require.Error(t, err)
|
||||
require.Empty(t, srv.addCalls, "server.AddPeer must not run when address parse fails")
|
||||
}
|
||||
|
||||
func TestAddPeer_InvalidWireGuardPubKey_ReturnsError(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", "not-a-valid-key")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAddPeer_ServerError_Propagates(t *testing.T) {
|
||||
srv := &mockServer{addErr: errors.New("boom")}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// Regression guard for issue #4341 (Android crash). If Run() has not completed
|
||||
// before OnConnected fires, m.rpWgHandler or m.server may be nil. Without the
|
||||
// nil guards, m.rpWgHandler.AddPeer panics on nil receiver.
|
||||
func TestAddPeer_NilHandler_ReturnsErrorNoCrash(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
m.rpWgHandler = nil // simulate Run() not yet completed
|
||||
|
||||
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "wg handler not initialized")
|
||||
}
|
||||
|
||||
func TestAddPeer_NilServer_ReturnsErrorNoCrash(t *testing.T) {
|
||||
m := newTestManager(0xFF, nil)
|
||||
m.server = nil // simulate Run() not yet completed
|
||||
|
||||
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "server not initialized")
|
||||
}
|
||||
|
||||
// NewManager must pre-initialize rpWgHandler so the nil-receiver crash from
|
||||
// issue #4341 cannot occur in the window between NewManager and Run().
|
||||
func TestNewManager_PreInitializesHandler(t *testing.T) {
|
||||
psk := wgtypes.Key{}
|
||||
m, err := NewManager(&psk, "wt0")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, m.rpWgHandler, "rpWgHandler must be initialized in NewManager")
|
||||
}
|
||||
|
||||
func TestAddPeer_RecordsPeerID(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
wgKey := validWGKey(t, 5)
|
||||
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, m.rpPeerIDs, wgKey)
|
||||
}
|
||||
|
||||
// --- OnConnected / OnDisconnected ------------------------------------------
|
||||
|
||||
func TestOnConnected_NilRemotePubKey_NoAddPeer(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
m.OnConnected(validWGKey(t, 1), nil, "100.1.1.1", "rp:5000")
|
||||
require.Empty(t, srv.addCalls, "nil remote rosenpass pubkey must skip AddPeer")
|
||||
require.Empty(t, m.rpPeerIDs)
|
||||
}
|
||||
|
||||
func TestOnConnected_ValidPubKey_CallsAddPeer(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
wgKey := validWGKey(t, 1)
|
||||
m.OnConnected(wgKey, make([]byte, 32), "100.1.1.1", "rp:5000")
|
||||
require.Len(t, srv.addCalls, 1)
|
||||
require.Contains(t, m.rpPeerIDs, wgKey)
|
||||
}
|
||||
|
||||
func TestOnDisconnected_UnknownPeer_NoOp(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
m.OnDisconnected(validWGKey(t, 99))
|
||||
require.Empty(t, srv.removed, "unknown peer key must not call RemovePeer")
|
||||
}
|
||||
|
||||
func TestOnDisconnected_KnownPeer_CallsRemoveAndForgets(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
wgKey := validWGKey(t, 1)
|
||||
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
|
||||
require.Contains(t, m.rpPeerIDs, wgKey)
|
||||
|
||||
m.OnDisconnected(wgKey)
|
||||
require.Len(t, srv.removed, 1)
|
||||
require.NotContains(t, m.rpPeerIDs, wgKey, "peer must be forgotten after disconnect")
|
||||
}
|
||||
|
||||
// --- IsPresharedKeyInitialized ---------------------------------------------
|
||||
|
||||
func TestIsPresharedKeyInitialized_UnknownPeer_ReturnsFalse(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
require.False(t, m.IsPresharedKeyInitialized(validWGKey(t, 1)))
|
||||
}
|
||||
|
||||
func TestIsPresharedKeyInitialized_AddedButNotHandshaken_ReturnsFalse(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
wgKey := validWGKey(t, 2)
|
||||
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
|
||||
require.False(t, m.IsPresharedKeyInitialized(wgKey))
|
||||
}
|
||||
|
||||
// --- NetbirdHandler.outputKey ----------------------------------------------
|
||||
|
||||
func TestHandler_OutputKey_FirstCallUsesUpdateOnlyFalse(t *testing.T) {
|
||||
h := NewNetbirdHandler()
|
||||
iface := &mockIface{}
|
||||
h.SetInterface(iface)
|
||||
|
||||
pid := rp.PeerID{0x01}
|
||||
wgKey := wgtypes.Key{0xAA}
|
||||
h.AddPeer(pid, "wt0", rp.Key(wgKey))
|
||||
|
||||
psk := rp.Key{0xBB}
|
||||
h.HandshakeCompleted(pid, psk)
|
||||
|
||||
require.Len(t, iface.calls, 1)
|
||||
require.False(t, iface.calls[0].updateOnly, "first PSK rotation must use updateOnly=false")
|
||||
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
|
||||
}
|
||||
|
||||
func TestHandler_OutputKey_SubsequentCallsUseUpdateOnlyTrue(t *testing.T) {
|
||||
h := NewNetbirdHandler()
|
||||
iface := &mockIface{}
|
||||
h.SetInterface(iface)
|
||||
|
||||
pid := rp.PeerID{0x02}
|
||||
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xCC}))
|
||||
|
||||
h.HandshakeCompleted(pid, rp.Key{0x01}) // first
|
||||
h.HandshakeCompleted(pid, rp.Key{0x02}) // second
|
||||
|
||||
require.Len(t, iface.calls, 2)
|
||||
require.False(t, iface.calls[0].updateOnly)
|
||||
require.True(t, iface.calls[1].updateOnly, "subsequent rotations must use updateOnly=true")
|
||||
}
|
||||
|
||||
func TestHandler_OutputKey_NilInterface_NoCrashNoCall(t *testing.T) {
|
||||
h := NewNetbirdHandler()
|
||||
// no SetInterface — iface remains nil
|
||||
pid := rp.PeerID{0x03}
|
||||
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{}))
|
||||
|
||||
// Must not panic.
|
||||
h.HandshakeCompleted(pid, rp.Key{})
|
||||
}
|
||||
|
||||
func TestHandler_OutputKey_UnknownPeer_NoCall(t *testing.T) {
|
||||
h := NewNetbirdHandler()
|
||||
iface := &mockIface{}
|
||||
h.SetInterface(iface)
|
||||
|
||||
h.HandshakeCompleted(rp.PeerID{0xFF}, rp.Key{})
|
||||
require.Empty(t, iface.calls, "unknown peer id must not trigger SetPresharedKey")
|
||||
}
|
||||
|
||||
func TestHandler_RemovePeer_ClearsInitializedState(t *testing.T) {
|
||||
h := NewNetbirdHandler()
|
||||
iface := &mockIface{}
|
||||
h.SetInterface(iface)
|
||||
|
||||
pid := rp.PeerID{0x04}
|
||||
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xDD}))
|
||||
h.HandshakeCompleted(pid, rp.Key{0x01})
|
||||
require.True(t, h.IsPeerInitialized(pid))
|
||||
|
||||
h.RemovePeer(pid)
|
||||
require.False(t, h.IsPeerInitialized(pid), "RemovePeer must clear initialized flag")
|
||||
}
|
||||
|
||||
func TestHandler_SetInterfaceAfterAddPeer_StillReceivesKey(t *testing.T) {
|
||||
h := NewNetbirdHandler()
|
||||
pid := rp.PeerID{0x05}
|
||||
wgKey := wgtypes.Key{0xEE}
|
||||
h.AddPeer(pid, "wt0", rp.Key(wgKey))
|
||||
|
||||
iface := &mockIface{}
|
||||
h.SetInterface(iface) // set after AddPeer
|
||||
|
||||
h.HandshakeCompleted(pid, rp.Key{0x42})
|
||||
require.Len(t, iface.calls, 1)
|
||||
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
|
||||
}
|
||||
|
||||
42
client/internal/rosenpass/seed.go
Normal file
42
client/internal/rosenpass/seed.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package rosenpass
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// DeterministicSeedKey derives a 32-byte WireGuard preshared key from a pair
|
||||
// of peer public keys. Both peers, given the same key pair, produce the same
|
||||
// output regardless of which side runs the function: the inputs are ordered
|
||||
// lexicographically before concatenation.
|
||||
//
|
||||
// NetBird uses this value as the initial Rosenpass-side preshared key when no
|
||||
// explicit account-level PSK is configured, so both peers converge on the same
|
||||
// PSK before the first post-quantum handshake completes.
|
||||
//
|
||||
// The resulting key MUST NOT be treated as quantum-safe: it is deterministic
|
||||
// from public keys and exists only to seed WireGuard until Rosenpass rotates
|
||||
// in a real post-quantum PSK.
|
||||
func DeterministicSeedKey(localKey, remoteKey string) (*wgtypes.Key, error) {
|
||||
lk := []byte(localKey)
|
||||
rk := []byte(remoteKey)
|
||||
if len(lk) < 16 || len(rk) < 16 {
|
||||
return nil, fmt.Errorf("rosenpass: peer keys must be at least 16 bytes (got local=%d, remote=%d)", len(lk), len(rk))
|
||||
}
|
||||
|
||||
var keyInput []byte
|
||||
if localKey > remoteKey {
|
||||
keyInput = append(keyInput, lk[:16]...)
|
||||
keyInput = append(keyInput, rk[:16]...)
|
||||
} else {
|
||||
keyInput = append(keyInput, rk[:16]...)
|
||||
keyInput = append(keyInput, lk[:16]...)
|
||||
}
|
||||
|
||||
key, err := wgtypes.NewKey(keyInput)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rosenpass: deterministic seed key: %w", err)
|
||||
}
|
||||
return &key, nil
|
||||
}
|
||||
43
client/internal/rosenpass/seed_test.go
Normal file
43
client/internal/rosenpass/seed_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package rosenpass
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDeterministicSeedKey_SameForBothSides(t *testing.T) {
|
||||
// Peer A and peer B must derive the same PSK regardless of which side
|
||||
// computes it: the function orders inputs internally.
|
||||
a := strings.Repeat("a", 32)
|
||||
b := strings.Repeat("b", 32)
|
||||
|
||||
keyAB, err := DeterministicSeedKey(a, b)
|
||||
require.NoError(t, err)
|
||||
keyBA, err := DeterministicSeedKey(b, a)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, keyAB.String(), keyBA.String(), "swapping arguments must yield identical key")
|
||||
}
|
||||
|
||||
func TestDeterministicSeedKey_ChangesWithKeys(t *testing.T) {
|
||||
a := strings.Repeat("a", 32)
|
||||
b := strings.Repeat("b", 32)
|
||||
c := strings.Repeat("c", 32)
|
||||
|
||||
keyAB, err := DeterministicSeedKey(a, b)
|
||||
require.NoError(t, err)
|
||||
keyAC, err := DeterministicSeedKey(a, c)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, keyAB.String(), keyAC.String(), "different peer pair must yield different key")
|
||||
}
|
||||
|
||||
func TestDeterministicSeedKey_TooShortKey_ReturnsError(t *testing.T) {
|
||||
short := "short" // < 16 bytes
|
||||
long := strings.Repeat("x", 32)
|
||||
|
||||
_, err := DeterministicSeedKey(short, long)
|
||||
require.Error(t, err)
|
||||
_, err = DeterministicSeedKey(long, short)
|
||||
require.Error(t, err)
|
||||
}
|
||||
@@ -96,17 +96,19 @@ func (m *Manager) Stop(ctx context.Context) error {
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
cancel := m.cancel
|
||||
done := m.done
|
||||
m.mu.Unlock()
|
||||
|
||||
if m.cancel == nil {
|
||||
if cancel == nil {
|
||||
return nil
|
||||
}
|
||||
m.cancel()
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-m.done:
|
||||
case <-done:
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -67,6 +67,10 @@ func init() {
|
||||
rootCmd.AddCommand(newTokenCommands())
|
||||
}
|
||||
|
||||
func RootCmd() *cobra.Command {
|
||||
return rootCmd
|
||||
}
|
||||
|
||||
func Execute() error {
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
@@ -168,7 +172,7 @@ func initializeConfig() error {
|
||||
// serverInstances holds all server instances created during startup.
|
||||
type serverInstances struct {
|
||||
relaySrv *relayServer.Server
|
||||
mgmtSrv *mgmtServer.BaseServer
|
||||
mgmtSrv mgmtServer.Server
|
||||
signalSrv *signalServer.Server
|
||||
healthcheck *healthcheck.Server
|
||||
stunServer *stun.Server
|
||||
@@ -324,19 +328,24 @@ func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) {
|
||||
return
|
||||
}
|
||||
|
||||
servers.mgmtSrv.AfterInit(func(s *mgmtServer.BaseServer) {
|
||||
grpcSrv := s.GRPCServer()
|
||||
if s, ok := servers.mgmtSrv.GetContainer(mgmtServer.ContainerKeyBaseServer); ok {
|
||||
if baseServer, ok := s.(*mgmtServer.BaseServer); ok {
|
||||
baseServer.AfterInit(func(s *mgmtServer.BaseServer) {
|
||||
grpcSrv := s.GRPCServer()
|
||||
|
||||
if servers.signalSrv != nil {
|
||||
proto.RegisterSignalExchangeServer(grpcSrv, servers.signalSrv)
|
||||
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
|
||||
}
|
||||
if servers.signalSrv != nil {
|
||||
proto.RegisterSignalExchangeServer(grpcSrv, servers.signalSrv)
|
||||
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
|
||||
}
|
||||
|
||||
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), s.IDPHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
|
||||
if servers.relaySrv != nil {
|
||||
log.Infof("Relay WebSocket handler added (path: /relay)")
|
||||
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), s.IDPHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
|
||||
if servers.relaySrv != nil {
|
||||
log.Infof("Relay WebSocket handler added (path: /relay)")
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func startServers(wg *sync.WaitGroup, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, metricsServer *sharedMetrics.Metrics) {
|
||||
@@ -346,38 +355,32 @@ func startServers(wg *sync.WaitGroup, srv *relayServer.Server, httpHealthcheck *
|
||||
log.Infof("Relay WebSocket multiplexed on management port (no separate relay listener)")
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wg.Go(func() {
|
||||
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
|
||||
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Fatalf("failed to start metrics server: %v", err)
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wg.Go(func() {
|
||||
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Fatalf("failed to start healthcheck server: %v", err)
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
if stunServer != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wg.Go(func() {
|
||||
if err := stunServer.Listen(); err != nil {
|
||||
if errors.Is(err, stun.ErrServerClosed) {
|
||||
return
|
||||
}
|
||||
log.Errorf("STUN server error: %v", err)
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func shutdownServers(ctx context.Context, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, mgmtSrv *mgmtServer.BaseServer, metricsServer *sharedMetrics.Metrics) error {
|
||||
func shutdownServers(ctx context.Context, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, mgmtSrv mgmtServer.Server, metricsServer *sharedMetrics.Metrics) error {
|
||||
var errs error
|
||||
|
||||
if err := httpHealthcheck.Shutdown(ctx); err != nil {
|
||||
@@ -491,7 +494,7 @@ func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) {
|
||||
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (mgmtServer.Server, error) {
|
||||
mgmt := cfg.Management
|
||||
|
||||
// Extract port from listen address
|
||||
@@ -502,7 +505,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
|
||||
}
|
||||
mgmtPort, _ := strconv.Atoi(portStr)
|
||||
|
||||
mgmtSrv := mgmtServer.NewServer(
|
||||
mgmtSrv := newServer(
|
||||
&mgmtServer.Config{
|
||||
NbConfig: mgmtConfig,
|
||||
DNSDomain: "",
|
||||
|
||||
13
combined/cmd/server.go
Normal file
13
combined/cmd/server.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
mgmtServer "github.com/netbirdio/netbird/management/internals/server"
|
||||
)
|
||||
|
||||
var newServer = func(cfg *mgmtServer.Config) mgmtServer.Server {
|
||||
return mgmtServer.NewServer(cfg)
|
||||
}
|
||||
|
||||
func SetNewServer(fn func(*mgmtServer.Config) mgmtServer.Server) {
|
||||
newServer = fn
|
||||
}
|
||||
10
go.mod
10
go.mod
@@ -3,7 +3,7 @@ module github.com/netbirdio/netbird
|
||||
go 1.25.5
|
||||
|
||||
require (
|
||||
cunicu.li/go-rosenpass v0.4.0
|
||||
cunicu.li/go-rosenpass v0.5.42
|
||||
github.com/cenkalti/backoff/v4 v4.3.0
|
||||
github.com/cloudflare/circl v1.3.3 // indirect
|
||||
github.com/golang/protobuf v1.5.4
|
||||
@@ -19,8 +19,8 @@ require (
|
||||
github.com/vishvananda/netlink v1.3.1
|
||||
golang.org/x/crypto v0.50.0
|
||||
golang.org/x/sys v0.43.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
google.golang.org/grpc v1.80.0
|
||||
google.golang.org/protobuf v1.36.11
|
||||
@@ -38,7 +38,7 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3
|
||||
github.com/c-robinson/iplib v1.0.3
|
||||
github.com/caddyserver/certmagic v0.21.3
|
||||
github.com/cilium/ebpf v0.15.0
|
||||
github.com/cilium/ebpf v0.19.0
|
||||
github.com/coder/websocket v1.8.14
|
||||
github.com/coreos/go-iptables v0.7.0
|
||||
github.com/coreos/go-oidc/v3 v3.18.0
|
||||
@@ -60,7 +60,7 @@ require (
|
||||
github.com/google/go-cmp v0.7.0
|
||||
github.com/google/gopacket v1.1.19
|
||||
github.com/google/nftables v0.3.0
|
||||
github.com/gopacket/gopacket v1.1.1
|
||||
github.com/gopacket/gopacket v1.4.0
|
||||
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
|
||||
github.com/hashicorp/go-multierror v1.1.1
|
||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
||||
|
||||
22
go.sum
22
go.sum
@@ -7,8 +7,8 @@ cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdB
|
||||
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
|
||||
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:b8xUw3004wk+3ipBhu0VU4RtUJsegMIiqjxSK4++lzA=
|
||||
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
|
||||
cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw=
|
||||
cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M=
|
||||
cunicu.li/go-rosenpass v0.5.42 h1:fRDsGwCxd7DhDgZI1Pxeo8GtNyq8BESZJ7w2/BGGJtU=
|
||||
cunicu.li/go-rosenpass v0.5.42/go.mod h1:YRBeyKOe/gWpSX2kpDUec5p9t0XOLsshTguId5gTGVg=
|
||||
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
|
||||
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
|
||||
filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw=
|
||||
@@ -111,8 +111,8 @@ github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1x
|
||||
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
|
||||
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
|
||||
github.com/cilium/ebpf v0.19.0 h1:Ro/rE64RmFBeA9FGjcTc+KmCeY6jXmryu6FfnzPRIao=
|
||||
github.com/cilium/ebpf v0.19.0/go.mod h1:fLCgMo3l8tZmAdM3B2XqdFzXBpwkcSTroaVqN08OWVY=
|
||||
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
||||
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||
@@ -225,8 +225,8 @@ github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3Bum
|
||||
github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ=
|
||||
github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM=
|
||||
github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY=
|
||||
github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI=
|
||||
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
|
||||
github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6 h1:teYtXy9B7y5lHTp8V9KPxpYRAVA7dozigQcMiBust1s=
|
||||
github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6/go.mod h1:p4lGIVX+8Wa6ZPNDvqcxq36XpUDLh42FLetFU7odllI=
|
||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
|
||||
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
|
||||
@@ -307,8 +307,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
|
||||
github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI=
|
||||
github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
|
||||
github.com/gopacket/gopacket v1.1.1 h1:zbx9F9d6A7sWNkFKrvMBZTfGgxFoY4NgUudFVVHMfcw=
|
||||
github.com/gopacket/gopacket v1.1.1/go.mod h1:HavMeONEl7W9036of9LbSWoonqhH7HA1+ZRO+rMIvFs=
|
||||
github.com/gopacket/gopacket v1.4.0 h1:cr1OlFpzksCkZHNO0eLjaSSOrMQnpPXg0j6qHIY3y2U=
|
||||
github.com/gopacket/gopacket v1.4.0/go.mod h1:EpvsxINeehp5qj4YMKMLf2/dekdhKn2IIAO/ZOifS7o=
|
||||
github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE=
|
||||
github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w=
|
||||
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
||||
@@ -390,6 +390,8 @@ github.com/jonboulle/clockwork v0.5.0 h1:Hyh9A8u51kptdkR+cqRpT1EebBwTn1oK9YfGYbd
|
||||
github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60=
|
||||
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
|
||||
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
||||
github.com/jsimonetti/rtnetlink/v2 v2.0.1 h1:xda7qaHDSVOsADNouv7ukSuicKZO7GgVUCXxpaIEIlM=
|
||||
github.com/jsimonetti/rtnetlink/v2 v2.0.1/go.mod h1:7MoNYNbb3UaDHtF8udiJo/RH6VsTKP1pqKLUTVCvToE=
|
||||
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M=
|
||||
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
|
||||
@@ -900,8 +902,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||
|
||||
@@ -44,7 +44,7 @@ type Controller struct {
|
||||
EphemeralPeersManager ephemeral.Manager
|
||||
|
||||
accountUpdateLocks sync.Map
|
||||
sendAccountUpdateLocks sync.Map
|
||||
affectedPeerUpdateLocks sync.Map
|
||||
updateAccountPeersBufferInterval atomic.Int64
|
||||
// dnsDomain is used for peer resolution. This is appended to the peer's name
|
||||
dnsDomain string
|
||||
@@ -63,6 +63,13 @@ type bufferUpdate struct {
|
||||
update atomic.Bool
|
||||
}
|
||||
|
||||
type bufferAffectedUpdate struct {
|
||||
sendMu sync.Mutex
|
||||
dataMu sync.Mutex
|
||||
next *time.Timer
|
||||
peerIDs map[string]struct{}
|
||||
}
|
||||
|
||||
var _ network_map.Controller = (*Controller)(nil)
|
||||
|
||||
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, ephemeralPeersManager ephemeral.Manager, config *config.Config) *Controller {
|
||||
@@ -112,7 +119,7 @@ func (c *Controller) CountStreams() int {
|
||||
return c.peersUpdateManager.CountStreams()
|
||||
}
|
||||
|
||||
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
||||
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
@@ -175,6 +182,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
continue
|
||||
}
|
||||
|
||||
if c.accountManagerMetrics != nil {
|
||||
c.accountManagerMetrics.CountNmapTriggered(string(reason.Resource), string(reason.Operation))
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
semaphore <- struct{}{}
|
||||
go func(p *nbpeer.Peer) {
|
||||
@@ -196,7 +207,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
|
||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
@@ -221,51 +232,150 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
||||
log.WithContext(ctx).Tracef("buffer sending update peers for account %s from %s", accountID, util.GetCallerName())
|
||||
|
||||
if c.accountManagerMetrics != nil {
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||
}
|
||||
|
||||
bufUpd, _ := c.sendAccountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
|
||||
b := bufUpd.(*bufferUpdate)
|
||||
|
||||
if !b.mu.TryLock() {
|
||||
b.update.Store(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
if b.next != nil {
|
||||
b.next.Stop()
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer b.mu.Unlock()
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
||||
if !b.update.Load() {
|
||||
return
|
||||
}
|
||||
b.update.Store(false)
|
||||
if b.next == nil {
|
||||
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
||||
})
|
||||
return
|
||||
}
|
||||
b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load()))
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
||||
if c.accountManagerMetrics != nil {
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||
}
|
||||
return c.sendUpdateAccountPeers(ctx, accountID)
|
||||
return c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// UpdateAffectedPeers updates only the specified peers that belong to an account.
|
||||
func (c *Controller) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
if len(peerIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return c.sendUpdateForAffectedPeers(ctx, accountID, peerIDs)
|
||||
}
|
||||
|
||||
func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: account %s, %d affected peers: %v (caller: %s)", accountID, len(peerIDs), peerIDs, util.GetCallerName())
|
||||
|
||||
if !c.hasConnectedPeers(peerIDs) {
|
||||
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no connected peers among %v, skipping", peerIDs)
|
||||
return nil
|
||||
}
|
||||
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account: %v", err)
|
||||
}
|
||||
|
||||
globalStart := time.Now()
|
||||
|
||||
peersToUpdate := c.filterConnectedAffectedPeers(account, peerIDs)
|
||||
if len(peersToUpdate) == 0 {
|
||||
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no peers to update (affected peers not found in account or no channels)")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: sending network map to %d connected peers", len(peersToUpdate))
|
||||
|
||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get validate peers: %v", err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
semaphore := make(chan struct{}, 10)
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return fmt.Errorf("failed to get proxy network maps: %v", err)
|
||||
}
|
||||
|
||||
extraSetting, err := c.settingsManager.GetExtraSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get flow enabled status: %v", err)
|
||||
}
|
||||
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
|
||||
return fmt.Errorf("failed to get account zones: %v", err)
|
||||
}
|
||||
|
||||
for _, peer := range peersToUpdate {
|
||||
wg.Add(1)
|
||||
semaphore <- struct{}{}
|
||||
go func(p *nbpeer.Peer) {
|
||||
defer wg.Done()
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
start := time.Now()
|
||||
|
||||
postureChecks, err := c.getPeerPostureChecks(account, p.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", p.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
|
||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
peerGroups := account.GetPeerGroups(p.ID)
|
||||
start = time.Now()
|
||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
||||
|
||||
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
|
||||
Update: update,
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
})
|
||||
}(peer)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
if c.accountManagerMetrics != nil {
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersDuration(time.Since(globalStart))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) hasConnectedPeers(peerIDs []string) bool {
|
||||
for _, id := range peerIDs {
|
||||
if c.peersUpdateManager.HasChannel(id) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Controller) filterConnectedAffectedPeers(account *types.Account, peerIDs []string) []*nbpeer.Peer {
|
||||
affected := make(map[string]struct{}, len(peerIDs))
|
||||
for _, id := range peerIDs {
|
||||
affected[id] = struct{}{}
|
||||
}
|
||||
|
||||
var result []*nbpeer.Peer
|
||||
for _, peer := range account.Peers {
|
||||
if _, ok := affected[peer.ID]; ok && c.peersUpdateManager.HasChannel(peer.ID) {
|
||||
result = append(result, peer)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
|
||||
@@ -359,14 +469,14 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
|
||||
|
||||
go func() {
|
||||
defer b.mu.Unlock()
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||
if !b.update.Load() {
|
||||
return
|
||||
}
|
||||
b.update.Store(false)
|
||||
if b.next == nil {
|
||||
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -376,6 +486,100 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
|
||||
return nil
|
||||
}
|
||||
|
||||
// BufferUpdateAffectedPeers accumulates peer IDs and flushes them after the buffer interval.
|
||||
func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error {
|
||||
if len(peerIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.accountManagerMetrics != nil {
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s", len(peerIDs), accountID, util.GetCallerName())
|
||||
|
||||
bufUpd, _ := c.affectedPeerUpdateLocks.LoadOrStore(accountID, &bufferAffectedUpdate{
|
||||
peerIDs: make(map[string]struct{}),
|
||||
})
|
||||
b := bufUpd.(*bufferAffectedUpdate)
|
||||
|
||||
b.addPeerIDs(peerIDs)
|
||||
|
||||
if !b.sendMu.TryLock() {
|
||||
// Another goroutine is already sending; it will pick up our IDs on its next drain.
|
||||
return nil
|
||||
}
|
||||
|
||||
b.stopTimer()
|
||||
|
||||
collected := b.drainPeerIDs()
|
||||
go func() {
|
||||
defer b.sendMu.Unlock()
|
||||
_ = c.sendUpdateForAffectedPeers(ctx, accountID, collected)
|
||||
|
||||
// Check if more peer IDs accumulated while we were sending.
|
||||
if !b.hasPending() {
|
||||
return
|
||||
}
|
||||
|
||||
// Schedule a debounced flush for the newly accumulated IDs.
|
||||
b.setTimer(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
||||
ids := b.drainPeerIDs()
|
||||
if len(ids) > 0 {
|
||||
_ = c.sendUpdateForAffectedPeers(ctx, accountID, ids)
|
||||
}
|
||||
})
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *bufferAffectedUpdate) addPeerIDs(ids []string) {
|
||||
b.dataMu.Lock()
|
||||
for _, id := range ids {
|
||||
b.peerIDs[id] = struct{}{}
|
||||
}
|
||||
b.dataMu.Unlock()
|
||||
}
|
||||
|
||||
func (b *bufferAffectedUpdate) drainPeerIDs() []string {
|
||||
b.dataMu.Lock()
|
||||
defer b.dataMu.Unlock()
|
||||
if len(b.peerIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
ids := make([]string, 0, len(b.peerIDs))
|
||||
for id := range b.peerIDs {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
b.peerIDs = make(map[string]struct{})
|
||||
return ids
|
||||
}
|
||||
|
||||
func (b *bufferAffectedUpdate) hasPending() bool {
|
||||
b.dataMu.Lock()
|
||||
defer b.dataMu.Unlock()
|
||||
return len(b.peerIDs) > 0
|
||||
}
|
||||
|
||||
func (b *bufferAffectedUpdate) stopTimer() {
|
||||
b.dataMu.Lock()
|
||||
defer b.dataMu.Unlock()
|
||||
if b.next != nil {
|
||||
b.next.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func (b *bufferAffectedUpdate) setTimer(d time.Duration, f func()) {
|
||||
b.dataMu.Lock()
|
||||
defer b.dataMu.Unlock()
|
||||
if b.next == nil {
|
||||
b.next = time.AfterFunc(d, f)
|
||||
return
|
||||
}
|
||||
b.next.Reset(d)
|
||||
}
|
||||
|
||||
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
if isRequiresApproval {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||
@@ -573,21 +777,24 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
err := c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationUpdate})
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
|
||||
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
|
||||
if len(affectedPeerIDs) == 0 {
|
||||
log.WithContext(ctx).Tracef("no affected peers for peer update in account %s, skipping", accountID)
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
|
||||
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationCreate})
|
||||
if len(affectedPeerIDs) == 0 {
|
||||
log.WithContext(ctx).Tracef("no affected peers for peer add in account %s, skipping", accountID)
|
||||
return nil
|
||||
}
|
||||
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationCreate})
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -620,7 +827,11 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
|
||||
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
||||
}
|
||||
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
|
||||
if len(affectedPeerIDs) == 0 {
|
||||
log.WithContext(ctx).Tracef("no affected peers for peer delete in account %s, skipping network map update", accountID)
|
||||
return nil
|
||||
}
|
||||
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
|
||||
}
|
||||
|
||||
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
|
||||
|
||||
@@ -19,6 +19,8 @@ const (
|
||||
|
||||
type Controller interface {
|
||||
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
|
||||
UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error
|
||||
BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error
|
||||
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
|
||||
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
@@ -27,9 +29,9 @@ type Controller interface {
|
||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||
CountStreams() int
|
||||
|
||||
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error
|
||||
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error
|
||||
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error
|
||||
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string, affectedPeerIDs []string) error
|
||||
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error
|
||||
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error
|
||||
DisconnectPeers(ctx context.Context, accountId string, peerIDs []string)
|
||||
OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error)
|
||||
OnPeerDisconnected(ctx context.Context, accountID string, peerID string)
|
||||
|
||||
@@ -57,6 +57,20 @@ func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, r
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// BufferUpdateAffectedPeers mocks base method.
|
||||
func (m *MockController) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BufferUpdateAffectedPeers", ctx, accountID, peerIDs, reason)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// BufferUpdateAffectedPeers indicates an expected call of BufferUpdateAffectedPeers.
|
||||
func (mr *MockControllerMockRecorder) BufferUpdateAffectedPeers(ctx, accountID, peerIDs, reason any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAffectedPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAffectedPeers), ctx, accountID, peerIDs, reason)
|
||||
}
|
||||
|
||||
// CountStreams mocks base method.
|
||||
func (m *MockController) CountStreams() int {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -158,45 +172,45 @@ func (mr *MockControllerMockRecorder) OnPeerDisconnected(ctx, accountID, peerID
|
||||
}
|
||||
|
||||
// OnPeersAdded mocks base method.
|
||||
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs)
|
||||
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs, affectedPeerIDs)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// OnPeersAdded indicates an expected call of OnPeersAdded.
|
||||
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs any) *gomock.Call {
|
||||
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs, affectedPeerIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs, affectedPeerIDs)
|
||||
}
|
||||
|
||||
// OnPeersDeleted mocks base method.
|
||||
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs)
|
||||
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs, affectedPeerIDs)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// OnPeersDeleted indicates an expected call of OnPeersDeleted.
|
||||
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs any) *gomock.Call {
|
||||
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs, affectedPeerIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs, affectedPeerIDs)
|
||||
}
|
||||
|
||||
// OnPeersUpdated mocks base method.
|
||||
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error {
|
||||
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string, affectedPeerIDs []string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs)
|
||||
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs, affectedPeerIDs)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// OnPeersUpdated indicates an expected call of OnPeersUpdated.
|
||||
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs any) *gomock.Call {
|
||||
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs, affectedPeerIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs, affectedPeerIDs)
|
||||
}
|
||||
|
||||
// StartWarmup mocks base method.
|
||||
@@ -250,3 +264,17 @@ func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID, reason
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// UpdateAffectedPeers mocks base method.
|
||||
func (m *MockController) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateAffectedPeers", ctx, accountID, peerIDs)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateAffectedPeers indicates an expected call of UpdateAffectedPeers.
|
||||
func (mr *MockControllerMockRecorder) UpdateAffectedPeers(ctx, accountID, peerIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAffectedPeers", reflect.TypeOf((*MockController)(nil).UpdateAffectedPeers), ctx, accountID, peerIDs)
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
||||
found = true
|
||||
select {
|
||||
case channel <- update:
|
||||
log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID)
|
||||
log.WithContext(ctx).Tracef("update was sent to channel for peer %s", peerID)
|
||||
default:
|
||||
dropped = true
|
||||
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel))
|
||||
|
||||
@@ -75,7 +75,7 @@ func (m *managerImpl) SetAccountManager(accountManager account.Manager) {
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) {
|
||||
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
|
||||
allowed, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
|
||||
}
|
||||
@@ -88,7 +88,7 @@ func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID str
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
|
||||
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
|
||||
allowed, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
|
||||
}
|
||||
|
||||
@@ -63,7 +63,7 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac
|
||||
|
||||
// GetAllAccessLogs retrieves access logs for an account with pagination and filtering
|
||||
func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, 0, status.NewPermissionValidationError(err)
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@ func NewManager(store store, proxyMgr proxyManager, permissionsManager permissio
|
||||
}
|
||||
|
||||
func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -122,7 +122,7 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
||||
}
|
||||
|
||||
func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*domain.Domain, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -163,7 +163,7 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
|
||||
}
|
||||
|
||||
func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -187,7 +187,7 @@ func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID s
|
||||
}
|
||||
|
||||
func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID string) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||
ok, _, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"accountID": accountID,
|
||||
|
||||
@@ -37,7 +37,7 @@ func (h *handler) createToken(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Create)
|
||||
ok, ctx, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Create)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
|
||||
return
|
||||
@@ -76,13 +76,13 @@ func (h *handler) createToken(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.SaveProxyAccessToken(r.Context(), &generated.ProxyAccessToken); err != nil {
|
||||
if err := h.store.SaveProxyAccessToken(ctx, &generated.ProxyAccessToken); err != nil {
|
||||
util.WriteErrorResponse("failed to save token", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toProxyTokenCreatedResponse(generated)
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
util.WriteJSONObject(ctx, w, resp)
|
||||
}
|
||||
|
||||
func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -92,7 +92,7 @@ func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Read)
|
||||
ok, ctx, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
|
||||
return
|
||||
@@ -102,7 +102,7 @@ func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
tokens, err := h.store.GetProxyAccessTokensByAccountID(r.Context(), store.LockingStrengthNone, userAuth.AccountId)
|
||||
tokens, err := h.store.GetProxyAccessTokensByAccountID(ctx, store.LockingStrengthNone, userAuth.AccountId)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to list tokens", http.StatusInternalServerError, w)
|
||||
return
|
||||
@@ -113,7 +113,7 @@ func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) {
|
||||
resp = append(resp, toProxyTokenResponse(token))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
util.WriteJSONObject(ctx, w, resp)
|
||||
}
|
||||
|
||||
func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -123,7 +123,7 @@ func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Delete)
|
||||
ok, ctx, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
|
||||
return
|
||||
@@ -139,7 +139,7 @@ func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
token, err := h.store.GetProxyAccessTokenByID(r.Context(), store.LockingStrengthNone, tokenID)
|
||||
token, err := h.store.GetProxyAccessTokenByID(ctx, store.LockingStrengthNone, tokenID)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound {
|
||||
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
|
||||
@@ -154,12 +154,12 @@ func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.RevokeProxyAccessToken(r.Context(), tokenID); err != nil {
|
||||
if err := h.store.RevokeProxyAccessToken(ctx, tokenID); err != nil {
|
||||
util.WriteErrorResponse("failed to revoke token", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
util.WriteJSONObject(ctx, w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func toProxyTokenResponse(token *types.ProxyAccessToken) api.ProxyToken {
|
||||
|
||||
@@ -47,7 +47,7 @@ func TestCreateToken_AccountScoped(t *testing.T) {
|
||||
)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Create).Return(true, nil)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Create).Return(true, context.Background(), nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
@@ -90,7 +90,7 @@ func TestCreateToken_WithExpiration(t *testing.T) {
|
||||
)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, context.Background(), nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
@@ -115,7 +115,7 @@ func TestCreateToken_EmptyName(t *testing.T) {
|
||||
defer ctrl.Finish()
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, context.Background(), nil)
|
||||
|
||||
h := &handler{
|
||||
permissionsManager: permsMgr,
|
||||
@@ -135,7 +135,7 @@ func TestCreateToken_PermissionDenied(t *testing.T) {
|
||||
defer ctrl.Finish()
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(false, nil)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(false, context.Background(), nil)
|
||||
|
||||
h := &handler{
|
||||
permissionsManager: permsMgr,
|
||||
@@ -164,7 +164,7 @@ func TestListTokens(t *testing.T) {
|
||||
}, nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Read).Return(true, nil)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Read).Return(true, context.Background(), nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
@@ -202,7 +202,7 @@ func TestRevokeToken_Success(t *testing.T) {
|
||||
mockStore.EXPECT().RevokeProxyAccessToken(gomock.Any(), "tok-1").Return(nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, nil)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, context.Background(), nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
@@ -231,7 +231,7 @@ func TestRevokeToken_WrongAccount(t *testing.T) {
|
||||
}, nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, context.Background(), nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
@@ -258,7 +258,7 @@ func TestRevokeToken_ManagementWideToken(t *testing.T) {
|
||||
}, nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, context.Background(), nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
|
||||
@@ -120,7 +120,7 @@ func (m *Manager) StartExposeReaper(ctx context.Context) {
|
||||
// capability flags reported by its active proxies so the dashboard can
|
||||
// render feature support without a second round-trip.
|
||||
func (m *Manager) GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -146,7 +146,7 @@ func (m *Manager) GetClusters(ctx context.Context, accountID, userID string) ([]
|
||||
// DeleteAccountCluster removes all proxy registrations for the given cluster address
|
||||
// owned by the account.
|
||||
func (m *Manager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -158,7 +158,7 @@ func (m *Manager) DeleteAccountCluster(ctx context.Context, accountID, userID, c
|
||||
}
|
||||
|
||||
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -222,7 +222,7 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
|
||||
}
|
||||
|
||||
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -243,7 +243,7 @@ func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID s
|
||||
}
|
||||
|
||||
func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -528,7 +528,7 @@ func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.St
|
||||
}
|
||||
|
||||
func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -836,7 +836,7 @@ func validateResourceTargetType(target *service.Target, resource *resourcetypes.
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -876,7 +876,7 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
|
||||
@@ -1172,7 +1172,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
|
||||
mockPerms.EXPECT().
|
||||
ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
mockAcct.EXPECT().
|
||||
StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any())
|
||||
mockAcct.EXPECT().
|
||||
|
||||
@@ -32,7 +32,7 @@ func NewManager(store store.Store, accountManager account.Manager, permissionsMa
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllZones(ctx context.Context, accountID, userID string) ([]*zones.Zone, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -44,7 +44,7 @@ func (m *managerImpl) GetAllZones(ctx context.Context, accountID, userID string)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetZone(ctx context.Context, accountID, userID, zoneID string) (*zones.Zone, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -56,7 +56,7 @@ func (m *managerImpl) GetZone(ctx context.Context, accountID, userID, zoneID str
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string, zone *zones.Zone) (*zones.Zone, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -103,7 +103,7 @@ func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string,
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, updatedZone *zones.Zone) (*zones.Zone, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -151,7 +151,7 @@ func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string,
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
|
||||
@@ -79,7 +79,7 @@ func TestManagerImpl_GetAllZones(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
|
||||
require.NoError(t, err)
|
||||
@@ -95,7 +95,7 @@ func TestManagerImpl_GetAllZones(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, nil)
|
||||
Return(false, ctx, nil)
|
||||
|
||||
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
|
||||
require.Error(t, err)
|
||||
@@ -112,7 +112,7 @@ func TestManagerImpl_GetAllZones(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, status.Errorf(status.Internal, "permission check failed"))
|
||||
Return(false, ctx, status.Errorf(status.Internal, "permission check failed"))
|
||||
|
||||
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
|
||||
require.Error(t, err)
|
||||
@@ -134,7 +134,7 @@ func TestManagerImpl_GetZone(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
result, err := manager.GetZone(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
@@ -150,7 +150,7 @@ func TestManagerImpl_GetZone(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, nil)
|
||||
Return(false, ctx, nil)
|
||||
|
||||
result, err := manager.GetZone(ctx, testAccountID, testUserID, testZoneID)
|
||||
require.Error(t, err)
|
||||
@@ -179,7 +179,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
@@ -212,7 +212,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(false, nil)
|
||||
Return(false, ctx, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
@@ -235,7 +235,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
@@ -261,7 +261,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
@@ -293,7 +293,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
@@ -319,7 +319,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
@@ -354,7 +354,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
storeEventCalled := false
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
@@ -394,7 +394,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||
require.Error(t, err)
|
||||
@@ -418,7 +418,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(false, nil)
|
||||
Return(false, ctx, nil)
|
||||
|
||||
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||
require.Error(t, err)
|
||||
@@ -441,7 +441,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||
require.Error(t, err)
|
||||
@@ -471,7 +471,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
storeEventCallCount := 0
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
@@ -503,7 +503,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
storeEventCalled := false
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
@@ -529,7 +529,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(false, nil)
|
||||
Return(false, ctx, nil)
|
||||
|
||||
err := manager.DeleteZone(ctx, testAccountID, testUserID, testZoneID)
|
||||
require.Error(t, err)
|
||||
@@ -545,7 +545,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
err := manager.DeleteZone(ctx, testAccountID, testUserID, "non-existent-zone")
|
||||
require.Error(t, err)
|
||||
|
||||
@@ -32,7 +32,7 @@ func NewManager(store store.Store, accountManager account.Manager, permissionsMa
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*records.Record, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -44,7 +44,7 @@ func (m *managerImpl) GetAllRecords(ctx context.Context, accountID, userID, zone
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*records.Record, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -56,7 +56,7 @@ func (m *managerImpl) GetRecord(ctx context.Context, accountID, userID, zoneID,
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *records.Record) (*records.Record, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -102,7 +102,7 @@ func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneI
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneID string, updatedRecord *records.Record) (*records.Record, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -161,7 +161,7 @@ func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneI
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
|
||||
@@ -80,7 +80,7 @@ func TestManagerImpl_GetAllRecords(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
@@ -96,7 +96,7 @@ func TestManagerImpl_GetAllRecords(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, nil)
|
||||
Return(false, ctx, nil)
|
||||
|
||||
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.Error(t, err)
|
||||
@@ -113,7 +113,7 @@ func TestManagerImpl_GetAllRecords(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, status.Errorf(status.Internal, "permission check failed"))
|
||||
Return(false, ctx, status.Errorf(status.Internal, "permission check failed"))
|
||||
|
||||
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.Error(t, err)
|
||||
@@ -135,7 +135,7 @@ func TestManagerImpl_GetRecord(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
|
||||
require.NoError(t, err)
|
||||
@@ -153,7 +153,7 @@ func TestManagerImpl_GetRecord(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, nil)
|
||||
Return(false, ctx, nil)
|
||||
|
||||
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
|
||||
require.Error(t, err)
|
||||
@@ -181,7 +181,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
@@ -215,7 +215,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
@@ -244,7 +244,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
@@ -273,7 +273,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(false, nil)
|
||||
Return(false, ctx, nil)
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.Error(t, err)
|
||||
@@ -297,7 +297,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.Error(t, err)
|
||||
@@ -323,7 +323,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.Error(t, err)
|
||||
@@ -349,7 +349,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.Error(t, err)
|
||||
@@ -380,7 +380,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
storeEventCalled := false
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
@@ -418,7 +418,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
// Event should be stored
|
||||
@@ -445,7 +445,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(false, nil)
|
||||
Return(false, ctx, nil)
|
||||
|
||||
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||
require.Error(t, err)
|
||||
@@ -470,7 +470,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||
require.Error(t, err)
|
||||
@@ -500,7 +500,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||
require.Error(t, err)
|
||||
@@ -523,7 +523,7 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
storeEventCalled := false
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
@@ -549,7 +549,7 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(false, nil)
|
||||
Return(false, ctx, nil)
|
||||
|
||||
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
|
||||
require.Error(t, err)
|
||||
@@ -565,7 +565,7 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
Return(true, ctx, nil)
|
||||
|
||||
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, "non-existent-record")
|
||||
require.Error(t, err)
|
||||
|
||||
@@ -120,7 +120,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
||||
|
||||
func (s *BaseServer) APIHandler() http.Handler {
|
||||
return Create(s, func() http.Handler {
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.InstanceManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.AuthMiddleware())
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter(), s.IsValidChildAccount)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -153,20 +153,6 @@ func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) AuthMiddleware() mux.MiddlewareFunc {
|
||||
return Create(s, func() mux.MiddlewareFunc {
|
||||
m := middleware.NewAuthMiddleware(
|
||||
s.AuthManager(),
|
||||
s.AccountManager().GetAccountIDFromUserAuth,
|
||||
s.AccountManager().SyncUserJWTGroups,
|
||||
s.AccountManager().GetUserFromUserAuth,
|
||||
s.RateLimiter(),
|
||||
s.Metrics().GetMeter(),
|
||||
)
|
||||
return m.Handler
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
return Create(s, func() *grpc.Server {
|
||||
trustedPeers := s.Config.ReverseProxy.TrustedPeers
|
||||
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/instance"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
@@ -152,16 +151,6 @@ func (s *BaseServer) IdpManager() idp.Manager {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) InstanceManager() instance.Manager {
|
||||
return Create(s, func() instance.Manager {
|
||||
m, err := instance.NewManager(context.Background(), s.Store(), s.IdpManager())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create instance manager: %v", err)
|
||||
}
|
||||
return m
|
||||
})
|
||||
}
|
||||
|
||||
// OAuthConfigProvider is only relevant when we have an embedded IdP service. Otherwise must be nil
|
||||
func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
|
||||
if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled {
|
||||
@@ -240,3 +229,6 @@ func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) IsValidChildAccount(_ context.Context, _, _, _ string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -34,6 +34,8 @@ const (
|
||||
ManagementLegacyPort = 33073
|
||||
// DefaultSelfHostedDomain is the default domain used for self-hosted fresh installs.
|
||||
DefaultSelfHostedDomain = "netbird.selfhosted"
|
||||
|
||||
ContainerKeyBaseServer = "baseServer"
|
||||
)
|
||||
|
||||
type Server interface {
|
||||
@@ -91,7 +93,7 @@ type Config struct {
|
||||
|
||||
// NewServer initializes and configures a new Server instance
|
||||
func NewServer(cfg *Config) *BaseServer {
|
||||
return &BaseServer{
|
||||
s := &BaseServer{
|
||||
Config: cfg.NbConfig,
|
||||
container: make(map[string]any),
|
||||
dnsDomain: cfg.DNSDomain,
|
||||
@@ -104,6 +106,9 @@ func NewServer(cfg *Config) *BaseServer {
|
||||
mgmtMetricsPort: cfg.MgmtMetricsPort,
|
||||
autoResolveDomains: cfg.AutoResolveDomains,
|
||||
}
|
||||
s.container[ContainerKeyBaseServer] = s
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *BaseServer) AfterInit(fn func(s *BaseServer)) {
|
||||
|
||||
@@ -6,9 +6,11 @@ import (
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
|
||||
@@ -185,9 +187,38 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
|
||||
}
|
||||
|
||||
// settings == nil → field stays nil → "no info in this snapshot", client
|
||||
// preserves the deadline it already had. settings non-nil → emit either a
|
||||
// valid deadline or the explicit-zero "disabled" sentinel via
|
||||
// encodeSessionExpiresAt.
|
||||
if settings != nil {
|
||||
response.SessionExpiresAt = encodeSessionExpiresAt(
|
||||
peer.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration),
|
||||
)
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
// encodeSessionExpiresAt encodes a server-side deadline into the 3-state wire
|
||||
// representation used on LoginResponse, SyncResponse and
|
||||
// ExtendAuthSessionResponse. See the proto comments on those messages.
|
||||
//
|
||||
// - deadline.IsZero() → returns &Timestamp{} (seconds=0, nanos=0): the
|
||||
// "expiry disabled or peer is not SSO-tracked" sentinel; the client clears
|
||||
// its anchor.
|
||||
// - deadline non-zero → returns timestamppb.New(deadline): the new absolute
|
||||
// UTC deadline.
|
||||
//
|
||||
// Returning nil ("no info, preserve client's anchor") is the caller's job —
|
||||
// only meaningful on Sync builds where settings were not resolved.
|
||||
func encodeSessionExpiresAt(deadline time.Time) *timestamppb.Timestamp {
|
||||
if deadline.IsZero() {
|
||||
return ×tamppb.Timestamp{}
|
||||
}
|
||||
return timestamppb.New(deadline)
|
||||
}
|
||||
|
||||
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
|
||||
userIDToIndex := make(map[string]uint32)
|
||||
var hashedUsers [][]byte
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
@@ -200,3 +201,29 @@ func TestBuildJWTConfig_Audiences(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodeSessionExpiresAt pins the wire encoding the client's
|
||||
// applySessionDeadline depends on:
|
||||
//
|
||||
// - zero deadline → &Timestamp{} (seconds=0, nanos=0): the explicit
|
||||
// "expiry disabled or peer is not SSO-tracked" sentinel.
|
||||
// - non-zero → timestamppb.New(deadline): the absolute UTC deadline.
|
||||
//
|
||||
// The third state (nil pointer = "no info in this snapshot") is the caller's
|
||||
// responsibility on the Sync path when settings could not be resolved; the
|
||||
// helper itself never returns nil.
|
||||
func TestEncodeSessionExpiresAt(t *testing.T) {
|
||||
t.Run("zero deadline encodes as explicit-zero sentinel", func(t *testing.T) {
|
||||
got := encodeSessionExpiresAt(time.Time{})
|
||||
assert.NotNil(t, got, "must not return nil; nil means 'no info', not 'disabled'")
|
||||
assert.Equal(t, int64(0), got.GetSeconds())
|
||||
assert.Equal(t, int32(0), got.GetNanos())
|
||||
})
|
||||
|
||||
t.Run("non-zero deadline round-trips", func(t *testing.T) {
|
||||
deadline := time.Date(2030, 1, 2, 3, 4, 5, 0, time.UTC)
|
||||
got := encodeSessionExpiresAt(deadline)
|
||||
assert.NotNil(t, got)
|
||||
assert.True(t, got.AsTime().Equal(deadline))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -437,7 +437,7 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
|
||||
return nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
||||
log.WithContext(ctx).Tracef("received an update for peer %s", peerKey.String())
|
||||
if debouncer.ProcessUpdate(update) {
|
||||
// Send immediately (first update or after quiet period)
|
||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
|
||||
@@ -492,7 +492,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
|
||||
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||
return status.Errorf(codes.Internal, "failed sending update message")
|
||||
}
|
||||
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
||||
log.WithContext(ctx).Tracef("sent an update to peer %s", peerKey.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -821,6 +821,80 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExtendAuthSession refreshes the peer's SSO session expiry deadline using a
|
||||
// fresh JWT. The same JWT validation pipeline as Login is used. The tunnel
|
||||
// stays up; no network map sync is performed. The new deadline is returned
|
||||
// in ExtendAuthSessionResponse.SessionExpiresAt.
|
||||
func (s *Server) ExtendAuthSession(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
extendReq := &proto.ExtendAuthSessionRequest{}
|
||||
peerKey, err := s.parseRequest(ctx, req, extendReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
||||
if accountID, accErr := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()); accErr == nil {
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||
}
|
||||
|
||||
jwt := extendReq.GetJwtToken()
|
||||
if jwt == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "jwt token is required")
|
||||
}
|
||||
|
||||
var userID string
|
||||
const attempts = 3
|
||||
for i := 0; i < attempts; i++ {
|
||||
userID, err = s.validateToken(ctx, peerKey.String(), jwt)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if i == attempts-1 {
|
||||
break
|
||||
}
|
||||
log.WithContext(ctx).Warnf("failed validating JWT token while extending session for peer %s: %v. Retrying (idP cache).", peerKey.String(), err)
|
||||
select {
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if userID == "" {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "jwt token did not yield a user id")
|
||||
}
|
||||
|
||||
deadline, err := s.accountManager.ExtendPeerSession(ctx, peerKey.String(), userID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed extending session for peer %s: %v", peerKey.String(), err)
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
|
||||
// Success path normally returns a non-zero deadline. A defensive zero
|
||||
// would still encode as the explicit "disabled" sentinel rather than nil,
|
||||
// so the client clears any stale anchor instead of preserving it.
|
||||
resp := &proto.ExtendAuthSessionResponse{
|
||||
SessionExpiresAt: encodeSessionExpiresAt(deadline),
|
||||
}
|
||||
|
||||
wgKey, err := s.secretsManager.GetWGKey()
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed processing request")
|
||||
}
|
||||
encrypted, err := encryption.EncryptMessage(peerKey, wgKey, resp)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed encrypting response")
|
||||
}
|
||||
return &proto.EncryptedMessage{
|
||||
WgPubKey: wgKey.PublicKey().String(),
|
||||
Body: encrypted,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
|
||||
var relayToken *Token
|
||||
var err error
|
||||
@@ -844,6 +918,12 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
|
||||
Checks: toProtocolChecks(ctx, postureChecks),
|
||||
}
|
||||
|
||||
// settings is always non-nil here, so we never emit nil — encoder returns
|
||||
// either a valid deadline or the explicit-zero "disabled" sentinel.
|
||||
loginResp.SessionExpiresAt = encodeSessionExpiresAt(
|
||||
peer.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration),
|
||||
)
|
||||
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -282,7 +282,7 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager {
|
||||
// User that performs the update has to belong to the account.
|
||||
// Returns an updated Settings
|
||||
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
|
||||
}
|
||||
@@ -355,7 +355,17 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
|
||||
oldSettings.DNSDomain != newSettings.DNSDomain ||
|
||||
oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion ||
|
||||
oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways {
|
||||
oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways ||
|
||||
oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled ||
|
||||
oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
|
||||
// Session deadline is derived from LastLogin + PeerLoginExpiration
|
||||
// on every Login/Sync response. Without a fan-out push, connected
|
||||
// peers keep the deadline they received at login time and only see
|
||||
// the new value after the next unrelated NetworkMap change. Add
|
||||
// these two fields to the trigger list so admin-side expiry tweaks
|
||||
// (e.g. shortening from 24h to 1h) reach every connected peer
|
||||
// within seconds, which is what the proactive-warning feature
|
||||
// relies on (see client/internal/auth/sessionwatch).
|
||||
updateAccountPeers = true
|
||||
}
|
||||
|
||||
@@ -845,7 +855,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
||||
return err
|
||||
}
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Delete)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Delete)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to validate user permissions: %w", err)
|
||||
}
|
||||
@@ -1412,7 +1422,7 @@ func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID strin
|
||||
|
||||
// GetAccountByID returns an account associated with this account ID.
|
||||
func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -1425,7 +1435,7 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s
|
||||
|
||||
// GetAccountMeta returns the account metadata associated with this account ID.
|
||||
func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -1438,7 +1448,7 @@ func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID s
|
||||
|
||||
// GetAccountOnboarding retrieves the onboarding information for a specific account.
|
||||
func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -1463,7 +1473,7 @@ func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accou
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
|
||||
}
|
||||
@@ -1530,7 +1540,8 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
|
||||
return accountID, user.Id, nil
|
||||
}
|
||||
|
||||
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
ctx, err = am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
@@ -1976,7 +1987,7 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -2544,7 +2555,7 @@ func (am *DefaultAccountManager) validateIPForUpdate(account *types.Account, pee
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validate user permissions: %w", err)
|
||||
}
|
||||
@@ -2562,7 +2573,9 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, []string{peerID})
|
||||
changedPeerIDs := []string{peerID}
|
||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, changedPeerIDs, affectedPeerIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
}
|
||||
@@ -2634,7 +2647,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti
|
||||
// UpdatePeerIPv6 updates the IPv6 overlay address of a peer, validating it's
|
||||
// within the account's v6 network range and not already taken.
|
||||
func (am *DefaultAccountManager) UpdatePeerIPv6(ctx context.Context, accountID, userID, peerID string, newIPv6 netip.Addr) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validate user permissions: %w", err)
|
||||
}
|
||||
@@ -2653,7 +2666,9 @@ func (am *DefaultAccountManager) UpdatePeerIPv6(ctx context.Context, accountID,
|
||||
}
|
||||
|
||||
if updateNetworkMap {
|
||||
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peerID}); err != nil {
|
||||
changedPeerIDs := []string{peerID}
|
||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
return fmt.Errorf("notify network map controller: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
@@ -109,6 +110,7 @@ type Manager interface {
|
||||
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
||||
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
||||
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||
ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) // used by peer gRPC API for ExtendAuthSession
|
||||
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API
|
||||
GetExternalCacheManager() ExternalCacheManager
|
||||
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||
@@ -127,6 +129,9 @@ type Manager interface {
|
||||
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
||||
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
|
||||
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason)
|
||||
UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string)
|
||||
BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason)
|
||||
ResolveAffectedPeers(ctx context.Context, s store.Store, accountID string, change affectedpeers.Change) []string
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason)
|
||||
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||
SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
dns "github.com/netbirdio/netbird/dns"
|
||||
service "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
activity "github.com/netbirdio/netbird/management/server/activity"
|
||||
affectedpeers "github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
idp "github.com/netbirdio/netbird/management/server/idp"
|
||||
peer "github.com/netbirdio/netbird/management/server/peer"
|
||||
posture "github.com/netbirdio/netbird/management/server/posture"
|
||||
@@ -122,6 +123,18 @@ func (mr *MockManagerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, reas
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).BufferUpdateAccountPeers), ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// BufferUpdateAffectedPeers mocks base method.
|
||||
func (m *MockManager) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "BufferUpdateAffectedPeers", ctx, accountID, peerIDs, reason)
|
||||
}
|
||||
|
||||
// BufferUpdateAffectedPeers indicates an expected call of BufferUpdateAffectedPeers.
|
||||
func (mr *MockManagerMockRecorder) BufferUpdateAffectedPeers(ctx, accountID, peerIDs, reason interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAffectedPeers", reflect.TypeOf((*MockManager)(nil).BufferUpdateAffectedPeers), ctx, accountID, peerIDs, reason)
|
||||
}
|
||||
|
||||
// BuildUserInfosForAccount mocks base method.
|
||||
func (m *MockManager) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1304,6 +1317,21 @@ func (mr *MockManagerMockRecorder) LoginPeer(ctx, login interface{}) *gomock.Cal
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoginPeer", reflect.TypeOf((*MockManager)(nil).LoginPeer), ctx, login)
|
||||
}
|
||||
|
||||
// ExtendPeerSession mocks base method.
|
||||
func (m *MockManager) ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ExtendPeerSession", ctx, peerPubKey, userID)
|
||||
ret0, _ := ret[0].(time.Time)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ExtendPeerSession indicates an expected call of ExtendPeerSession.
|
||||
func (mr *MockManagerMockRecorder) ExtendPeerSession(ctx, peerPubKey, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExtendPeerSession", reflect.TypeOf((*MockManager)(nil).ExtendPeerSession), ctx, peerPubKey, userID)
|
||||
}
|
||||
|
||||
// MarkPeerConnected mocks base method.
|
||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1622,6 +1650,32 @@ func (mr *MockManagerMockRecorder) UpdateAccountPeers(ctx, accountID, reason int
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).UpdateAccountPeers), ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// UpdateAffectedPeers mocks base method.
|
||||
func (m *MockManager) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "UpdateAffectedPeers", ctx, accountID, peerIDs)
|
||||
}
|
||||
|
||||
// UpdateAffectedPeers indicates an expected call of UpdateAffectedPeers.
|
||||
func (mr *MockManagerMockRecorder) UpdateAffectedPeers(ctx, accountID, peerIDs interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAffectedPeers", reflect.TypeOf((*MockManager)(nil).UpdateAffectedPeers), ctx, accountID, peerIDs)
|
||||
}
|
||||
|
||||
// ResolveAffectedPeers mocks base method.
|
||||
func (m *MockManager) ResolveAffectedPeers(ctx context.Context, s store.Store, accountID string, change affectedpeers.Change) []string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ResolveAffectedPeers", ctx, s, accountID, change)
|
||||
ret0, _ := ret[0].([]string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ResolveAffectedPeers indicates an expected call of ResolveAffectedPeers.
|
||||
func (mr *MockManagerMockRecorder) ResolveAffectedPeers(ctx, s, accountID, change interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResolveAffectedPeers", reflect.TypeOf((*MockManager)(nil).ResolveAffectedPeers), ctx, s, accountID, change)
|
||||
}
|
||||
|
||||
// UpdateAccountSettings mocks base method.
|
||||
func (m *MockManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -3282,6 +3282,19 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
|
||||
// when the channel delivers.
|
||||
const peerUpdateTimeout = 5 * time.Second
|
||||
|
||||
func drainPeerUpdates(ch <-chan *network_map.UpdateMessage) {
|
||||
for {
|
||||
select {
|
||||
case _, ok := <-ch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
|
||||
t.Helper()
|
||||
select {
|
||||
|
||||
@@ -240,6 +240,10 @@ const (
|
||||
AccountLocalMfaEnabled Activity = 123
|
||||
// AccountLocalMfaDisabled indicates that a user disabled TOTP MFA for local users
|
||||
AccountLocalMfaDisabled Activity = 124
|
||||
// UserExtendedPeerSession indicates that a user refreshed their peer's
|
||||
// SSO session deadline via ExtendAuthSession without re-establishing the
|
||||
// tunnel. Distinct from UserLoggedInPeer (full interactive login).
|
||||
UserExtendedPeerSession Activity = 125
|
||||
|
||||
AccountDeleted Activity = 99999
|
||||
)
|
||||
@@ -394,6 +398,8 @@ var activityMap = map[Activity]Code{
|
||||
AccountLocalMfaEnabled: {"Account local MFA enabled", "account.setting.local.mfa.enable"},
|
||||
AccountLocalMfaDisabled: {"Account local MFA disabled", "account.setting.local.mfa.disable"},
|
||||
|
||||
UserExtendedPeerSession: {"User extended peer session", "user.peer.session.extend"},
|
||||
|
||||
DomainAdded: {"Domain added", "domain.add"},
|
||||
DomainDeleted: {"Domain deleted", "domain.delete"},
|
||||
DomainValidated: {"Domain validated", "domain.validate"},
|
||||
|
||||
119
management/server/affected_peers_coverage_test.go
Normal file
119
management/server/affected_peers_coverage_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// TestAffectedPeers_DependencyCoverageMatrix enumerates each network-map
|
||||
// dependency crossed with the change-type that can alter it, asserting the
|
||||
// resolver folds in exactly the peers whose map changes. A new dependency that
|
||||
// the resolver fails to walk should fail one of these rows; a new change-type
|
||||
// without a row is a coverage gap to add here.
|
||||
func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
|
||||
type row struct {
|
||||
name string
|
||||
build func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string)
|
||||
}
|
||||
|
||||
rows := []row{
|
||||
{
|
||||
name: "policy-groups/source-group-change refreshes source+routing, excludes unrelated",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{ChangedGroupIDs: []string{s.sourceGroupID}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "resource-routing-bridge/router-peer-change refreshes policy sources",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{ChangedPeerIDs: []string{s.routerPeerID}},
|
||||
[]string{s.sourcePeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "policy-change/explicit-policy refreshes source+routing",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
return affectedpeers.Change{Policies: []*types.Policy{policy}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "policy-destinationresource/explicit-policy bridges to routing peer",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
policy := peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID)
|
||||
return affectedpeers.Change{Policies: []*types.Policy{policy}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "resource-change refreshes source+routing on its network",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{ResourceIDs: []string{s.resourceID}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "network-change refreshes source+routing on that network",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{NetworkIDs: []string{s.networkID}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "posture-check-change refreshes source+routing of gated policy",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
check, err := s.manager.SavePostureChecks(ctx, s.accountID, userID, &posture.Checks{
|
||||
Name: "cov-min-version",
|
||||
Checks: posture.ChecksDefinition{NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"}},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
policy.SourcePostureChecks = []string{check.ID}
|
||||
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{PostureCheckIDs: []string{check.ID}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty-change yields nothing",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
return affectedpeers.Change{}, nil, []string{s.sourcePeerID, s.routerPeerID, s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, r := range rows {
|
||||
t.Run(r.name, func(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
change, mustContain, mustExclude := r.build(t, s, ctx)
|
||||
affected := s.manager.ResolveAffectedPeers(ctx, s.manager.Store, s.accountID, change)
|
||||
|
||||
for _, id := range mustContain {
|
||||
assert.Contains(t, affected, id, "expected peer to be affected")
|
||||
}
|
||||
for _, id := range mustExclude {
|
||||
assert.NotContains(t, affected, id, "peer must not be affected")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
143
management/server/affected_peers_oldstate_test.go
Normal file
143
management/server/affected_peers_oldstate_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// An update spans an old and a new state. The affected set must be the UNION of
|
||||
// peers reachable before and after the change; resolving only against the final
|
||||
// state drops peers that were reachable but no longer are. These tests pin the
|
||||
// two paths where the old state is reachable only by the changed object's
|
||||
// previous references: detaching a resource group, and re-pointing a router peer.
|
||||
|
||||
// TestAffectedPeers_E2E_UpdateResource_DetachGroup_RefreshesOldGroupSources:
|
||||
// a resource is reachable by a source group via two destination resource groups;
|
||||
// detaching one of them must still refresh that group's policy source peers, even
|
||||
// though the post-update resource no longer maps to it.
|
||||
func TestAffectedPeers_E2E_UpdateResource_DetachGroup_RefreshesOldGroupSources(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
// A second resource group + a second source group/peer that reaches the
|
||||
// resource only through that second group.
|
||||
const detachGroupID = "rs-detach-grp"
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{ID: detachGroupID, Name: "rs-detach"}))
|
||||
|
||||
const secondSourceGroupID = "rs-source-grp-2"
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-detach-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
secondSourcePeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
|
||||
ID: secondSourceGroupID, Name: "rs-source-2", Peers: []string{secondSourcePeer.ID},
|
||||
}))
|
||||
|
||||
resourcesManager, _, _ := s.managers()
|
||||
|
||||
// Attach the resource to the detach group as well: now in [resourceGroup, detachGroup].
|
||||
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
ID: s.resourceID,
|
||||
AccountID: s.accountID,
|
||||
NetworkID: s.networkID,
|
||||
Name: "rs-resource-host",
|
||||
Address: "10.20.30.0/24",
|
||||
GroupIDs: []string{s.resourceGroupID, detachGroupID},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Policy granting the second source group access via the detach group.
|
||||
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(secondSourceGroupID, detachGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
secondSrcCh := s.updateManager.CreateChannel(ctx, secondSourcePeer.ID)
|
||||
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, secondSourcePeer.ID) })
|
||||
settleAffectedUpdates(secondSrcCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
// Detaching the resource from detachGroup removes the second source's
|
||||
// access; that source peer must be refreshed even though the post-update
|
||||
// resource no longer maps to detachGroup.
|
||||
peerShouldReceiveUpdate(t, secondSrcCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
ID: s.resourceID,
|
||||
AccountID: s.accountID,
|
||||
NetworkID: s.networkID,
|
||||
Name: "rs-resource-host",
|
||||
Address: "10.20.30.0/24",
|
||||
GroupIDs: []string{s.resourceGroupID}, // detached detachGroup
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: detaching a resource group did not refresh the old group's policy source peer")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAffectedPeers_E2E_UpdateRouter_RepointPeer_RefreshesOldRoutingPeer:
|
||||
// changing router.Peer within the same network must still refresh the OLD routing
|
||||
// peer, which loses its routing role.
|
||||
func TestAffectedPeers_E2E_UpdateRouter_RepointPeer_RefreshesOldRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, routersManager, _ := s.managers()
|
||||
|
||||
routers, err := s.manager.Store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, s.accountID, s.networkID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, routers, 1)
|
||||
router := routers[0]
|
||||
oldRoutingPeer := router.Peer
|
||||
require.NotEmpty(t, oldRoutingPeer)
|
||||
|
||||
// A new peer to become the routing peer in place of the old one.
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-newrouter-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
newRoutingPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
|
||||
oldCh := s.updateManager.CreateChannel(ctx, oldRoutingPeer)
|
||||
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, oldRoutingPeer) })
|
||||
settleAffectedUpdates(oldCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
// The old routing peer stops serving the resource and must be refreshed.
|
||||
peerShouldReceiveUpdate(t, oldCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = routersManager.UpdateRouter(ctx, userID, &routerTypes.NetworkRouter{
|
||||
ID: router.ID,
|
||||
NetworkID: s.networkID,
|
||||
AccountID: s.accountID,
|
||||
Peer: newRoutingPeer.ID, // repoint within the same network
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: re-pointing the router peer did not refresh the old routing peer")
|
||||
}
|
||||
}
|
||||
251
management/server/affected_peers_property_test.go
Normal file
251
management/server/affected_peers_property_test.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// allPeerMaps computes the serialized per-peer network map for every peer in the
|
||||
// account, mirroring the controller's compute path so the property test compares
|
||||
// against real output.
|
||||
func allPeerMaps(t *testing.T, manager *DefaultAccountManager, accountID string) map[string]string {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
account, err := manager.Store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
|
||||
validated := make(map[string]struct{}, len(account.Peers))
|
||||
for id := range account.Peers {
|
||||
validated[id] = struct{}{}
|
||||
}
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
out := make(map[string]string, len(account.Peers))
|
||||
for peerID := range account.Peers {
|
||||
nm := account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validated, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
// Network.Serial is an account-global counter bumped on every change; it
|
||||
// is not a per-peer dependency, so normalize it out of the comparison.
|
||||
if nm.Network != nil {
|
||||
nm.Network.Serial = 0
|
||||
}
|
||||
out[peerID] = canonicalJSON(t, nm)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// canonicalJSON marshals v and returns an order-insensitive string form: every
|
||||
// JSON array is sorted by the canonical form of its elements. The network map's
|
||||
// Peers/Routes/FirewallRules/SourceRanges slices have nondeterministic order, so
|
||||
// a raw JSON compare would report spurious changes.
|
||||
func canonicalJSON(t *testing.T, v interface{}) string {
|
||||
t.Helper()
|
||||
b, err := json.Marshal(v)
|
||||
require.NoError(t, err)
|
||||
var parsed interface{}
|
||||
require.NoError(t, json.Unmarshal(b, &parsed))
|
||||
canonicalized, err := json.Marshal(sortAny(parsed))
|
||||
require.NoError(t, err)
|
||||
return string(canonicalized)
|
||||
}
|
||||
|
||||
func sortAny(v interface{}) interface{} {
|
||||
switch val := v.(type) {
|
||||
case []interface{}:
|
||||
for i := range val {
|
||||
val[i] = sortAny(val[i])
|
||||
}
|
||||
sort.Slice(val, func(i, j int) bool {
|
||||
bi, _ := json.Marshal(val[i])
|
||||
bj, _ := json.Marshal(val[j])
|
||||
return string(bi) < string(bj)
|
||||
})
|
||||
return val
|
||||
case map[string]interface{}:
|
||||
for k := range val {
|
||||
val[k] = sortAny(val[k])
|
||||
}
|
||||
return val
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// changedPeers returns the peer IDs whose serialized map differs between before
|
||||
// and after.
|
||||
func changedPeers(before, after map[string]string) []string {
|
||||
var changed []string
|
||||
for id, b := range before {
|
||||
a, ok := after[id]
|
||||
if !ok || a != b {
|
||||
changed = append(changed, id)
|
||||
}
|
||||
}
|
||||
for id := range after {
|
||||
if _, ok := before[id]; !ok {
|
||||
changed = append(changed, id)
|
||||
}
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
// TestAffectedPeers_Property_ResolverSupersetsRealChanges builds a topology,
|
||||
// applies random changes, and asserts that the resolver's affected set is a
|
||||
// superset of the peers whose real network map actually changed. If the resolver
|
||||
// ever misses a dependency, a change will alter a peer's map without that peer
|
||||
// appearing in the affected set, failing here.
|
||||
func TestAffectedPeers_Property_ResolverSupersetsRealChanges(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
// A pre-existing peer->resource policy so the resource/router bridge is live.
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Extra peers and groups to give mutations room to move membership around.
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "prop-key", types.SetupKeyReusable, 0, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
extraPeers := make([]string, 0, 4)
|
||||
for i := 0; i < 4; i++ {
|
||||
p := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
extraPeers = append(extraPeers, p.ID)
|
||||
}
|
||||
extraGroups := []string{"prop-grp-0", "prop-grp-1"}
|
||||
for _, g := range extraGroups {
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{ID: g, Name: g}))
|
||||
}
|
||||
|
||||
rng := rand.New(rand.NewSource(1))
|
||||
allGroups := append([]string{s.sourceGroupID, s.resourceGroupID, s.routerPeerGroupID}, extraGroups...)
|
||||
allPeers := append([]string{s.sourcePeerID, s.routerPeerID, s.routerGroupPeerID, s.unrelatedPeerID}, extraPeers...)
|
||||
|
||||
for iter := 0; iter < 60; iter++ {
|
||||
change, apply := s.randomMutation(t, rng, allGroups, allPeers)
|
||||
if apply == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
before := allPeerMaps(t, s.manager, s.accountID)
|
||||
|
||||
resolvedSet := make(map[string]struct{})
|
||||
resolve := func() {
|
||||
require.NoError(t, s.manager.Store.ExecuteInTransaction(ctx, func(tx store.Store) error {
|
||||
for _, id := range s.manager.ResolveAffectedPeers(ctx, tx, s.accountID, change) {
|
||||
resolvedSet[id] = struct{}{}
|
||||
}
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
// Resolve on both sides of the mutation and union: removals are visible
|
||||
// only pre-apply (the leaving peer is still a member), additions only
|
||||
// post-apply (the joining peer is now a member). Production captures both
|
||||
// via per-path handling (e.g. UpdateGroup passes peersToRemove); the union
|
||||
// models that without coupling the test to each path's ordering.
|
||||
resolve()
|
||||
changedIDs := change.ChangedPeerIDs
|
||||
apply()
|
||||
resolve()
|
||||
|
||||
after := allPeerMaps(t, s.manager, s.accountID)
|
||||
|
||||
// The explicitly-changed peer's own map refresh is the caller's
|
||||
// responsibility (the resolver returns the peers to propagate to), so it
|
||||
// is allowed to be absent from the resolved set.
|
||||
changedExplicitly := make(map[string]struct{}, len(changedIDs))
|
||||
for _, id := range changedIDs {
|
||||
changedExplicitly[id] = struct{}{}
|
||||
}
|
||||
|
||||
for _, id := range changedPeers(before, after) {
|
||||
if _, stillExists := after[id]; !stillExists {
|
||||
continue
|
||||
}
|
||||
if _, isExplicit := changedExplicitly[id]; isExplicit {
|
||||
continue
|
||||
}
|
||||
_, ok := resolvedSet[id]
|
||||
require.Truef(t, ok,
|
||||
"iter %d: peer %s network map changed but was not in the resolver's affected set %v (change=%+v)",
|
||||
iter, id, maps.Keys(resolvedSet), change)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// randomMutation picks a random change, returns the Change to resolve and a
|
||||
// function that applies the underlying store mutation. apply is nil when the
|
||||
// drawn mutation is a no-op for the current state.
|
||||
func (s *routerScenario) randomMutation(t *testing.T, rng *rand.Rand, allGroups, allPeers []string) (affectedpeers.Change, func()) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
switch rng.Intn(3) {
|
||||
case 0:
|
||||
groupID := allGroups[rng.Intn(len(allGroups))]
|
||||
peerID := allPeers[rng.Intn(len(allPeers))]
|
||||
grp, err := s.manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, s.accountID, groupID)
|
||||
require.NoError(t, err)
|
||||
if slicesContains(grp.Peers, peerID) {
|
||||
return affectedpeers.Change{}, nil
|
||||
}
|
||||
return affectedpeers.Change{ChangedGroupIDs: []string{groupID}, ChangedPeerIDs: []string{peerID}},
|
||||
func() {
|
||||
require.NoError(t, s.manager.GroupAddPeer(ctx, s.accountID, groupID, peerID))
|
||||
}
|
||||
case 1:
|
||||
groupID := allGroups[rng.Intn(len(allGroups))]
|
||||
grp, err := s.manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, s.accountID, groupID)
|
||||
require.NoError(t, err)
|
||||
if len(grp.Peers) == 0 {
|
||||
return affectedpeers.Change{}, nil
|
||||
}
|
||||
peerID := grp.Peers[rng.Intn(len(grp.Peers))]
|
||||
return affectedpeers.Change{ChangedGroupIDs: []string{groupID}, ChangedPeerIDs: []string{peerID}},
|
||||
func() {
|
||||
require.NoError(t, s.manager.GroupDeletePeer(ctx, s.accountID, groupID, peerID))
|
||||
}
|
||||
default:
|
||||
src := allGroups[rng.Intn(len(allGroups))]
|
||||
dst := allGroups[rng.Intn(len(allGroups))]
|
||||
policy := &types.Policy{
|
||||
Enabled: true,
|
||||
Name: fmt.Sprintf("prop-policy-%d", rng.Int()),
|
||||
Rules: []*types.PolicyRule{{
|
||||
Enabled: true,
|
||||
Sources: []string{src},
|
||||
Destinations: []string{dst},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
}},
|
||||
}
|
||||
return affectedpeers.Change{Policies: []*types.Policy{policy}},
|
||||
func() {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, policy, true)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func slicesContains(s []string, v string) bool {
|
||||
for _, x := range s {
|
||||
if x == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
125
management/server/affected_peers_querycount_test.go
Normal file
125
management/server/affected_peers_querycount_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// countingStore wraps a real store and counts the per-account collection loads
|
||||
// the resolver performs, so a test can assert each is read at most once and that
|
||||
// irrelevant collections are skipped entirely.
|
||||
type countingStore struct {
|
||||
store.Store
|
||||
mu sync.Mutex
|
||||
counts map[string]int
|
||||
}
|
||||
|
||||
func newCountingStore(s store.Store) *countingStore {
|
||||
return &countingStore{Store: s, counts: map[string]int{}}
|
||||
}
|
||||
|
||||
func (c *countingStore) bump(name string) {
|
||||
c.mu.Lock()
|
||||
c.counts[name]++
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *countingStore) count(name string) int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.counts[name]
|
||||
}
|
||||
|
||||
func (c *countingStore) GetAccountPolicies(ctx context.Context, ls store.LockingStrength, accountID string) ([]*types.Policy, error) {
|
||||
c.bump("policies")
|
||||
return c.Store.GetAccountPolicies(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
func (c *countingStore) GetAccountRoutes(ctx context.Context, ls store.LockingStrength, accountID string) ([]*route.Route, error) {
|
||||
c.bump("routes")
|
||||
return c.Store.GetAccountRoutes(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
func (c *countingStore) GetAccountNameServerGroups(ctx context.Context, ls store.LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
|
||||
c.bump("nameservers")
|
||||
return c.Store.GetAccountNameServerGroups(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
func (c *countingStore) GetAccountDNSSettings(ctx context.Context, ls store.LockingStrength, accountID string) (*types.DNSSettings, error) {
|
||||
c.bump("dnssettings")
|
||||
return c.Store.GetAccountDNSSettings(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
func (c *countingStore) GetNetworkRoutersByAccountID(ctx context.Context, ls store.LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) {
|
||||
c.bump("routers")
|
||||
return c.Store.GetNetworkRoutersByAccountID(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
func (c *countingStore) GetNetworkResourcesByAccountID(ctx context.Context, ls store.LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) {
|
||||
c.bump("resources")
|
||||
return c.Store.GetNetworkResourcesByAccountID(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
func (c *countingStore) GetAccountServices(ctx context.Context, ls store.LockingStrength, accountID string) ([]*rpservice.Service, error) {
|
||||
c.bump("services")
|
||||
return c.Store.GetAccountServices(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
// TestAffectedPeers_QueryCount_NoRedundantFullTableLoads asserts the resolver
|
||||
// loads each per-account collection at most once per Resolve (memoization) even
|
||||
// on a change that drives every bridge, and skips the services table when the
|
||||
// account has no embedded proxy peers.
|
||||
func TestAffectedPeers_QueryCount_NoRedundantFullTableLoads(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
cs := newCountingStore(s.manager.Store)
|
||||
|
||||
// A group change that exercises policies, routers, resources and the bridge.
|
||||
affected, err := affectedpeers.Resolve(ctx, cs, s.accountID, affectedpeers.Change{ChangedGroupIDs: []string{s.sourceGroupID}})
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, affected, s.routerPeerID, "bridge must still resolve the routing peer")
|
||||
|
||||
for _, name := range []string{"policies", "routes", "nameservers", "dnssettings", "routers", "resources"} {
|
||||
assert.LessOrEqualf(t, cs.count(name), 1,
|
||||
"%s must be loaded at most once per Resolve, got %d", name, cs.count(name))
|
||||
}
|
||||
assert.Equal(t, 0, cs.count("services"),
|
||||
"services must not be loaded when the account has no embedded proxy peers")
|
||||
}
|
||||
|
||||
// TestAffectedPeers_QueryCount_NarrowChangeSkipsLoads asserts that a change with
|
||||
// no group/peer signal touches no per-account collections beyond what its inputs
|
||||
// require.
|
||||
func TestAffectedPeers_QueryCount_NarrowChangeSkipsLoads(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
cs := newCountingStore(s.manager.Store)
|
||||
|
||||
// A bare network-id change drives only the router->source bridge: routers and
|
||||
// resources are needed, but routes/nameservers/dnssettings/services are not.
|
||||
_, err := affectedpeers.Resolve(ctx, cs, s.accountID, affectedpeers.Change{NetworkIDs: []string{s.networkID}})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 0, cs.count("routes"), "routes must not be loaded for a network-only change")
|
||||
assert.Equal(t, 0, cs.count("nameservers"), "nameservers must not be loaded for a network-only change")
|
||||
assert.Equal(t, 0, cs.count("dnssettings"), "dnssettings must not be loaded for a network-only change")
|
||||
assert.Equal(t, 0, cs.count("services"), "services must not be loaded for a network-only change")
|
||||
}
|
||||
369
management/server/affected_peers_router_paths_test.go
Normal file
369
management/server/affected_peers_router_paths_test.go
Normal file
@@ -0,0 +1,369 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func (s *routerScenario) resolveGroupChangeAffected(ctx context.Context, changedGroupIDs []string) []string {
|
||||
return s.manager.ResolveAffectedPeers(ctx, s.manager.Store, s.accountID, affectedpeers.Change{ChangedGroupIDs: changedGroupIDs})
|
||||
}
|
||||
|
||||
func (s *routerScenario) resolvePeerChangeAffected(ctx context.Context, changedPeerIDs []string) []string {
|
||||
return s.manager.resolveAffectedPeersForPeerChanges(ctx, s.manager.Store, s.accountID, changedPeerIDs)
|
||||
}
|
||||
|
||||
func TestAffectedPeers_GroupChange_SourceGroupMembership_RefreshesRoutingPeer_DirectRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolveGroupChangeAffected(ctx, []string{s.sourceGroupID})
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source group member must be affected")
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"changing the source group of a peer->resource policy must refresh the resource's routing peer")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_GroupChange_SourceGroupMembership_RefreshesRoutingPeer_RouterPeerGroups(t *testing.T) {
|
||||
s := setupRouterScenario(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolveGroupChangeAffected(ctx, []string{s.sourceGroupID})
|
||||
|
||||
assert.Contains(t, affected, s.routerGroupPeerID,
|
||||
"changing the source group must refresh the routing peer defined via router.PeerGroups")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_GroupChange_RouterPeerGroupMembership_RefreshesPolicySources(t *testing.T) {
|
||||
s := setupRouterScenario(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolveGroupChangeAffected(ctx, []string{s.routerPeerGroupID})
|
||||
|
||||
assert.Contains(t, affected, s.routerGroupPeerID, "the routing peer itself must be affected")
|
||||
assert.Contains(t, affected, s.sourcePeerID,
|
||||
"changing the router's PeerGroups must refresh the source peers of policies serving the resource")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PeerChange_SourcePeer_RefreshesRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolvePeerChangeAffected(ctx, []string{s.sourcePeerID})
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"a status change on a source peer must refresh the resource's routing peer that serves it")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PeerChange_RoutingPeer_RefreshesPolicySources(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolvePeerChangeAffected(ctx, []string{s.routerPeerID})
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID,
|
||||
"a status change on the routing peer must refresh the source peers that route through it")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PeerChange_SourcePeer_ByDestinationResource_RefreshesRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolvePeerChangeAffected(ctx, []string{s.sourcePeerID})
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"DestinationResource-targeted policy must still bridge a source-peer change to the routing peer")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_DeleteGroup_ResolvesAffectedPeers(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
const memberOnlyGroupID = "rs-memberonly-grp"
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
|
||||
ID: memberOnlyGroupID, Name: "rs-memberonly", Peers: []string{s.sourcePeerID},
|
||||
}))
|
||||
|
||||
affected := s.resolveGroupChangeAffected(ctx, []string{memberOnlyGroupID})
|
||||
assert.Empty(t, affected, "an unlinked group has no network-map impact, so no peer is affected")
|
||||
|
||||
require.NoError(t, s.manager.DeleteGroup(ctx, s.accountID, userID, memberOnlyGroupID))
|
||||
}
|
||||
|
||||
func TestAffectedPeers_DeleteGroup_LinkedGroupIsBlocked(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.manager.DeleteGroup(ctx, s.accountID, userID, s.sourceGroupID)
|
||||
require.Error(t, err, "deleting a policy-linked group must be blocked by validateDeleteGroup")
|
||||
|
||||
var linkErr *GroupLinkError
|
||||
require.ErrorAs(t, err, &linkErr, "expected a GroupLinkError")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_GroupAddResource_RefreshesRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
const extraResourceGroupID = "rs-resource-grp-extra"
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
|
||||
ID: extraResourceGroupID, Name: "rs-resource-extra",
|
||||
}))
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, extraResourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, s.manager.GroupAddResource(ctx, s.accountID, extraResourceGroupID, types.Resource{
|
||||
ID: s.resourceID,
|
||||
Type: types.ResourceTypeHost,
|
||||
}))
|
||||
|
||||
affected := s.resolveGroupChangeAffected(ctx, []string{extraResourceGroupID})
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"attaching a resource to a policy destination group must refresh the resource's routing peer")
|
||||
assert.Contains(t, affected, s.sourcePeerID, "policy source peers must refresh")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func (s *routerScenario) resolvePostureCheckAffected(ctx context.Context, postureCheckID string) []string {
|
||||
return s.manager.ResolveAffectedPeers(ctx, s.manager.Store, s.accountID, affectedpeers.Change{PostureCheckIDs: []string{postureCheckID}})
|
||||
}
|
||||
|
||||
func (s *routerScenario) createPostureCheckGatedPolicy(t *testing.T, ctx context.Context) string {
|
||||
t.Helper()
|
||||
|
||||
check, err := s.manager.SavePostureChecks(ctx, s.accountID, userID, &posture.Checks{
|
||||
Name: "rs-min-version",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
policy.SourcePostureChecks = []string{check.ID}
|
||||
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
return check.ID
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PostureCheckChange_RefreshesRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
checkID := s.createPostureCheckGatedPolicy(t, ctx)
|
||||
|
||||
affected := s.resolvePostureCheckAffected(ctx, checkID)
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID, "policy source peer must be affected by a posture-check change")
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"a posture check gating a peer->resource policy must refresh the resource's routing peer")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_SavePostureCheck_RefreshesRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
checkID := s.createPostureCheckGatedPolicy(t, ctx)
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
unrelatedCh := s.updateManager.CreateChannel(ctx, s.unrelatedPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.unrelatedPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh, unrelatedCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
peerShouldNotReceiveUpdate(t, unrelatedCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := s.manager.SavePostureChecks(ctx, s.accountID, userID, &posture.Checks{
|
||||
ID: checkID,
|
||||
Name: "rs-min-version",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.31.0"},
|
||||
},
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: editing a posture check did not refresh source + routing peers")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_UpdateResource_DestinationResourcePolicy_RefreshesSourcePeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
resourcesManager, _, _ := s.managers()
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
unrelatedCh := s.updateManager.CreateChannel(ctx, s.unrelatedPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.unrelatedPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh, unrelatedCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
peerShouldNotReceiveUpdate(t, unrelatedCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
ID: s.resourceID,
|
||||
AccountID: s.accountID,
|
||||
NetworkID: s.networkID,
|
||||
Name: "rs-resource-host",
|
||||
Address: "10.20.30.0/25",
|
||||
GroupIDs: []string{s.resourceGroupID},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: updating a DestinationResource-targeted resource did not refresh its policy source peer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouter_StillBridged(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
resourcesManager, routersManager, _ := s.managers()
|
||||
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-disabled", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
disabledRouterPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
_, err = routersManager.CreateRouter(ctx, userID, &routerTypes.NetworkRouter{
|
||||
NetworkID: s.networkID,
|
||||
AccountID: s.accountID,
|
||||
Peer: disabledRouterPeer.ID,
|
||||
Masquerade: true,
|
||||
Metric: 9000,
|
||||
Enabled: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
disabledCh := s.updateManager.CreateChannel(ctx, disabledRouterPeer.ID)
|
||||
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, disabledRouterPeer.ID) })
|
||||
|
||||
settleAffectedUpdates(disabledCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, disabledCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
ID: s.resourceID,
|
||||
AccountID: s.accountID,
|
||||
NetworkID: s.networkID,
|
||||
Name: "rs-resource-host",
|
||||
Address: "10.20.30.0/25",
|
||||
GroupIDs: []string{s.resourceGroupID},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: resource update did not refresh the disabled sibling router's peer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_GroupChange_RouterInOtherNetworkNotAffected(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
second := s.addSecondTopology(t, "groupiso")
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolveGroupChangeAffected(ctx, []string{s.sourceGroupID})
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID, "network A's routing peer must be affected")
|
||||
assert.NotContains(t, affected, second.routerPeerID,
|
||||
"a router in an unrelated network must not be affected by a source-group change for another resource")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PeerChange_RouterInOtherNetworkNotAffected(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
second := s.addSecondTopology(t, "peeriso")
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolvePeerChangeAffected(ctx, []string{s.sourcePeerID})
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID, "network A's routing peer must be affected")
|
||||
assert.NotContains(t, affected, second.routerPeerID,
|
||||
"a router in an unrelated network must not be affected by a source-peer change for another resource")
|
||||
}
|
||||
791
management/server/affected_peers_router_test.go
Normal file
791
management/server/affected_peers_router_test.go
Normal file
@@ -0,0 +1,791 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// routerScenario captures the topology from the bug report:
|
||||
//
|
||||
// network ── router (routing peer) ── resource (in resourceGroup)
|
||||
// independent peer ──(policy: source -> resource)──> resource
|
||||
//
|
||||
// The routing peer must be refreshed when a policy grants a source peer access
|
||||
// to the resource, because the network map connects the source peer to the
|
||||
// routing peer at compute time (Account.GetPoliciesForNetworkResource +
|
||||
// addNetworksRoutingPeers). The routing peer is NOT a member of the resource
|
||||
// group, so static group/peer resolution alone cannot find it.
|
||||
type routerScenario struct {
|
||||
manager *DefaultAccountManager
|
||||
updateManager *update_channel.PeersUpdateManager
|
||||
accountID string
|
||||
networkID string
|
||||
|
||||
sourcePeerID string // independent peer that the policy grants access from
|
||||
sourceGroupID string // group containing the source peer
|
||||
|
||||
routerPeerID string // peer acting as the routing peer (direct router.Peer)
|
||||
routerGroupPeerID string // peer that is a member of routerPeerGroup
|
||||
routerPeerGroupID string // group used for router.PeerGroups
|
||||
|
||||
resourceID string // network resource
|
||||
resourceGroupID string // group whose member is the resource (no peers)
|
||||
|
||||
unrelatedPeerID string // peer in no relevant entity
|
||||
}
|
||||
|
||||
// setupRouterScenario builds the topology above with the default policy removed
|
||||
// and channels NOT yet created, so callers control exactly when updates can flow.
|
||||
func setupRouterScenario(t *testing.T, directRouterPeer bool) *routerScenario {
|
||||
t.Helper()
|
||||
|
||||
manager, updateManager, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
account, err := createAccount(manager, "router_scenario", userID, "")
|
||||
require.NoError(t, err)
|
||||
accountID := account.Id
|
||||
|
||||
// Remove the default policy so AddPeer/CreateGroup don't schedule unrelated updates.
|
||||
policies, err := manager.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
for _, p := range policies {
|
||||
require.NoError(t, manager.Store.DeletePolicy(ctx, accountID, p.ID))
|
||||
}
|
||||
|
||||
setupKey, err := manager.CreateSetupKey(ctx, accountID, "rs-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
sourcePeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
|
||||
routerPeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
|
||||
routerGroupPeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
|
||||
unrelatedPeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
|
||||
|
||||
const (
|
||||
sourceGroupID = "rs-source-grp"
|
||||
routerPeerGroupID = "rs-router-grp"
|
||||
resourceGroupID = "rs-resource-grp"
|
||||
)
|
||||
|
||||
for _, g := range []*types.Group{
|
||||
{ID: sourceGroupID, Name: "rs-source", Peers: []string{sourcePeer.ID}},
|
||||
{ID: routerPeerGroupID, Name: "rs-router", Peers: []string{routerGroupPeer.ID}},
|
||||
{ID: resourceGroupID, Name: "rs-resource"}, // intentionally peerless; the resource is its only member
|
||||
} {
|
||||
require.NoError(t, manager.CreateGroup(ctx, accountID, userID, g))
|
||||
}
|
||||
|
||||
permissionsManager := permissions.NewManager(manager.Store)
|
||||
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
|
||||
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
|
||||
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
|
||||
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
|
||||
|
||||
network, err := networksManager.CreateNetwork(ctx, userID, &networkTypes.Network{
|
||||
ID: "rs-network",
|
||||
AccountID: accountID,
|
||||
Name: "rs-network",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
resource, err := resourcesManager.CreateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
AccountID: accountID,
|
||||
NetworkID: network.ID,
|
||||
Name: "rs-resource-host",
|
||||
Address: "10.20.30.0/24",
|
||||
GroupIDs: []string{resourceGroupID},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
router := &routerTypes.NetworkRouter{
|
||||
ID: "rs-router",
|
||||
NetworkID: network.ID,
|
||||
AccountID: accountID,
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
}
|
||||
if directRouterPeer {
|
||||
router.Peer = routerPeer.ID
|
||||
} else {
|
||||
router.PeerGroups = []string{routerPeerGroupID}
|
||||
}
|
||||
_, err = routersManager.CreateRouter(ctx, userID, router)
|
||||
require.NoError(t, err)
|
||||
|
||||
return &routerScenario{
|
||||
manager: manager,
|
||||
updateManager: updateManager,
|
||||
accountID: accountID,
|
||||
networkID: network.ID,
|
||||
sourcePeerID: sourcePeer.ID,
|
||||
sourceGroupID: sourceGroupID,
|
||||
routerPeerID: routerPeer.ID,
|
||||
routerGroupPeerID: routerGroupPeer.ID,
|
||||
routerPeerGroupID: routerPeerGroupID,
|
||||
resourceID: resource.ID,
|
||||
resourceGroupID: resourceGroupID,
|
||||
unrelatedPeerID: unrelatedPeer.ID,
|
||||
}
|
||||
}
|
||||
|
||||
// peerToResourcePolicy builds a policy granting the source group access to the
|
||||
// resource, referencing the resource by its group in the rule destination.
|
||||
func peerToResourcePolicyByGroup(sourceGroupID, resourceGroupID string) *types.Policy {
|
||||
return &types.Policy{
|
||||
Enabled: true,
|
||||
Name: "peer-to-resource-by-group",
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{sourceGroupID},
|
||||
Destinations: []string{resourceGroupID},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// peerToResourcePolicyByResource builds a policy referencing the resource
|
||||
// directly via DestinationResource rather than its group.
|
||||
func peerToResourcePolicyByResource(sourceGroupID, resourceID string) *types.Policy {
|
||||
return &types.Policy{
|
||||
Enabled: true,
|
||||
Name: "peer-to-resource-by-resource",
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{sourceGroupID},
|
||||
DestinationResource: types.Resource{ID: resourceID, Type: types.ResourceTypeHost},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Resolution-level tests: collectPolicyAffectedGroupsAndPeers + resolvePeerIDs.
|
||||
//
|
||||
// These isolate the resolver from the controller and assert directly on the set
|
||||
// of peer IDs the policy path would refresh. They make the gap explicit: the
|
||||
// routing peer is expected to be in the affected set but is not.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// resolvePolicyAffected mirrors SavePolicy's resolution: resolve the affected
|
||||
// peers for the given policy.
|
||||
func (s *routerScenario) resolvePolicyAffected(ctx context.Context, policy *types.Policy) []string {
|
||||
return s.manager.ResolveAffectedPeers(ctx, s.manager.Store, s.accountID, affectedpeers.Change{Policies: []*types.Policy{policy}})
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PolicyToResourceByGroup_IncludesSourcePeer_DirectRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
// The source peer is in the source group and must always be present.
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PolicyToResourceByGroup_IncludesRoutingPeer_DirectRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
// BUG: the direct routing peer serves the resource's subnet to the source
|
||||
// peer, so it must be refreshed when the policy is created. The policy path
|
||||
// only resolves the literal rule groups (source group + resource group);
|
||||
// the resource group has no peer members and the router peer is reachable
|
||||
// only through the network, so it is dropped.
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"routing peer (router.Peer) serving the resource must be affected by a policy granting access to it")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PolicyToResourceByGroup_IncludesRoutingPeer_RouterPeerGroups(t *testing.T) {
|
||||
s := setupRouterScenario(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
// Same gap when the router is defined via PeerGroups instead of a direct peer.
|
||||
assert.Contains(t, affected, s.routerGroupPeerID,
|
||||
"routing peer (router.PeerGroups member) serving the resource must be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PolicyToResourceByDestinationResource_IncludesRoutingPeer_DirectRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID)
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
// When the resource is referenced via DestinationResource, RuleGroups()
|
||||
// returns only the source group and the resource ID is not a peer, so
|
||||
// collectPolicyAffectedGroupsAndPeers yields nothing for the destination at
|
||||
// all. The routing peer is dropped here too.
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"routing peer must be affected when the resource is referenced via DestinationResource")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PolicyToResourceByDestinationResource_IncludesRoutingPeer_RouterPeerGroups(t *testing.T) {
|
||||
s := setupRouterScenario(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID)
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
assert.Contains(t, affected, s.routerGroupPeerID,
|
||||
"routing peer (PeerGroups) must be affected when the resource is referenced via DestinationResource")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PolicyToResourceWithSourceResourcePeer_IncludesRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
// Source expressed as a direct peer (SourceResource), destination as resource group.
|
||||
policy := &types.Policy{
|
||||
Enabled: true,
|
||||
Name: "sourceResource-peer-to-resource",
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
SourceResource: types.Resource{ID: s.sourcePeerID, Type: types.ResourceTypePeer},
|
||||
Destinations: []string{s.resourceGroupID},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
// The direct source peer IS picked up (collectPolicyAffectedGroupsAndPeers
|
||||
// handles SourceResource peers), but the routing peer is still missing.
|
||||
assert.Contains(t, affected, s.sourcePeerID, "direct source peer must be affected")
|
||||
assert.Contains(t, affected, s.routerPeerID, "routing peer must be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PolicyToResource_UnrelatedPeerNotAffected(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
// Guard against an over-broad fix: a peer in no relevant entity must never
|
||||
// be pulled in.
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Control: the resource/router managers DO bridge resource-group -> router.
|
||||
// These document the existing (correct) behaviour on the resource side and
|
||||
// highlight the asymmetry with the policy side above.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAffectedPeers_ResourceSideBridgesToRoutingPeer_DirectRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
// A pre-existing policy grants the source group access to the resource.
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Drive an update through the resource manager and assert the routing peer
|
||||
// is among the affected set by observing the channel. This path walks
|
||||
// policies whose destinations reference the resource's groups, folds in the
|
||||
// source groups, and loads the network's routers, so it reaches both the
|
||||
// source peer and the routing peer.
|
||||
permissionsManager := permissions.NewManager(s.manager.Store)
|
||||
groupsManager := groups.NewManager(s.manager.Store, permissionsManager, s.manager)
|
||||
rm := resources.NewManager(s.manager.Store, permissionsManager, groupsManager, s.manager, s.manager.serviceManager)
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
})
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = rm.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
ID: s.resourceID,
|
||||
AccountID: s.accountID,
|
||||
NetworkID: s.networkID,
|
||||
Name: "rs-resource-host",
|
||||
Address: "10.20.30.0/24",
|
||||
GroupIDs: []string{s.resourceGroupID},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: resource update did not refresh source peer + routing peer")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// End-to-end: reproduce the reported symptom through SavePolicy with channels.
|
||||
//
|
||||
// Creating the peer->resource policy must wake the routing peer. These fail on
|
||||
// the current code because the policy path never resolves the routing peer.
|
||||
//
|
||||
// IMPORTANT: setup (CreateNetwork/CreateResource/CreateRouter) fires async
|
||||
// `go UpdateAffectedPeers` goroutines. Channels are opened and then drained via
|
||||
// settleAffectedUpdates before the action under test, so the assertion only
|
||||
// observes updates caused by that action and not stragglers from setup. The
|
||||
// resolution-level tests above are the timing-free, authoritative proof; these
|
||||
// reproduce the operator-visible symptom.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// settleAffectedUpdates waits for in-flight async updates to arrive, then drains
|
||||
// every given channel so subsequent assertions start from a clean slate.
|
||||
func settleAffectedUpdates(chans ...<-chan *network_map.UpdateMessage) {
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
for _, ch := range chans {
|
||||
drainPeerUpdates(ch)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_CreatePolicyToResource_RefreshesRoutingPeer_DirectRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
unrelatedCh := s.updateManager.CreateChannel(ctx, s.unrelatedPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.unrelatedPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh, unrelatedCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh) // FAILS today: routing peer not resolved
|
||||
peerShouldNotReceiveUpdate(t, unrelatedCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: creating peer->resource policy did not refresh the routing peer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_CreatePolicyToResource_RefreshesRoutingPeer_RouterPeerGroups(t *testing.T) {
|
||||
s := setupRouterScenario(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerGroupPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerGroupPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh) // FAILS today
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: routing peer (PeerGroups) not refreshed on policy create")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_CreatePolicyByDestinationResource_RefreshesRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh) // FAILS today
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: routing peer not refreshed when policy targets DestinationResource")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_DeletePolicyToResource_RefreshesRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
policy, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh) // FAILS today: deleting the policy must also refresh the router
|
||||
close(done)
|
||||
}()
|
||||
|
||||
require.NoError(t, s.manager.DeletePolicy(ctx, s.accountID, policy.ID, userID))
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: deleting peer->resource policy did not refresh the routing peer")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *routerScenario) managers() (resources.Manager, routers.Manager, networks.Manager) {
|
||||
permissionsManager := permissions.NewManager(s.manager.Store)
|
||||
groupsManager := groups.NewManager(s.manager.Store, permissionsManager, s.manager)
|
||||
resourcesManager := resources.NewManager(s.manager.Store, permissionsManager, groupsManager, s.manager, s.manager.serviceManager)
|
||||
routersManager := routers.NewManager(s.manager.Store, permissionsManager, s.manager)
|
||||
networksManager := networks.NewManager(s.manager.Store, permissionsManager, resourcesManager, routersManager, s.manager)
|
||||
return resourcesManager, routersManager, networksManager
|
||||
}
|
||||
|
||||
type secondTopology struct {
|
||||
networkID string
|
||||
resourceID string
|
||||
resourceGroupID string
|
||||
routerPeerID string
|
||||
}
|
||||
|
||||
func (s *routerScenario) addSecondTopology(t *testing.T, suffix string) secondTopology {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
resourcesManager, routersManager, networksManager := s.managers()
|
||||
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-"+suffix, types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
routerPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
|
||||
resourceGroupID := "rs-resource-grp-" + suffix
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
|
||||
ID: resourceGroupID, Name: "rs-resource-" + suffix,
|
||||
}))
|
||||
|
||||
network, err := networksManager.CreateNetwork(ctx, userID, &networkTypes.Network{
|
||||
ID: "rs-network-" + suffix,
|
||||
AccountID: s.accountID,
|
||||
Name: "rs-network-" + suffix,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
resource, err := resourcesManager.CreateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
AccountID: s.accountID,
|
||||
NetworkID: network.ID,
|
||||
Name: "rs-resource-host-" + suffix,
|
||||
Address: "10.40.50.0/24",
|
||||
GroupIDs: []string{resourceGroupID},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = routersManager.CreateRouter(ctx, userID, &routerTypes.NetworkRouter{
|
||||
NetworkID: network.ID,
|
||||
AccountID: s.accountID,
|
||||
Peer: routerPeer.ID,
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return secondTopology{
|
||||
networkID: network.ID,
|
||||
resourceID: resource.ID,
|
||||
resourceGroupID: resourceGroupID,
|
||||
routerPeerID: routerPeer.ID,
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_UpdatePolicyRepointResource_RefreshesBothRoutingPeers(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
second := s.addSecondTopology(t, "b")
|
||||
ctx := context.Background()
|
||||
|
||||
policy, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerACh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
routerBCh := s.updateManager.CreateChannel(ctx, second.routerPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
s.updateManager.CloseChannel(ctx, second.routerPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerACh, routerBCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerACh)
|
||||
peerShouldReceiveUpdate(t, routerBCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
policy.Rules[0].Destinations = []string{second.resourceGroupID}
|
||||
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: re-pointing the policy destination did not refresh both routing peers")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_UpdatePolicyAddSourceGroup_RefreshesRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
const secondSourceGroupID = "rs-source-grp-2"
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-2", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
secondSourcePeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
|
||||
ID: secondSourceGroupID, Name: "rs-source-2", Peers: []string{secondSourcePeer.ID},
|
||||
}))
|
||||
|
||||
policy, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
newSrcCh := s.updateManager.CreateChannel(ctx, secondSourcePeer.ID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, secondSourcePeer.ID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(newSrcCh, routerCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, newSrcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
policy.Rules[0].Sources = []string{s.sourceGroupID, secondSourceGroupID}
|
||||
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: adding a source group did not refresh the new source peer + routing peer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_CreatePolicyByDestinationResource_RefreshesRoutingPeer_RouterPeerGroups(t *testing.T) {
|
||||
s := setupRouterScenario(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerGroupPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerGroupPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: DestinationResource policy with PeerGroups router did not refresh the routing peer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PolicyToResource_IncludesAllRoutingPeersOnNetwork(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, routersManager, _ := s.managers()
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-r2", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
secondRouterPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
_, err = routersManager.CreateRouter(ctx, userID, &routerTypes.NetworkRouter{
|
||||
NetworkID: s.networkID,
|
||||
AccountID: s.accountID,
|
||||
Peer: secondRouterPeer.ID,
|
||||
Masquerade: true,
|
||||
Metric: 9998,
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID, "first routing peer must be affected")
|
||||
assert.Contains(t, affected, secondRouterPeer.ID, "second routing peer on the same network must also be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PolicyToResource_DisabledRouterStillAffected(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
routers, err := s.manager.Store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, s.accountID, s.networkID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, routers, 1)
|
||||
routers[0].Enabled = false
|
||||
require.NoError(t, s.manager.Store.UpdateNetworkRouter(ctx, routers[0]))
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"disabled router's peer must still be affected: Enabled must not gate affected-peers")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PolicyToResource_DisabledResourceStillAffected(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
res, err := s.manager.Store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, s.accountID, s.resourceID)
|
||||
require.NoError(t, err)
|
||||
res.Enabled = false
|
||||
require.NoError(t, s.manager.Store.SaveNetworkResource(ctx, res))
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"disabled resource must still resolve the routing peer: Enabled must not gate affected-peers")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PolicyToResource_DisabledRuleStillAffected(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
policy.Rules[0].Enabled = false
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"disabled rule must still resolve the routing peer: Enabled must not gate affected-peers")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_MultiRulePolicy_IncludesAllRoutingPeers(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
second := s.addSecondTopology(t, "c")
|
||||
ctx := context.Background()
|
||||
|
||||
policy := &types.Policy{
|
||||
Enabled: true,
|
||||
Name: "multi-rule-two-resources",
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{s.sourceGroupID},
|
||||
Destinations: []string{s.resourceGroupID},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{s.sourceGroupID},
|
||||
Destinations: []string{second.resourceGroupID},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID, "routing peer for resource A must be affected")
|
||||
assert.Contains(t, affected, second.routerPeerID, "routing peer for resource B must be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PolicyToResource_RouterInOtherNetworkNotAffected(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
second := s.addSecondTopology(t, "d")
|
||||
ctx := context.Background()
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID, "network A's routing peer must be affected")
|
||||
assert.NotContains(t, affected, second.routerPeerID,
|
||||
"a router in an unrelated network must not be affected by a policy that does not target its resource")
|
||||
}
|
||||
1921
management/server/affected_peers_test.go
Normal file
1921
management/server/affected_peers_test.go
Normal file
File diff suppressed because it is too large
Load Diff
669
management/server/affectedpeers/resolver.go
Normal file
669
management/server/affectedpeers/resolver.go
Normal file
@@ -0,0 +1,669 @@
|
||||
package affectedpeers
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// Change describes what changed in an account. The resolver never consults the
|
||||
// Enabled flag of any object: toggling Enabled is itself an observable change.
|
||||
type Change struct {
|
||||
ChangedGroupIDs []string
|
||||
ChangedPeerIDs []string
|
||||
Policies []*types.Policy
|
||||
Routes []*route.Route
|
||||
PostureCheckIDs []string
|
||||
ResourceIDs []string
|
||||
NetworkIDs []string
|
||||
}
|
||||
|
||||
func (c Change) isEmpty() bool {
|
||||
return len(c.ChangedGroupIDs) == 0 &&
|
||||
len(c.ChangedPeerIDs) == 0 &&
|
||||
len(c.Policies) == 0 &&
|
||||
len(c.Routes) == 0 &&
|
||||
len(c.PostureCheckIDs) == 0 &&
|
||||
len(c.ResourceIDs) == 0 &&
|
||||
len(c.NetworkIDs) == 0
|
||||
}
|
||||
|
||||
// Resolve returns the deduplicated peer IDs whose network map may have changed by
|
||||
// the given Change. Safe to call inside or after a transaction.
|
||||
//
|
||||
// At trace level it logs the full reasoning — which inputs drove which graph
|
||||
// walks to which groups/peers, including the resource<->router bridge hops — so a
|
||||
// miscalculation can be diagnosed from the logs alone.
|
||||
func Resolve(ctx context.Context, s store.Store, accountID string, c Change) ([]string, error) {
|
||||
if c.isEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
r := newResolver(ctx, s, accountID, c)
|
||||
log.WithContext(ctx).Tracef("affectedpeers resolve start: account=%s changedGroups=%v changedPeers=%v policies=%d routes=%d postureChecks=%v resources=%v networks=%v",
|
||||
accountID, c.ChangedGroupIDs, c.ChangedPeerIDs, len(c.Policies), len(c.Routes), c.PostureCheckIDs, c.ResourceIDs, c.NetworkIDs)
|
||||
r.walk()
|
||||
return r.expand()
|
||||
}
|
||||
|
||||
// Collect returns the affected group IDs and direct peer IDs without expanding
|
||||
// groups to members. For tests asserting on the intermediate sets; use Resolve otherwise.
|
||||
func Collect(ctx context.Context, s store.Store, accountID string, c Change) (groupIDs []string, directPeerIDs []string) {
|
||||
if c.isEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
r := newResolver(ctx, s, accountID, c)
|
||||
r.walk()
|
||||
return setToSlice(r.groupSet), setToSlice(r.peerSet)
|
||||
}
|
||||
|
||||
func newResolver(ctx context.Context, s store.Store, accountID string, c Change) *resolver {
|
||||
r := &resolver{
|
||||
ctx: ctx,
|
||||
store: s,
|
||||
accountID: accountID,
|
||||
change: c,
|
||||
changedGroupSet: toSet(c.ChangedGroupIDs),
|
||||
changedPeerSet: toSet(c.ChangedPeerIDs),
|
||||
groupSet: make(map[string]struct{}),
|
||||
peerSet: make(map[string]struct{}),
|
||||
resourceIDs: toSet(c.ResourceIDs),
|
||||
networkIDs: toSet(c.NetworkIDs),
|
||||
}
|
||||
r.matchedPolicies = append(r.matchedPolicies, c.Policies...)
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *resolver) walk() {
|
||||
r.collectFromExplicitPolicies()
|
||||
r.collectFromExplicitRoutes(r.change.Routes)
|
||||
r.collectFromPostureChecks(r.change.PostureCheckIDs)
|
||||
|
||||
if len(r.changedGroupSet) > 0 || len(r.changedPeerSet) > 0 {
|
||||
r.collectFromPolicies()
|
||||
r.collectFromRoutes()
|
||||
r.collectFromNameServers()
|
||||
r.collectFromDNSSettings()
|
||||
r.collectFromNetworkRouters()
|
||||
r.collectFromProxyServices()
|
||||
}
|
||||
|
||||
r.collectResourceRouterBridge()
|
||||
}
|
||||
|
||||
type resolver struct {
|
||||
ctx context.Context
|
||||
store store.Store
|
||||
accountID string
|
||||
change Change
|
||||
|
||||
changedGroupSet map[string]struct{}
|
||||
changedPeerSet map[string]struct{}
|
||||
|
||||
groupSet map[string]struct{}
|
||||
peerSet map[string]struct{}
|
||||
|
||||
matchedPolicies []*types.Policy
|
||||
resourceIDs map[string]struct{}
|
||||
networkIDs map[string]struct{}
|
||||
|
||||
// Memoized per-account collections: each is loaded from the store at most
|
||||
// once per Resolve and only when a walker actually needs it.
|
||||
cachedPolicies []*types.Policy
|
||||
policiesLoaded bool
|
||||
cachedResources []*resourceTypes.NetworkResource
|
||||
resourcesLoaded bool
|
||||
cachedRouters []*routerTypes.NetworkRouter
|
||||
routersLoaded bool
|
||||
}
|
||||
|
||||
func (r *resolver) policies() []*types.Policy {
|
||||
if r.policiesLoaded {
|
||||
return r.cachedPolicies
|
||||
}
|
||||
r.policiesLoaded = true
|
||||
policies, err := r.store.GetAccountPolicies(r.ctx, store.LockingStrengthNone, r.accountID)
|
||||
if err != nil {
|
||||
log.WithContext(r.ctx).Errorf("failed to get policies for affected peers resolution: %v", err)
|
||||
return nil
|
||||
}
|
||||
r.cachedPolicies = policies
|
||||
return r.cachedPolicies
|
||||
}
|
||||
|
||||
func (r *resolver) networkResources() []*resourceTypes.NetworkResource {
|
||||
if r.resourcesLoaded {
|
||||
return r.cachedResources
|
||||
}
|
||||
r.resourcesLoaded = true
|
||||
resources, err := r.store.GetNetworkResourcesByAccountID(r.ctx, store.LockingStrengthNone, r.accountID)
|
||||
if err != nil {
|
||||
log.WithContext(r.ctx).Errorf("failed to get network resources for affected peers resolution: %v", err)
|
||||
return nil
|
||||
}
|
||||
r.cachedResources = resources
|
||||
return r.cachedResources
|
||||
}
|
||||
|
||||
func (r *resolver) networkRouters() []*routerTypes.NetworkRouter {
|
||||
if r.routersLoaded {
|
||||
return r.cachedRouters
|
||||
}
|
||||
r.routersLoaded = true
|
||||
routers, err := r.store.GetNetworkRoutersByAccountID(r.ctx, store.LockingStrengthNone, r.accountID)
|
||||
if err != nil {
|
||||
log.WithContext(r.ctx).Errorf("failed to get network routers for affected peers resolution: %v", err)
|
||||
return nil
|
||||
}
|
||||
r.cachedRouters = routers
|
||||
return r.cachedRouters
|
||||
}
|
||||
|
||||
func (r *resolver) expand() ([]string, error) {
|
||||
groupIDs := setToSlice(r.groupSet)
|
||||
var peerIDs []string
|
||||
if len(groupIDs) > 0 {
|
||||
ids, err := r.store.GetPeerIDsByGroups(r.ctx, r.accountID, groupIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peerIDs = ids
|
||||
}
|
||||
|
||||
log.WithContext(r.ctx).Tracef("affectedpeers resolve expand: account=%s affectedGroups=%v -> %d group-member peers; direct peers=%v",
|
||||
r.accountID, groupIDs, len(peerIDs), setToSlice(r.peerSet))
|
||||
|
||||
seen := make(map[string]struct{}, len(peerIDs))
|
||||
for _, id := range peerIDs {
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
for id := range r.peerSet {
|
||||
if _, ok := seen[id]; !ok {
|
||||
peerIDs = append(peerIDs, id)
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
log.WithContext(r.ctx).Tracef("affectedpeers resolve done: account=%s -> %d affected peers: %v", r.accountID, len(peerIDs), peerIDs)
|
||||
return peerIDs, nil
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromExplicitPolicies() {
|
||||
for _, policy := range r.matchedPolicies {
|
||||
if policy == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitPolicies: changed policy %s (%s) -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromExplicitRoutes(routes []*route.Route) {
|
||||
for _, rt := range routes {
|
||||
if rt == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitRoutes: changed route %s -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
|
||||
rt.ID, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
|
||||
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
|
||||
if rt.Peer != "" {
|
||||
r.peerSet[rt.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromPostureChecks(postureCheckIDs []string) {
|
||||
if len(postureCheckIDs) == 0 {
|
||||
return
|
||||
}
|
||||
ids := toSet(postureCheckIDs)
|
||||
for _, policy := range r.policies() {
|
||||
if !policyReferencesPostureChecks(policy, ids) {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromPostureChecks: policy %s (%s) references changed posture checks %v -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, postureCheckIDs, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
r.matchedPolicies = append(r.matchedPolicies, policy)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromPolicies() {
|
||||
for _, policy := range r.policies() {
|
||||
matchedByGroup := policyReferencesGroups(policy, r.changedGroupSet)
|
||||
matchedByPeer := len(r.changedPeerSet) > 0 && policyReferencesDirectPeers(policy, r.changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromPolicies: policy %s (%s) matched (byGroup=%t byPeer=%t) -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, matchedByGroup, matchedByPeer, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
r.matchedPolicies = append(r.matchedPolicies, policy)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromRoutes() {
|
||||
routes, err := r.store.GetAccountRoutes(r.ctx, store.LockingStrengthNone, r.accountID)
|
||||
if err != nil {
|
||||
log.WithContext(r.ctx).Errorf("failed to get routes for affected peers resolution: %v", err)
|
||||
return
|
||||
}
|
||||
for _, rt := range routes {
|
||||
matchedByGroup := anyInSet(rt.Groups, r.changedGroupSet) || anyInSet(rt.PeerGroups, r.changedGroupSet) || anyInSet(rt.AccessControlGroups, r.changedGroupSet)
|
||||
matchedByPeer := rt.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(rt.Peer, r.changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromRoutes: route %s matched (byGroup=%t byPeer=%t) -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
|
||||
rt.ID, matchedByGroup, matchedByPeer, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
|
||||
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
|
||||
if rt.Peer != "" {
|
||||
r.peerSet[rt.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromNameServers() {
|
||||
if len(r.changedGroupSet) == 0 {
|
||||
return
|
||||
}
|
||||
nsGroups, err := r.store.GetAccountNameServerGroups(r.ctx, store.LockingStrengthNone, r.accountID)
|
||||
if err != nil {
|
||||
log.WithContext(r.ctx).Errorf("failed to get nameserver groups for affected peers resolution: %v", err)
|
||||
return
|
||||
}
|
||||
for _, ns := range nsGroups {
|
||||
if anyInSet(ns.Groups, r.changedGroupSet) {
|
||||
log.WithContext(r.ctx).Tracef("collectFromNameServers: nameserver group %s references a changed group -> folding its groups %v", ns.ID, ns.Groups)
|
||||
addAll(r.groupSet, ns.Groups)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromDNSSettings() {
|
||||
if len(r.changedGroupSet) == 0 {
|
||||
return
|
||||
}
|
||||
dnsSettings, err := r.store.GetAccountDNSSettings(r.ctx, store.LockingStrengthNone, r.accountID)
|
||||
if err != nil {
|
||||
log.WithContext(r.ctx).Errorf("failed to get DNS settings for affected peers resolution: %v", err)
|
||||
return
|
||||
}
|
||||
for _, gID := range dnsSettings.DisabledManagementGroups {
|
||||
if _, ok := r.changedGroupSet[gID]; ok {
|
||||
log.WithContext(r.ctx).Tracef("collectFromDNSSettings: changed group %s is in DisabledManagementGroups -> folding it", gID)
|
||||
r.groupSet[gID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromNetworkRouters() {
|
||||
for _, router := range r.networkRouters() {
|
||||
matchedByGroup := anyInSet(router.PeerGroups, r.changedGroupSet)
|
||||
matchedByPeer := router.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(router.Peer, r.changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromNetworkRouters: router %s on network %s matched (byGroup=%t byPeer=%t) -> folding peerGroups=%v peer=%q and marking network for source bridge",
|
||||
router.ID, router.NetworkID, matchedByGroup, matchedByPeer, router.PeerGroups, router.Peer)
|
||||
addAll(r.groupSet, router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
r.peerSet[router.Peer] = struct{}{}
|
||||
}
|
||||
r.networkIDs[router.NetworkID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromProxyServices() {
|
||||
services, proxyByCluster, ok := r.loadProxyServiceContext()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
expanded := r.expandChangedPeersWithGroups()
|
||||
|
||||
for _, svc := range services {
|
||||
if svc == nil {
|
||||
continue
|
||||
}
|
||||
proxyPeers := proxyByCluster[svc.ProxyCluster]
|
||||
if len(proxyPeers) == 0 {
|
||||
continue
|
||||
}
|
||||
matchedByPeer := serviceMatchesChangedPeers(svc, proxyPeers, expanded)
|
||||
matchedByAccessGroup := anyInSet(svc.AccessGroups, r.changedGroupSet)
|
||||
if !matchedByPeer && !matchedByAccessGroup {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromProxyServices: service %s (cluster=%s) matched (byProxyOrTargetPeer=%t byAccessGroup=%t) -> folding %d proxy peers, peer targets and access groups %v",
|
||||
svc.ID, svc.ProxyCluster, matchedByPeer, matchedByAccessGroup, len(proxyPeers), svc.AccessGroups)
|
||||
for _, pid := range proxyPeers {
|
||||
r.peerSet[pid] = struct{}{}
|
||||
}
|
||||
for _, target := range svc.Targets {
|
||||
if target.TargetType == rpservice.TargetTypePeer && target.TargetId != "" {
|
||||
r.peerSet[target.TargetId] = struct{}{}
|
||||
}
|
||||
}
|
||||
addAll(r.groupSet, svc.AccessGroups)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) loadProxyServiceContext() ([]*rpservice.Service, map[string][]string, bool) {
|
||||
// Embedded proxy peers are the prerequisite for any synthesized proxy policy.
|
||||
// Probe that first (a narrow, indexed lookup) and skip the services table load
|
||||
// entirely when the account has no embedded proxy peers.
|
||||
proxyByCluster, err := r.store.GetEmbeddedProxyPeerIDsByCluster(r.ctx, r.accountID)
|
||||
if err != nil {
|
||||
log.WithContext(r.ctx).Errorf("failed to get embedded proxy peers for affected peers resolution: %v", err)
|
||||
return nil, nil, false
|
||||
}
|
||||
if len(proxyByCluster) == 0 {
|
||||
return nil, nil, false
|
||||
}
|
||||
services, err := r.store.GetAccountServices(r.ctx, store.LockingStrengthNone, r.accountID)
|
||||
if err != nil {
|
||||
log.WithContext(r.ctx).Errorf("failed to get services for affected peers resolution: %v", err)
|
||||
return nil, nil, false
|
||||
}
|
||||
if len(services) == 0 {
|
||||
return nil, nil, false
|
||||
}
|
||||
return services, proxyByCluster, true
|
||||
}
|
||||
|
||||
func (r *resolver) expandChangedPeersWithGroups() map[string]struct{} {
|
||||
if len(r.changedGroupSet) == 0 {
|
||||
return r.changedPeerSet
|
||||
}
|
||||
ids, err := r.store.GetPeerIDsByGroups(r.ctx, r.accountID, setToSlice(r.changedGroupSet))
|
||||
if err != nil {
|
||||
log.WithContext(r.ctx).Errorf("failed to expand changed groups to peers for service resolution: %v", err)
|
||||
return r.changedPeerSet
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return r.changedPeerSet
|
||||
}
|
||||
merged := make(map[string]struct{}, len(r.changedPeerSet)+len(ids))
|
||||
for id := range r.changedPeerSet {
|
||||
merged[id] = struct{}{}
|
||||
}
|
||||
for _, id := range ids {
|
||||
merged[id] = struct{}{}
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
// collectResourceRouterBridge folds in the routing peers serving the resources
|
||||
// targeted by matched/explicit policies (source -> router), and the source peers
|
||||
// of policies serving resources on the affected networks (router -> source). The
|
||||
// routing peer is reachable only through resource -> network -> router, never
|
||||
// through the policy's own groups, so it must be collected here.
|
||||
func (r *resolver) collectResourceRouterBridge() {
|
||||
r.bridgeSourceToRouters()
|
||||
r.bridgeRoutersToSources()
|
||||
}
|
||||
|
||||
func (r *resolver) bridgeSourceToRouters() {
|
||||
resourceIDs := r.policyDestinationResourceIDs(r.matchedPolicies...)
|
||||
for id := range r.resourceIDs {
|
||||
resourceIDs[id] = struct{}{}
|
||||
}
|
||||
if len(resourceIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
networkIDs := r.resourceNetworkIDs(resourceIDs)
|
||||
log.WithContext(r.ctx).Tracef("bridgeSourceToRouters: targeted resources %v -> networks %v (their routers become affected via the router->source pass)",
|
||||
setToSlice(resourceIDs), setToSlice(networkIDs))
|
||||
for id := range networkIDs {
|
||||
r.networkIDs[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) bridgeRoutersToSources() {
|
||||
if len(r.networkIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: affected networks %v -> folding their routing peers and the source peers of policies targeting their resources",
|
||||
setToSlice(r.networkIDs))
|
||||
|
||||
r.foldRoutersOnNetworks(r.networkIDs)
|
||||
|
||||
resourceIDs := make(map[string]struct{})
|
||||
for _, resource := range r.networkResources() {
|
||||
if _, ok := r.networkIDs[resource.NetworkID]; ok {
|
||||
resourceIDs[resource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(resourceIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, policy := range r.policies() {
|
||||
if r.policyTargetsResources(policy, resourceIDs) {
|
||||
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: policy %s (%s) targets an affected-network resource -> folding its source groups/peers", policy.ID, policy.Name)
|
||||
collectPolicySources(policy, r.groupSet, r.peerSet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) foldRoutersOnNetworks(networkIDs map[string]struct{}) {
|
||||
for _, router := range r.networkRouters() {
|
||||
if _, ok := networkIDs[router.NetworkID]; !ok {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: router %s serves affected network %s -> folding peerGroups=%v peer=%q",
|
||||
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
|
||||
addAll(r.groupSet, router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
r.peerSet[router.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) resourceNetworkIDs(resourceIDs map[string]struct{}) map[string]struct{} {
|
||||
networkIDs := make(map[string]struct{})
|
||||
for _, resource := range r.networkResources() {
|
||||
if _, ok := resourceIDs[resource.ID]; ok {
|
||||
networkIDs[resource.NetworkID] = struct{}{}
|
||||
}
|
||||
}
|
||||
return networkIDs
|
||||
}
|
||||
|
||||
func (r *resolver) policyTargetsResources(policy *types.Policy, resourceIDs map[string]struct{}) bool {
|
||||
if policy == nil {
|
||||
return false
|
||||
}
|
||||
destGroupSet := make(map[string]struct{})
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.DestinationResource.Type != types.ResourceTypePeer && isInSet(rule.DestinationResource.ID, resourceIDs) {
|
||||
return true
|
||||
}
|
||||
for _, gID := range rule.Destinations {
|
||||
destGroupSet[gID] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(destGroupSet) == 0 {
|
||||
return false
|
||||
}
|
||||
groups, err := r.store.GetGroupsByIDs(r.ctx, store.LockingStrengthNone, r.accountID, setToSlice(destGroupSet))
|
||||
if err != nil {
|
||||
log.WithContext(r.ctx).Errorf("failed to get destination groups for router policy bridge: %v", err)
|
||||
return false
|
||||
}
|
||||
for _, group := range groups {
|
||||
for _, res := range group.Resources {
|
||||
if isInSet(res.ID, resourceIDs) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *resolver) policyDestinationResourceIDs(policies ...*types.Policy) map[string]struct{} {
|
||||
resourceIDs := make(map[string]struct{})
|
||||
destGroupSet := collectPolicyDestinations(resourceIDs, policies...)
|
||||
r.addGroupResourceIDs(destGroupSet, resourceIDs)
|
||||
return resourceIDs
|
||||
}
|
||||
|
||||
// collectPolicyDestinations adds each rule's direct destination resource IDs to
|
||||
// resourceIDs and returns the set of destination group IDs referenced.
|
||||
func collectPolicyDestinations(resourceIDs map[string]struct{}, policies ...*types.Policy) map[string]struct{} {
|
||||
destGroupSet := make(map[string]struct{})
|
||||
for _, policy := range policies {
|
||||
if policy == nil {
|
||||
continue
|
||||
}
|
||||
for _, rule := range policy.Rules {
|
||||
addAll(destGroupSet, rule.Destinations)
|
||||
if rule.DestinationResource.Type != types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
resourceIDs[rule.DestinationResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
return destGroupSet
|
||||
}
|
||||
|
||||
// addGroupResourceIDs folds the resource IDs of the given groups into resourceIDs.
|
||||
func (r *resolver) addGroupResourceIDs(groupIDs map[string]struct{}, resourceIDs map[string]struct{}) {
|
||||
if len(groupIDs) == 0 {
|
||||
return
|
||||
}
|
||||
groups, err := r.store.GetGroupsByIDs(r.ctx, store.LockingStrengthNone, r.accountID, setToSlice(groupIDs))
|
||||
if err != nil {
|
||||
log.WithContext(r.ctx).Errorf("failed to get destination groups for resource router bridge: %v", err)
|
||||
return
|
||||
}
|
||||
for _, group := range groups {
|
||||
for _, res := range group.Resources {
|
||||
if res.ID != "" {
|
||||
resourceIDs[res.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func collectPolicyDirectPeers(policy *types.Policy, peerSet map[string]struct{}) {
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
peerSet[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
peerSet[rule.DestinationResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func collectPolicySources(policy *types.Policy, groupSet, peerSet map[string]struct{}) {
|
||||
for _, rule := range policy.Rules {
|
||||
addAll(groupSet, rule.Sources)
|
||||
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
peerSet[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
|
||||
for _, rule := range policy.Rules {
|
||||
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func policyReferencesDirectPeers(policy *types.Policy, changedSet map[string]struct{}) bool {
|
||||
for _, rule := range policy.Rules {
|
||||
if isDirectPeerInSet(rule.SourceResource, changedSet) || isDirectPeerInSet(rule.DestinationResource, changedSet) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func policyReferencesPostureChecks(policy *types.Policy, ids map[string]struct{}) bool {
|
||||
for _, id := range policy.SourcePostureChecks {
|
||||
if _, ok := ids[id]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isDirectPeerInSet(res types.Resource, set map[string]struct{}) bool {
|
||||
if res.Type != types.ResourceTypePeer || res.ID == "" {
|
||||
return false
|
||||
}
|
||||
_, ok := set[res.ID]
|
||||
return ok
|
||||
}
|
||||
|
||||
func serviceMatchesChangedPeers(svc *rpservice.Service, proxyPeers []string, changedPeers map[string]struct{}) bool {
|
||||
for _, pid := range proxyPeers {
|
||||
if _, ok := changedPeers[pid]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, target := range svc.Targets {
|
||||
if target.TargetType != rpservice.TargetTypePeer || target.TargetId == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := changedPeers[target.TargetId]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func anyInSet(ids []string, set map[string]struct{}) bool {
|
||||
for _, id := range ids {
|
||||
if _, ok := set[id]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isInSet(id string, set map[string]struct{}) bool {
|
||||
_, ok := set[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
func addAll(set map[string]struct{}, slices ...[]string) {
|
||||
for _, s := range slices {
|
||||
for _, id := range s {
|
||||
set[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func toSet(ids []string) map[string]struct{} {
|
||||
set := make(map[string]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
set[id] = struct{}{}
|
||||
}
|
||||
return set
|
||||
}
|
||||
|
||||
func setToSlice(set map[string]struct{}) []string {
|
||||
s := make([]string, 0, len(set))
|
||||
for id := range set {
|
||||
s = append(s, id)
|
||||
}
|
||||
return s
|
||||
}
|
||||
138
management/server/affectedpeers/resolver_test.go
Normal file
138
management/server/affectedpeers/resolver_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package affectedpeers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// policyGroupsAndPeers mirrors the explicit-policy extraction (RuleGroups +
|
||||
// direct peers) the resolver folds in, for asserting the pure logic.
|
||||
func policyGroupsAndPeers(policies ...*types.Policy) (groups []string, peers []string) {
|
||||
peerSet := map[string]struct{}{}
|
||||
for _, p := range policies {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
groups = append(groups, p.RuleGroups()...)
|
||||
collectPolicyDirectPeers(p, peerSet)
|
||||
}
|
||||
for id := range peerSet {
|
||||
peers = append(peers, id)
|
||||
}
|
||||
return groups, peers
|
||||
}
|
||||
|
||||
func TestPolicyGroupsAndPeers_Basic(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1", "g2"}, Destinations: []string{"g3"}}}}
|
||||
groups, peers := policyGroupsAndPeers(policy)
|
||||
assert.ElementsMatch(t, []string{"g1", "g2", "g3"}, groups)
|
||||
assert.Empty(t, peers)
|
||||
}
|
||||
|
||||
func TestPolicyGroupsAndPeers_WithPeerResources(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
Sources: []string{"g1"},
|
||||
SourceResource: types.Resource{ID: "p1", Type: types.ResourceTypePeer},
|
||||
Destinations: []string{"g2"},
|
||||
DestinationResource: types.Resource{ID: "p2", Type: types.ResourceTypePeer},
|
||||
}}}
|
||||
groups, peers := policyGroupsAndPeers(policy)
|
||||
assert.ElementsMatch(t, []string{"g1", "g2"}, groups)
|
||||
assert.ElementsMatch(t, []string{"p1", "p2"}, peers)
|
||||
}
|
||||
|
||||
func TestPolicyGroupsAndPeers_NilPolicy(t *testing.T) {
|
||||
groups, peers := policyGroupsAndPeers(nil)
|
||||
assert.Nil(t, groups)
|
||||
assert.Nil(t, peers)
|
||||
}
|
||||
|
||||
func TestPolicyGroupsAndPeers_MultiplePolicies(t *testing.T) {
|
||||
old := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1"}, Destinations: []string{"g2"}}}}
|
||||
updated := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g3"}, Destinations: []string{"g4"}}}}
|
||||
groups, _ := policyGroupsAndPeers(updated, old)
|
||||
assert.ElementsMatch(t, []string{"g1", "g2", "g3", "g4"}, groups)
|
||||
}
|
||||
|
||||
func TestPolicyGroupsAndPeers_NonPeerResource(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
Sources: []string{"g1"},
|
||||
SourceResource: types.Resource{ID: "domain-1", Type: types.ResourceTypeDomain},
|
||||
Destinations: []string{"g2"},
|
||||
}}}
|
||||
groups, peers := policyGroupsAndPeers(policy)
|
||||
assert.ElementsMatch(t, []string{"g1", "g2"}, groups)
|
||||
assert.Empty(t, peers, "domain resource type should not produce direct peer IDs")
|
||||
}
|
||||
|
||||
func TestChangeIsEmpty(t *testing.T) {
|
||||
assert.True(t, Change{}.isEmpty())
|
||||
assert.False(t, Change{ChangedGroupIDs: []string{"g"}}.isEmpty())
|
||||
assert.False(t, Change{ChangedPeerIDs: []string{"p"}}.isEmpty())
|
||||
assert.False(t, Change{Policies: []*types.Policy{{}}}.isEmpty())
|
||||
assert.False(t, Change{ResourceIDs: []string{"r"}}.isEmpty())
|
||||
assert.False(t, Change{NetworkIDs: []string{"n"}}.isEmpty())
|
||||
assert.False(t, Change{PostureCheckIDs: []string{"pc"}}.isEmpty())
|
||||
}
|
||||
|
||||
func TestPolicyReferencesGroups(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1", "g2"}, Destinations: []string{"g3"}}}}
|
||||
|
||||
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g1": {}}))
|
||||
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g3": {}}))
|
||||
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{"g4": {}}))
|
||||
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{}))
|
||||
}
|
||||
|
||||
func TestPolicyReferencesDirectPeers(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
|
||||
}}}
|
||||
|
||||
assert.True(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p1": {}}))
|
||||
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"r1": {}}))
|
||||
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p2": {}}))
|
||||
}
|
||||
|
||||
func TestPolicyReferencesPostureChecks(t *testing.T) {
|
||||
policy := &types.Policy{SourcePostureChecks: []string{"pc1", "pc2"}}
|
||||
|
||||
assert.True(t, policyReferencesPostureChecks(policy, map[string]struct{}{"pc1": {}}))
|
||||
assert.False(t, policyReferencesPostureChecks(policy, map[string]struct{}{"pc3": {}}))
|
||||
}
|
||||
|
||||
func TestCollectPolicyDirectPeers(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypePeer, ID: "p2"},
|
||||
}, {
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
|
||||
}}}
|
||||
|
||||
peerSet := map[string]struct{}{}
|
||||
collectPolicyDirectPeers(policy, peerSet)
|
||||
|
||||
assert.Contains(t, peerSet, "p1")
|
||||
assert.Contains(t, peerSet, "p2")
|
||||
assert.NotContains(t, peerSet, "r1")
|
||||
}
|
||||
|
||||
func TestCollectPolicySources(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
Sources: []string{"g1"},
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
Destinations: []string{"g2"},
|
||||
}}}
|
||||
|
||||
groupSet := map[string]struct{}{}
|
||||
peerSet := map[string]struct{}{}
|
||||
collectPolicySources(policy, groupSet, peerSet)
|
||||
|
||||
assert.Contains(t, groupSet, "g1")
|
||||
assert.NotContains(t, groupSet, "g2", "destination groups must not be collected as sources")
|
||||
assert.Contains(t, peerSet, "p1")
|
||||
}
|
||||
@@ -1,10 +1,27 @@
|
||||
package context
|
||||
|
||||
import "github.com/netbirdio/netbird/shared/context"
|
||||
import (
|
||||
"context"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/shared/context"
|
||||
)
|
||||
|
||||
const (
|
||||
RequestIDKey = context.RequestIDKey
|
||||
AccountIDKey = context.AccountIDKey
|
||||
UserIDKey = context.UserIDKey
|
||||
PeerIDKey = context.PeerIDKey
|
||||
RequestIDKey = nbcontext.RequestIDKey
|
||||
AccountIDKey = nbcontext.AccountIDKey
|
||||
RoleKey = nbcontext.RoleKey
|
||||
UserIDKey = nbcontext.UserIDKey
|
||||
PeerIDKey = nbcontext.PeerIDKey
|
||||
)
|
||||
|
||||
// RoleFromContext returns the role stored in ctx, or empty string and false if absent.
|
||||
func RoleFromContext(ctx context.Context) (string, bool) {
|
||||
role, ok := ctx.Value(RoleKey).(string)
|
||||
return role, ok
|
||||
}
|
||||
|
||||
// WithRole returns a new context carrying the given role.
|
||||
func WithRole(ctx context.Context, role string) context.Context {
|
||||
//nolint
|
||||
return context.WithValue(ctx, RoleKey, role)
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ const (
|
||||
|
||||
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
|
||||
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -39,7 +39,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
|
||||
}
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -47,8 +47,8 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var updateAccountPeers bool
|
||||
var eventsToStore []func()
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
|
||||
@@ -63,11 +63,6 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
addedGroups := util.Difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
||||
removedGroups := util.Difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||
|
||||
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
@@ -75,6 +70,9 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
return err
|
||||
}
|
||||
|
||||
allGroups := slices.Concat(addedGroups, removedGroups)
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroups, nil)
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -85,8 +83,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceDNSSettings, Operation: types.UpdateOperationUpdate})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("SaveDNSSettings: updating %d affected peers: %v", len(affectedPeerIDs), affectedPeerIDs)
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("SaveDNSSettings: no affected peers")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -133,20 +134,6 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
|
||||
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, addedGroups)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, accountID, removedGroups)
|
||||
}
|
||||
|
||||
// validateDNSSettings validates the DNS settings.
|
||||
func validateDNSSettings(ctx context.Context, transaction store.Store, accountID string, settings *types.DNSSettings) error {
|
||||
if len(settings.DisabledManagementGroups) == 0 {
|
||||
|
||||
@@ -23,7 +23,7 @@ func isEnabled() bool {
|
||||
|
||||
// GetEvents returns a list of activity events of an account
|
||||
func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Events, operations.Read)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Events, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
@@ -32,7 +33,7 @@ func (e *GroupLinkError) Error() string {
|
||||
|
||||
// CheckGroupPermissions validates if a user has the necessary permissions to view groups
|
||||
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Read)
|
||||
allowed, _, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Read)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -70,7 +71,7 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName,
|
||||
|
||||
// CreateGroup object of the peers
|
||||
func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -79,7 +80,7 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
@@ -91,11 +92,6 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to create group: %v", err)
|
||||
}
|
||||
@@ -106,6 +102,8 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
}
|
||||
|
||||
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{ChangedGroupIDs: []string{newGroup.ID}})
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -116,8 +114,11 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("CreateGroup %s: updating %d affected peers: %v", newGroup.ID, len(affectedPeerIDs), affectedPeerIDs)
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("CreateGroup %s: no affected peers", newGroup.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -125,7 +126,7 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
|
||||
// UpdateGroup object of the peers
|
||||
func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -134,7 +135,7 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
@@ -151,22 +152,8 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID)
|
||||
}
|
||||
|
||||
peersToAdd := util.Difference(newGroup.Peers, oldGroup.Peers)
|
||||
peersToRemove := util.Difference(oldGroup.Peers, newGroup.Peers)
|
||||
|
||||
for _, peerID := range peersToAdd {
|
||||
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, newGroup.ID); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, newGroup.ID, err)
|
||||
}
|
||||
}
|
||||
for _, peerID := range peersToRemove {
|
||||
if err := transaction.RemovePeerFromGroup(ctx, peerID, newGroup.ID); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to remove peer %s from group %s: %v", peerID, newGroup.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
|
||||
if err != nil {
|
||||
if err = syncGroupMembership(ctx, transaction, accountID, newGroup.ID, util.Difference(newGroup.Peers, oldGroup.Peers), peersToRemove); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -178,6 +165,8 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
return err
|
||||
}
|
||||
|
||||
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{ChangedGroupIDs: []string{newGroup.ID}, ChangedPeerIDs: peersToRemove})
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -188,19 +177,37 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("UpdateGroup %s: updating %d affected peers: %v", newGroup.ID, len(affectedPeerIDs), affectedPeerIDs)
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("UpdateGroup %s: no affected peers", newGroup.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// syncGroupMembership applies the peer membership delta for a group within a transaction.
|
||||
func syncGroupMembership(ctx context.Context, transaction store.Store, accountID, groupID string, peersToAdd, peersToRemove []string) error {
|
||||
for _, peerID := range peersToAdd {
|
||||
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, groupID, err)
|
||||
}
|
||||
}
|
||||
for _, peerID := range peersToRemove {
|
||||
if err := transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to remove peer %s from group %s: %v", peerID, groupID, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateGroups adds new groups to the account.
|
||||
// Note: This function does not acquire the global lock.
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that.
|
||||
func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -209,7 +216,6 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
|
||||
var globalErr error
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
@@ -247,17 +253,16 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate})
|
||||
affectedPeerIDs := am.ResolveAffectedPeers(ctx, am.Store, accountID, affectedpeers.Change{ChangedGroupIDs: groupIDs})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("CreateGroups %v: updating %d affected peers: %v", groupIDs, len(affectedPeerIDs), affectedPeerIDs)
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("CreateGroups %v: no affected peers", groupIDs)
|
||||
}
|
||||
|
||||
return globalErr
|
||||
@@ -268,7 +273,7 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that.
|
||||
func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -277,7 +282,6 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
|
||||
var globalErr error
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
@@ -295,17 +299,16 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
|
||||
groupIDs = append(groupIDs, newGroup.ID)
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
affectedPeerIDs := am.ResolveAffectedPeers(ctx, am.Store, accountID, affectedpeers.Change{ChangedGroupIDs: groupIDs})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("UpdateGroups %v: updating %d affected peers: %v", groupIDs, len(affectedPeerIDs), affectedPeerIDs)
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("UpdateGroups %v: no affected peers", groupIDs)
|
||||
}
|
||||
|
||||
return globalErr
|
||||
@@ -427,7 +430,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
|
||||
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
|
||||
// Errors are collected and returned at the end.
|
||||
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Delete)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -438,6 +441,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
var allErrors error
|
||||
var groupIDsToDelete []string
|
||||
var deletedGroups []*types.Group
|
||||
var affectedPeerIDs []string
|
||||
|
||||
extraSettings, err := am.settingsManager.GetExtraSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
@@ -445,26 +449,17 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
for _, groupID := range groupIDs {
|
||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||
if err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err = validateDeleteGroup(ctx, transaction, group, userID, extraSettings.FlowGroups); err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
|
||||
groupIDsToDelete = append(groupIDsToDelete, groupID)
|
||||
deletedGroups = append(deletedGroups, group)
|
||||
deletedGroups, allErrors = collectDeletableGroups(ctx, transaction, accountID, userID, groupIDs, extraSettings.FlowGroups)
|
||||
for _, group := range deletedGroups {
|
||||
groupIDsToDelete = append(groupIDsToDelete, group.ID)
|
||||
}
|
||||
|
||||
if len(groupIDsToDelete) == 0 {
|
||||
return allErrors
|
||||
}
|
||||
|
||||
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{ChangedGroupIDs: groupIDsToDelete})
|
||||
|
||||
if err = transaction.DeleteGroups(ctx, accountID, groupIDsToDelete); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -483,20 +478,42 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
|
||||
}
|
||||
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("DeleteGroups %v: updating %d affected peers: %v", groupIDsToDelete, len(affectedPeerIDs), affectedPeerIDs)
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("DeleteGroups %v: no affected peers", groupIDsToDelete)
|
||||
}
|
||||
|
||||
return allErrors
|
||||
}
|
||||
|
||||
// collectDeletableGroups loads and validates each group for deletion, returning
|
||||
// the groups that may be deleted and the joined validation errors for the rest.
|
||||
func collectDeletableGroups(ctx context.Context, transaction store.Store, accountID, userID string, groupIDs, flowGroups []string) ([]*types.Group, error) {
|
||||
var deletable []*types.Group
|
||||
var allErrors error
|
||||
for _, groupID := range groupIDs {
|
||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||
if err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
if err = validateDeleteGroup(ctx, transaction, group, userID, flowGroups); err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
deletable = append(deletable, group)
|
||||
}
|
||||
return deletable, allErrors
|
||||
}
|
||||
|
||||
// GroupAddPeer appends peer to the group
|
||||
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -505,14 +522,19 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
return err
|
||||
}
|
||||
|
||||
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{ChangedGroupIDs: []string{groupID}})
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("GroupAddPeer group=%s peer=%s: updating %d affected peers: %v", groupID, peerID, len(affectedPeerIDs), affectedPeerIDs)
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("GroupAddPeer group=%s peer=%s: no affected peers", groupID, peerID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -521,7 +543,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
// GroupAddResource appends resource to the group
|
||||
func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
|
||||
var group *types.Group
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
@@ -534,23 +556,23 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.UpdateGroup(ctx, group); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{ChangedGroupIDs: []string{groupID}})
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("GroupAddResource group=%s resource=%s: updating %d affected peers: %v", groupID, resource.ID, len(affectedPeerIDs), affectedPeerIDs)
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("GroupAddResource group=%s resource=%s: no affected peers", groupID, resource.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -558,14 +580,12 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
||||
|
||||
// GroupDeletePeer removes peer from the group
|
||||
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Resolve before removing, so the peer being removed is still included
|
||||
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{ChangedGroupIDs: []string{groupID}})
|
||||
|
||||
if err = transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
|
||||
return err
|
||||
@@ -581,8 +601,11 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("GroupDeletePeer group=%s peer=%s: updating %d affected peers: %v", groupID, peerID, len(affectedPeerIDs), affectedPeerIDs)
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("GroupDeletePeer group=%s peer=%s: no affected peers", groupID, peerID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -591,7 +614,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
// GroupDeleteResource removes resource from the group
|
||||
func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
|
||||
var group *types.Group
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
@@ -604,23 +627,23 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.UpdateGroup(ctx, group); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{ChangedGroupIDs: []string{groupID}})
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("GroupDeleteResource group=%s resource=%s: updating %d affected peers: %v", groupID, resource.ID, len(affectedPeerIDs), affectedPeerIDs)
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("GroupDeleteResource group=%s resource=%s: no affected peers", groupID, resource.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -832,49 +855,103 @@ func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store,
|
||||
}
|
||||
|
||||
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
|
||||
// It fetches each collection once and checks all groupIDs against them in memory.
|
||||
func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
groupSet := make(map[string]struct{}, len(groupIDs))
|
||||
for _, id := range groupIDs {
|
||||
groupSet[id] = struct{}{}
|
||||
}
|
||||
|
||||
if affected, err := dnsSettingsReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
|
||||
return affected, err
|
||||
}
|
||||
if affected, err := nameServersReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
|
||||
return affected, err
|
||||
}
|
||||
if affected, err := policiesReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
|
||||
return affected, err
|
||||
}
|
||||
if affected, err := routesReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
|
||||
return affected, err
|
||||
}
|
||||
if affected, err := networkRoutersReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
|
||||
return affected, err
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func dnsSettingsReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, groupID := range groupIDs {
|
||||
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToNetworkRouter(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
return anyInSet(dnsSettings.DisabledManagementGroups, groupSet), nil
|
||||
}
|
||||
|
||||
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
|
||||
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupIDs)
|
||||
func nameServersReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
|
||||
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
if group.HasPeers() || group.HasResources() {
|
||||
for _, ns := range nameServerGroups {
|
||||
if anyInSet(ns.Groups, groupSet) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func policiesReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, policy := range policies {
|
||||
for _, rule := range policy.Rules {
|
||||
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func routesReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
|
||||
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, r := range routes {
|
||||
if anyInSet(r.Groups, groupSet) || anyInSet(r.PeerGroups, groupSet) || anyInSet(r.AccessControlGroups, groupSet) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func networkRoutersReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
|
||||
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, router := range routers {
|
||||
if anyInSet(router.PeerGroups, groupSet) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func anyInSet(ids []string, set map[string]struct{}) bool {
|
||||
for _, id := range ids {
|
||||
if _, ok := set[id]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ func NewManager(store store.Store, permissionsManager permissions.Manager, accou
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Read)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Read)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -73,7 +73,7 @@ func (m *managerImpl) GetAllGroupsMap(ctx context.Context, accountID, userID str
|
||||
}
|
||||
|
||||
func (m *managerImpl) AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resource *types.Resource) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/rs/cors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||
|
||||
@@ -19,6 +20,7 @@ import (
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
@@ -32,6 +34,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
nbgroups "github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/accounts"
|
||||
@@ -46,6 +49,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/routes"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/setup_keys"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/users"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
nbinstance "github.com/netbirdio/netbird/management/server/instance"
|
||||
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
|
||||
@@ -55,7 +59,7 @@ import (
|
||||
)
|
||||
|
||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, appMetrics telemetry.AppMetrics, permissionsManager permissions.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, instanceManager nbinstance.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, authMiddleware mux.MiddlewareFunc) (http.Handler, error) {
|
||||
func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, permissionsManager permissions.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter, isValidChildAccount middleware.IsValidChildAccountFunc) (http.Handler, error) {
|
||||
|
||||
// Register bypass paths for unauthenticated endpoints
|
||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||
@@ -76,11 +80,32 @@ func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager accou
|
||||
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
||||
}
|
||||
|
||||
if rateLimiter == nil {
|
||||
log.Warn("NewAPIHandler: nil rate limiter, rate limiting disabled")
|
||||
rateLimiter = middleware.NewAPIRateLimiter(nil)
|
||||
rateLimiter.SetEnabled(false)
|
||||
}
|
||||
|
||||
authMiddleware := middleware.NewAuthMiddleware(
|
||||
authManager,
|
||||
accountManager.GetAccountIDFromUserAuth,
|
||||
accountManager.SyncUserJWTGroups,
|
||||
accountManager.GetUserFromUserAuth,
|
||||
rateLimiter,
|
||||
appMetrics.GetMeter(),
|
||||
isValidChildAccount,
|
||||
)
|
||||
|
||||
corsMiddleware := cors.AllowAll()
|
||||
|
||||
metricsMiddleware := appMetrics.HTTPMiddleware()
|
||||
|
||||
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware)
|
||||
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler)
|
||||
|
||||
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), idpManager)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create instance manager: %w", err)
|
||||
}
|
||||
|
||||
accounts.AddEndpoints(accountManager, settingsManager, router)
|
||||
peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager)
|
||||
|
||||
@@ -405,48 +405,48 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), accountID, userID, modules.Peers, operations.Read)
|
||||
allowed, ctx, err := h.permissionsManager.ValidateUserPermissions(r.Context(), accountID, userID, modules.Peers, operations.Read)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), status.NewPermissionValidationError(err), w)
|
||||
util.WriteError(ctx, status.NewPermissionValidationError(err), w)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator)
|
||||
account, err := h.accountManager.GetAccountByID(ctx, accountID, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if !allowed && !userAuth.IsChild {
|
||||
if account.Settings.RegularUsersViewBlocked {
|
||||
util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{})
|
||||
util.WriteJSONObject(ctx, w, []api.AccessiblePeer{})
|
||||
return
|
||||
}
|
||||
|
||||
peer, ok := account.Peers[peerID]
|
||||
if !ok {
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "peer not found"), w)
|
||||
util.WriteError(ctx, status.Errorf(status.NotFound, "peer not found"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if peer.UserID != user.Id {
|
||||
util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{})
|
||||
util.WriteJSONObject(ctx, w, []api.AccessiblePeer{})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||
validPeers, _, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
|
||||
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
|
||||
|
||||
netMap := account.GetPeerNetworkMapFromComponents(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
netMap := account.GetPeerNetworkMapFromComponents(ctx, peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
||||
util.WriteJSONObject(ctx, w, toAccessiblePeers(netMap, dnsDomain))
|
||||
}
|
||||
|
||||
func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -116,15 +116,15 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
|
||||
|
||||
ctrl2 := gomock.NewController(t)
|
||||
permissionsManager := permissions.NewMockManager(ctrl2)
|
||||
permissionsManager.EXPECT().ValidateAccountAccess(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
permissionsManager.EXPECT().ValidateAccountAccess(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(context.Background(), nil).AnyTimes()
|
||||
permissionsManager.EXPECT().
|
||||
ValidateUserPermissions(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Eq(modules.Peers), gomock.Eq(operations.Read)).
|
||||
DoAndReturn(func(ctx context.Context, accountID, userID string, module modules.Module, operation operations.Operation) (bool, error) {
|
||||
DoAndReturn(func(ctx context.Context, accountID, userID string, module modules.Module, operation operations.Operation) (bool, context.Context, error) {
|
||||
user, ok := account.Users[userID]
|
||||
if !ok {
|
||||
return false, fmt.Errorf("user not found")
|
||||
return false, ctx, fmt.Errorf("user not found")
|
||||
}
|
||||
return user.HasAdminPower() || user.IsServiceUser, nil
|
||||
return user.HasAdminPower() || user.IsServiceUser, ctx, nil
|
||||
}).
|
||||
AnyTimes()
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ func initGeolocationTestData(t *testing.T) *geolocationsHandler {
|
||||
permissionsManagerMock.
|
||||
EXPECT().
|
||||
ValidateUserPermissions(gomock.Any(), gomock.Any(), gomock.Any(), modules.Policies, operations.Read).
|
||||
Return(true, nil).
|
||||
Return(true, context.Background(), nil).
|
||||
AnyTimes()
|
||||
|
||||
return &geolocationsHandler{
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
@@ -26,8 +25,7 @@ type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) err
|
||||
|
||||
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
|
||||
// jwtTokenCtxKey carries the parsed JWT token.
|
||||
type jwtTokenCtxKey struct{}
|
||||
type IsValidChildAccountFunc func(ctx context.Context, userID, accountID, childAccountID string) bool
|
||||
|
||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
@@ -37,6 +35,7 @@ type AuthMiddleware struct {
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc
|
||||
rateLimiter *APIRateLimiter
|
||||
patUsageTracker *PATUsageTracker
|
||||
isValidChildAccount IsValidChildAccountFunc
|
||||
}
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
@@ -47,6 +46,7 @@ func NewAuthMiddleware(
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||
rateLimiter *APIRateLimiter,
|
||||
meter metric.Meter,
|
||||
isValidChildAccount IsValidChildAccountFunc,
|
||||
) *AuthMiddleware {
|
||||
var patUsageTracker *PATUsageTracker
|
||||
if meter != nil {
|
||||
@@ -64,18 +64,12 @@ func NewAuthMiddleware(
|
||||
getUserFromUserAuth: getUserFromUserAuth,
|
||||
rateLimiter: rateLimiter,
|
||||
patUsageTracker: patUsageTracker,
|
||||
isValidChildAccount: isValidChildAccount,
|
||||
}
|
||||
}
|
||||
|
||||
// Handler composes the full authentication chain by wrapping the given
|
||||
// handler with ValidationHandler followed by AccountAccessHandler.
|
||||
// Handler method of the middleware which authenticates a user either by JWT claims or by PAT
|
||||
func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
return m.ValidationHandler(m.AccountAccessHandler(h))
|
||||
}
|
||||
|
||||
// ValidationHandler authenticates the caller via JWT or PAT and stores the
|
||||
// resulting UserAuth in the request context. It performs no account-level work.
|
||||
func (m *AuthMiddleware) ValidationHandler(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if bypass.ShouldBypass(r.URL.Path, h, w, r) {
|
||||
return
|
||||
@@ -92,14 +86,14 @@ func (m *AuthMiddleware) ValidationHandler(h http.Handler) http.Handler {
|
||||
|
||||
switch authType {
|
||||
case "bearer":
|
||||
if err := m.validateJWT(r, authHeader); err != nil {
|
||||
if err := m.checkJWTFromRequest(r, authHeader); err != nil {
|
||||
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(w, r)
|
||||
case "token":
|
||||
if err := m.validatePAT(r, authHeader); err != nil {
|
||||
if err := m.checkPATFromRequest(r, authHeader); err != nil {
|
||||
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
||||
// Check if it's a status error, otherwise default to Unauthorized
|
||||
if _, ok := status.FromError(err); !ok {
|
||||
@@ -116,55 +110,66 @@ func (m *AuthMiddleware) ValidationHandler(h http.Handler) http.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
// AccountAccessHandler runs post-validation access checks for JWT-authenticated
|
||||
// requests. PAT requests pass through unchanged.
|
||||
func (m *AuthMiddleware) AccountAccessHandler(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if bypass.ShouldBypass(r.URL.Path, h, w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromRequest(r)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if userAuth.IsPAT {
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
validatedToken, _ := r.Context().Value(jwtTokenCtxKey{}).(*jwt.Token)
|
||||
|
||||
if err := m.applyAccountAccess(r, userAuth, validatedToken); err != nil {
|
||||
log.WithContext(r.Context()).Errorf("Error applying JWT account access: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
return
|
||||
}
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AuthMiddleware) validateJWT(r *http.Request, authHeaderParts []string) error {
|
||||
// CheckJWTFromRequest checks if the JWT is valid
|
||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) error {
|
||||
token, err := getTokenFromJWTRequest(authHeaderParts)
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
if err != nil {
|
||||
return fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(r.Context(), token)
|
||||
ctx := r.Context()
|
||||
|
||||
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||
if m.isValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
||||
userAuth.AccountId = impersonate[0]
|
||||
userAuth.IsChild = true
|
||||
}
|
||||
}
|
||||
|
||||
// Email is now extracted in ToUserAuth (from claims or userinfo endpoint)
|
||||
// Available as userAuth.Email
|
||||
|
||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||
accountId, _, err := m.ensureAccount(ctx, userAuth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if userAuth.AccountId != accountId {
|
||||
log.WithContext(ctx).Tracef("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
|
||||
userAuth.AccountId = accountId
|
||||
}
|
||||
|
||||
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = m.syncUserJWTGroups(ctx, userAuth)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("HTTP server failed to sync user JWT groups: %s", err)
|
||||
}
|
||||
|
||||
_, err = m.getUserFromUserAuth(ctx, userAuth)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// propagates ctx change to upstream middleware
|
||||
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
|
||||
*r = *r.WithContext(context.WithValue(r.Context(), jwtTokenCtxKey{}, validatedToken))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AuthMiddleware) validatePAT(r *http.Request, authHeaderParts []string) error {
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) error {
|
||||
token, err := getTokenFromPATRequest(authHeaderParts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error extracting token: %w", err)
|
||||
@@ -187,7 +192,8 @@ func (m *AuthMiddleware) validatePAT(r *http.Request, authHeaderParts []string)
|
||||
return fmt.Errorf("token expired")
|
||||
}
|
||||
|
||||
if err := m.authManager.MarkPATUsed(ctx, pat.ID); err != nil {
|
||||
err = m.authManager.MarkPATUsed(ctx, pat.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -199,40 +205,11 @@ func (m *AuthMiddleware) validatePAT(r *http.Request, authHeaderParts []string)
|
||||
IsPAT: true,
|
||||
}
|
||||
|
||||
// propagates ctx change to upstream middleware
|
||||
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyAccountAccess executes account-level checks for an authenticated JWT
|
||||
// user: ensures the account exists, verifies access via JWT groups, syncs
|
||||
// groups, and fetches the user record.
|
||||
func (m *AuthMiddleware) applyAccountAccess(r *http.Request, userAuth auth.UserAuth, validatedToken *jwt.Token) error {
|
||||
ctx := r.Context()
|
||||
|
||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||
accountId, _, err := m.ensureAccount(ctx, userAuth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if userAuth.AccountId != accountId {
|
||||
log.WithContext(ctx).Tracef("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
|
||||
userAuth.AccountId = accountId
|
||||
}
|
||||
|
||||
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.syncUserJWTGroups(ctx, userAuth); err != nil {
|
||||
log.WithContext(ctx).Errorf("HTTP server failed to sync user JWT groups: %s", err)
|
||||
}
|
||||
|
||||
if _, err := m.getUserFromUserAuth(ctx, userAuth); err != nil {
|
||||
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
|
||||
return err
|
||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||
if m.isValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
||||
userAuth.AccountId = impersonate[0]
|
||||
userAuth.IsChild = true
|
||||
}
|
||||
}
|
||||
|
||||
// propagates ctx change to upstream middleware
|
||||
|
||||
@@ -211,6 +211,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
},
|
||||
disabledLimiter,
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handlerToTest := authMiddleware.Handler(nextHandler)
|
||||
@@ -270,6 +271,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -322,6 +324,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -365,6 +368,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -409,6 +413,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -473,6 +478,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -532,6 +538,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -587,6 +594,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -687,6 +695,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
},
|
||||
disabledLimiter,
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
for _, tc := range tt {
|
||||
|
||||
@@ -40,7 +40,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
http2 "github.com/netbirdio/netbird/management/server/http"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
@@ -137,12 +136,8 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
rateLimiter := middleware.NewAPIRateLimiter(nil)
|
||||
rateLimiter.SetEnabled(false)
|
||||
authMiddleware := middleware.NewAuthMiddleware(authManagerMock, am.GetAccountIDFromUserAuth, am.SyncUserJWTGroups, am.GetUserFromUserAuth, rateLimiter, metrics.GetMeter())
|
||||
|
||||
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, authMiddleware.Handler)
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -271,12 +266,8 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
rateLimiter := middleware.NewAPIRateLimiter(nil)
|
||||
rateLimiter.SetEnabled(false)
|
||||
authMiddleware := middleware.NewAuthMiddleware(authManagerMock, am.GetAccountIDFromUserAuth, am.SyncUserJWTGroups, am.GetUserFromUserAuth, rateLimiter, metrics.GetMeter())
|
||||
|
||||
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, authMiddleware.Handler)
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
|
||||
@@ -88,7 +88,7 @@ func validateIdentityProviderConfig(ctx context.Context, idpConfig *types.Identi
|
||||
|
||||
// GetIdentityProviders returns all identity providers for an account
|
||||
func (am *DefaultAccountManager) GetIdentityProviders(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error) {
|
||||
ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Read)
|
||||
ok, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -117,7 +117,7 @@ func (am *DefaultAccountManager) GetIdentityProviders(ctx context.Context, accou
|
||||
|
||||
// GetIdentityProvider returns a specific identity provider by ID
|
||||
func (am *DefaultAccountManager) GetIdentityProvider(ctx context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error) {
|
||||
ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Read)
|
||||
ok, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -143,7 +143,7 @@ func (am *DefaultAccountManager) GetIdentityProvider(ctx context.Context, accoun
|
||||
|
||||
// CreateIdentityProvider creates a new identity provider
|
||||
func (am *DefaultAccountManager) CreateIdentityProvider(ctx context.Context, accountID, userID string, idpConfig *types.IdentityProvider) (*types.IdentityProvider, error) {
|
||||
ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Create)
|
||||
ok, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -180,7 +180,7 @@ func (am *DefaultAccountManager) CreateIdentityProvider(ctx context.Context, acc
|
||||
|
||||
// UpdateIdentityProvider updates an existing identity provider
|
||||
func (am *DefaultAccountManager) UpdateIdentityProvider(ctx context.Context, accountID, idpID, userID string, idpConfig *types.IdentityProvider) (*types.IdentityProvider, error) {
|
||||
ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Update)
|
||||
ok, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -213,7 +213,7 @@ func (am *DefaultAccountManager) UpdateIdentityProvider(ctx context.Context, acc
|
||||
|
||||
// DeleteIdentityProvider deletes an identity provider
|
||||
func (am *DefaultAccountManager) DeleteIdentityProvider(ctx context.Context, accountID, idpID, userID string) error {
|
||||
ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Delete)
|
||||
ok, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
@@ -98,6 +99,7 @@ type MockAccountManager struct {
|
||||
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
||||
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
ExtendPeerSessionFunc func(ctx context.Context, peerPubKey, userID string) (time.Time, error)
|
||||
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
|
||||
ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error)
|
||||
@@ -131,6 +133,9 @@ type MockAccountManager struct {
|
||||
|
||||
AllowSyncFunc func(string, uint64) bool
|
||||
UpdateAccountPeersFunc func(ctx context.Context, accountID string, reason types.UpdateReason)
|
||||
UpdateAffectedPeersFunc func(ctx context.Context, accountID string, peerIDs []string)
|
||||
BufferUpdateAffectedPeersFunc func(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason)
|
||||
ResolveAffectedPeersFunc func(ctx context.Context, s store.Store, accountID string, change affectedpeers.Change) []string
|
||||
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string, reason types.UpdateReason)
|
||||
RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error
|
||||
|
||||
@@ -208,6 +213,25 @@ func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID
|
||||
}
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) {
|
||||
if am.UpdateAffectedPeersFunc != nil {
|
||||
am.UpdateAffectedPeersFunc(ctx, accountID, peerIDs)
|
||||
}
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) ResolveAffectedPeers(ctx context.Context, s store.Store, accountID string, change affectedpeers.Change) []string {
|
||||
if am.ResolveAffectedPeersFunc != nil {
|
||||
return am.ResolveAffectedPeersFunc(ctx, s, accountID, change)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) {
|
||||
if am.BufferUpdateAffectedPeersFunc != nil {
|
||||
am.BufferUpdateAffectedPeersFunc(ctx, accountID, peerIDs, reason)
|
||||
}
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
|
||||
if am.BufferUpdateAccountPeersFunc != nil {
|
||||
am.BufferUpdateAccountPeersFunc(ctx, accountID, reason)
|
||||
@@ -860,6 +884,14 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login types.PeerLog
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented")
|
||||
}
|
||||
|
||||
// ExtendPeerSession mocks ExtendPeerSession of the AccountManager interface
|
||||
func (am *MockAccountManager) ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) {
|
||||
if am.ExtendPeerSessionFunc != nil {
|
||||
return am.ExtendPeerSessionFunc(ctx, peerPubKey, userID)
|
||||
}
|
||||
return time.Time{}, status.Errorf(codes.Unimplemented, "method ExtendPeerSession is not implemented")
|
||||
}
|
||||
|
||||
// SyncPeer mocks SyncPeer of the AccountManager interface
|
||||
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
if am.SyncPeerFunc != nil {
|
||||
|
||||
@@ -4,10 +4,12 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
@@ -23,7 +25,7 @@ var errInvalidDomainName = errors.New("invalid domain name")
|
||||
|
||||
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
|
||||
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Read)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -36,7 +38,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
|
||||
|
||||
// CreateNameServerGroup creates and saves a new nameserver group
|
||||
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Create)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -57,22 +59,19 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
SearchDomainsEnabled: searchDomainEnabled,
|
||||
}
|
||||
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, newNSGroup.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, newNSGroup.Groups, nil)
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -81,8 +80,11 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
|
||||
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationCreate})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("CreateNameServerGroup %s: updating %d affected peers: %v", newNSGroup.ID, len(affectedPeerIDs), affectedPeerIDs)
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("CreateNameServerGroup %s: no affected peers", newNSGroup.ID)
|
||||
}
|
||||
|
||||
return newNSGroup.Copy(), nil
|
||||
@@ -94,7 +96,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
|
||||
}
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Update)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Update)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -102,7 +104,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupToSave.ID)
|
||||
@@ -115,15 +117,13 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areNameServerGroupChangesAffectPeers(ctx, transaction, nsGroupToSave, oldNSGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
allGroups := slices.Concat(nsGroupToSave.Groups, oldNSGroup.Groups)
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroups, nil)
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -132,8 +132,11 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
|
||||
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationUpdate})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("SaveNameServerGroup %s: updating %d affected peers: %v", nsGroupToSave.ID, len(affectedPeerIDs), affectedPeerIDs)
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("SaveNameServerGroup %s: no affected peers", nsGroupToSave.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -141,7 +144,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
|
||||
// DeleteNameServerGroup deletes nameserver group with nsGroupID
|
||||
func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Delete)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -150,7 +153,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
}
|
||||
|
||||
var nsGroup *nbdns.NameServerGroup
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
nsGroup, err = transaction.GetNameServerGroupByID(ctx, store.LockingStrengthUpdate, accountID, nsGroupID)
|
||||
@@ -158,10 +161,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, nsGroup.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, nsGroup.Groups, nil)
|
||||
|
||||
if err = transaction.DeleteNameServerGroup(ctx, accountID, nsGroupID); err != nil {
|
||||
return err
|
||||
@@ -175,8 +175,11 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
|
||||
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationDelete})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("DeleteNameServerGroup %s: updating %d affected peers: %v", nsGroupID, len(affectedPeerIDs), affectedPeerIDs)
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("DeleteNameServerGroup %s: no affected peers", nsGroupID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -184,7 +187,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
|
||||
// ListNameServerGroups returns a list of nameserver groups from account
|
||||
func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Read)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -224,24 +227,6 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
|
||||
return validateGroups(nameserverGroup.Groups, groups)
|
||||
}
|
||||
|
||||
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
|
||||
func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction store.Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) {
|
||||
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups)
|
||||
}
|
||||
|
||||
func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error {
|
||||
if !primary && len(domains) == 0 {
|
||||
return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+
|
||||
|
||||
@@ -5,9 +5,11 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
"github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -15,7 +17,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
serverTypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -49,7 +50,7 @@ func NewManager(store store.Store, permissionsManager permissions.Manager, resou
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllNetworks(ctx context.Context, accountID, userID string) ([]*types.Network, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -61,7 +62,7 @@ func (m *managerImpl) GetAllNetworks(ctx context.Context, accountID, userID stri
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, network.AccountID, userID, modules.Networks, operations.Create)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, network.AccountID, userID, modules.Networks, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -82,7 +83,7 @@ func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -94,7 +95,7 @@ func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, network
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, network.AccountID, userID, modules.Networks, operations.Update)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, network.AccountID, userID, modules.Networks, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -113,7 +114,7 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Delete)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -127,7 +128,10 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var affectedPeerIDs []string
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
affectedPeerIDs = m.accountManager.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{NetworkIDs: []string{networkID}})
|
||||
|
||||
resources, err := transaction.GetNetworkResourcesByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get resources in network: %w", err)
|
||||
@@ -141,12 +145,12 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
|
||||
eventsToStore = append(eventsToStore, event...)
|
||||
}
|
||||
|
||||
routers, err := transaction.GetNetworkRoutersByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
|
||||
netRouters, err := transaction.GetNetworkRoutersByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get routers in network: %w", err)
|
||||
}
|
||||
|
||||
for _, router := range routers {
|
||||
for _, router := range netRouters {
|
||||
event, err := m.routersManager.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, router.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete router: %w", err)
|
||||
@@ -178,7 +182,12 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
|
||||
event()
|
||||
}
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetwork, Operation: serverTypes.UpdateOperationDelete})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("DeleteNetwork %s: updating %d affected peers: %v", networkID, len(affectedPeerIDs), affectedPeerIDs)
|
||||
go m.accountManager.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("DeleteNetwork %s: no affected peers", networkID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
@@ -54,7 +55,7 @@ func NewManager(store store.Store, permissionsManager permissions.Manager, group
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllResourcesInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkResource, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -66,7 +67,7 @@ func (m *managerImpl) GetAllResourcesInNetwork(ctx context.Context, accountID, u
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -78,7 +79,7 @@ func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, u
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -100,7 +101,7 @@ func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID,
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, resource.AccountID, userID, modules.Networks, operations.Create)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, resource.AccountID, userID, modules.Networks, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -114,45 +115,11 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var affectedPeerIDs []string
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
|
||||
if err == nil {
|
||||
return status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name)
|
||||
}
|
||||
|
||||
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.SaveNetworkResource(ctx, resource)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save network resource: %w", err)
|
||||
}
|
||||
|
||||
event := func() {
|
||||
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceCreated, resource.EventMeta(network))
|
||||
}
|
||||
eventsToStore = append(eventsToStore, event)
|
||||
|
||||
res := nbtypes.Resource{
|
||||
ID: resource.ID,
|
||||
Type: nbtypes.ResourceType(resource.Type.String()),
|
||||
}
|
||||
for _, groupID := range resource.GroupIDs {
|
||||
event, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, resource.AccountID, userID, groupID, &res)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add resource to group: %w", err)
|
||||
}
|
||||
eventsToStore = append(eventsToStore, event)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, resource.AccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
var txErr error
|
||||
eventsToStore, affectedPeerIDs, txErr = m.createResourceInTransaction(ctx, transaction, userID, resource)
|
||||
return txErr
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create network resource: %w", err)
|
||||
@@ -162,13 +129,59 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
|
||||
event()
|
||||
}
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationCreate})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("CreateResource %s: updating %d affected peers: %v", resource.ID, len(affectedPeerIDs), affectedPeerIDs)
|
||||
go m.accountManager.UpdateAffectedPeers(ctx, resource.AccountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("CreateResource %s: no affected peers", resource.ID)
|
||||
}
|
||||
|
||||
return resource, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) createResourceInTransaction(ctx context.Context, transaction store.Store, userID string, resource *types.NetworkResource) ([]func(), []string, error) {
|
||||
_, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
|
||||
if err == nil {
|
||||
return nil, nil, status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name)
|
||||
}
|
||||
|
||||
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.SaveNetworkResource(ctx, resource); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to save network resource: %w", err)
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceCreated, resource.EventMeta(network))
|
||||
})
|
||||
|
||||
res := nbtypes.Resource{
|
||||
ID: resource.ID,
|
||||
Type: nbtypes.ResourceType(resource.Type.String()),
|
||||
}
|
||||
for _, groupID := range resource.GroupIDs {
|
||||
event, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, resource.AccountID, userID, groupID, &res)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to add resource to group: %w", err)
|
||||
}
|
||||
eventsToStore = append(eventsToStore, event)
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, resource.AccountID); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
affectedPeerIDs := m.accountManager.ResolveAffectedPeers(ctx, transaction, resource.AccountID, affectedpeers.Change{ResourceIDs: []string{resource.ID}})
|
||||
|
||||
return eventsToStore, affectedPeerIDs, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -189,7 +202,7 @@ func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networ
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, resource.AccountID, userID, modules.Networks, operations.Update)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, resource.AccountID, userID, modules.Networks, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -207,6 +220,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
resource.Prefix = prefix
|
||||
|
||||
var eventsToStore []func()
|
||||
var affectedPeerIDs []string
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID)
|
||||
if err != nil {
|
||||
@@ -232,6 +246,15 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
return fmt.Errorf("failed to get network resource: %w", err)
|
||||
}
|
||||
|
||||
oldGroups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthNone, resource.AccountID, resource.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get old resource groups: %w", err)
|
||||
}
|
||||
var oldGroupIDs []string
|
||||
for _, g := range oldGroups {
|
||||
oldGroupIDs = append(oldGroupIDs, g.ID)
|
||||
}
|
||||
|
||||
err = transaction.SaveNetworkResource(ctx, resource)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save network resource: %w", err)
|
||||
@@ -247,6 +270,13 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceUpdated, resource.EventMeta(network))
|
||||
})
|
||||
|
||||
// Pass both old and new resource group IDs so policies that targeted the
|
||||
// resource via a now-detached group still refresh their source peers.
|
||||
affectedPeerIDs = m.accountManager.ResolveAffectedPeers(ctx, transaction, resource.AccountID, affectedpeers.Change{
|
||||
ResourceIDs: []string{resource.ID},
|
||||
ChangedGroupIDs: append(oldGroupIDs, resource.GroupIDs...),
|
||||
})
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, resource.AccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
@@ -270,7 +300,12 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
}
|
||||
}()
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationUpdate})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("UpdateResource %s: updating %d affected peers: %v", resource.ID, len(affectedPeerIDs), affectedPeerIDs)
|
||||
go m.accountManager.UpdateAffectedPeers(ctx, resource.AccountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("UpdateResource %s: no affected peers", resource.ID)
|
||||
}
|
||||
|
||||
return resource, nil
|
||||
}
|
||||
@@ -314,7 +349,7 @@ func (m *managerImpl) updateResourceGroups(ctx context.Context, transaction stor
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Delete)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -331,7 +366,10 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
|
||||
}
|
||||
|
||||
var events []func()
|
||||
var affectedPeerIDs []string
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
affectedPeerIDs = m.accountManager.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{ResourceIDs: []string{resourceID}})
|
||||
|
||||
events, err = m.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resourceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete resource: %w", err)
|
||||
@@ -352,7 +390,12 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
|
||||
event()
|
||||
}
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationDelete})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("DeleteResource %s: updating %d affected peers: %v", resourceID, len(affectedPeerIDs), affectedPeerIDs)
|
||||
go m.accountManager.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("DeleteResource %s: no affected peers", resourceID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6,16 +6,17 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
serverTypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -47,7 +48,7 @@ func NewManager(store store.Store, permissionsManager permissions.Manager, accou
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllRoutersInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkRouter, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -59,7 +60,7 @@ func (m *managerImpl) GetAllRoutersInNetwork(ctx context.Context, accountID, use
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -81,7 +82,7 @@ func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, use
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, router.AccountID, userID, modules.Networks, operations.Create)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, router.AccountID, userID, modules.Networks, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -90,6 +91,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
||||
}
|
||||
|
||||
var network *networkTypes.Network
|
||||
var affectedPeerIDs []string
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
|
||||
if err != nil {
|
||||
@@ -112,6 +114,8 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
affectedPeerIDs = m.accountManager.ResolveAffectedPeers(ctx, transaction, router.AccountID, affectedpeers.Change{NetworkIDs: []string{router.NetworkID}})
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
@@ -120,13 +124,18 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network))
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationCreate})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("CreateRouter %s: updating %d affected peers: %v", router.ID, len(affectedPeerIDs), affectedPeerIDs)
|
||||
go m.accountManager.UpdateAffectedPeers(ctx, router.AccountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("CreateRouter %s: no affected peers", router.ID)
|
||||
}
|
||||
|
||||
return router, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -147,7 +156,7 @@ func (m *managerImpl) GetRouter(ctx context.Context, accountID, userID, networkI
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, router.AccountID, userID, modules.Networks, operations.Update)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, router.AccountID, userID, modules.Networks, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -156,36 +165,11 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
|
||||
}
|
||||
|
||||
var network *networkTypes.Network
|
||||
var affectedPeerIDs []string
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
|
||||
existing, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthUpdate, router.AccountID, router.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network router: %w", err)
|
||||
}
|
||||
|
||||
if existing.AccountID != router.AccountID {
|
||||
return status.NewNetworkRouterNotFoundError(router.ID)
|
||||
}
|
||||
|
||||
if existing.NetworkID != router.NetworkID {
|
||||
return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
|
||||
}
|
||||
|
||||
err = transaction.UpdateNetworkRouter(ctx, router)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update network router: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, router.AccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
var txErr error
|
||||
network, affectedPeerIDs, txErr = m.updateRouterInTransaction(ctx, transaction, router)
|
||||
return txErr
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -193,13 +177,78 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network))
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationUpdate})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("UpdateRouter %s: updating %d affected peers: %v", router.ID, len(affectedPeerIDs), affectedPeerIDs)
|
||||
go m.accountManager.UpdateAffectedPeers(ctx, router.AccountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("UpdateRouter %s: no affected peers", router.ID)
|
||||
}
|
||||
|
||||
return router, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) updateRouterInTransaction(ctx context.Context, transaction store.Store, router *types.NetworkRouter) (*networkTypes.Network, []string, error) {
|
||||
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
|
||||
existing, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthUpdate, router.AccountID, router.ID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to get network router: %w", err)
|
||||
}
|
||||
|
||||
if existing.AccountID != router.AccountID {
|
||||
return nil, nil, status.NewNetworkRouterNotFoundError(router.ID)
|
||||
}
|
||||
|
||||
if existing.NetworkID != router.NetworkID {
|
||||
return nil, nil, status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
|
||||
}
|
||||
|
||||
if err = transaction.UpdateNetworkRouter(ctx, router); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to update network router: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, router.AccountID); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
networkIDs := []string{router.NetworkID}
|
||||
if existing.NetworkID != router.NetworkID {
|
||||
networkIDs = append(networkIDs, existing.NetworkID)
|
||||
}
|
||||
|
||||
affectedPeerIDs := m.accountManager.ResolveAffectedPeers(ctx, transaction, router.AccountID, affectedpeers.Change{NetworkIDs: networkIDs})
|
||||
|
||||
// The previous routing peer / peer-group members lose their routing role and
|
||||
// are no longer reachable from the post-update network state, so add them
|
||||
// explicitly.
|
||||
affectedPeerIDs = append(affectedPeerIDs, oldRoutingPeerIDs(ctx, transaction, router.AccountID, existing)...)
|
||||
|
||||
return network, affectedPeerIDs, nil
|
||||
}
|
||||
|
||||
// oldRoutingPeerIDs returns the peer IDs that served as the router's routing peers
|
||||
// before an update (direct Peer plus PeerGroups members).
|
||||
func oldRoutingPeerIDs(ctx context.Context, transaction store.Store, accountID string, existing *types.NetworkRouter) []string {
|
||||
var ids []string
|
||||
if existing.Peer != "" {
|
||||
ids = append(ids, existing.Peer)
|
||||
}
|
||||
if len(existing.PeerGroups) > 0 {
|
||||
groupPeers, err := transaction.GetPeerIDsByGroups(ctx, accountID, existing.PeerGroups)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get old router peer-group members for affected peers: %v", err)
|
||||
} else {
|
||||
ids = append(ids, groupPeers...)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Delete)
|
||||
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -208,7 +257,10 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
|
||||
}
|
||||
|
||||
var event func()
|
||||
var affectedPeerIDs []string
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
affectedPeerIDs = m.accountManager.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{NetworkIDs: []string{networkID}})
|
||||
|
||||
event, err = m.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, routerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete network router: %w", err)
|
||||
@@ -227,7 +279,12 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
|
||||
|
||||
event()
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationDelete})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("DeleteRouter %s: updating %d affected peers: %v", routerID, len(affectedPeerIDs), affectedPeerIDs)
|
||||
go m.accountManager.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("DeleteRouter %s: no affected peers", routerID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -42,7 +43,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -120,19 +121,23 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
}
|
||||
|
||||
if expired {
|
||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||
err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// An embedded proxy peer flipping to connected is the trigger for
|
||||
// SynthesizePrivateServiceZones to emit DNS A records pointing at its
|
||||
// tunnel IP. Without an account-wide netmap recompute, user peers keep
|
||||
// the stale synth (or no synth at all on first connect) until some
|
||||
// other change pokes the controller. Fire OnPeersUpdated so the
|
||||
// buffered recompute fans the new state out to every peer.
|
||||
// tunnel IP. Fan the change out to every peer that has a synthesized
|
||||
// policy or DNS-record edge from this proxy peer; otherwise their
|
||||
// synth state stays stale until some other change pokes the controller.
|
||||
if peer.ProxyMeta.Embedded {
|
||||
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
log.WithContext(ctx).Warnf("notify network map controller of embedded proxy %s connect: %v", peer.ID, err)
|
||||
}
|
||||
}
|
||||
@@ -174,11 +179,12 @@ func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerP
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusApplied)
|
||||
|
||||
// Symmetric with MarkPeerConnected: when an embedded proxy peer goes
|
||||
// offline, drive an account-wide netmap recompute so the synthesized
|
||||
// DNS records that pointed at it are pulled. Without this the records
|
||||
// linger client-side at TTL until something else triggers a refresh.
|
||||
// offline, refresh the peers that had synthesized records pointing at
|
||||
// it so they pull the stale entries instead of waiting out TTL.
|
||||
if peer.ProxyMeta.Embedded {
|
||||
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
log.WithContext(ctx).Warnf("notify network map controller of embedded proxy %s disconnect: %v", peer.ID, err)
|
||||
}
|
||||
}
|
||||
@@ -209,7 +215,7 @@ func (am *DefaultAccountManager) updatePeerLocationIfChanged(ctx context.Context
|
||||
|
||||
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated.
|
||||
func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -345,7 +351,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
}
|
||||
}
|
||||
|
||||
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||
affectedPeerIDs = append(affectedPeerIDs, peer.ID)
|
||||
err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
}
|
||||
@@ -354,7 +363,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.RemoteJobs, operations.Create)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.RemoteJobs, operations.Create)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -430,7 +439,7 @@ func (am *DefaultAccountManager) CreatePeerJob(ctx context.Context, accountID, p
|
||||
|
||||
func (am *DefaultAccountManager) GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error) {
|
||||
// todo: Create permissions for job
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.RemoteJobs, operations.Read)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.RemoteJobs, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -456,7 +465,7 @@ func (am *DefaultAccountManager) GetAllPeerJobs(ctx context.Context, accountID,
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.RemoteJobs, operations.Read)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.RemoteJobs, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -483,7 +492,7 @@ func (am *DefaultAccountManager) GetPeerJobByID(ctx context.Context, accountID,
|
||||
|
||||
// DeletePeer removes peer from the account by its IP
|
||||
func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Delete)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -503,6 +512,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
var peer *nbpeer.Peer
|
||||
var settings *types.Settings
|
||||
var eventsToStore []func()
|
||||
var affectedPeerIDs []string
|
||||
|
||||
serviceID, err := am.serviceManager.GetServiceIDByTargetID(ctx, accountID, peerID)
|
||||
if err != nil {
|
||||
@@ -527,6 +537,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
return err
|
||||
}
|
||||
|
||||
affectedPeerIDs = am.resolveAffectedPeersForPeerChanges(ctx, transaction, accountID, []string{peerID})
|
||||
|
||||
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}, settings)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete peer: %w", err)
|
||||
@@ -550,7 +562,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
log.WithContext(ctx).Errorf("failed to delete peer %s from integrated validator: %v", peerID, err)
|
||||
}
|
||||
|
||||
if err = am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}); err != nil {
|
||||
if err = am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}, affectedPeerIDs); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peerID, err)
|
||||
}
|
||||
|
||||
@@ -643,7 +655,7 @@ func (am *DefaultAccountManager) handleUserAddedPeer(ctx context.Context, accoun
|
||||
}
|
||||
|
||||
if temporary {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Create)
|
||||
allowed, _, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Create)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -923,12 +935,18 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||
}
|
||||
|
||||
if err := am.networkMapController.OnPeersAdded(ctx, accountID, []string{newPeer.ID}); err != nil {
|
||||
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer)
|
||||
if err != nil {
|
||||
return p, nmap, pc, err
|
||||
}
|
||||
|
||||
changedPeerIDs := []string{newPeer.ID}
|
||||
affectedPeerIDs := affectedPeerIDsFromNetworkMap(nmap, newPeer.ID)
|
||||
if err := am.networkMapController.OnPeersAdded(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
|
||||
}
|
||||
|
||||
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer)
|
||||
return p, nmap, pc, err
|
||||
return p, nmap, pc, nil
|
||||
}
|
||||
|
||||
func getPeerIPDNSLabel(ip netip.Addr, peerHostName string) (string, error) {
|
||||
@@ -1011,7 +1029,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
}
|
||||
|
||||
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(postureChecks) > 0 || versionChanged)) {
|
||||
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||
err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
}
|
||||
@@ -1141,7 +1161,9 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
}
|
||||
|
||||
if updateRemotePeers || isStatusChanged || ipv6CapabilityChanged || (isPeerUpdated && len(postureChecks) > 0) {
|
||||
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||
err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
}
|
||||
@@ -1151,6 +1173,79 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
return p, nmap, pc, err
|
||||
}
|
||||
|
||||
// ExtendPeerSession refreshes the peer's SSO session deadline by updating
|
||||
// LastLogin after a successful JWT validation. The tunnel is untouched: no
|
||||
// network map sync, no peer reconnect.
|
||||
//
|
||||
// Preconditions enforced here:
|
||||
// - userID must be present (caller validated the JWT and extracted the user ID).
|
||||
// - The peer must exist and be SSO-registered (AddedWithSSOLogin) with
|
||||
// LoginExpirationEnabled.
|
||||
// - Account-level PeerLoginExpirationEnabled must be true.
|
||||
// - The JWT user must match peer.UserID (mirrors LoginPeer at peer.go ~1028).
|
||||
//
|
||||
// Returns the new absolute UTC deadline.
|
||||
func (am *DefaultAccountManager) ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) {
|
||||
if userID == "" {
|
||||
return time.Time{}, status.Errorf(status.PermissionDenied, "session extend requires a JWT")
|
||||
}
|
||||
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
if !settings.PeerLoginExpirationEnabled {
|
||||
return time.Time{}, status.Errorf(status.PreconditionFailed, "peer login expiration is disabled for the account")
|
||||
}
|
||||
|
||||
var refreshed *nbpeer.Peer
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
peer, err := transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !peer.AddedWithSSOLogin() || !peer.LoginExpirationEnabled {
|
||||
return status.Errorf(status.PreconditionFailed, "peer is not eligible for session extension")
|
||||
}
|
||||
|
||||
if peer.UserID != userID {
|
||||
log.WithContext(ctx).Warnf("user mismatch when extending session for peer %s: peer user %s, jwt user %s", peer.ID, peer.UserID, userID)
|
||||
return status.NewPeerLoginMismatchError()
|
||||
}
|
||||
|
||||
peer = peer.UpdateLastLogin()
|
||||
if err := transaction.SavePeer(ctx, accountID, peer); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.SaveUserLastLogin(ctx, accountID, userID, peer.GetLastLogin()); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to update user last login during session extend: %v", err)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.UserExtendedPeerSession, peer.EventMeta(am.networkMapController.GetDNSDomain(settings)))
|
||||
refreshed = peer
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
|
||||
// Reschedule the per-account expiration job. schedulePeerLoginExpiration
|
||||
// is a no-op when a job is already running, but the running job will pick
|
||||
// up the new LastLogin on its next tick. Calling it here is harmless and
|
||||
// guarantees a job is in flight even if a prior one ended right before
|
||||
// the extend.
|
||||
am.schedulePeerLoginExpiration(ctx, accountID)
|
||||
|
||||
return refreshed.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration), nil
|
||||
}
|
||||
|
||||
// getPeerPostureChecks returns the posture checks for the peer.
|
||||
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
@@ -1306,7 +1401,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -1333,6 +1428,99 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
_ = am.networkMapController.UpdateAccountPeers(ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// UpdateAffectedPeers updates only the specified peers that belong to an account.
|
||||
func (am *DefaultAccountManager) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) {
|
||||
ctx = context.WithoutCancel(ctx)
|
||||
log.WithContext(ctx).Tracef("UpdateAffectedPeers: %d peers for account %s", len(peerIDs), accountID)
|
||||
_ = am.networkMapController.UpdateAffectedPeers(ctx, accountID, peerIDs)
|
||||
}
|
||||
|
||||
// resolvePeerIDs resolves group IDs and direct peer IDs into a deduplicated peer ID list.
|
||||
func (am *DefaultAccountManager) resolvePeerIDs(ctx context.Context, s store.Store, accountID string, groupIDs []string, directPeerIDs []string) []string {
|
||||
peerIDs, err := s.GetPeerIDsByGroups(ctx, accountID, groupIDs)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to resolve peer IDs by groups: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(directPeerIDs) == 0 {
|
||||
log.WithContext(ctx).Tracef("resolvePeerIDs: groups=%v -> %d peers: %v", groupIDs, len(peerIDs), peerIDs)
|
||||
return peerIDs
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{}, len(peerIDs))
|
||||
for _, id := range peerIDs {
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
for _, id := range directPeerIDs {
|
||||
if _, exists := seen[id]; !exists {
|
||||
peerIDs = append(peerIDs, id)
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("resolvePeerIDs: groups=%v + directPeers=%v -> %d peers: %v", groupIDs, directPeerIDs, len(peerIDs), peerIDs)
|
||||
return peerIDs
|
||||
}
|
||||
|
||||
// BufferUpdateAffectedPeers accumulates peer IDs and flushes them after the buffer interval.
|
||||
func (am *DefaultAccountManager) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) {
|
||||
_ = am.networkMapController.BufferUpdateAffectedPeers(ctx, accountID, peerIDs, reason)
|
||||
}
|
||||
|
||||
// ResolveAffectedPeers resolves a description of what changed into the peer IDs
|
||||
// whose network map may have changed. It is the single entry point shared by the
|
||||
// server package and the networks managers (via the account.Manager interface).
|
||||
func (am *DefaultAccountManager) ResolveAffectedPeers(ctx context.Context, s store.Store, accountID string, change affectedpeers.Change) []string {
|
||||
peerIDs, err := affectedpeers.Resolve(ctx, s, accountID, change)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to resolve affected peers: %v", err)
|
||||
return nil
|
||||
}
|
||||
return peerIDs
|
||||
}
|
||||
|
||||
// affectedPeerIDsFromNetworkMap returns the peer IDs referenced by a peer's
|
||||
// network map (its connected and offline peers, which include routing and proxy
|
||||
// peers), excluding the peer itself. For a freshly added peer these are, by ACL
|
||||
// symmetry, exactly the peers its addition affects.
|
||||
func affectedPeerIDsFromNetworkMap(nmap *types.NetworkMap, selfPeerID string) []string {
|
||||
if nmap == nil {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[string]struct{}, len(nmap.Peers)+len(nmap.OfflinePeers))
|
||||
ids := make([]string, 0, len(nmap.Peers)+len(nmap.OfflinePeers))
|
||||
add := func(peers []*nbpeer.Peer) {
|
||||
for _, p := range peers {
|
||||
if p == nil || p.ID == "" || p.ID == selfPeerID {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[p.ID]; ok {
|
||||
continue
|
||||
}
|
||||
seen[p.ID] = struct{}{}
|
||||
ids = append(ids, p.ID)
|
||||
}
|
||||
}
|
||||
add(nmap.Peers)
|
||||
add(nmap.OfflinePeers)
|
||||
return ids
|
||||
}
|
||||
|
||||
// resolveAffectedPeersForPeerChanges resolves changed peer IDs into the full set of affected peer IDs.
|
||||
func (am *DefaultAccountManager) resolveAffectedPeersForPeerChanges(ctx context.Context, s store.Store, accountID string, changedPeerIDs []string) []string {
|
||||
groupIDs, err := s.GetGroupIDsByPeerIDs(ctx, accountID, changedPeerIDs)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get group IDs for changed peers: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return am.ResolveAffectedPeers(ctx, s, accountID, affectedpeers.Change{
|
||||
ChangedGroupIDs: groupIDs,
|
||||
ChangedPeerIDs: changedPeerIDs,
|
||||
})
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
|
||||
_ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID, reason)
|
||||
}
|
||||
|
||||
@@ -367,6 +367,22 @@ func (p *Peer) LoginExpired(expiresIn time.Duration) (bool, time.Duration) {
|
||||
return timeLeft <= 0, timeLeft
|
||||
}
|
||||
|
||||
// SessionExpiresAt returns the absolute UTC instant at which the peer's SSO
|
||||
// session expires, derived from LastLogin and the account-level
|
||||
// PeerLoginExpiration setting. Returns the zero value when login expiration
|
||||
// does not apply (peer not SSO-registered, peer-level toggle off, or account
|
||||
// expiry disabled). Callers should treat the zero value as "no deadline".
|
||||
func (p *Peer) SessionExpiresAt(accountExpirationEnabled bool, expiresIn time.Duration) time.Time {
|
||||
if !accountExpirationEnabled || !p.AddedWithSSOLogin() || !p.LoginExpirationEnabled {
|
||||
return time.Time{}
|
||||
}
|
||||
last := p.GetLastLogin()
|
||||
if last.IsZero() {
|
||||
return time.Time{}
|
||||
}
|
||||
return last.Add(expiresIn).UTC()
|
||||
}
|
||||
|
||||
// FQDN returns peers FQDN combined of the peer's DNS label and the system's DNS domain
|
||||
func (p *Peer) FQDN(dnsDomain string) string {
|
||||
if dnsDomain == "" {
|
||||
|
||||
@@ -1855,7 +1855,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("adding peer to unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg) //
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1880,7 +1880,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting peer with unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -2018,7 +2018,10 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Adding peer to group linked with route should update account peers and send peer update
|
||||
// drain any buffered updates from previous subtests
|
||||
drainPeerUpdates(updMsg)
|
||||
|
||||
// Adding peer to group linked with route should update peers in that group, not unrelated peers
|
||||
t.Run("adding peer to group linked with route", func(t *testing.T) {
|
||||
route := nbroute.Route{
|
||||
ID: "testingRoute1",
|
||||
@@ -2042,7 +2045,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -2059,16 +2062,16 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting peer with linked group to route should update account peers and send peer update
|
||||
// Deleting peer with linked group to route should update peers in that group, not unrelated peers
|
||||
t.Run("deleting peer with linked group to route", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -2077,12 +2080,12 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Adding peer to group linked with name server group should update account peers and send peer update
|
||||
// Adding peer to group linked with name server group should update peers in that group, not unrelated peers
|
||||
t.Run("adding peer to group linked with name server group", func(t *testing.T) {
|
||||
_, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "nsGroup", "nsGroup", []nbdns.NameServer{{
|
||||
@@ -2097,7 +2100,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -2114,16 +2117,16 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting peer with linked group to name server group should update account peers and send peer update
|
||||
// Deleting peer with linked group to name server group should update peers in that group, not unrelated peers
|
||||
t.Run("deleting peer with linked group to route", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -2132,8 +2135,8 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/roles"
|
||||
@@ -18,9 +19,9 @@ import (
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
ValidateUserPermissions(ctx context.Context, accountID, userID string, module modules.Module, operation operations.Operation) (bool, error)
|
||||
ValidateUserPermissions(ctx context.Context, accountID, userID string, module modules.Module, operation operations.Operation) (bool, context.Context, error)
|
||||
ValidateRoleModuleAccess(ctx context.Context, accountID string, role roles.RolePermissions, module modules.Module, operation operations.Operation) bool
|
||||
ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error
|
||||
ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) (context.Context, error)
|
||||
|
||||
GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error)
|
||||
SetAccountManager(accountManager account.Manager)
|
||||
@@ -42,42 +43,43 @@ func (m *managerImpl) ValidateUserPermissions(
|
||||
userID string,
|
||||
module modules.Module,
|
||||
operation operations.Operation,
|
||||
) (bool, error) {
|
||||
) (bool, context.Context, error) {
|
||||
if userID == activity.SystemInitiator {
|
||||
return true, nil
|
||||
return true, ctx, nil
|
||||
}
|
||||
|
||||
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return false, ctx, err
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
return false, status.NewUserNotFoundError(userID)
|
||||
return false, ctx, status.NewUserNotFoundError(userID)
|
||||
}
|
||||
|
||||
if user.IsBlocked() && !user.PendingApproval {
|
||||
return false, status.NewUserBlockedError()
|
||||
return false, ctx, status.NewUserBlockedError()
|
||||
}
|
||||
|
||||
if user.IsBlocked() && user.PendingApproval {
|
||||
return false, status.NewUserPendingApprovalError()
|
||||
return false, ctx, status.NewUserPendingApprovalError()
|
||||
}
|
||||
|
||||
if err := m.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return false, err
|
||||
ctxEnriched, err := m.ValidateAccountAccess(ctx, accountID, user, false)
|
||||
if err != nil {
|
||||
return false, ctx, err
|
||||
}
|
||||
|
||||
if operation == operations.Read && user.IsServiceUser {
|
||||
return true, nil // this should be replaced by proper granular access role
|
||||
return true, ctxEnriched, nil // this should be replaced by proper granular access role
|
||||
}
|
||||
|
||||
role, ok := roles.RolesMap[user.Role]
|
||||
if !ok {
|
||||
return false, status.NewUserRoleNotFoundError(string(user.Role))
|
||||
return false, ctxEnriched, status.NewUserRoleNotFoundError(string(user.Role))
|
||||
}
|
||||
|
||||
return m.ValidateRoleModuleAccess(ctx, accountID, role, module, operation), nil
|
||||
return m.ValidateRoleModuleAccess(ctx, accountID, role, module, operation), ctxEnriched, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) ValidateRoleModuleAccess(
|
||||
@@ -98,11 +100,14 @@ func (m *managerImpl) ValidateRoleModuleAccess(
|
||||
return role.AutoAllowNew[operation]
|
||||
}
|
||||
|
||||
func (m *managerImpl) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error {
|
||||
func (m *managerImpl) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) (context.Context, error) {
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
return ctx, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
return nil
|
||||
|
||||
ctx = nbcontext.WithRole(ctx, string(user.Role))
|
||||
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error) {
|
||||
|
||||
@@ -67,11 +67,12 @@ func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{})
|
||||
}
|
||||
|
||||
// ValidateAccountAccess mocks base method.
|
||||
func (m *MockManager) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error {
|
||||
func (m *MockManager) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) (context.Context, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ValidateAccountAccess", ctx, accountID, user, allowOwnerAndAdmin)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret0, _ := ret[0].(context.Context)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ValidateAccountAccess indicates an expected call of ValidateAccountAccess.
|
||||
@@ -95,12 +96,13 @@ func (mr *MockManagerMockRecorder) ValidateRoleModuleAccess(ctx, accountID, role
|
||||
}
|
||||
|
||||
// ValidateUserPermissions mocks base method.
|
||||
func (m *MockManager) ValidateUserPermissions(ctx context.Context, accountID, userID string, module modules.Module, operation operations.Operation) (bool, error) {
|
||||
func (m *MockManager) ValidateUserPermissions(ctx context.Context, accountID, userID string, module modules.Module, operation operations.Operation) (bool, context.Context, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ValidateUserPermissions", ctx, accountID, userID, module, operation)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
ret1, _ := ret[1].(context.Context)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// ValidateUserPermissions indicates an expected call of ValidateUserPermissions.
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
_ "embed"
|
||||
|
||||
"github.com/rs/xid"
|
||||
"github.com/sirupsen/logrus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
@@ -13,13 +13,14 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// GetPolicy from the store
|
||||
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -36,7 +37,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
if !create {
|
||||
operation = operations.Update
|
||||
}
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operation)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operation)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -45,44 +46,37 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
}
|
||||
|
||||
var isUpdate = policy.ID != ""
|
||||
var updateAccountPeers bool
|
||||
var existingPolicy *types.Policy
|
||||
var action = activity.PolicyAdded
|
||||
var unchanged bool
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
existingPolicy, err := validatePolicy(ctx, transaction, accountID, policy)
|
||||
existingPolicy, err = validatePolicy(ctx, transaction, accountID, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isUpdate {
|
||||
if policy.Equal(existingPolicy) {
|
||||
logrus.WithContext(ctx).Tracef("policy update skipped because equal to stored one - policy id %s", policy.ID)
|
||||
log.WithContext(ctx).Tracef("policy update skipped because equal to stored one - policy id %s", policy.ID)
|
||||
unchanged = true
|
||||
return nil
|
||||
}
|
||||
|
||||
action = activity.PolicyUpdated
|
||||
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeersWithExisting(ctx, transaction, policy, existingPolicy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.SavePolicy(ctx, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.CreatePolicy(ctx, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{Policies: []*types.Policy{policy, existingPolicy}})
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -95,12 +89,11 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
policyOp := types.UpdateOperationCreate
|
||||
if isUpdate {
|
||||
policyOp = types.UpdateOperationUpdate
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePolicy, Operation: policyOp})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Tracef("SavePolicy %s: updating %d affected peers: %v", policy.ID, len(affectedPeerIDs), affectedPeerIDs)
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("SavePolicy %s: no affected peers", policy.ID)
|
||||
}
|
||||
|
||||
return policy, nil
|
||||
@@ -108,7 +101,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
|
||||
// DeletePolicy from the store
|
||||
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Delete)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -117,7 +110,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
}
|
||||
|
||||
var policy *types.Policy
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
policy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthUpdate, accountID, policyID)
|
||||
@@ -125,10 +118,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{Policies: []*types.Policy{policy}})
|
||||
|
||||
if err = transaction.DeletePolicy(ctx, accountID, policyID); err != nil {
|
||||
return err
|
||||
@@ -142,8 +132,11 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
|
||||
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePolicy, Operation: types.UpdateOperationDelete})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("DeletePolicy %s: updating %d affected peers: %v", policyID, len(affectedPeerIDs), affectedPeerIDs)
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("DeletePolicy %s: no affected peers", policyID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -151,7 +144,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
|
||||
// ListPolicies from the store.
|
||||
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -162,46 +155,6 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
|
||||
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// arePolicyChangesAffectPeers checks if a policy (being created or deleted) will affect any associated peers.
|
||||
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, policy *types.Policy) (bool, error) {
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
|
||||
}
|
||||
|
||||
func arePolicyChangesAffectPeersWithExisting(ctx context.Context, transaction store.Store, policy *types.Policy, existingPolicy *types.Policy) (bool, error) {
|
||||
if !policy.Enabled && !existingPolicy.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, rule := range existingPolicy.Rules {
|
||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
|
||||
}
|
||||
|
||||
// validatePolicy validates the policy and its rules. For updates it returns
|
||||
// the existing policy loaded from the store so callers can avoid a second read.
|
||||
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) (*types.Policy, error) {
|
||||
|
||||
@@ -1319,12 +1319,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Updating disabled policy with destination and source groups containing peers should not update account's peers
|
||||
// or send peer update
|
||||
// Updating disabled policy with destination and source groups containing peers should still update account's peers
|
||||
// because affected peer resolution does not filter by policy enabled state
|
||||
t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) {
|
||||
drainPeerUpdates(updMsg)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1335,8 +1337,8 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
@@ -33,9 +32,6 @@ func (n *NBVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, err
|
||||
return true, nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("peer %s NB version %s is older than minimum allowed version %s",
|
||||
peer.ID, peer.Meta.WtVersion, n.MinVersion)
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -100,8 +100,6 @@ func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *M
|
||||
return true, nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("peer %s OS version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinVersion)
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -125,7 +123,5 @@ func checkMinKernelVersion(ctx context.Context, peerGoOS, peerVersion string, ch
|
||||
return true, nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("peer %s kernel version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinKernelVersion)
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -5,18 +5,19 @@ import (
|
||||
"slices"
|
||||
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -33,7 +34,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
if !create {
|
||||
operation = operations.Update
|
||||
}
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operation)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operation)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -41,9 +42,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var updateAccountPeers bool
|
||||
var isUpdate = postureChecks.ID != ""
|
||||
var action = activity.PostureCheckCreated
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil {
|
||||
@@ -51,12 +52,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
if isUpdate {
|
||||
updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
action = activity.PostureCheckUpdated
|
||||
|
||||
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{PostureCheckIDs: []string{postureChecks.ID}})
|
||||
}
|
||||
|
||||
postureChecks.AccountID = accountID
|
||||
@@ -76,12 +74,11 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
|
||||
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
postureOp := types.UpdateOperationCreate
|
||||
if isUpdate {
|
||||
postureOp = types.UpdateOperationUpdate
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePostureCheck, Operation: postureOp})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("SavePostureChecks %s: updating %d affected peers: %v", postureChecks.ID, len(affectedPeerIDs), affectedPeerIDs)
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("SavePostureChecks %s: no affected peers", postureChecks.ID)
|
||||
}
|
||||
|
||||
return postureChecks, nil
|
||||
@@ -89,7 +86,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
|
||||
// DeletePostureChecks deletes a posture check by ID.
|
||||
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Delete)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -126,7 +123,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
|
||||
|
||||
// ListPostureChecks returns a list of posture checks.
|
||||
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -137,29 +134,6 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
|
||||
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
|
||||
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, policy := range policies {
|
||||
if slices.Contains(policy.SourcePostureChecks, postureCheckID) {
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, policy.RuleGroups())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// validatePostureChecks validates the posture checks.
|
||||
func validatePostureChecks(ctx context.Context, transaction store.Store, accountID string, postureChecks *posture.Checks) error {
|
||||
if err := postureChecks.Validate(); err != nil {
|
||||
|
||||
@@ -503,21 +503,20 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
|
||||
require.NoError(t, err, "failed to save policy")
|
||||
|
||||
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
|
||||
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, result)
|
||||
groupIDs, _ := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||
assert.NotEmpty(t, groupIDs)
|
||||
})
|
||||
|
||||
t.Run("posture check exists but is not linked to any policy", func(t *testing.T) {
|
||||
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckB.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, result)
|
||||
groupIDs, directPeerIDs := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, postureCheckB.ID)
|
||||
assert.Empty(t, groupIDs)
|
||||
assert.Empty(t, directPeerIDs)
|
||||
})
|
||||
|
||||
t.Run("posture check does not exist", func(t *testing.T) {
|
||||
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, "unknown")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, result)
|
||||
groupIDs, directPeerIDs := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, "unknown")
|
||||
assert.Empty(t, groupIDs)
|
||||
assert.Empty(t, directPeerIDs)
|
||||
})
|
||||
|
||||
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
|
||||
@@ -526,9 +525,8 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy, true)
|
||||
require.NoError(t, err, "failed to update policy")
|
||||
|
||||
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, result)
|
||||
groupIDs, _ := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||
assert.NotEmpty(t, groupIDs)
|
||||
})
|
||||
|
||||
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
|
||||
@@ -537,9 +535,8 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy, true)
|
||||
require.NoError(t, err, "failed to update policy")
|
||||
|
||||
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, result)
|
||||
groupIDs, _ := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||
assert.NotEmpty(t, groupIDs)
|
||||
})
|
||||
|
||||
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
|
||||
@@ -547,9 +544,9 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
|
||||
err = manager.UpdateGroup(context.Background(), account.Id, adminUserID, groupA)
|
||||
require.NoError(t, err, "failed to save groups")
|
||||
|
||||
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, result)
|
||||
// The collector returns groups even if they have no peers — the groups are still referenced
|
||||
groupIDs, _ := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||
assert.NotEmpty(t, groupIDs)
|
||||
})
|
||||
|
||||
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
|
||||
@@ -558,8 +555,10 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy, true)
|
||||
require.NoError(t, err, "failed to update policy")
|
||||
|
||||
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, result)
|
||||
// Non-existent groups are filtered out during SavePolicy validation,
|
||||
// so the saved policy has empty Sources/Destinations
|
||||
groupIDs, directPeerIDs := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||
assert.Empty(t, groupIDs)
|
||||
assert.Empty(t, directPeerIDs)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -8,8 +8,10 @@ import (
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -21,7 +23,7 @@ import (
|
||||
|
||||
// GetRoute gets a route object from account and route IDs
|
||||
func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -134,7 +136,7 @@ func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string {
|
||||
|
||||
// CreateRoute creates and saves a new route
|
||||
func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool, skipAutoApply bool) (*route.Route, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Create)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -147,7 +149,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
}
|
||||
|
||||
var newRoute *route.Route
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
newRoute = &route.Route{
|
||||
@@ -173,15 +175,12 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, newRoute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.SaveRoute(ctx, newRoute); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{Routes: []*route.Route{newRoute}})
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -190,8 +189,11 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
|
||||
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationCreate})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("CreateRoute %s: updating %d affected peers: %v", newRoute.ID, len(affectedPeerIDs), affectedPeerIDs)
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("CreateRoute %s: no affected peers", newRoute.ID)
|
||||
}
|
||||
|
||||
return newRoute, nil
|
||||
@@ -199,7 +201,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
|
||||
// SaveRoute saves route
|
||||
func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userID string, routeToSave *route.Route) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -208,8 +210,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
}
|
||||
|
||||
var oldRoute *route.Route
|
||||
var oldRouteAffectsPeers bool
|
||||
var newRouteAffectsPeers bool
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateRoute(ctx, transaction, accountID, routeToSave); err != nil {
|
||||
@@ -221,21 +222,14 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
return err
|
||||
}
|
||||
|
||||
oldRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, oldRoute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, routeToSave)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
routeToSave.AccountID = accountID
|
||||
|
||||
if err = transaction.SaveRoute(ctx, routeToSave); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{Routes: []*route.Route{routeToSave, oldRoute}})
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -244,8 +238,11 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
|
||||
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||
|
||||
if oldRouteAffectsPeers || newRouteAffectsPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationUpdate})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("SaveRoute %s: updating %d affected peers: %v", routeToSave.ID, len(affectedPeerIDs), affectedPeerIDs)
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("SaveRoute %s: no affected peers", routeToSave.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -253,7 +250,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
|
||||
// DeleteRoute deletes route with routeID
|
||||
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -261,19 +258,16 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var route *route.Route
|
||||
var updateAccountPeers bool
|
||||
var rt *route.Route
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
route, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
|
||||
rt, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, route)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{Routes: []*route.Route{rt}})
|
||||
|
||||
if err = transaction.DeleteRoute(ctx, accountID, string(routeID)); err != nil {
|
||||
return err
|
||||
@@ -285,10 +279,13 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
||||
return fmt.Errorf("failed to delete route %s: %w", routeID, err)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
|
||||
am.StoreEvent(ctx, userID, string(rt.ID), accountID, activity.RouteRemoved, rt.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationDelete})
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
log.WithContext(ctx).Debugf("DeleteRoute %s: updating %d affected peers: %v", routeID, len(affectedPeerIDs), affectedPeerIDs)
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("DeleteRoute %s: no affected peers", routeID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -296,7 +293,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
||||
|
||||
// ListRoutes returns a list of routes from account
|
||||
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -377,25 +374,6 @@ func getPlaceholderIP() netip.Prefix {
|
||||
return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32)
|
||||
}
|
||||
|
||||
// areRouteChangesAffectPeers checks if a given route affects peers by determining
|
||||
// if it has a routing peer, distribution, or peer groups that include peers.
|
||||
func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) {
|
||||
if route.Peer != "" {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.Groups)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.PeerGroups)
|
||||
}
|
||||
|
||||
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
|
||||
func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
|
||||
accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
|
||||
|
||||
@@ -1962,8 +1962,10 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
})
|
||||
|
||||
// Creating a route with no routing peer and having peers in groups should update account peers and send peer update
|
||||
// Creating a route with no routing peer and having peers in groups that don't include peer1 should not send peer1 an update
|
||||
t.Run("creating a route with peers in PeerGroups and Groups", func(t *testing.T) {
|
||||
drainPeerUpdates(updMsg)
|
||||
|
||||
route := route.Route{
|
||||
ID: "testingRoute2",
|
||||
Network: netip.MustParsePrefix("192.0.2.0/32"),
|
||||
@@ -1979,7 +1981,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1992,8 +1994,8 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
@@ -59,7 +59,7 @@ func (m *managerImpl) GetExtraSettingsManager() extra_settings.Manager {
|
||||
|
||||
func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) {
|
||||
if userID != activity.SystemInitiator {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
|
||||
ok, _, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ type SetupKeyUpdateOperation struct {
|
||||
func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType,
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error) {
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Create)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -105,7 +105,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil")
|
||||
}
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Update)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -162,7 +162,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
||||
|
||||
// ListSetupKeys returns a list of all setup keys of the account
|
||||
func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Read)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -175,7 +175,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
|
||||
|
||||
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
|
||||
func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Read)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -198,7 +198,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
|
||||
|
||||
// DeleteSetupKey removes the setup key from the account
|
||||
func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Delete)
|
||||
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user