mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-27 02:59:54 +00:00
Compare commits
4 Commits
chore/adju
...
sha-pinnin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8c7bef3334 | ||
|
|
74a5fd63a7 | ||
|
|
e60a9e0e80 | ||
|
|
2b59191665 |
46
.github/actions/parse-semver/action.yml
vendored
Normal file
46
.github/actions/parse-semver/action.yml
vendored
Normal file
@@ -0,0 +1,46 @@
|
||||
name: Parse semver string
|
||||
description: Parse a refs/tags/vX.Y.Z[-prerelease][+build] ref into version components. Falls back to 0.0.0 when not run on a version tag.
|
||||
outputs:
|
||||
major:
|
||||
description: Major version number (e.g. "1" from v1.2.3).
|
||||
value: ${{ steps.parse.outputs.major }}
|
||||
minor:
|
||||
description: Minor version number (e.g. "2" from v1.2.3).
|
||||
value: ${{ steps.parse.outputs.minor }}
|
||||
patch:
|
||||
description: Patch version number (e.g. "3" from v1.2.3).
|
||||
value: ${{ steps.parse.outputs.patch }}
|
||||
prerelease:
|
||||
description: Prerelease identifier (e.g. "rc.1" from v1.2.3-rc.1), empty when absent.
|
||||
value: ${{ steps.parse.outputs.prerelease }}
|
||||
build:
|
||||
description: Build metadata (e.g. "build.7" from v1.2.3+build.7), empty when absent.
|
||||
value: ${{ steps.parse.outputs.build }}
|
||||
fullversion:
|
||||
description: MAJOR.MINOR.PATCH joined (e.g. "1.2.3").
|
||||
value: ${{ steps.parse.outputs.fullversion }}
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- id: parse
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
REF="${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}"
|
||||
if [[ "$REF" =~ ^refs/tags/v([0-9]+)\.([0-9]+)\.([0-9]+)(-([0-9A-Za-z.-]+))?(\+([0-9A-Za-z.-]+))?$ ]]; then
|
||||
MAJOR="${BASH_REMATCH[1]}"
|
||||
MINOR="${BASH_REMATCH[2]}"
|
||||
PATCH="${BASH_REMATCH[3]}"
|
||||
PRERELEASE="${BASH_REMATCH[5]}"
|
||||
BUILD="${BASH_REMATCH[7]}"
|
||||
else
|
||||
MAJOR=0; MINOR=0; PATCH=0; PRERELEASE=""; BUILD=""
|
||||
fi
|
||||
{
|
||||
echo "major=$MAJOR"
|
||||
echo "minor=$MINOR"
|
||||
echo "patch=$PATCH"
|
||||
echo "prerelease=$PRERELEASE"
|
||||
echo "build=$BUILD"
|
||||
echo "fullversion=$MAJOR.$MINOR.$PATCH"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
37
.github/dependabot.yml
vendored
Normal file
37
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
open-pull-requests-limit: 15
|
||||
groups:
|
||||
actions:
|
||||
patterns:
|
||||
- "*"
|
||||
|
||||
- package-ecosystem: "gomod"
|
||||
directories:
|
||||
- "/"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
open-pull-requests-limit: 15
|
||||
groups:
|
||||
aws-sdk:
|
||||
patterns:
|
||||
- "github.com/aws/aws-sdk-go-v2/*"
|
||||
pion:
|
||||
patterns:
|
||||
- "github.com/pion/*"
|
||||
gorm:
|
||||
patterns:
|
||||
- "gorm.io/*"
|
||||
otel:
|
||||
patterns:
|
||||
- "go.opentelemetry.io/*"
|
||||
testcontainers:
|
||||
patterns:
|
||||
- "github.com/testcontainers/testcontainers-go/*"
|
||||
wireguard:
|
||||
patterns:
|
||||
- "golang.zx2c4.com/wireguard*"
|
||||
109
.github/workflows/check-license-dependencies.yml
vendored
109
.github/workflows/check-license-dependencies.yml
vendored
@@ -2,16 +2,16 @@ name: Check License Dependencies
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
branches: [main]
|
||||
paths:
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.github/workflows/check-license-dependencies.yml'
|
||||
- "go.mod"
|
||||
- "go.sum"
|
||||
- ".github/workflows/check-license-dependencies.yml"
|
||||
pull_request:
|
||||
paths:
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.github/workflows/check-license-dependencies.yml'
|
||||
- "go.mod"
|
||||
- "go.sum"
|
||||
- ".github/workflows/check-license-dependencies.yml"
|
||||
|
||||
jobs:
|
||||
check-internal-dependencies:
|
||||
@@ -19,7 +19,10 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check for problematic license dependencies
|
||||
run: |
|
||||
@@ -56,55 +59,57 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: 'go.mod'
|
||||
cache: true
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: true
|
||||
|
||||
- name: Install go-licenses
|
||||
run: go install github.com/google/go-licenses@v1.6.0
|
||||
- name: Install go-licenses
|
||||
run: go install github.com/google/go-licenses@v1.6.0
|
||||
|
||||
- name: Check for GPL/AGPL licensed dependencies
|
||||
run: |
|
||||
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
|
||||
echo ""
|
||||
|
||||
# Check all Go packages for copyleft licenses, excluding internal netbird packages
|
||||
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
|
||||
|
||||
if [ -n "$COPYLEFT_DEPS" ]; then
|
||||
echo "Found copyleft licensed dependencies:"
|
||||
echo "$COPYLEFT_DEPS"
|
||||
- name: Check for GPL/AGPL licensed dependencies
|
||||
run: |
|
||||
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
|
||||
echo ""
|
||||
|
||||
# Filter out dependencies that are only pulled in by internal AGPL packages
|
||||
INCOMPATIBLE=""
|
||||
while IFS=',' read -r package url license; do
|
||||
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
|
||||
# Find ALL packages that import this GPL package using go list
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
# Check all Go packages for copyleft licenses, excluding internal netbird packages
|
||||
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
|
||||
else
|
||||
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
|
||||
fi
|
||||
fi
|
||||
done <<< "$COPYLEFT_DEPS"
|
||||
|
||||
if [ -n "$INCOMPATIBLE" ]; then
|
||||
if [ -n "$COPYLEFT_DEPS" ]; then
|
||||
echo "Found copyleft licensed dependencies:"
|
||||
echo "$COPYLEFT_DEPS"
|
||||
echo ""
|
||||
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
|
||||
echo -e "$INCOMPATIBLE"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
|
||||
# Filter out dependencies that are only pulled in by internal AGPL packages
|
||||
INCOMPATIBLE=""
|
||||
while IFS=',' read -r package url license; do
|
||||
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
|
||||
# Find ALL packages that import this GPL package using go list
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
|
||||
else
|
||||
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
|
||||
fi
|
||||
fi
|
||||
done <<< "$COPYLEFT_DEPS"
|
||||
|
||||
if [ -n "$INCOMPATIBLE" ]; then
|
||||
echo ""
|
||||
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
|
||||
echo -e "$INCOMPATIBLE"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
|
||||
|
||||
2
.github/workflows/docs-ack.yml
vendored
2
.github/workflows/docs-ack.yml
vendored
@@ -83,7 +83,7 @@ jobs:
|
||||
|
||||
- name: Verify docs PR exists (and is open or merged)
|
||||
if: steps.validate.outputs.mode == 'added'
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
id: verify
|
||||
with:
|
||||
pr_number: ${{ steps.extract.outputs.pr_number }}
|
||||
|
||||
5
.github/workflows/forum.yml
vendored
5
.github/workflows/forum.yml
vendored
@@ -8,11 +8,10 @@ jobs:
|
||||
post:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: roots/discourse-topic-github-release-action@main
|
||||
- uses: roots/discourse-topic-github-release-action@557d74ea05b6cc0c47f555c1d5d28a89d904005b # v1.1.0
|
||||
with:
|
||||
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
|
||||
discourse-base-url: https://forum.netbird.io
|
||||
discourse-author-username: NetBird
|
||||
discourse-category: 17
|
||||
discourse-tags:
|
||||
releases
|
||||
discourse-tags: releases
|
||||
|
||||
8
.github/workflows/git-town.yml
vendored
8
.github/workflows/git-town.yml
vendored
@@ -3,7 +3,7 @@ name: Git Town
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- '**'
|
||||
- "**"
|
||||
|
||||
jobs:
|
||||
git-town:
|
||||
@@ -15,7 +15,9 @@ jobs:
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: git-town/action@v1.2.1
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: git-town/action@670e1f4feb81fdef4226fc09deefe09018eb20d1 # v1.3.3
|
||||
with:
|
||||
skip-single-stacks: true
|
||||
|
||||
9
.github/workflows/golang-test-darwin.yml
vendored
9
.github/workflows/golang-test-darwin.yml
vendored
@@ -16,16 +16,18 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: macos-gotest-${{ hashFiles('**/go.sum') }}
|
||||
@@ -44,4 +46,3 @@ jobs:
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
|
||||
21
.github/workflows/golang-test-freebsd.yml
vendored
21
.github/workflows/golang-test-freebsd.yml
vendored
@@ -15,20 +15,31 @@ jobs:
|
||||
name: "Client / Unit"
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Read Go version from go.mod
|
||||
id: goversion
|
||||
run: echo "version=$(awk '/^go / {print $2}' go.mod)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Test in FreeBSD
|
||||
id: test
|
||||
uses: vmactions/freebsd-vm@v1
|
||||
env:
|
||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
release: "14.2"
|
||||
release: "15.0"
|
||||
envs: "GO_VERSION"
|
||||
prepare: |
|
||||
pkg install -y curl pkgconf xorg
|
||||
GO_TARBALL="go1.25.3.freebsd-amd64.tar.gz"
|
||||
GO_TARBALL="go${GO_VERSION}.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -vLO "$GO_URL"
|
||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||
|
||||
# -x - to print all executed commands
|
||||
# -e - to faile on first error
|
||||
|
||||
138
.github/workflows/golang-test-linux.yml
vendored
138
.github/workflows/golang-test-linux.yml
vendored
@@ -18,9 +18,11 @@ jobs:
|
||||
management: ${{ steps.filter.outputs.management }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- uses: dorny/paths-filter@v3
|
||||
- uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
|
||||
id: filter
|
||||
with:
|
||||
filters: |
|
||||
@@ -28,7 +30,7 @@ jobs:
|
||||
- 'management/**'
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -36,10 +38,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
id: cache
|
||||
with:
|
||||
path: |
|
||||
@@ -113,14 +115,16 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
arch: ["386", "amd64"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -128,10 +132,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -158,14 +162,16 @@ jobs:
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -177,7 +183,7 @@ jobs:
|
||||
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
id: cache-restore
|
||||
with:
|
||||
path: |
|
||||
@@ -231,10 +237,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -246,10 +254,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -277,14 +285,16 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
arch: ["386", "amd64"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -298,7 +308,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -324,14 +334,16 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
arch: ["386", "amd64"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -343,10 +355,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -370,19 +382,21 @@ jobs:
|
||||
|
||||
test_management:
|
||||
name: "Management / Unit"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
arch: ["amd64"]
|
||||
store: ["sqlite", "postgres", "mysql"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -390,10 +404,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -410,7 +424,7 @@ jobs:
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -427,7 +441,7 @@ jobs:
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
@@ -437,13 +451,13 @@ jobs:
|
||||
|
||||
benchmark:
|
||||
name: "Management / Benchmark"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
arch: ["amd64"]
|
||||
store: ["sqlite", "postgres"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Create Docker network
|
||||
@@ -474,10 +488,12 @@ jobs:
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -485,10 +501,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -505,7 +521,7 @@ jobs:
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -529,13 +545,13 @@ jobs:
|
||||
|
||||
api_benchmark:
|
||||
name: "Management / Benchmark (API)"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
arch: ["amd64"]
|
||||
store: ["sqlite", "postgres"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Create Docker network
|
||||
@@ -566,10 +582,12 @@ jobs:
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -577,10 +595,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -597,7 +615,7 @@ jobs:
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -623,20 +641,22 @@ jobs:
|
||||
|
||||
api_integration_test:
|
||||
name: "Management / Integration"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres']
|
||||
arch: ["amd64"]
|
||||
store: ["sqlite", "postgres"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -644,10 +664,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
|
||||
29
.github/workflows/golang-test-windows.yml
vendored
29
.github/workflows/golang-test-windows.yml
vendored
@@ -18,10 +18,12 @@ jobs:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
id: go
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
@@ -33,7 +35,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -44,16 +46,23 @@ jobs:
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Download wintun
|
||||
uses: carlosperate/download-file-action@v2
|
||||
id: download-wintun
|
||||
with:
|
||||
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
||||
file-name: wintun.zip
|
||||
location: ${{ env.downloadPath }}
|
||||
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
|
||||
shell: pwsh
|
||||
run: |
|
||||
$url = "https://pkgs.netbird.io/wintun/wintun-0.14.1.zip"
|
||||
$dest = "${env:downloadPath}\wintun.zip"
|
||||
$expected = "07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51"
|
||||
|
||||
New-Item -ItemType Directory -Force -Path (Split-Path -Parent $dest) | Out-Null
|
||||
Invoke-WebRequest $url -OutFile $dest
|
||||
$actual = (Get-FileHash $dest -Algorithm SHA256).Hash.ToLower()
|
||||
if ($actual -ne $expected) {
|
||||
throw "wintun.zip checksum mismatch: expected $expected, got $actual"
|
||||
}
|
||||
"file-path=$dest" | Out-File -FilePath $env:GITHUB_OUTPUT -Append -Encoding utf8
|
||||
|
||||
- name: Decompressing wintun files
|
||||
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||
run: tar -xvf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||
|
||||
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'
|
||||
|
||||
|
||||
14
.github/workflows/golangci-lint.yml
vendored
14
.github/workflows/golangci-lint.yml
vendored
@@ -15,9 +15,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
uses: codespell-project/actions-codespell@8f01853be192eb0f849a5c7d721450e7a467c579 # v2.2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
@@ -38,13 +40,15 @@ jobs:
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Check for duplicate constants
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: |
|
||||
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -52,7 +56,7 @@ jobs:
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
|
||||
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
|
||||
with:
|
||||
version: latest
|
||||
skip-cache: true
|
||||
|
||||
4
.github/workflows/install-script-test.yml
vendored
4
.github/workflows/install-script-test.yml
vendored
@@ -22,7 +22,9 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: run install script
|
||||
env:
|
||||
|
||||
18
.github/workflows/mobile-build-validation.yml
vendored
18
.github/workflows/mobile-build-validation.yml
vendored
@@ -16,23 +16,25 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Setup Android SDK
|
||||
uses: android-actions/setup-android@v3
|
||||
uses: android-actions/setup-android@40fd30fb8d7440372e1316f5d1809ec01dcd3699 # v4.0.1
|
||||
with:
|
||||
cmdline-tools-version: 8512546
|
||||
- name: Setup Java
|
||||
uses: actions/setup-java@v4
|
||||
uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654
|
||||
with:
|
||||
java-version: "11"
|
||||
distribution: "adopt"
|
||||
- name: NDK Cache
|
||||
id: ndk-cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: /usr/local/lib/android/sdk/ndk
|
||||
key: ndk-cache-23.1.7779620
|
||||
@@ -52,9 +54,11 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: install gomobile
|
||||
|
||||
2
.github/workflows/pr-title-check.yml
vendored
2
.github/workflows/pr-title-check.yml
vendored
@@ -9,7 +9,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Validate PR title prefix
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
with:
|
||||
script: |
|
||||
const title = context.payload.pull_request.title;
|
||||
|
||||
68
.github/workflows/proto-version-check.yml
vendored
68
.github/workflows/proto-version-check.yml
vendored
@@ -10,7 +10,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check for proto tool version changes
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
with:
|
||||
script: |
|
||||
const files = await github.paginate(github.rest.pulls.listFiles, {
|
||||
@@ -20,66 +20,34 @@ jobs:
|
||||
per_page: 100,
|
||||
});
|
||||
|
||||
const modifiedPbFiles = files.filter(
|
||||
f => f.filename.endsWith('.pb.go') && f.status === 'modified'
|
||||
);
|
||||
if (modifiedPbFiles.length === 0) {
|
||||
console.log('No modified .pb.go files to check');
|
||||
const pbFiles = files.filter(f => f.filename.endsWith('.pb.go'));
|
||||
const missingPatch = pbFiles.filter(f => !f.patch).map(f => f.filename);
|
||||
if (missingPatch.length > 0) {
|
||||
core.setFailed(
|
||||
`Cannot inspect patch data for:\n` +
|
||||
missingPatch.map(f => `- ${f}`).join('\n') +
|
||||
`\nThis can happen with very large PRs. Verify proto versions manually.`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const versionPattern = /^\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
|
||||
const baseSha = context.payload.pull_request.base.sha;
|
||||
const headSha = context.payload.pull_request.head.sha;
|
||||
|
||||
async function getVersionHeader(path, ref) {
|
||||
try {
|
||||
const res = await github.rest.repos.getContent({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
path,
|
||||
ref,
|
||||
});
|
||||
if (!res.data.content) {
|
||||
return { ok: false, reason: 'no inline content (file too large)' };
|
||||
}
|
||||
const content = Buffer.from(res.data.content, 'base64').toString('utf8');
|
||||
const lines = content
|
||||
.split('\n')
|
||||
.slice(0, 20)
|
||||
.filter(line => versionPattern.test(line));
|
||||
return { ok: true, lines };
|
||||
} catch (e) {
|
||||
return { ok: false, reason: e.message };
|
||||
}
|
||||
}
|
||||
|
||||
const versionPattern = /^[+-]\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
|
||||
const violations = [];
|
||||
for (const file of modifiedPbFiles) {
|
||||
const [base, head] = await Promise.all([
|
||||
getVersionHeader(file.filename, baseSha),
|
||||
getVersionHeader(file.filename, headSha),
|
||||
]);
|
||||
if (!base.ok || !head.ok) {
|
||||
core.warning(
|
||||
`Skipping ${file.filename}: base=${base.ok ? 'ok' : base.reason}, head=${head.ok ? 'ok' : head.reason}`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
if (base.lines.join('\n') !== head.lines.join('\n')) {
|
||||
|
||||
for (const file of pbFiles) {
|
||||
const changed = file.patch
|
||||
.split('\n')
|
||||
.filter(line => versionPattern.test(line));
|
||||
if (changed.length > 0) {
|
||||
violations.push({
|
||||
file: file.filename,
|
||||
base: base.lines,
|
||||
head: head.lines,
|
||||
lines: changed,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (violations.length > 0) {
|
||||
const details = violations.map(v =>
|
||||
`${v.file}:\n` +
|
||||
` base:\n${v.base.map(l => ' ' + l).join('\n') || ' (none)'}\n` +
|
||||
` head:\n${v.head.map(l => ' ' + l).join('\n') || ' (none)'}`
|
||||
`${v.file}:\n${v.lines.map(l => ' ' + l).join('\n')}`
|
||||
).join('\n\n');
|
||||
|
||||
core.setFailed(
|
||||
|
||||
184
.github/workflows/release.yml
vendored
184
.github/workflows/release.yml
vendored
@@ -24,7 +24,9 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Generate FreeBSD port diff
|
||||
run: bash release_files/freebsd-port-diff.sh
|
||||
@@ -51,19 +53,26 @@ jobs:
|
||||
echo "Generated files for version: $VERSION"
|
||||
cat netbird-*.diff
|
||||
|
||||
- name: Read Go version from go.mod
|
||||
id: goversion
|
||||
run: echo "version=$(awk '/^go / {print $2}' go.mod)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Test FreeBSD port
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
uses: vmactions/freebsd-vm@v1
|
||||
env:
|
||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
release: "15.0"
|
||||
envs: "GO_VERSION"
|
||||
prepare: |
|
||||
# Install required packages
|
||||
pkg install -y git curl portlint go
|
||||
|
||||
# Install Go for building
|
||||
GO_TARBALL="go1.25.5.freebsd-amd64.tar.gz"
|
||||
GO_TARBALL="go${GO_VERSION}.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -LO "$GO_URL"
|
||||
tar -C /usr/local -xzf "$GO_TARBALL"
|
||||
@@ -93,19 +102,19 @@ jobs:
|
||||
|
||||
# Show patched Makefile
|
||||
version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}')
|
||||
|
||||
|
||||
cd /usr/ports/security/netbird
|
||||
export BATCH=yes
|
||||
make package
|
||||
pkg add ./work/pkg/netbird-*.pkg
|
||||
|
||||
|
||||
netbird version | grep "$version"
|
||||
|
||||
echo "FreeBSD port test completed successfully!"
|
||||
|
||||
- name: Upload FreeBSD port files
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: freebsd-port-files
|
||||
path: |
|
||||
@@ -124,26 +133,25 @@ jobs:
|
||||
env:
|
||||
flags: ""
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||
with:
|
||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||
version_extractor_regex: '\/v(.*)$'
|
||||
uses: ./.github/actions/parse-semver
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -156,18 +164,18 @@ jobs:
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v2
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a #v4.0.0
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0
|
||||
- name: Login to Docker hub
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v1
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Log in to the GitHub container registry
|
||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
@@ -191,7 +199,7 @@ jobs:
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --clean ${{ env.flags }}
|
||||
@@ -282,28 +290,28 @@ jobs:
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
- name: upload non tags for debug purposes
|
||||
id: upload_release
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: release
|
||||
path: dist/
|
||||
retention-days: 7
|
||||
- name: upload linux packages
|
||||
id: upload_linux_packages
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: linux-packages
|
||||
path: dist/netbird_linux**
|
||||
retention-days: 7
|
||||
- name: upload windows packages
|
||||
id: upload_windows_packages
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: windows-packages
|
||||
path: dist/netbird_windows**
|
||||
retention-days: 7
|
||||
- name: upload macos packages
|
||||
id: upload_macos_packages
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: macos-packages
|
||||
path: dist/netbird_darwin**
|
||||
@@ -314,27 +322,26 @@ jobs:
|
||||
outputs:
|
||||
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||
with:
|
||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||
version_extractor_regex: '\/v(.*)$'
|
||||
uses: ./.github/actions/parse-semver
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -375,7 +382,7 @@ jobs:
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
|
||||
@@ -404,7 +411,7 @@ jobs:
|
||||
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
||||
- name: upload non tags for debug purposes
|
||||
id: upload_release_ui
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: release-ui
|
||||
path: dist/
|
||||
@@ -418,16 +425,17 @@ jobs:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -441,7 +449,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
|
||||
@@ -449,7 +457,7 @@ jobs:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: upload non tags for debug purposes
|
||||
id: upload_release_ui_darwin
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: release-ui-darwin
|
||||
path: dist/
|
||||
@@ -474,27 +482,26 @@ jobs:
|
||||
PackageWorkdir: netbird_windows_${{ matrix.arch }}
|
||||
downloadPath: '${{ github.workspace }}\temp'
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||
with:
|
||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||
version_extractor_regex: '\/v(.*)$'
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: ./.github/actions/parse-semver
|
||||
|
||||
- name: Add 7-Zip to PATH
|
||||
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
|
||||
- name: Download release artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
||||
with:
|
||||
name: release
|
||||
path: release
|
||||
|
||||
- name: Download UI release artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
||||
with:
|
||||
name: release-ui
|
||||
path: release-ui
|
||||
@@ -514,29 +521,41 @@ jobs:
|
||||
Get-ChildItem $workdir
|
||||
|
||||
- name: Download wintun
|
||||
uses: carlosperate/download-file-action@v2
|
||||
id: download-wintun
|
||||
with:
|
||||
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
||||
file-name: wintun.zip
|
||||
location: ${{ env.downloadPath }}
|
||||
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
|
||||
shell: pwsh
|
||||
run: |
|
||||
$url = "https://pkgs.netbird.io/wintun/wintun-0.14.1.zip"
|
||||
$dest = "${env:downloadPath}\wintun.zip"
|
||||
$expected = "07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51"
|
||||
New-Item -ItemType Directory -Force -Path (Split-Path -Parent $dest) | Out-Null
|
||||
Invoke-WebRequest $url -OutFile $dest
|
||||
$actual = (Get-FileHash $dest -Algorithm SHA256).Hash.ToLower()
|
||||
if ($actual -ne $expected) {
|
||||
throw "wintun.zip checksum mismatch: expected $expected, got $actual"
|
||||
}
|
||||
"file-path=$dest" | Out-File -FilePath $env:GITHUB_OUTPUT -Append -Encoding utf8
|
||||
|
||||
- name: Decompress wintun files
|
||||
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||
run: tar -xvf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||
|
||||
- name: Move wintun.dll into dist
|
||||
run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
|
||||
|
||||
- name: Download Mesa3D (amd64 only)
|
||||
uses: carlosperate/download-file-action@v2
|
||||
id: download-mesa3d
|
||||
if: matrix.arch == 'amd64'
|
||||
with:
|
||||
file-url: https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z
|
||||
file-name: mesa3d.7z
|
||||
location: ${{ env.downloadPath }}
|
||||
sha256: '71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9'
|
||||
shell: pwsh
|
||||
run: |
|
||||
$url = "https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z"
|
||||
$dest = "${env:downloadPath}\mesa3d.7z"
|
||||
$expected = "71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9"
|
||||
New-Item -ItemType Directory -Force -Path (Split-Path -Parent $dest) | Out-Null
|
||||
Invoke-WebRequest $url -OutFile $dest
|
||||
$actual = (Get-FileHash $dest -Algorithm SHA256).Hash.ToLower()
|
||||
if ($actual -ne $expected) {
|
||||
throw "mesa3d.7z checksum mismatch: expected $expected, got $actual"
|
||||
}
|
||||
"file-path=$dest" | Out-File -FilePath $env:GITHUB_OUTPUT -Append -Encoding utf8
|
||||
|
||||
- name: Extract Mesa3D driver (amd64 only)
|
||||
if: matrix.arch == 'amd64'
|
||||
@@ -547,35 +566,36 @@ jobs:
|
||||
run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
|
||||
|
||||
- name: Download EnVar plugin for NSIS
|
||||
uses: carlosperate/download-file-action@v2
|
||||
with:
|
||||
file-url: https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip
|
||||
file-name: envar_plugin.zip
|
||||
location: ${{ github.workspace }}
|
||||
shell: pwsh
|
||||
run: |
|
||||
Invoke-WebRequest "https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip" `
|
||||
-OutFile "${{ github.workspace }}\envar_plugin.zip"
|
||||
|
||||
- name: Extract EnVar plugin
|
||||
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/envar_plugin.zip"
|
||||
|
||||
- name: Download ShellExecAsUser plugin for NSIS (amd64 only)
|
||||
uses: carlosperate/download-file-action@v2
|
||||
if: matrix.arch == 'amd64'
|
||||
with:
|
||||
file-url: https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z
|
||||
file-name: ShellExecAsUser_amd64-Unicode.7z
|
||||
location: ${{ github.workspace }}
|
||||
shell: pwsh
|
||||
run: |
|
||||
Invoke-WebRequest "https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z" `
|
||||
-OutFile "${{ github.workspace }}\ShellExecAsUser_amd64-Unicode.7z"
|
||||
|
||||
- name: Extract ShellExecAsUser plugin (amd64 only)
|
||||
if: matrix.arch == 'amd64'
|
||||
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z"
|
||||
|
||||
- name: Build NSIS installer
|
||||
uses: joncloud/makensis-action@v3.3
|
||||
with:
|
||||
additional-plugin-paths: ${{ github.workspace }}/NSIS_Plugins/Plugins
|
||||
script-file: client/installer.nsis
|
||||
arguments: "/V4 /DARCH=${{ matrix.arch }}"
|
||||
shell: pwsh
|
||||
env:
|
||||
APPVER: ${{ steps.semver_parser.outputs.major }}.${{ steps.semver_parser.outputs.minor }}.${{ steps.semver_parser.outputs.patch }}.${{ github.run_id }}
|
||||
run: |
|
||||
$nsisPluginDir = "C:\Program Files (x86)\NSIS\Plugins\x86-unicode"
|
||||
$srcPlugins = "${{ github.workspace }}\NSIS_Plugins\Plugins"
|
||||
Get-ChildItem -Path $srcPlugins -Recurse -Filter *.dll |
|
||||
Copy-Item -Destination $nsisPluginDir -Force
|
||||
& "C:\Program Files (x86)\NSIS\makensis.exe" /V4 "/DARCH=${{ matrix.arch }}" client\installer.nsis
|
||||
if ($LASTEXITCODE -ne 0) { throw "makensis failed with exit code $LASTEXITCODE" }
|
||||
|
||||
- name: Rename NSIS installer
|
||||
run: mv netbird-installer.exe netbird_installer_test_windows_${{ matrix.arch }}.exe
|
||||
@@ -592,7 +612,7 @@ jobs:
|
||||
|
||||
- name: Upload installer artifacts
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: windows-installer-test-${{ matrix.arch }}
|
||||
path: |
|
||||
@@ -611,7 +631,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Create or update PR comment
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
env:
|
||||
RELEASE_RESULT: ${{ needs.release.result }}
|
||||
RELEASE_UI_RESULT: ${{ needs.release_ui.result }}
|
||||
@@ -703,7 +723,7 @@ jobs:
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- name: Trigger binaries sign pipelines
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: Sign bin and installer
|
||||
repo: netbirdio/sign-pipelines
|
||||
|
||||
4
.github/workflows/sync-main.yml
vendored
4
.github/workflows/sync-main.yml
vendored
@@ -14,9 +14,9 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Trigger main branch sync
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: sync-main.yml
|
||||
repo: ${{ secrets.UPSTREAM_REPO }}
|
||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
||||
inputs: '{ "sha": "${{ github.sha }}" }'
|
||||
inputs: '{ "sha": "${{ github.sha }}" }'
|
||||
|
||||
10
.github/workflows/sync-tag.yml
vendored
10
.github/workflows/sync-tag.yml
vendored
@@ -3,7 +3,7 @@ name: sync tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
- "v*"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
@@ -16,7 +16,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Trigger release tag sync
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: sync-tag.yml
|
||||
ref: main
|
||||
@@ -29,7 +29,7 @@ jobs:
|
||||
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
|
||||
steps:
|
||||
- name: Trigger android-client submodule bump
|
||||
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: bump-netbird.yml
|
||||
ref: main
|
||||
@@ -42,10 +42,10 @@ jobs:
|
||||
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
|
||||
steps:
|
||||
- name: Trigger ios-client submodule bump
|
||||
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: bump-netbird.yml
|
||||
ref: main
|
||||
repo: netbirdio/ios-client
|
||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
||||
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
||||
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
||||
|
||||
26
.github/workflows/test-infrastructure-files.yml
vendored
26
.github/workflows/test-infrastructure-files.yml
vendored
@@ -6,10 +6,10 @@ on:
|
||||
- main
|
||||
pull_request:
|
||||
paths:
|
||||
- 'infrastructure_files/**'
|
||||
- '.github/workflows/test-infrastructure-files.yml'
|
||||
- 'management/cmd/**'
|
||||
- 'signal/cmd/**'
|
||||
- "infrastructure_files/**"
|
||||
- ".github/workflows/test-infrastructure-files.yml"
|
||||
- "management/cmd/**"
|
||||
- "signal/cmd/**"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
@@ -20,7 +20,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
store: ["sqlite", "postgres", "mysql"]
|
||||
services:
|
||||
postgres:
|
||||
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
|
||||
@@ -68,15 +68,17 @@ jobs:
|
||||
run: sudo apt-get install -y curl
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -139,8 +141,8 @@ jobs:
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
||||
CI_NETBIRD_SIGNAL_PORT: 12345
|
||||
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
|
||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
|
||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
|
||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN: "${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$"
|
||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: "${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$"
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
||||
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
|
||||
@@ -254,7 +256,9 @@ jobs:
|
||||
run: sudo apt-get install -y jq
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: run script with Zitadel PostgreSQL
|
||||
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
|
||||
|
||||
8
.github/workflows/update-docs.yml
vendored
8
.github/workflows/update-docs.yml
vendored
@@ -3,9 +3,9 @@ name: update docs
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
- "v*"
|
||||
paths:
|
||||
- 'shared/management/http/api/openapi.yml'
|
||||
- "shared/management/http/api/openapi.yml"
|
||||
|
||||
jobs:
|
||||
trigger_docs_api_update:
|
||||
@@ -13,10 +13,10 @@ jobs:
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- name: Trigger API pages generation
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: generate api pages
|
||||
repo: netbirdio/docs
|
||||
ref: "refs/heads/main"
|
||||
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
||||
inputs: '{ "tag": "${{ github.ref }}" }'
|
||||
inputs: '{ "tag": "${{ github.ref }}" }'
|
||||
|
||||
15
.github/workflows/wasm-build-validation.yml
vendored
15
.github/workflows/wasm-build-validation.yml
vendored
@@ -19,15 +19,17 @@ jobs:
|
||||
GOARCH: wasm
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: Install golangci-lint
|
||||
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
|
||||
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
|
||||
with:
|
||||
version: latest
|
||||
install-mode: binary
|
||||
@@ -42,9 +44,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Build Wasm client
|
||||
@@ -65,4 +69,3 @@ jobs:
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"go.opentelemetry.io/otel"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
|
||||
@@ -109,7 +109,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
iv, _ := validator.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
|
||||
iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
@@ -85,12 +84,6 @@ type Options struct {
|
||||
DisableIPv6 bool
|
||||
// BlockInbound blocks all inbound connections from peers
|
||||
BlockInbound bool
|
||||
// BlockLANAccess blocks the embedded peer from reaching the host's
|
||||
// LAN (RFC 1918, link-local, loopback) when it's used as a routing
|
||||
// peer. Mirrors profilemanager.ConfigInput.BlockLANAccess. Useful
|
||||
// when the embedded client must never act as a stepping stone into
|
||||
// the host's local network (e.g. the proxy's overlay peer).
|
||||
BlockLANAccess bool
|
||||
// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
|
||||
WireguardPort *int
|
||||
// MTU is the MTU for the tunnel interface.
|
||||
@@ -101,26 +94,6 @@ type Options struct {
|
||||
MTU *uint16
|
||||
// DNSLabels defines additional DNS labels configured in the peer.
|
||||
DNSLabels []string
|
||||
// Performance configures the tunnel's buffer pool cap and batch size.
|
||||
Performance Performance
|
||||
}
|
||||
|
||||
// Performance configures the embedded client's tunnel memory/throughput knobs.
|
||||
//
|
||||
// These settings are process-global: any non-nil field also becomes the
|
||||
// default for Clients constructed by later embed.New calls in the same
|
||||
// process. Nil fields are ignored.
|
||||
type Performance struct {
|
||||
// PreallocatedBuffersPerPool caps the per-tunnel buffer pool. Zero
|
||||
// leaves the pool unbounded. Lower values trade throughput for a
|
||||
// tighter memory ceiling. May also be changed on a running Client via
|
||||
// Client.SetPerformance, provided this field was nonzero at construction.
|
||||
PreallocatedBuffersPerPool *uint32
|
||||
// MaxBatchSize overrides the number of packets the tunnel reads or
|
||||
// writes per syscall, which also bounds eager buffer allocation per
|
||||
// worker. Zero uses the platform default. Applied at construction
|
||||
// only; ignored by Client.SetPerformance.
|
||||
MaxBatchSize *uint32
|
||||
}
|
||||
|
||||
// validateCredentials checks that exactly one credential type is provided
|
||||
@@ -202,7 +175,6 @@ func New(opts Options) (*Client, error) {
|
||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||
DisableIPv6: &opts.DisableIPv6,
|
||||
BlockInbound: &opts.BlockInbound,
|
||||
BlockLANAccess: &opts.BlockLANAccess,
|
||||
WireguardPort: opts.WireguardPort,
|
||||
MTU: opts.MTU,
|
||||
DNSLabels: parsedLabels,
|
||||
@@ -220,13 +192,6 @@ func New(opts Options) (*Client, error) {
|
||||
config.PrivateKey = opts.PrivateKey
|
||||
}
|
||||
|
||||
if opts.Performance.PreallocatedBuffersPerPool != nil {
|
||||
wgdevice.SetPreallocatedBuffersPerPool(*opts.Performance.PreallocatedBuffersPerPool)
|
||||
}
|
||||
if opts.Performance.MaxBatchSize != nil {
|
||||
wgdevice.SetMaxBatchSizeOverride(*opts.Performance.MaxBatchSize)
|
||||
}
|
||||
|
||||
return &Client{
|
||||
deviceName: opts.DeviceName,
|
||||
setupKey: opts.SetupKey,
|
||||
@@ -440,21 +405,6 @@ func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// IdentityForIP looks up a remote peer by its tunnel IP using the
|
||||
// embedded client's status recorder. Returns the peer's WireGuard public
|
||||
// key and FQDN. ok=false means the IP isn't in this client's peer
|
||||
// roster — callers should treat that as "unknown peer".
|
||||
func (c *Client) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) {
|
||||
if !ip.IsValid() || c.recorder == nil {
|
||||
return "", "", false
|
||||
}
|
||||
state, found := c.recorder.PeerStateByIP(ip.String())
|
||||
if !found {
|
||||
return "", "", false
|
||||
}
|
||||
return state.PubKey, state.FQDN, true
|
||||
}
|
||||
|
||||
// Status returns the current status of the client.
|
||||
func (c *Client) Status() (peer.FullStatus, error) {
|
||||
c.mu.Lock()
|
||||
@@ -523,25 +473,6 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
|
||||
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
||||
}
|
||||
|
||||
// SetPerformance retunes a running Client. Only PreallocatedBuffersPerPool
|
||||
// takes effect, and only when it was nonzero at construction;
|
||||
// MaxBatchSize is construction-only and returns an error if set here.
|
||||
//
|
||||
// Returns ErrClientNotStarted / ErrEngineNotStarted if the Client is not
|
||||
// running yet.
|
||||
func (c *Client) SetPerformance(t Performance) error {
|
||||
if t.MaxBatchSize != nil {
|
||||
return errors.New("MaxBatchSize is construction-only and cannot be changed at runtime")
|
||||
}
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return engine.SetPerformance(internal.Performance{
|
||||
PreallocatedBuffersPerPool: t.PreallocatedBuffersPerPool,
|
||||
})
|
||||
}
|
||||
|
||||
// StartCapture begins capturing packets on this client's tunnel device.
|
||||
// Only one capture can be active at a time; starting a new one stops the previous.
|
||||
// Call StopCapture (or CaptureSession.Stop) to end it.
|
||||
|
||||
@@ -339,7 +339,8 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
case entry.Pattern == ".":
|
||||
return true
|
||||
case entry.IsWildcard:
|
||||
return strings.HasSuffix(qname, "."+entry.Pattern)
|
||||
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
|
||||
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
|
||||
default:
|
||||
// For non-wildcard patterns:
|
||||
// If handler wants subdomain matching, allow suffix match
|
||||
|
||||
@@ -164,54 +164,6 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
||||
matchSubdomains: true,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard label-boundary mismatch (suffix overlap)",
|
||||
handlerDomain: "*.b.test.",
|
||||
queryDomain: "x.ab.test.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard label-boundary match",
|
||||
handlerDomain: "*.b.test.",
|
||||
queryDomain: "x.b.test.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard multi-label match",
|
||||
handlerDomain: "*.b.test.",
|
||||
queryDomain: "x.y.b.test.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard no match on multi-label apex",
|
||||
handlerDomain: "*.b.test.",
|
||||
queryDomain: "b.test.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard no match on unrelated suffix containment",
|
||||
handlerDomain: "*.example.com.",
|
||||
queryDomain: "notexample.com.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard accepts pattern registered without trailing dot",
|
||||
handlerDomain: "*.b.test",
|
||||
queryDomain: "x.b.test.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -321,19 +273,6 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
||||
expectedCalls: 1,
|
||||
expectedHandler: 2, // highest priority matching handler should be called
|
||||
},
|
||||
{
|
||||
name: "overlapping wildcard suffixes route to correct handler",
|
||||
handlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
}{
|
||||
{pattern: "*.b.test.", priority: nbdns.PriorityDNSRoute},
|
||||
{pattern: "*.ab.test.", priority: nbdns.PriorityDNSRoute},
|
||||
},
|
||||
queryDomain: "app.ab.test.",
|
||||
expectedCalls: 1,
|
||||
expectedHandler: 1,
|
||||
},
|
||||
{
|
||||
name: "root zone with specific domain",
|
||||
handlers: []struct {
|
||||
|
||||
@@ -26,19 +26,6 @@ type resolver interface {
|
||||
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||
}
|
||||
|
||||
// PeerConnectivity reports whether a tunnel IP belongs to a peer the
|
||||
// client knows about and whether that peer is currently connected. The
|
||||
// local resolver uses this to suppress A/AAAA answers whose RDATA points
|
||||
// at a disconnected peer (typical case: a synthesized private-service
|
||||
// record pointing at an embedded proxy peer that just went offline).
|
||||
//
|
||||
// known=false means the IP isn't in the local peerstore at all — the
|
||||
// record is left alone (it points at something outside our mesh, e.g.
|
||||
// a non-peer upstream).
|
||||
type PeerConnectivity interface {
|
||||
IsConnectedByIP(ip string) (known, connected bool)
|
||||
}
|
||||
|
||||
type Resolver struct {
|
||||
mu sync.RWMutex
|
||||
records map[dns.Question][]dns.RR
|
||||
@@ -46,11 +33,6 @@ type Resolver struct {
|
||||
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
|
||||
zones map[domain.Domain]bool
|
||||
resolver resolver
|
||||
// peerConn, when non-nil, is consulted on every A/AAAA answer to
|
||||
// drop records pointing at disconnected peers. nil disables the
|
||||
// filter and preserves the legacy "return whatever is registered"
|
||||
// behaviour for callers that never wire a status source.
|
||||
peerConn PeerConnectivity
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
@@ -67,15 +49,6 @@ func NewResolver() *Resolver {
|
||||
}
|
||||
}
|
||||
|
||||
// SetPeerConnectivity wires the per-IP connectivity check used to filter
|
||||
// out A/AAAA answers pointing at disconnected peers. Pass nil to disable.
|
||||
// Safe to call multiple times; the latest value wins.
|
||||
func (d *Resolver) SetPeerConnectivity(p PeerConnectivity) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
d.peerConn = p
|
||||
}
|
||||
|
||||
func (d *Resolver) MatchSubdomains() bool {
|
||||
return true
|
||||
}
|
||||
@@ -122,7 +95,6 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
replyMessage.RecursionAvailable = true
|
||||
|
||||
result := d.lookupRecords(logger, question)
|
||||
result.records = d.filterDisconnectedPeerAnswers(logger, question, result.records)
|
||||
replyMessage.Authoritative = !result.hasExternalData
|
||||
replyMessage.Answer = result.records
|
||||
replyMessage.Rcode = d.determineRcode(question, result)
|
||||
@@ -464,78 +436,6 @@ func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16,
|
||||
}
|
||||
}
|
||||
|
||||
// filterDisconnectedPeerAnswers drops A/AAAA records whose RDATA matches
|
||||
// a known but disconnected peer. The synthesized private-service zones
|
||||
// emit one A record per connected proxy peer in a cluster; when a peer
|
||||
// goes offline, the server-side refresh removes the record from the
|
||||
// next netmap, but the client may still hold the previous netmap for a
|
||||
// short window. This filter is the local belt to that braces — even on
|
||||
// the stale netmap, the resolver hides the offline target.
|
||||
//
|
||||
// Records pointing at unknown IPs (outside the local peerstore, e.g.
|
||||
// non-mesh upstreams) are never dropped. Non-A/AAAA records pass
|
||||
// through untouched.
|
||||
//
|
||||
// Escape hatch: if filtering would leave the answer empty AND at least
|
||||
// one record was filtered, the original list is returned. Better to
|
||||
// hand the client a record that may not respond than NXDOMAIN it
|
||||
// completely when every proxy peer is offline (the upstream may still
|
||||
// be reachable some other way, or the peerstore may be stale).
|
||||
func (d *Resolver) filterDisconnectedPeerAnswers(logger *log.Entry, question dns.Question, records []dns.RR) []dns.RR {
|
||||
if len(records) == 0 {
|
||||
return records
|
||||
}
|
||||
d.mu.RLock()
|
||||
checker := d.peerConn
|
||||
d.mu.RUnlock()
|
||||
if checker == nil {
|
||||
return records
|
||||
}
|
||||
|
||||
kept := make([]dns.RR, 0, len(records))
|
||||
var dropped int
|
||||
for _, rr := range records {
|
||||
ip := extractRecordIP(rr)
|
||||
if ip == "" {
|
||||
kept = append(kept, rr)
|
||||
continue
|
||||
}
|
||||
known, connected := checker.IsConnectedByIP(ip)
|
||||
if known && !connected {
|
||||
dropped++
|
||||
continue
|
||||
}
|
||||
kept = append(kept, rr)
|
||||
}
|
||||
if dropped == 0 {
|
||||
return records
|
||||
}
|
||||
if len(kept) == 0 {
|
||||
logger.Debugf("all %d answers for %s point at disconnected peers; returning the original list", dropped, question.Name)
|
||||
return records
|
||||
}
|
||||
logger.Tracef("dropped %d disconnected-peer answer(s) for %s, returning %d", dropped, question.Name, len(kept))
|
||||
return kept
|
||||
}
|
||||
|
||||
// extractRecordIP returns the dotted-decimal / colon-hex IP carried by
|
||||
// an A or AAAA record, or "" for any other record type.
|
||||
func extractRecordIP(rr dns.RR) string {
|
||||
switch r := rr.(type) {
|
||||
case *dns.A:
|
||||
if r.A == nil {
|
||||
return ""
|
||||
}
|
||||
return r.A.String()
|
||||
case *dns.AAAA:
|
||||
if r.AAAA == nil {
|
||||
return ""
|
||||
}
|
||||
return r.AAAA.String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Update replaces all zones and their records
|
||||
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
|
||||
d.mu.Lock()
|
||||
|
||||
@@ -30,21 +30,6 @@ func (m *mockResolver) LookupNetIP(ctx context.Context, network, host string) ([
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// mockPeerConnectivity returns canned (known, connected) results per IP.
|
||||
// Used by the disconnected-peer filter tests below. IPs not in the map
|
||||
// are reported as unknown so the filter leaves them alone.
|
||||
type mockPeerConnectivity struct {
|
||||
byIP map[string]struct{ known, connected bool }
|
||||
}
|
||||
|
||||
func (m mockPeerConnectivity) IsConnectedByIP(ip string) (known, connected bool) {
|
||||
v, ok := m.byIP[ip]
|
||||
if !ok {
|
||||
return false, false
|
||||
}
|
||||
return v.known, v.connected
|
||||
}
|
||||
|
||||
func TestLocalResolver_ServeDNS(t *testing.T) {
|
||||
recordA := nbdns.SimpleRecord{
|
||||
Name: "peera.netbird.cloud.",
|
||||
@@ -2667,114 +2652,3 @@ func BenchmarkIsInManagedZone_ManyZones(b *testing.B) {
|
||||
resolver.isInManagedZone(qname)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLocalResolver_FilterDisconnectedPeerAnswers verifies the
|
||||
// connectivity-aware filtering layered on top of lookupRecords:
|
||||
// when an A record's IP belongs to a known peer that's disconnected,
|
||||
// the record is dropped from the answer. Records for unknown IPs pass
|
||||
// through. If filtering would empty the answer entirely and at least
|
||||
// one record was dropped, the original list is restored (escape hatch
|
||||
// for the "all proxies offline" case).
|
||||
func TestLocalResolver_FilterDisconnectedPeerAnswers(t *testing.T) {
|
||||
zone := "svc.cluster.netbird."
|
||||
connectedRec := nbdns.SimpleRecord{
|
||||
Name: zone,
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 5,
|
||||
RData: "100.64.0.10",
|
||||
}
|
||||
disconnectedRec := nbdns.SimpleRecord{
|
||||
Name: zone,
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 5,
|
||||
RData: "100.64.0.11",
|
||||
}
|
||||
unknownRec := nbdns.SimpleRecord{
|
||||
Name: zone,
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 5,
|
||||
RData: "203.0.113.5",
|
||||
}
|
||||
|
||||
type ipState struct{ known, connected bool }
|
||||
tests := []struct {
|
||||
name string
|
||||
records []nbdns.SimpleRecord
|
||||
connByIP map[string]ipState
|
||||
wantInOrder []string
|
||||
}{
|
||||
{
|
||||
name: "drops disconnected peer, keeps connected",
|
||||
records: []nbdns.SimpleRecord{connectedRec, disconnectedRec},
|
||||
connByIP: map[string]ipState{
|
||||
"100.64.0.10": {known: true, connected: true},
|
||||
"100.64.0.11": {known: true, connected: false},
|
||||
},
|
||||
wantInOrder: []string{"100.64.0.10"},
|
||||
},
|
||||
{
|
||||
name: "unknown IPs pass through untouched",
|
||||
records: []nbdns.SimpleRecord{unknownRec, disconnectedRec},
|
||||
connByIP: map[string]ipState{
|
||||
"100.64.0.11": {known: true, connected: false},
|
||||
},
|
||||
wantInOrder: []string{"203.0.113.5"},
|
||||
},
|
||||
{
|
||||
name: "all disconnected falls back to original list",
|
||||
records: []nbdns.SimpleRecord{disconnectedRec, connectedRec},
|
||||
connByIP: map[string]ipState{
|
||||
"100.64.0.10": {known: true, connected: false},
|
||||
"100.64.0.11": {known: true, connected: false},
|
||||
},
|
||||
wantInOrder: []string{"100.64.0.11", "100.64.0.10"},
|
||||
},
|
||||
{
|
||||
name: "no checker wired returns all records",
|
||||
records: []nbdns.SimpleRecord{connectedRec, disconnectedRec},
|
||||
connByIP: nil,
|
||||
wantInOrder: []string{"100.64.0.10", "100.64.0.11"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
if tc.connByIP != nil {
|
||||
cm := mockPeerConnectivity{byIP: make(map[string]struct{ known, connected bool }, len(tc.connByIP))}
|
||||
for ip, st := range tc.connByIP {
|
||||
cm.byIP[ip] = struct{ known, connected bool }{st.known, st.connected}
|
||||
}
|
||||
resolver.SetPeerConnectivity(cm)
|
||||
}
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: strings.TrimSuffix(zone, "."),
|
||||
Records: tc.records,
|
||||
NonAuthoritative: true,
|
||||
}})
|
||||
|
||||
var got *dns.Msg
|
||||
writer := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
got = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
req := new(dns.Msg).SetQuestion(zone, dns.TypeA)
|
||||
resolver.ServeDNS(writer, req)
|
||||
|
||||
require.NotNil(t, got, "resolver must produce a response")
|
||||
require.Len(t, got.Answer, len(tc.wantInOrder),
|
||||
"answer count must match expected: %v", tc.wantInOrder)
|
||||
for i, want := range tc.wantInOrder {
|
||||
a, ok := got.Answer[i].(*dns.A)
|
||||
require.True(t, ok, "answer[%d] must be an A record", i)
|
||||
assert.Equal(t, want, a.A.String(),
|
||||
"answer[%d] expected %s got %s", i, want, a.A.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -301,11 +301,6 @@ func newDefaultServer(
|
||||
warningDelayBase: defaultWarningDelayBase,
|
||||
healthRefresh: make(chan struct{}, 1),
|
||||
}
|
||||
// Wire the local resolver against the peer status recorder so it can
|
||||
// suppress A/AAAA answers that point at disconnected peers (typical
|
||||
// case: synthesised private-service records pointing at an embedded
|
||||
// proxy peer that just went offline).
|
||||
defaultServer.localResolver.SetPeerConnectivity(localPeerConnectivity{statusRecorder})
|
||||
|
||||
// register with root zone, handler chain takes care of the routing
|
||||
dnsService.RegisterMux(".", handlerChain)
|
||||
@@ -1391,25 +1386,3 @@ func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// localPeerConnectivity adapts *peer.Status to local.PeerConnectivity so
|
||||
// the local resolver can ask "is this IP a known peer and is it
|
||||
// connected?" without taking on the peer package as a dependency.
|
||||
// A nil status recorder always reports known=false so the resolver
|
||||
// short-circuits to the legacy "return everything" path.
|
||||
type localPeerConnectivity struct {
|
||||
status *peer.Status
|
||||
}
|
||||
|
||||
// IsConnectedByIP looks the IP up in the peerstore and surfaces both
|
||||
// the known and connected bits. Used by Resolver.filterDisconnectedPeerAnswers.
|
||||
func (l localPeerConnectivity) IsConnectedByIP(ip string) (known, connected bool) {
|
||||
if l.status == nil {
|
||||
return false, false
|
||||
}
|
||||
state, ok := l.status.PeerStateByIP(ip)
|
||||
if !ok {
|
||||
return false, false
|
||||
}
|
||||
return true, state.ConnStatus == peer.StatusConnected
|
||||
}
|
||||
|
||||
@@ -1967,29 +1967,6 @@ func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
|
||||
return e.clientMetrics
|
||||
}
|
||||
|
||||
// Performance bundles runtime-adjustable tunnel pool knobs.
|
||||
// See Engine.SetPerformance. Nil fields are ignored.
|
||||
type Performance struct {
|
||||
PreallocatedBuffersPerPool *uint32
|
||||
}
|
||||
|
||||
// SetPerformance applies the given tuning to this engine's live Device.
|
||||
func (e *Engine) SetPerformance(t Performance) error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
if e.wgInterface == nil {
|
||||
return fmt.Errorf("wg interface not initialized")
|
||||
}
|
||||
dev := e.wgInterface.GetWGDevice()
|
||||
if dev == nil {
|
||||
return fmt.Errorf("wg device not initialized")
|
||||
}
|
||||
if t.PreallocatedBuffersPerPool != nil {
|
||||
dev.SetPreallocatedBuffersPerPool(*t.PreallocatedBuffersPerPool)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||
iface, err := net.InterfaceByName(ifaceName)
|
||||
if err != nil {
|
||||
|
||||
@@ -27,7 +27,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
@@ -66,8 +66,8 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
@@ -1641,7 +1641,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -50,7 +50,7 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
|
||||
switch msg.Type {
|
||||
// handle route changes
|
||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||
route, flags, err := parseRouteMessage(buf[:n])
|
||||
route, err := parseRouteMessage(buf[:n])
|
||||
if err != nil {
|
||||
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
||||
continue
|
||||
@@ -66,10 +66,6 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
|
||||
}
|
||||
switch msg.Type {
|
||||
case unix.RTM_ADD:
|
||||
if systemops.IgnoreAddedDefaultRoute(flags) {
|
||||
log.Debugf("Network monitor: ignoring added default route via %s, interface %s, flags %#x", route.Gw, intf, flags)
|
||||
continue
|
||||
}
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
case unix.RTM_DELETE:
|
||||
@@ -82,26 +78,22 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
|
||||
}
|
||||
}
|
||||
|
||||
func parseRouteMessage(buf []byte) (*systemops.Route, int, error) {
|
||||
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("parse RIB: %v", err)
|
||||
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||
}
|
||||
|
||||
if len(msgs) != 1 {
|
||||
return nil, 0, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||
}
|
||||
|
||||
msg, ok := msgs[0].(*route.RouteMessage)
|
||||
if !ok {
|
||||
return nil, 0, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||
}
|
||||
|
||||
r, err := systemops.MsgToRoute(msg)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return r, msg.Flags, nil
|
||||
return systemops.MsgToRoute(msg)
|
||||
}
|
||||
|
||||
// waitReadable blocks until fd has data to read, or ctx is cancelled.
|
||||
|
||||
@@ -185,12 +185,9 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
|
||||
return s.eventsChan
|
||||
}
|
||||
|
||||
// Status holds a state of peers, signal, management connections and relays.
|
||||
// mux is an RWMutex so hot read paths (notably PeerStateByIP, called for
|
||||
// every private-service request) don't contend against each other.
|
||||
// Pure read methods take RLock; anything that mutates state takes Lock.
|
||||
// Status holds a state of peers, signal, management connections and relays
|
||||
type Status struct {
|
||||
mux sync.RWMutex
|
||||
mux sync.Mutex
|
||||
peers map[string]State
|
||||
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
||||
signalState bool
|
||||
@@ -286,8 +283,8 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string, ipv6 string)
|
||||
|
||||
// GetPeer adds peer to Daemon status map
|
||||
func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
state, ok := d.peers[peerPubKey]
|
||||
if !ok {
|
||||
@@ -297,8 +294,8 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
||||
}
|
||||
|
||||
func (d *Status) PeerByIP(ip string) (string, bool) {
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
for _, state := range d.peers {
|
||||
if state.IP == ip {
|
||||
@@ -308,25 +305,6 @@ func (d *Status) PeerByIP(ip string) (string, bool) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// PeerStateByIP returns the full peer State for the given tunnel IP.
|
||||
// Matches against either the IPv4 (State.IP) or IPv6 (State.IPv6) tunnel
|
||||
// address so dual-stack peers are reachable on either family. Returns the
|
||||
// zero State and false when no peer matches or the input is empty.
|
||||
func (d *Status) PeerStateByIP(ip string) (State, bool) {
|
||||
if ip == "" {
|
||||
return State{}, false
|
||||
}
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
|
||||
for _, state := range d.peers {
|
||||
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
|
||||
return state, true
|
||||
}
|
||||
}
|
||||
return State{}, false
|
||||
}
|
||||
|
||||
// RemovePeer removes peer from Daemon status map
|
||||
func (d *Status) RemovePeer(peerPubKey string) error {
|
||||
d.mux.Lock()
|
||||
@@ -724,8 +702,8 @@ func (d *Status) UnsubscribePeerStateChanges(subscription *StatusChangeSubscript
|
||||
|
||||
// GetLocalPeerState returns the local peer state
|
||||
func (d *Status) GetLocalPeerState() LocalPeerState {
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
return d.localPeer.Clone()
|
||||
}
|
||||
|
||||
@@ -931,8 +909,8 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
||||
}
|
||||
|
||||
func (d *Status) GetRosenpassState() RosenpassState {
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
return RosenpassState{
|
||||
d.rosenpassEnabled,
|
||||
d.rosenpassPermissive,
|
||||
@@ -940,14 +918,14 @@ func (d *Status) GetRosenpassState() RosenpassState {
|
||||
}
|
||||
|
||||
func (d *Status) GetLazyConnection() bool {
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
return d.lazyConnectionEnabled
|
||||
}
|
||||
|
||||
func (d *Status) GetManagementState() ManagementState {
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
return ManagementState{
|
||||
d.mgmAddress,
|
||||
d.managementState,
|
||||
@@ -973,8 +951,8 @@ func (d *Status) UpdateLatency(pubKey string, latency time.Duration) error {
|
||||
|
||||
// IsLoginRequired determines if a peer's login has expired.
|
||||
func (d *Status) IsLoginRequired() bool {
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
// if peer is connected to the management then login is not expired
|
||||
if d.managementState {
|
||||
@@ -989,8 +967,8 @@ func (d *Status) IsLoginRequired() bool {
|
||||
}
|
||||
|
||||
func (d *Status) GetSignalState() SignalState {
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
return SignalState{
|
||||
d.signalAddress,
|
||||
d.signalState,
|
||||
@@ -1000,8 +978,8 @@ func (d *Status) GetSignalState() SignalState {
|
||||
|
||||
// GetRelayStates returns the stun/turn/permanent relay states
|
||||
func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
if d.relayMgr == nil {
|
||||
return d.relayStates
|
||||
}
|
||||
@@ -1030,8 +1008,8 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
}
|
||||
|
||||
func (d *Status) ForwardingRules() []firewall.ForwardRule {
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
if d.ingressGwMgr == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -1040,16 +1018,16 @@ func (d *Status) ForwardingRules() []firewall.ForwardRule {
|
||||
}
|
||||
|
||||
func (d *Status) GetDNSStates() []NSGroupState {
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
// shallow copy is good enough, as slices fields are currently not updated
|
||||
return slices.Clone(d.nsGroupStates)
|
||||
}
|
||||
|
||||
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
return maps.Clone(d.resolvedDomainsStates)
|
||||
}
|
||||
|
||||
@@ -1065,8 +1043,8 @@ func (d *Status) GetFullStatus() FullStatus {
|
||||
LazyConnectionEnabled: d.GetLazyConnection(),
|
||||
}
|
||||
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
fullStatus.LocalPeerState = d.localPeer
|
||||
|
||||
@@ -1241,8 +1219,8 @@ func (d *Status) SetWgIface(wgInterface WGIfaceStatus) {
|
||||
}
|
||||
|
||||
func (d *Status) PeersStatus() (*configurer.Stats, error) {
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
if d.wgIface == nil {
|
||||
return nil, fmt.Errorf("wgInterface is nil, cannot retrieve peers status")
|
||||
}
|
||||
|
||||
@@ -63,33 +63,6 @@ func TestUpdatePeerState(t *testing.T) {
|
||||
assert.Equal(t, ip, state.IP, "ip should be equal")
|
||||
}
|
||||
|
||||
func TestStatus_PeerStateByIP(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
req := require.New(t)
|
||||
|
||||
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", ""))
|
||||
req.NoError(status.AddPeer("pk-2", "peer-2.netbird", "100.64.0.11", ""))
|
||||
|
||||
state, ok := status.PeerStateByIP("100.64.0.10")
|
||||
req.True(ok, "known tunnel IP should resolve to a peer state")
|
||||
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
|
||||
req.Equal("peer-1.netbird", state.FQDN, "matching state must carry the right FQDN")
|
||||
|
||||
_, ok = status.PeerStateByIP("100.64.0.99")
|
||||
req.False(ok, "unknown IP must report ok=false")
|
||||
}
|
||||
|
||||
func TestStatus_PeerStateByIP_MatchesIPv6(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
req := require.New(t)
|
||||
|
||||
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", "fd00::1"))
|
||||
|
||||
state, ok := status.PeerStateByIP("fd00::1")
|
||||
req.True(ok, "IPv6-only match must resolve to the peer state")
|
||||
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
|
||||
}
|
||||
|
||||
func TestStatus_UpdatePeerFQDN(t *testing.T) {
|
||||
key := "abc"
|
||||
fqdn := "peer-a.netbird.local"
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
//go:build dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package systemops
|
||||
|
||||
// IgnoreAddedDefaultRoute reports whether an RTM_ADD default route with the
|
||||
// given flags should be ignored by the network monitor.
|
||||
func IgnoreAddedDefaultRoute(flags int) bool {
|
||||
return filterRoutesByFlags(flags)
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
//go:build darwin
|
||||
|
||||
package systemops
|
||||
|
||||
import "golang.org/x/sys/unix"
|
||||
|
||||
// IgnoreAddedDefaultRoute reports whether an RTM_ADD default route with the
|
||||
// given flags should be ignored by the network monitor. Scoped routes
|
||||
// (RTF_IFSCOPE) are tied to a specific interface index and cannot replace the
|
||||
// unscoped default the kernel uses for general egress, so flapping ones (e.g.
|
||||
// Wi-Fi calling IMS tunnels on ipsec0, Docker bridges, scoped utun defaults)
|
||||
// must not trigger an engine restart.
|
||||
func IgnoreAddedDefaultRoute(flags int) bool {
|
||||
if filterRoutesByFlags(flags) {
|
||||
return true
|
||||
}
|
||||
if flags&unix.RTF_IFSCOPE != 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
@@ -315,7 +315,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -3,14 +3,15 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/zcalusic/sysinfo"
|
||||
|
||||
@@ -28,11 +29,19 @@ func UpdateStaticInfoAsync() {
|
||||
|
||||
// GetInfo retrieves and parses the system information
|
||||
func GetInfo(ctx context.Context) *Info {
|
||||
kernelName, kernelVersion, kernelPlatform := kernelInfo()
|
||||
info := _getInfo()
|
||||
for strings.Contains(info, "broken pipe") {
|
||||
info = _getInfo()
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
osStr := strings.ReplaceAll(info, "\n", "")
|
||||
osStr = strings.ReplaceAll(osStr, "\r\n", "")
|
||||
osInfo := strings.Split(osStr, " ")
|
||||
|
||||
osName, osVersion := readOsReleaseFile()
|
||||
if osName == "" {
|
||||
osName = kernelName
|
||||
osName = osInfo[3]
|
||||
}
|
||||
|
||||
systemHostname, _ := os.Hostname()
|
||||
@@ -49,8 +58,8 @@ func GetInfo(ctx context.Context) *Info {
|
||||
}
|
||||
|
||||
gio := &Info{
|
||||
Kernel: kernelName,
|
||||
Platform: kernelPlatform,
|
||||
Kernel: osInfo[0],
|
||||
Platform: osInfo[2],
|
||||
OS: osName,
|
||||
OSVersion: osVersion,
|
||||
Hostname: extractDeviceName(ctx, systemHostname),
|
||||
@@ -58,7 +67,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
CPUs: runtime.NumCPU(),
|
||||
NetbirdVersion: version.NetbirdVersion(),
|
||||
UIVersion: extractUserAgent(ctx),
|
||||
KernelVersion: kernelVersion,
|
||||
KernelVersion: osInfo[1],
|
||||
NetworkAddresses: addrs,
|
||||
SystemSerialNumber: si.SystemSerialNumber,
|
||||
SystemProductName: si.SystemProductName,
|
||||
@@ -69,12 +78,18 @@ func GetInfo(ctx context.Context) *Info {
|
||||
return gio
|
||||
}
|
||||
|
||||
func kernelInfo() (string, string, string) {
|
||||
var uts unix.Utsname
|
||||
if err := unix.Uname(&uts); err != nil {
|
||||
return "", "", ""
|
||||
func _getInfo() string {
|
||||
cmd := exec.Command("uname", "-srio")
|
||||
cmd.Stdin = strings.NewReader("some")
|
||||
var out bytes.Buffer
|
||||
var stderr bytes.Buffer
|
||||
cmd.Stdout = &out
|
||||
cmd.Stderr = &stderr
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
log.Warnf("getInfo: %s", err)
|
||||
}
|
||||
return unix.ByteSliceToString(uts.Sysname[:]), unix.ByteSliceToString(uts.Release[:]), unix.ByteSliceToString(uts.Machine[:])
|
||||
return out.String()
|
||||
}
|
||||
|
||||
func sysInfo() (string, string, string) {
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"sync"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
@@ -14,7 +13,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
certValidationTimeout = 5 * time.Minute
|
||||
certValidationTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) {
|
||||
@@ -47,31 +46,17 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert
|
||||
|
||||
promise := conn.wsHandlers.Call("onCertificateRequest", certInfo)
|
||||
|
||||
resultChan := make(chan bool, 1)
|
||||
errorChan := make(chan error, 1)
|
||||
resultChan := make(chan bool)
|
||||
errorChan := make(chan error)
|
||||
|
||||
// Release from inside the callbacks so a post-timeout promise resolution
|
||||
// does not invoke an already-released func.
|
||||
var thenFn, catchFn js.Func
|
||||
var releaseOnce sync.Once
|
||||
release := func() {
|
||||
releaseOnce.Do(func() {
|
||||
thenFn.Release()
|
||||
catchFn.Release()
|
||||
})
|
||||
}
|
||||
thenFn = js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
defer release()
|
||||
resultChan <- args[0].Bool()
|
||||
promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
result := args[0].Bool()
|
||||
resultChan <- result
|
||||
return nil
|
||||
})
|
||||
catchFn = js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
defer release()
|
||||
})).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
errorChan <- fmt.Errorf("certificate validation failed")
|
||||
return nil
|
||||
})
|
||||
|
||||
promise.Call("then", thenFn).Call("catch", catchFn)
|
||||
}))
|
||||
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
@@ -58,8 +57,6 @@ type RDCleanPathProxy struct {
|
||||
}
|
||||
activeConnections map[string]*proxyConnection
|
||||
destinations map[string]string
|
||||
pendingHandlers map[string]js.Func
|
||||
nextID atomic.Uint64
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
@@ -69,15 +66,8 @@ type proxyConnection struct {
|
||||
rdpConn net.Conn
|
||||
tlsConn *tls.Conn
|
||||
wsHandlers js.Value
|
||||
// Go-side callbacks exposed to JS. js.FuncOf pins the Go closure in a
|
||||
// global handle map and MUST be released, otherwise every connection
|
||||
// leaks the Go memory the closure captures.
|
||||
wsHandlerFn js.Func
|
||||
onMessageFn js.Func
|
||||
onCloseFn js.Func
|
||||
cleanupOnce sync.Once
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewRDCleanPathProxy creates a new RDCleanPath proxy
|
||||
@@ -90,11 +80,7 @@ func NewRDCleanPathProxy(client interface {
|
||||
}
|
||||
}
|
||||
|
||||
// CreateProxy creates a new proxy endpoint for the given destination.
|
||||
// The registered handler fn and its destinations/pendingHandlers entries are
|
||||
// only released once a connection is established and cleanupConnection runs.
|
||||
// If a caller invokes CreateProxy but never connects to the returned URL,
|
||||
// those entries stay pinned for the lifetime of the page.
|
||||
// CreateProxy creates a new proxy endpoint for the given destination
|
||||
func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
||||
destination := net.JoinHostPort(hostname, port)
|
||||
|
||||
@@ -102,7 +88,7 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
||||
resolve := args[0]
|
||||
|
||||
go func() {
|
||||
proxyID := fmt.Sprintf("proxy_%d", p.nextID.Add(1))
|
||||
proxyID := fmt.Sprintf("proxy_%d", len(p.activeConnections))
|
||||
|
||||
p.mu.Lock()
|
||||
if p.destinations == nil {
|
||||
@@ -114,7 +100,7 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
||||
proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID)
|
||||
|
||||
// Register the WebSocket handler for this specific proxy
|
||||
handlerFn := js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf("error: requires WebSocket argument")
|
||||
}
|
||||
@@ -122,14 +108,7 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
||||
ws := args[0]
|
||||
p.HandleWebSocketConnection(ws, proxyID)
|
||||
return nil
|
||||
})
|
||||
p.mu.Lock()
|
||||
if p.pendingHandlers == nil {
|
||||
p.pendingHandlers = make(map[string]js.Func)
|
||||
}
|
||||
p.pendingHandlers[proxyID] = handlerFn
|
||||
p.mu.Unlock()
|
||||
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), handlerFn)
|
||||
}))
|
||||
|
||||
log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination)
|
||||
resolve.Invoke(proxyURL)
|
||||
@@ -163,10 +142,6 @@ func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string
|
||||
|
||||
p.mu.Lock()
|
||||
p.activeConnections[proxyID] = conn
|
||||
if fn, ok := p.pendingHandlers[proxyID]; ok {
|
||||
conn.wsHandlerFn = fn
|
||||
delete(p.pendingHandlers, proxyID)
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
p.setupWebSocketHandlers(ws, conn)
|
||||
@@ -175,7 +150,7 @@ func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) {
|
||||
conn.onMessageFn = js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return nil
|
||||
}
|
||||
@@ -183,15 +158,13 @@ func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnec
|
||||
data := args[0]
|
||||
go p.handleWebSocketMessage(conn, data)
|
||||
return nil
|
||||
})
|
||||
ws.Set("onGoMessage", conn.onMessageFn)
|
||||
}))
|
||||
|
||||
conn.onCloseFn = js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
log.Debug("WebSocket closed by JavaScript")
|
||||
conn.cancel()
|
||||
return nil
|
||||
})
|
||||
ws.Set("onGoClose", conn.onCloseFn)
|
||||
}))
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) {
|
||||
@@ -288,49 +261,25 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) {
|
||||
conn.cleanupOnce.Do(func() {
|
||||
log.Debugf("Cleaning up connection %s", conn.id)
|
||||
conn.cancel()
|
||||
if conn.tlsConn != nil {
|
||||
log.Debug("Closing TLS connection")
|
||||
if err := conn.tlsConn.Close(); err != nil {
|
||||
log.Debugf("Error closing TLS connection: %v", err)
|
||||
}
|
||||
conn.tlsConn = nil
|
||||
log.Debugf("Cleaning up connection %s", conn.id)
|
||||
conn.cancel()
|
||||
if conn.tlsConn != nil {
|
||||
log.Debug("Closing TLS connection")
|
||||
if err := conn.tlsConn.Close(); err != nil {
|
||||
log.Debugf("Error closing TLS connection: %v", err)
|
||||
}
|
||||
if conn.rdpConn != nil {
|
||||
log.Debug("Closing TCP connection")
|
||||
if err := conn.rdpConn.Close(); err != nil {
|
||||
log.Debugf("Error closing TCP connection: %v", err)
|
||||
}
|
||||
conn.rdpConn = nil
|
||||
conn.tlsConn = nil
|
||||
}
|
||||
if conn.rdpConn != nil {
|
||||
log.Debug("Closing TCP connection")
|
||||
if err := conn.rdpConn.Close(); err != nil {
|
||||
log.Debugf("Error closing TCP connection: %v", err)
|
||||
}
|
||||
js.Global().Delete(fmt.Sprintf("handleRDCleanPathWebSocket_%s", conn.id))
|
||||
|
||||
// Detach before releasing so late JS calls surface as TypeError instead
|
||||
// of silent "call to released function".
|
||||
if conn.wsHandlers.Truthy() {
|
||||
conn.wsHandlers.Set("onGoMessage", js.Undefined())
|
||||
conn.wsHandlers.Set("onGoClose", js.Undefined())
|
||||
}
|
||||
|
||||
// wsHandlerFn may be zero-value if the pending handler lookup missed.
|
||||
if conn.wsHandlerFn.Truthy() {
|
||||
conn.wsHandlerFn.Release()
|
||||
}
|
||||
if conn.onMessageFn.Truthy() {
|
||||
conn.onMessageFn.Release()
|
||||
}
|
||||
if conn.onCloseFn.Truthy() {
|
||||
conn.onCloseFn.Release()
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
delete(p.activeConnections, conn.id)
|
||||
delete(p.destinations, conn.id)
|
||||
delete(p.pendingHandlers, conn.id)
|
||||
p.mu.Unlock()
|
||||
})
|
||||
conn.rdpConn = nil
|
||||
}
|
||||
p.mu.Lock()
|
||||
delete(p.activeConnections, conn.id)
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
func CreateJSInterface(client *Client) js.Value {
|
||||
jsInterface := js.Global().Get("Object").Call("create", js.Null())
|
||||
|
||||
writeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
jsInterface.Set("write", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
@@ -32,10 +32,9 @@ func CreateJSInterface(client *Client) js.Value {
|
||||
|
||||
_, err := client.Write(bytes)
|
||||
return js.ValueOf(err == nil)
|
||||
})
|
||||
jsInterface.Set("write", writeFunc)
|
||||
}))
|
||||
|
||||
resizeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
jsInterface.Set("resize", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 2 {
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
@@ -43,26 +42,14 @@ func CreateJSInterface(client *Client) js.Value {
|
||||
rows := args[1].Int()
|
||||
err := client.Resize(cols, rows)
|
||||
return js.ValueOf(err == nil)
|
||||
})
|
||||
jsInterface.Set("resize", resizeFunc)
|
||||
}))
|
||||
|
||||
closeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
jsInterface.Set("close", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
client.Close()
|
||||
return js.Undefined()
|
||||
})
|
||||
jsInterface.Set("close", closeFunc)
|
||||
}))
|
||||
|
||||
go func() {
|
||||
readLoop(client, jsInterface)
|
||||
// Detach before releasing so late JS calls surface as TypeError instead
|
||||
// of silent "call to released function".
|
||||
jsInterface.Set("write", js.Undefined())
|
||||
jsInterface.Set("resize", js.Undefined())
|
||||
jsInterface.Set("close", js.Undefined())
|
||||
writeFunc.Release()
|
||||
resizeFunc.Release()
|
||||
closeFunc.Release()
|
||||
}()
|
||||
go readLoop(client, jsInterface)
|
||||
|
||||
return jsInterface
|
||||
}
|
||||
|
||||
@@ -332,7 +332,7 @@ func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) {
|
||||
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
|
||||
}
|
||||
|
||||
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), s.IDPHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
|
||||
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
|
||||
if servers.relaySrv != nil {
|
||||
log.Infof("Relay WebSocket handler added (path: /relay)")
|
||||
}
|
||||
@@ -521,7 +521,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
|
||||
}
|
||||
|
||||
// createCombinedHandler creates an HTTP handler that multiplexes Management, Signal (via wsproxy), and Relay WebSocket traffic
|
||||
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, idpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
|
||||
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
|
||||
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
|
||||
|
||||
var relayAcceptFn func(conn listener.Conn)
|
||||
@@ -556,10 +556,6 @@ func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, id
|
||||
http.Error(w, "Relay service not enabled", http.StatusNotFound)
|
||||
}
|
||||
|
||||
// Embedded IdP (Dex)
|
||||
case idpHandler != nil && strings.HasPrefix(r.URL.Path, "/oauth2"):
|
||||
idpHandler.ServeHTTP(w, r)
|
||||
|
||||
// Management HTTP API (default)
|
||||
default:
|
||||
httpHandler.ServeHTTP(w, r)
|
||||
|
||||
2
go.mod
2
go.mod
@@ -335,7 +335,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
|
||||
|
||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0
|
||||
|
||||
replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||
|
||||
|
||||
4
go.sum
4
go.sum
@@ -499,8 +499,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f h1:ff2D57RBjWtyQ2wVwJOxOgXAXOe/J2lJWtSX0Bz/BRk=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 h1:h/QnNzm7xzHPm+gajcblYUOclrW2FeNeDlUNj6tTWKQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
|
||||
|
||||
@@ -51,7 +51,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
||||
found = true
|
||||
select {
|
||||
case channel <- update:
|
||||
log.WithContext(ctx).Tracef("update was sent to channel for peer %s", peerID)
|
||||
log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID)
|
||||
default:
|
||||
dropped = true
|
||||
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel))
|
||||
|
||||
@@ -5,7 +5,6 @@ package peers
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
@@ -36,14 +35,6 @@ type Manager interface {
|
||||
SetAccountManager(accountManager account.Manager)
|
||||
GetPeerID(ctx context.Context, peerKey string) (string, error)
|
||||
CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error
|
||||
// GetPeerByTunnelIP looks up a peer in accountID by its WireGuard tunnel IP.
|
||||
// Returns nil with an error when no match exists. No permission check;
|
||||
// callers (the proxy's ValidateTunnelPeer RPC) are trusted server components.
|
||||
GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error)
|
||||
// GetPeerWithGroups returns the peer and the list of *types.Group it belongs
|
||||
// to. Used by the proxy's auth path to authorise a request by the calling
|
||||
// peer's group memberships.
|
||||
GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error)
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
@@ -108,26 +99,6 @@ func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string,
|
||||
return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
|
||||
}
|
||||
|
||||
// GetPeerByTunnelIP delegates to the store's indexed lookup.
|
||||
func (m *managerImpl) GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error) {
|
||||
return m.store.GetPeerByIP(ctx, store.LockingStrengthNone, accountID, ip)
|
||||
}
|
||||
|
||||
// GetPeerWithGroups returns the peer plus its group memberships. Any store
|
||||
// error returns (nil, nil, err) so callers never receive a valid peer
|
||||
// alongside a non-nil error.
|
||||
func (m *managerImpl) GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error) {
|
||||
p, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
groups, err := m.store.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return p, groups, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
|
||||
@@ -6,7 +6,6 @@ package peers
|
||||
|
||||
import (
|
||||
context "context"
|
||||
net "net"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
@@ -14,7 +13,6 @@ import (
|
||||
account "github.com/netbirdio/netbird/management/server/account"
|
||||
integrated_validator "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
peer "github.com/netbirdio/netbird/management/server/peer"
|
||||
types "github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// MockManager is a mock of Manager interface.
|
||||
@@ -40,20 +38,6 @@ func (m *MockManager) EXPECT() *MockManagerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// CreateProxyPeer mocks base method.
|
||||
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID, peerKey, cluster string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
|
||||
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
|
||||
}
|
||||
|
||||
// DeletePeers mocks base method.
|
||||
func (m *MockManager) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -113,21 +97,6 @@ func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *go
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID)
|
||||
}
|
||||
|
||||
// GetPeerByTunnelIP mocks base method.
|
||||
func (m *MockManager) GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPeerByTunnelIP", ctx, accountID, ip)
|
||||
ret0, _ := ret[0].(*peer.Peer)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPeerByTunnelIP indicates an expected call of GetPeerByTunnelIP.
|
||||
func (mr *MockManagerMockRecorder) GetPeerByTunnelIP(ctx, accountID, ip interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerByTunnelIP", reflect.TypeOf((*MockManager)(nil).GetPeerByTunnelIP), ctx, accountID, ip)
|
||||
}
|
||||
|
||||
// GetPeerID mocks base method.
|
||||
func (m *MockManager) GetPeerID(ctx context.Context, peerKey string) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -143,22 +112,6 @@ func (mr *MockManagerMockRecorder) GetPeerID(ctx, peerKey interface{}) *gomock.C
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerID", reflect.TypeOf((*MockManager)(nil).GetPeerID), ctx, peerKey)
|
||||
}
|
||||
|
||||
// GetPeerWithGroups mocks base method.
|
||||
func (m *MockManager) GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPeerWithGroups", ctx, accountID, peerID)
|
||||
ret0, _ := ret[0].(*peer.Peer)
|
||||
ret1, _ := ret[1].([]*types.Group)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// GetPeerWithGroups indicates an expected call of GetPeerWithGroups.
|
||||
func (mr *MockManagerMockRecorder) GetPeerWithGroups(ctx, accountID, peerID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerWithGroups", reflect.TypeOf((*MockManager)(nil).GetPeerWithGroups), ctx, accountID, peerID)
|
||||
}
|
||||
|
||||
// GetPeersByGroupIDs mocks base method.
|
||||
func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -209,3 +162,17 @@ func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
|
||||
}
|
||||
|
||||
// CreateProxyPeer mocks base method.
|
||||
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
|
||||
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
|
||||
}
|
||||
|
||||
@@ -23,8 +23,6 @@ type Domain struct {
|
||||
// SupportsCrowdSec is populated at query time from proxy cluster capabilities.
|
||||
// Not persisted.
|
||||
SupportsCrowdSec *bool `gorm:"-"`
|
||||
// SupportsPrivate is populated at query time from proxy cluster capabilities. Not persisted.
|
||||
SupportsPrivate *bool `gorm:"-"`
|
||||
}
|
||||
|
||||
// EventMeta returns activity event metadata for a domain
|
||||
|
||||
@@ -49,7 +49,6 @@ func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
|
||||
SupportsCustomPorts: d.SupportsCustomPorts,
|
||||
RequireSubdomain: d.RequireSubdomain,
|
||||
SupportsCrowdsec: d.SupportsCrowdSec,
|
||||
SupportsPrivate: d.SupportsPrivate,
|
||||
}
|
||||
if d.TargetCluster != "" {
|
||||
resp.TargetCluster = &d.TargetCluster
|
||||
|
||||
@@ -35,7 +35,6 @@ type proxyManager interface {
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
@@ -94,7 +93,6 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
||||
d.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, cluster)
|
||||
d.RequireSubdomain = m.proxyManager.ClusterRequireSubdomain(ctx, cluster)
|
||||
d.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, cluster)
|
||||
d.SupportsPrivate = m.proxyManager.ClusterSupportsPrivate(ctx, cluster)
|
||||
ret = append(ret, d)
|
||||
}
|
||||
|
||||
@@ -111,7 +109,6 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
||||
if d.TargetCluster != "" {
|
||||
cd.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, d.TargetCluster)
|
||||
cd.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, d.TargetCluster)
|
||||
cd.SupportsPrivate = m.proxyManager.ClusterSupportsPrivate(ctx, d.TargetCluster)
|
||||
}
|
||||
// Custom domains never require a subdomain by default since
|
||||
// the account owns them and should be able to use the bare domain.
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
type mockProxyManager struct {
|
||||
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||
}
|
||||
|
||||
@@ -40,10 +40,6 @@ func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockProxyManager) ClusterSupportsPrivate(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
|
||||
@@ -155,3 +151,4 @@ func TestGetClusterAllowList_PublicEmpty_BYOPOnly(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"byop.example.com"}, result)
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ type Manager interface {
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
|
||||
CountAccountProxies(ctx context.Context, accountID string) (int64, error)
|
||||
|
||||
@@ -21,7 +21,6 @@ type store interface {
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
|
||||
@@ -138,11 +137,6 @@ func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string
|
||||
return m.store.GetClusterSupportsCrowdSec(ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// ClusterSupportsPrivate reports whether any active proxy claims the private capability (nil = unreported).
|
||||
func (m Manager) ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
|
||||
return m.store.GetClusterSupportsPrivate(ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
|
||||
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
|
||||
@@ -184,3 +178,4 @@ func (m *Manager) DeleteAccountCluster(ctx context.Context, clusterAddress, acco
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -15,16 +15,16 @@ import (
|
||||
)
|
||||
|
||||
type mockStore struct {
|
||||
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
|
||||
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
|
||||
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
|
||||
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
|
||||
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
|
||||
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
|
||||
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
|
||||
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
|
||||
}
|
||||
|
||||
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||
@@ -99,9 +99,6 @@ func (m *mockStore) GetClusterRequireSubdomain(_ context.Context, _ string) *boo
|
||||
func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetClusterSupportsPrivate(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTestManager(s store) *Manager {
|
||||
meter := noop.NewMeterProvider().Meter("test")
|
||||
|
||||
@@ -92,20 +92,6 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCrowdSec", reflect.TypeOf((*MockManager)(nil).ClusterSupportsCrowdSec), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// ClusterSupportsPrivate mocks base method.
|
||||
func (m *MockManager) ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ClusterSupportsPrivate", ctx, clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ClusterSupportsPrivate indicates an expected call of ClusterSupportsPrivate.
|
||||
func (mr *MockManagerMockRecorder) ClusterSupportsPrivate(ctx, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsPrivate", reflect.TypeOf((*MockManager)(nil).ClusterSupportsPrivate), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// Connect mocks base method.
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -20,9 +20,6 @@ type Capabilities struct {
|
||||
RequireSubdomain *bool
|
||||
// SupportsCrowdsec indicates whether this proxy has CrowdSec configured.
|
||||
SupportsCrowdsec *bool
|
||||
// Private indicates whether this proxy supports inbound access via Wireguard
|
||||
// tunnel and netbird-only authentication policies
|
||||
Private *bool
|
||||
}
|
||||
|
||||
// Proxy represents a reverse proxy instance
|
||||
@@ -70,9 +67,10 @@ type Cluster struct {
|
||||
Type ClusterType
|
||||
Online bool
|
||||
ConnectedProxies int
|
||||
// *bool: nil = no proxy reported the capability; the dashboard renders that as unknown.
|
||||
// Capability flags. *bool because nil means "no proxy reported a
|
||||
// capability for this cluster" — the dashboard renders these as
|
||||
// unknown rather than false.
|
||||
SupportsCustomPorts *bool
|
||||
RequireSubdomain *bool
|
||||
SupportsCrowdSec *bool
|
||||
Private *bool
|
||||
}
|
||||
|
||||
@@ -204,7 +204,6 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
|
||||
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||
RequireSubdomain: c.RequireSubdomain,
|
||||
SupportsCrowdsec: c.SupportsCrowdSec,
|
||||
Private: c.Private,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -82,7 +82,6 @@ type CapabilityProvider interface {
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
@@ -137,7 +136,6 @@ func (m *Manager) GetClusters(ctx context.Context, accountID, userID string) ([]
|
||||
clusters[i].SupportsCustomPorts = m.capabilities.ClusterSupportsCustomPorts(ctx, clusters[i].Address)
|
||||
clusters[i].RequireSubdomain = m.capabilities.ClusterRequireSubdomain(ctx, clusters[i].Address)
|
||||
clusters[i].SupportsCrowdSec = m.capabilities.ClusterSupportsCrowdSec(ctx, clusters[i].Address)
|
||||
clusters[i].Private = m.capabilities.ClusterSupportsPrivate(ctx, clusters[i].Address)
|
||||
}
|
||||
|
||||
return clusters, nil
|
||||
@@ -210,9 +208,6 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
|
||||
target.Host = resource.Domain
|
||||
case service.TargetTypeSubnet:
|
||||
// For subnets we do not do any lookups on the resource
|
||||
case service.TargetTypeCluster:
|
||||
// Cluster targets carry the upstream address on target_id; the
|
||||
// proxy resolves the destination at request time.
|
||||
default:
|
||||
return fmt.Errorf("unknown target type: %s", target.TargetType)
|
||||
}
|
||||
@@ -784,10 +779,6 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
|
||||
if err := validateResourceTarget(ctx, transaction, accountID, target); err != nil {
|
||||
return err
|
||||
}
|
||||
case service.TargetTypeCluster:
|
||||
if err := validateClusterTarget(target); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId)
|
||||
}
|
||||
@@ -795,13 +786,6 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateClusterTarget(target *service.Target) error {
|
||||
if !target.Options.DirectUpstream {
|
||||
return status.Errorf(status.InvalidArgument, "cluster target %s has direct upstream disabled", target.Host)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validatePeerTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error {
|
||||
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||
@@ -978,14 +962,12 @@ func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID str
|
||||
return fmt.Errorf("failed to get services: %w", err)
|
||||
}
|
||||
|
||||
oidcCfg := m.proxyController.GetOIDCValidationConfig()
|
||||
|
||||
for _, s := range services {
|
||||
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||
}
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster)
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -1344,66 +1344,3 @@ func TestValidateSubdomainRequirement(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTargetReferences_ClusterTargetSkipsLookup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
accountID := "test-account"
|
||||
|
||||
// No peer or resource lookups must be issued for cluster targets.
|
||||
targets := []*rpservice.Target{
|
||||
{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: rpservice.TargetTypeCluster,
|
||||
Options: rpservice.TargetOptions{DirectUpstream: true},
|
||||
},
|
||||
}
|
||||
require.NoError(t, validateTargetReferences(ctx, mockStore, accountID, targets), "cluster target must validate without store lookups")
|
||||
}
|
||||
|
||||
// TestValidateTargetReferences_ClusterTargetRequiresDirectUpstream pins the
|
||||
// store-side check that cluster targets must opt into the host-stack dial
|
||||
// path. Without DirectUpstream the proxy would route this target through
|
||||
// the embedded NetBird client and fail on every request.
|
||||
func TestValidateTargetReferences_ClusterTargetRequiresDirectUpstream(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
accountID := "test-account"
|
||||
|
||||
targets := []*rpservice.Target{
|
||||
{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: rpservice.TargetTypeCluster,
|
||||
Host: "backend.lan",
|
||||
},
|
||||
}
|
||||
err := validateTargetReferences(ctx, mockStore, accountID, targets)
|
||||
require.Error(t, err, "cluster target without direct_upstream must be rejected")
|
||||
assert.ErrorContains(t, err, "direct upstream disabled")
|
||||
}
|
||||
|
||||
func TestReplaceHostByLookup_SkipsClusterTarget(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
accountID := "test-account"
|
||||
|
||||
mgr := &Manager{store: mockStore}
|
||||
|
||||
svc := &rpservice.Service{
|
||||
ID: "svc-1",
|
||||
AccountID: accountID,
|
||||
Targets: []*rpservice.Target{
|
||||
{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: rpservice.TargetTypeCluster,
|
||||
Host: "127.0.0.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
require.NoError(t, mgr.replaceHostByLookup(ctx, accountID, svc), "cluster target must not trigger peer/resource lookup")
|
||||
assert.Equal(t, "127.0.0.1", svc.Targets[0].Host, "operator-supplied host must be preserved for cluster target")
|
||||
}
|
||||
|
||||
@@ -45,11 +45,10 @@ const (
|
||||
StatusCertificateFailed Status = "certificate_failed"
|
||||
StatusError Status = "error"
|
||||
|
||||
TargetTypePeer TargetType = "peer"
|
||||
TargetTypeHost TargetType = "host"
|
||||
TargetTypeDomain TargetType = "domain"
|
||||
TargetTypeSubnet TargetType = "subnet"
|
||||
TargetTypeCluster TargetType = "cluster"
|
||||
TargetTypePeer TargetType = "peer"
|
||||
TargetTypeHost TargetType = "host"
|
||||
TargetTypeDomain TargetType = "domain"
|
||||
TargetTypeSubnet TargetType = "subnet"
|
||||
|
||||
SourcePermanent = "permanent"
|
||||
SourceEphemeral = "ephemeral"
|
||||
@@ -61,11 +60,6 @@ type TargetOptions struct {
|
||||
SessionIdleTimeout time.Duration `json:"session_idle_timeout,omitempty"`
|
||||
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
|
||||
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
|
||||
// DirectUpstream bypasses the proxy's embedded NetBird client and dials
|
||||
// the target via the proxy host's network stack. Useful for upstreams
|
||||
// reachable without WireGuard (public APIs, LAN services, localhost
|
||||
// sidecars). Default false.
|
||||
DirectUpstream bool `json:"direct_upstream,omitempty"`
|
||||
}
|
||||
|
||||
type Target struct {
|
||||
@@ -73,7 +67,7 @@ type Target struct {
|
||||
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
|
||||
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
|
||||
Path *string `json:"path,omitempty"`
|
||||
Host string `json:"host"`
|
||||
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
|
||||
Port uint16 `gorm:"index:idx_target_port" json:"port"`
|
||||
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
|
||||
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
|
||||
@@ -206,10 +200,6 @@ type Service struct {
|
||||
Mode string `gorm:"default:'http'"`
|
||||
ListenPort uint16
|
||||
PortAutoAssigned bool
|
||||
// Private marks the service as NetBird-only: auth via ValidateTunnelPeer against AccessGroups instead of SSO. HTTP-only.
|
||||
Private bool
|
||||
// AccessGroups is the group ID allowlist for inbound peers on private services. Mutually exclusive with bearer SSO.
|
||||
AccessGroups []string `json:"access_groups,omitempty" gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// InitNewRecord generates a new unique ID and resets metadata for a newly created
|
||||
@@ -309,12 +299,6 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
Mode: &mode,
|
||||
ListenPort: &listenPort,
|
||||
PortAutoAssigned: &s.PortAutoAssigned,
|
||||
Private: &s.Private,
|
||||
}
|
||||
|
||||
if len(s.AccessGroups) > 0 {
|
||||
groups := append([]string(nil), s.AccessGroups...)
|
||||
resp.AccessGroups = &groups
|
||||
}
|
||||
|
||||
if s.ProxyCluster != "" {
|
||||
@@ -324,7 +308,6 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
return resp
|
||||
}
|
||||
|
||||
// ToProtoMapping converts the service into the wire format the proxy consumes.
|
||||
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping {
|
||||
pathMappings := s.buildPathMappings()
|
||||
|
||||
@@ -366,7 +349,6 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
|
||||
RewriteRedirects: s.RewriteRedirects,
|
||||
Mode: s.Mode,
|
||||
ListenPort: int32(s.ListenPort), //nolint:gosec
|
||||
Private: s.Private,
|
||||
}
|
||||
|
||||
if r := restrictionsToProto(s.Restrictions); r != nil {
|
||||
@@ -473,8 +455,7 @@ func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode {
|
||||
}
|
||||
|
||||
func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
|
||||
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 &&
|
||||
opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 && !opts.DirectUpstream {
|
||||
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 {
|
||||
return nil
|
||||
}
|
||||
apiOpts := &api.ServiceTargetOptions{}
|
||||
@@ -496,22 +477,17 @@ func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
|
||||
if len(opts.CustomHeaders) > 0 {
|
||||
apiOpts.CustomHeaders = &opts.CustomHeaders
|
||||
}
|
||||
if opts.DirectUpstream {
|
||||
apiOpts.DirectUpstream = &opts.DirectUpstream
|
||||
}
|
||||
return apiOpts
|
||||
}
|
||||
|
||||
func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
|
||||
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 &&
|
||||
len(opts.CustomHeaders) == 0 && !opts.DirectUpstream {
|
||||
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 && len(opts.CustomHeaders) == 0 {
|
||||
return nil
|
||||
}
|
||||
popts := &proto.PathTargetOptions{
|
||||
SkipTlsVerify: opts.SkipTLSVerify,
|
||||
PathRewrite: pathRewriteToProto(opts.PathRewrite),
|
||||
CustomHeaders: opts.CustomHeaders,
|
||||
DirectUpstream: opts.DirectUpstream,
|
||||
SkipTlsVerify: opts.SkipTLSVerify,
|
||||
PathRewrite: pathRewriteToProto(opts.PathRewrite),
|
||||
CustomHeaders: opts.CustomHeaders,
|
||||
}
|
||||
if opts.RequestTimeout != 0 {
|
||||
popts.RequestTimeout = durationpb.New(opts.RequestTimeout)
|
||||
@@ -561,9 +537,6 @@ func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions,
|
||||
if o.CustomHeaders != nil {
|
||||
opts.CustomHeaders = *o.CustomHeaders
|
||||
}
|
||||
if o.DirectUpstream != nil {
|
||||
opts.DirectUpstream = *o.DirectUpstream
|
||||
}
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
@@ -578,14 +551,6 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro
|
||||
if req.ListenPort != nil {
|
||||
s.ListenPort = uint16(*req.ListenPort) //nolint:gosec
|
||||
}
|
||||
if req.Private != nil {
|
||||
s.Private = *req.Private
|
||||
}
|
||||
if req.AccessGroups != nil {
|
||||
s.AccessGroups = append([]string(nil), *req.AccessGroups...)
|
||||
} else {
|
||||
s.AccessGroups = nil
|
||||
}
|
||||
|
||||
targets, err := targetsFromAPI(accountID, req.Targets)
|
||||
if err != nil {
|
||||
@@ -775,9 +740,6 @@ func (s *Service) Validate() error {
|
||||
if err := validateAccessRestrictions(&s.Restrictions); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.validatePrivateRequirements(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch s.Mode {
|
||||
case ModeHTTP:
|
||||
@@ -791,23 +753,6 @@ func (s *Service) Validate() error {
|
||||
}
|
||||
}
|
||||
|
||||
// validatePrivateRequirements enforces the private-service contract: HTTP mode, ≥1 access group, no bearer auth.
|
||||
func (s *Service) validatePrivateRequirements() error {
|
||||
if !s.Private {
|
||||
return nil
|
||||
}
|
||||
if s.Mode != "" && s.Mode != ModeHTTP {
|
||||
return fmt.Errorf("private services only support HTTP mode, got %q", s.Mode)
|
||||
}
|
||||
if len(s.AccessGroups) == 0 {
|
||||
return errors.New("private services require at least one access group")
|
||||
}
|
||||
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
|
||||
return errors.New("private services cannot enable bearer auth (SSO): NetBird-only access and SSO are mutually exclusive")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) validateHTTPMode() error {
|
||||
if s.Domain == "" {
|
||||
return errors.New("service domain is required")
|
||||
@@ -854,21 +799,11 @@ func (s *Service) validateHTTPTargets() error {
|
||||
for i, target := range s.Targets {
|
||||
switch target.TargetType {
|
||||
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
|
||||
// Host is normally overwritten by replaceHostByLookup with the
|
||||
// resolved peer IP / resource address; operator-supplied values
|
||||
// are honored only when DirectUpstream is set. Validate the
|
||||
// override here so misconfigured hosts fail fast at API time.
|
||||
if err := validateDirectUpstreamHost(i, target); err != nil {
|
||||
return err
|
||||
}
|
||||
// host field will be ignored
|
||||
case TargetTypeSubnet:
|
||||
if target.Host == "" {
|
||||
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
|
||||
}
|
||||
case TargetTypeCluster:
|
||||
if err := validateClusterTarget(i, target); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
|
||||
}
|
||||
@@ -886,67 +821,25 @@ func (s *Service) validateHTTPTargets() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateClusterTarget cluster targets should not have empty hosts and should have direct upstream enabled.
|
||||
func validateClusterTarget(idx int, target *Target) error {
|
||||
host := strings.TrimSpace(target.Host)
|
||||
if host == "" {
|
||||
return fmt.Errorf("target %d: has empty host", idx)
|
||||
}
|
||||
if !target.Options.DirectUpstream {
|
||||
return fmt.Errorf("target %d: %s has direct upstream disabled", idx, target.Host)
|
||||
}
|
||||
return validateDirectUpstreamHost(idx, target)
|
||||
}
|
||||
|
||||
// validateDirectUpstreamHost validates the operator-supplied Host on a
|
||||
// peer/host/domain target when DirectUpstream is set. Empty Host is
|
||||
// allowed — the lookup fills in the default peer IP / resource address.
|
||||
// Without DirectUpstream the Host value is silently overwritten by
|
||||
// replaceHostByLookup, so we don't validate it (preserves the historical
|
||||
// behaviour where APIs accepted any value and dropped it). Non-empty
|
||||
// Host with DirectUpstream must look like a hostname or IP and must
|
||||
// not carry a port (port lives on Target.Port).
|
||||
func validateDirectUpstreamHost(idx int, target *Target) error {
|
||||
if !target.Options.DirectUpstream {
|
||||
return nil
|
||||
}
|
||||
host := strings.TrimSpace(target.Host)
|
||||
if host == "" {
|
||||
return nil
|
||||
}
|
||||
if strings.ContainsAny(host, " \t/") {
|
||||
return fmt.Errorf("target %d: host %q contains invalid characters", idx, host)
|
||||
}
|
||||
if _, _, err := net.SplitHostPort(host); err == nil {
|
||||
return fmt.Errorf("target %d: host %q must not include a port (set target.port instead)", idx, host)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) validateL4Target(target *Target) error {
|
||||
// L4 services have a single target; per-target disable is meaningless
|
||||
// (use the service-level Enabled flag instead). Force it on so that
|
||||
// buildPathMappings always includes the target in the proto.
|
||||
target.Enabled = true
|
||||
|
||||
if target.Port == 0 {
|
||||
return errors.New("target port is required for L4 services")
|
||||
}
|
||||
if target.TargetId == "" {
|
||||
return errors.New("target_id is required for L4 services")
|
||||
}
|
||||
if target.TargetType != TargetTypeCluster && target.Port == 0 {
|
||||
return errors.New("target port is required for L4 services")
|
||||
}
|
||||
switch target.TargetType {
|
||||
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
|
||||
if err := validateDirectUpstreamHost(0, target); err != nil {
|
||||
return err
|
||||
}
|
||||
// OK
|
||||
case TargetTypeSubnet:
|
||||
if target.Host == "" {
|
||||
return errors.New("target host is required for subnet targets")
|
||||
}
|
||||
case TargetTypeCluster:
|
||||
// target_id carries the cluster address; the proxy resolves
|
||||
// the upstream at request time.
|
||||
default:
|
||||
return fmt.Errorf("invalid target_type %q for L4 service", target.TargetType)
|
||||
}
|
||||
@@ -1281,11 +1174,6 @@ func (s *Service) Copy() *Service {
|
||||
}
|
||||
}
|
||||
|
||||
var accessGroups []string
|
||||
if len(s.AccessGroups) > 0 {
|
||||
accessGroups = append([]string(nil), s.AccessGroups...)
|
||||
}
|
||||
|
||||
return &Service{
|
||||
ID: s.ID,
|
||||
AccountID: s.AccountID,
|
||||
@@ -1307,8 +1195,6 @@ func (s *Service) Copy() *Service {
|
||||
Mode: s.Mode,
|
||||
ListenPort: s.ListenPort,
|
||||
PortAutoAssigned: s.PortAutoAssigned,
|
||||
Private: s.Private,
|
||||
AccessGroups: accessGroups,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -1117,191 +1116,3 @@ func TestValidate_HeaderAuths(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "exceeds maximum length")
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidate_HTTPClusterTarget(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Host: "backend.lan",
|
||||
Options: TargetOptions{DirectUpstream: true},
|
||||
Enabled: true,
|
||||
}}
|
||||
require.NoError(t, rp.Validate(), "HTTP cluster target with target_id, host, and direct_upstream must validate")
|
||||
}
|
||||
|
||||
func TestValidate_HTTPClusterTarget_RequiresTargetId(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets = []*Target{{
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Host: "backend.lan",
|
||||
Options: TargetOptions{DirectUpstream: true},
|
||||
Enabled: true,
|
||||
}}
|
||||
assert.ErrorContains(t, rp.Validate(), "empty target_id", "cluster target must reject empty target_id")
|
||||
}
|
||||
|
||||
// TestValidate_HTTPClusterTarget_RequiresHost pins the new cluster-target
|
||||
// rule that operator-supplied Host is mandatory: cluster targets dial the
|
||||
// upstream via the host network stack (direct_upstream is implied), so an
|
||||
// empty Host leaves the proxy with nothing to dial.
|
||||
func TestValidate_HTTPClusterTarget_RequiresHost(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Options: TargetOptions{DirectUpstream: true},
|
||||
Enabled: true,
|
||||
}}
|
||||
assert.ErrorContains(t, rp.Validate(), "empty host", "cluster target must reject empty host")
|
||||
}
|
||||
|
||||
// TestValidate_HTTPClusterTarget_RequiresDirectUpstream pins the second
|
||||
// half of the cluster-target rule: DirectUpstream must be true so the
|
||||
// stdlib transport branch in MultiTransport is taken. Without it the
|
||||
// embedded NetBird client would try to dial the cluster address through
|
||||
// the WG tunnel, which is the wrong network for a cluster upstream.
|
||||
func TestValidate_HTTPClusterTarget_RequiresDirectUpstream(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Host: "backend.lan",
|
||||
Enabled: true,
|
||||
}}
|
||||
assert.ErrorContains(t, rp.Validate(), "direct upstream disabled", "cluster target must reject direct_upstream=false")
|
||||
}
|
||||
|
||||
func TestValidate_L4ClusterTarget(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Mode = ModeTCP
|
||||
rp.ListenPort = 9000
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "tcp",
|
||||
Enabled: true,
|
||||
}}
|
||||
require.NoError(t, rp.Validate(), "L4 cluster target must validate without an explicit port")
|
||||
}
|
||||
|
||||
func TestService_Copy_RoundtripsPrivate(t *testing.T) {
|
||||
svc := validProxy()
|
||||
svc.Private = true
|
||||
svc.AccessGroups = []string{"grp-admins", "grp-ops"}
|
||||
cp := svc.Copy()
|
||||
require.NotNil(t, cp)
|
||||
assert.True(t, cp.Private)
|
||||
assert.Equal(t, []string{"grp-admins", "grp-ops"}, cp.AccessGroups)
|
||||
|
||||
cp.Private = false
|
||||
assert.True(t, svc.Private)
|
||||
|
||||
cp.AccessGroups[0] = "grp-other"
|
||||
assert.Equal(t, []string{"grp-admins", "grp-ops"}, svc.AccessGroups)
|
||||
}
|
||||
|
||||
func TestService_APIRoundtrip_Private(t *testing.T) {
|
||||
enabled := true
|
||||
private := true
|
||||
accessGroups := []string{"grp-admins"}
|
||||
targets := []api.ServiceTarget{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: api.ServiceTargetTargetType("cluster"),
|
||||
Protocol: "http",
|
||||
Port: 80,
|
||||
Enabled: true,
|
||||
}}
|
||||
req := &api.ServiceRequest{
|
||||
Name: "svc-private",
|
||||
Domain: "myapp.eu.proxy.netbird.io",
|
||||
Enabled: enabled,
|
||||
Private: &private,
|
||||
AccessGroups: &accessGroups,
|
||||
Targets: &targets,
|
||||
}
|
||||
|
||||
svc := &Service{}
|
||||
require.NoError(t, svc.FromAPIRequest(req, "acc-1"))
|
||||
assert.True(t, svc.Private)
|
||||
assert.Equal(t, []string{"grp-admins"}, svc.AccessGroups)
|
||||
|
||||
resp := svc.ToAPIResponse()
|
||||
require.NotNil(t, resp.Private)
|
||||
assert.True(t, *resp.Private)
|
||||
require.NotNil(t, resp.AccessGroups)
|
||||
assert.Equal(t, []string{"grp-admins"}, *resp.AccessGroups)
|
||||
}
|
||||
|
||||
func TestValidate_Private_RequiresAccessGroups(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Private = true
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Host: "backend.lan",
|
||||
Options: TargetOptions{DirectUpstream: true},
|
||||
Enabled: true,
|
||||
}}
|
||||
assert.ErrorContains(t, rp.Validate(), "access group")
|
||||
}
|
||||
|
||||
func TestValidate_Private_RejectsBearerAuth(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Private = true
|
||||
rp.AccessGroups = []string{"grp-admins"}
|
||||
rp.Auth.BearerAuth = &BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"grp-sso"},
|
||||
}
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Host: "backend.lan",
|
||||
Options: TargetOptions{DirectUpstream: true},
|
||||
Enabled: true,
|
||||
}}
|
||||
assert.ErrorContains(t, rp.Validate(), "mutually exclusive")
|
||||
}
|
||||
|
||||
func TestValidate_Private_AcceptsNonClusterTargets(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Private = true
|
||||
rp.AccessGroups = []string{"grp-admins"}
|
||||
require.NoError(t, rp.Validate())
|
||||
}
|
||||
|
||||
func TestValidate_Private_AcceptsClusterTargetWithAccessGroups(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Private = true
|
||||
rp.AccessGroups = []string{"grp-admins"}
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Host: "backend.lan",
|
||||
Options: TargetOptions{DirectUpstream: true},
|
||||
Enabled: true,
|
||||
}}
|
||||
require.NoError(t, rp.Validate())
|
||||
}
|
||||
|
||||
func TestValidate_Private_RejectsNonHTTPMode(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Private = true
|
||||
rp.AccessGroups = []string{"grp-admins"}
|
||||
rp.Mode = ModeTCP
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "tcp",
|
||||
Enabled: true,
|
||||
}}
|
||||
assert.ErrorContains(t, rp.Validate(), "HTTP")
|
||||
}
|
||||
|
||||
@@ -20,20 +20,6 @@ type KeyPair struct {
|
||||
type Claims struct {
|
||||
jwt.RegisteredClaims
|
||||
Method auth.Method `json:"method"`
|
||||
// Email is the calling user's email address. Carried so the
|
||||
// proxy can stamp identity on upstream requests (e.g.
|
||||
// x-litellm-end-user-id) without an extra management
|
||||
// round-trip on every cookie-bearing request.
|
||||
Email string `json:"email,omitempty"`
|
||||
// Groups carries the user's group IDs so the proxy can stamp them
|
||||
// onto upstream requests (X-NetBird-Groups) from the cookie path
|
||||
// without an extra management round-trip.
|
||||
Groups []string `json:"groups,omitempty"`
|
||||
// GroupNames carries the human-readable display names for the ids
|
||||
// in Groups, ordered identically (positional pairing). Slice may be
|
||||
// shorter than Groups for tokens minted before names were
|
||||
// resolvable; the consumer falls back to ids for missing positions.
|
||||
GroupNames []string `json:"group_names,omitempty"`
|
||||
}
|
||||
|
||||
func GenerateKeyPair() (*KeyPair, error) {
|
||||
@@ -48,13 +34,7 @@ func GenerateKeyPair() (*KeyPair, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SignToken mints a session JWT for the given user and domain. email,
|
||||
// groups, and groupNames, when non-empty, are embedded so the proxy can
|
||||
// authorise and stamp identity for policy-aware middlewares without a
|
||||
// management round-trip on every cookie-bearing request. groupNames
|
||||
// pairs positionally with groups; pass nil when names couldn't be
|
||||
// resolved.
|
||||
func SignToken(privKeyB64, userID, email, domain string, method auth.Method, groups, groupNames []string, expiration time.Duration) (string, error) {
|
||||
func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration time.Duration) (string, error) {
|
||||
privKeyBytes, err := base64.StdEncoding.DecodeString(privKeyB64)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode private key: %w", err)
|
||||
@@ -76,10 +56,7 @@ func SignToken(privKeyB64, userID, email, domain string, method auth.Method, gro
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
Method: method,
|
||||
Email: email,
|
||||
Groups: append([]string(nil), groups...),
|
||||
GroupNames: append([]string(nil), groupNames...),
|
||||
Method: method,
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
|
||||
|
||||
@@ -10,10 +10,8 @@ import (
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
|
||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||
"github.com/rs/cors"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
@@ -21,6 +19,7 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
cachestore "github.com/eko/gocache/lib/v4/store"
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
@@ -28,20 +27,16 @@ import (
|
||||
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
activitystore "github.com/netbirdio/netbird/management/server/activity/store"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
const apiPrefix = "/api"
|
||||
|
||||
var (
|
||||
kaep = keepalive.EnforcementPolicy{
|
||||
MinTime: 15 * time.Second,
|
||||
@@ -99,17 +94,12 @@ func (s *BaseServer) Store() store.Store {
|
||||
|
||||
func (s *BaseServer) EventStore() activity.Store {
|
||||
return Create(s, func() activity.Store {
|
||||
var err error
|
||||
key := s.Config.DataStoreEncryptionKey
|
||||
if key == "" {
|
||||
log.Debugf("generate new activity store encryption key")
|
||||
key, err = crypt.GenerateKey()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to generate event store encryption key: %v", err)
|
||||
}
|
||||
integrationMetrics, err := integrations.InitIntegrationMetrics(context.Background(), s.Metrics())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize integration metrics: %v", err)
|
||||
}
|
||||
|
||||
eventStore, err := activitystore.NewSqlStore(context.Background(), s.Config.Datadir, key)
|
||||
eventStore, _, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize event store: %v", err)
|
||||
}
|
||||
@@ -120,7 +110,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
||||
|
||||
func (s *BaseServer) APIHandler() http.Handler {
|
||||
return Create(s, func() http.Handler {
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter(), s.IsValidChildAccount)
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -128,22 +118,6 @@ func (s *BaseServer) APIHandler() http.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
// IDPHandler returns the HTTP handler for the embedded IdP (Dex), or nil if
|
||||
// the deployment isn't using the embedded variant.
|
||||
func (s *BaseServer) IDPHandler() http.Handler {
|
||||
embeddedIdP, ok := s.IdpManager().(*idp.EmbeddedIdPManager)
|
||||
if !ok || embeddedIdP == nil {
|
||||
return nil
|
||||
}
|
||||
return cors.AllowAll().Handler(embeddedIdP.Handler())
|
||||
}
|
||||
|
||||
func (s *BaseServer) Router() *mux.Router {
|
||||
return Create(s, func() *mux.Router {
|
||||
return mux.NewRouter().PathPrefix(apiPrefix).Subrouter()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
|
||||
return Create(s, func() *middleware.APIRateLimiter {
|
||||
cfg, enabled := middleware.RateLimiterConfigFromEnv()
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
@@ -39,7 +38,7 @@ func (s *BaseServer) JobManager() *job.Manager {
|
||||
|
||||
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
|
||||
return Create(s, func() integrated_validator.IntegratedValidator {
|
||||
integratedPeerValidator, err := validator.NewIntegratedValidator(
|
||||
integratedPeerValidator, err := integrations.NewIntegratedValidator(
|
||||
context.Background(),
|
||||
s.PeersManager(),
|
||||
s.SettingsManager(),
|
||||
|
||||
@@ -57,7 +57,13 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
|
||||
|
||||
func (s *BaseServer) PermissionsManager() permissions.Manager {
|
||||
return Create(s, func() permissions.Manager {
|
||||
return permissions.NewManager(s.Store())
|
||||
manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter())
|
||||
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
manager.SetAccountManager(s.AccountManager())
|
||||
})
|
||||
|
||||
return manager
|
||||
})
|
||||
}
|
||||
|
||||
@@ -147,6 +153,7 @@ func (s *BaseServer) IdpManager() idp.Manager {
|
||||
return idpManager
|
||||
}
|
||||
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@@ -228,7 +235,3 @@ func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
||||
return &m
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) IsValidChildAccount(_ context.Context, _, _, _ string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -188,7 +188,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
||||
log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
|
||||
}
|
||||
|
||||
rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.IDPHandler(), s.Metrics().GetMeter())
|
||||
rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.Metrics().GetMeter())
|
||||
switch {
|
||||
case s.certManager != nil:
|
||||
// a call to certManager.Listener() always creates a new listener so we do it once
|
||||
@@ -299,7 +299,7 @@ func (s *BaseServer) SetHandlerFunc(handler http.Handler) {
|
||||
log.Tracef("custom handler set successfully")
|
||||
}
|
||||
|
||||
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, idpHandler http.Handler, meter metric.Meter) http.Handler {
|
||||
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
|
||||
// Check if a custom handler was set (for multiplexing additional services)
|
||||
if customHandler, ok := s.GetContainer("customHandler"); ok {
|
||||
if handler, ok := customHandler.(http.Handler); ok {
|
||||
@@ -318,8 +318,6 @@ func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, ht
|
||||
gRPCHandler.ServeHTTP(writer, request)
|
||||
case request.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent:
|
||||
wsProxy.Handler().ServeHTTP(writer, request)
|
||||
case idpHandler != nil && strings.HasPrefix(request.URL.Path, "/oauth2"):
|
||||
idpHandler.ServeHTTP(writer, request)
|
||||
default:
|
||||
httpHandler.ServeHTTP(writer, request)
|
||||
}
|
||||
|
||||
@@ -351,7 +351,6 @@ func (s *ProxyServiceServer) registerProxyConnection(ctx context.Context, params
|
||||
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||
RequireSubdomain: c.RequireSubdomain,
|
||||
SupportsCrowdsec: c.SupportsCrowdsec,
|
||||
Private: c.Private,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -755,11 +754,6 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
|
||||
InitialSyncComplete: update.InitialSyncComplete,
|
||||
}
|
||||
}
|
||||
// Drop mappings the proxy lacks capability for (e.g. private without SupportsPrivateService).
|
||||
connUpdate = filterMappingsForProxy(conn, connUpdate)
|
||||
if connUpdate == nil || len(connUpdate.Mapping) == 0 {
|
||||
return true
|
||||
}
|
||||
resp := s.perProxyMessage(connUpdate, conn.proxyID)
|
||||
if resp == nil {
|
||||
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
|
||||
@@ -888,20 +882,16 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
|
||||
}
|
||||
}
|
||||
|
||||
// proxyAcceptsMapping returns whether the proxy can receive this mapping.
|
||||
// Private mappings require SupportsPrivateService; custom-port L4 mappings
|
||||
// require SupportsCustomPorts. Remove operations always pass so proxies can
|
||||
// clean up.
|
||||
// proxyAcceptsMapping returns whether the proxy should receive this mapping.
|
||||
// Old proxies that never reported capabilities are skipped for non-TLS L4
|
||||
// mappings with a custom listen port, since they don't understand the
|
||||
// protocol. Proxies that report capabilities (even SupportsCustomPorts=false)
|
||||
// are new enough to handle the mapping. TLS uses SNI routing and works on
|
||||
// any proxy. Delete operations are always sent so proxies can clean up.
|
||||
func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) bool {
|
||||
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
|
||||
return true
|
||||
}
|
||||
if mapping.GetPrivate() {
|
||||
caps := conn.capabilities
|
||||
if caps == nil || caps.SupportsPrivateService == nil || !*caps.SupportsPrivateService {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if mapping.ListenPort == 0 || mapping.Mode == "tls" {
|
||||
return true
|
||||
}
|
||||
@@ -910,29 +900,6 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo
|
||||
return conn.capabilities != nil && conn.capabilities.SupportsCustomPorts != nil
|
||||
}
|
||||
|
||||
// filterMappingsForProxy drops mappings the proxy cannot safely receive
|
||||
// (e.g. private mappings to a proxy without SupportsPrivateService).
|
||||
// Returns the input unchanged when no filtering is needed.
|
||||
func filterMappingsForProxy(conn *proxyConnection, update *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse {
|
||||
if update == nil || len(update.Mapping) == 0 {
|
||||
return update
|
||||
}
|
||||
kept := make([]*proto.ProxyMapping, 0, len(update.Mapping))
|
||||
for _, m := range update.Mapping {
|
||||
if !proxyAcceptsMapping(conn, m) {
|
||||
continue
|
||||
}
|
||||
kept = append(kept, m)
|
||||
}
|
||||
if len(kept) == len(update.Mapping) {
|
||||
return update
|
||||
}
|
||||
return &proto.GetMappingUpdateResponse{
|
||||
Mapping: kept,
|
||||
InitialSyncComplete: update.InitialSyncComplete,
|
||||
}
|
||||
}
|
||||
|
||||
// perProxyMessage returns a copy of update with a fresh one-time token for
|
||||
// create/update operations. For delete operations the original mapping is
|
||||
// used unchanged because proxies do not need to authenticate for removal.
|
||||
@@ -994,10 +961,7 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
|
||||
|
||||
authenticated, userId, method := s.authenticateRequest(ctx, req, service)
|
||||
|
||||
// Non-OIDC schemes (PIN/Password/Header) authenticate against per-service
|
||||
// secrets and have no user-level group context, so groups stay nil. Email
|
||||
// is also empty — these schemes don't resolve a user record at sign time.
|
||||
token, err := s.generateSessionToken(ctx, authenticated, service, userId, "", method, nil, nil)
|
||||
token, err := s.generateSessionToken(ctx, authenticated, service, userId, method)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1086,7 +1050,7 @@ func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId, userEmail string, method proxyauth.Method, groupIDs, groupNames []string) (string, error) {
|
||||
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId string, method proxyauth.Method) (string, error) {
|
||||
if !authenticated || service.SessionPrivateKey == "" {
|
||||
return "", nil
|
||||
}
|
||||
@@ -1094,11 +1058,8 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
|
||||
token, err := sessionkey.SignToken(
|
||||
service.SessionPrivateKey,
|
||||
userId,
|
||||
userEmail,
|
||||
service.Domain,
|
||||
method,
|
||||
groupIDs,
|
||||
groupNames,
|
||||
proxyauth.DefaultSessionExpiry,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -1109,26 +1070,6 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// pairGroupIDsAndNames splits a slice of resolved *types.Group records
|
||||
// into parallel id and name slices. ids[i] and names[i] always pair to
|
||||
// the same group. nil entries (orphan ids the manager couldn't resolve)
|
||||
// are skipped so the consumer can rely on positional pairing.
|
||||
func pairGroupIDsAndNames(groups []*types.Group) (ids, names []string) {
|
||||
if len(groups) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
ids = make([]string, 0, len(groups))
|
||||
names = make([]string, 0, len(groups))
|
||||
for _, g := range groups {
|
||||
if g == nil {
|
||||
continue
|
||||
}
|
||||
ids = append(ids, g.ID)
|
||||
names = append(names, g.Name)
|
||||
}
|
||||
return ids, names
|
||||
}
|
||||
|
||||
// SendStatusUpdate handles status updates from proxy clients.
|
||||
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
|
||||
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||
@@ -1393,9 +1334,7 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
|
||||
return verifier, redirectURL, nil
|
||||
}
|
||||
|
||||
// GenerateSessionToken creates a signed session JWT for the given domain and
|
||||
// user. The user's group memberships are embedded in the token so policy-aware
|
||||
// middlewares on the proxy can authorise without an extra management round-trip.
|
||||
// GenerateSessionToken creates a signed session JWT for the given domain and user.
|
||||
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
|
||||
service, err := s.getServiceByDomain(ctx, domain)
|
||||
if err != nil {
|
||||
@@ -1406,29 +1345,11 @@ func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, u
|
||||
return "", fmt.Errorf("no session key configured for domain: %s", domain)
|
||||
}
|
||||
|
||||
var (
|
||||
email string
|
||||
groupIDs []string
|
||||
groupNames []string
|
||||
)
|
||||
if s.usersManager != nil {
|
||||
user, userGroups, uerr := s.usersManager.GetUserWithGroups(ctx, userID)
|
||||
if uerr != nil {
|
||||
log.WithContext(ctx).Debugf("session token mint: lookup user %s: %v", userID, uerr)
|
||||
} else if user != nil {
|
||||
email = user.Email
|
||||
groupIDs, groupNames = pairGroupIDsAndNames(userGroups)
|
||||
}
|
||||
}
|
||||
|
||||
return sessionkey.SignToken(
|
||||
service.SessionPrivateKey,
|
||||
userID,
|
||||
email,
|
||||
domain,
|
||||
method,
|
||||
groupIDs,
|
||||
groupNames,
|
||||
proxyauth.DefaultSessionExpiry,
|
||||
)
|
||||
}
|
||||
@@ -1532,7 +1453,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
}, nil
|
||||
}
|
||||
|
||||
userID, _, _, _, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes)
|
||||
userID, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"domain": domain,
|
||||
@@ -1545,7 +1466,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
}, nil
|
||||
}
|
||||
|
||||
user, userGroups, err := s.usersManager.GetUserWithGroups(ctx, userID)
|
||||
user, err := s.usersManager.GetUser(ctx, userID)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"domain": domain,
|
||||
@@ -1579,15 +1500,12 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
"user_id": userID,
|
||||
"error": err.Error(),
|
||||
}).Debug("ValidateSession: access denied")
|
||||
groupIDs, groupNames := pairGroupIDsAndNames(userGroups)
|
||||
//nolint:nilerr
|
||||
return &proto.ValidateSessionResponse{
|
||||
Valid: false,
|
||||
UserId: user.Id,
|
||||
UserEmail: user.Email,
|
||||
DeniedReason: "not_in_group",
|
||||
PeerGroupIds: groupIDs,
|
||||
PeerGroupNames: groupNames,
|
||||
Valid: false,
|
||||
UserId: user.Id,
|
||||
UserEmail: user.Email,
|
||||
DeniedReason: "not_in_group",
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1597,13 +1515,10 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
"email": user.Email,
|
||||
}).Debug("ValidateSession: access granted")
|
||||
|
||||
groupIDs, groupNames := pairGroupIDsAndNames(userGroups)
|
||||
return &proto.ValidateSessionResponse{
|
||||
Valid: true,
|
||||
UserId: user.Id,
|
||||
UserEmail: user.Email,
|
||||
PeerGroupIds: groupIDs,
|
||||
PeerGroupNames: groupNames,
|
||||
Valid: true,
|
||||
UserId: user.Id,
|
||||
UserEmail: user.Email,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1636,154 +1551,3 @@ func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *
|
||||
}
|
||||
|
||||
func ptr[T any](v T) *T { return &v }
|
||||
|
||||
// ValidateTunnelPeer resolves an inbound peer by its WireGuard tunnel IP and
|
||||
// checks the peer's group membership against the service's access groups.
|
||||
// Peers without a user (machine agents, automation workloads) are first-class
|
||||
// callers; authorisation runs off peer-group memberships rather than the
|
||||
// optional owning user's auto-groups. On success a session JWT is minted so
|
||||
// the proxy can install a cookie and skip subsequent management round-trips.
|
||||
func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
domain := req.GetDomain()
|
||||
tunnelIPStr := req.GetTunnelIp()
|
||||
|
||||
if domain == "" || tunnelIPStr == "" {
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
DeniedReason: "missing domain or tunnel_ip",
|
||||
}, nil
|
||||
}
|
||||
|
||||
tunnelIP := net.ParseIP(tunnelIPStr)
|
||||
if tunnelIP == nil {
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
DeniedReason: "invalid_tunnel_ip",
|
||||
}, nil
|
||||
}
|
||||
|
||||
service, err := s.getServiceByDomain(ctx, domain)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{"domain": domain, "error": err.Error()}).Debug("ValidateTunnelPeer: service not found")
|
||||
//nolint:nilerr
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
DeniedReason: "service_not_found",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Mirror ValidateSession: account-scoped (BYOP) proxy tokens may only
|
||||
// validate and mint session cookies for their own account's domains.
|
||||
if err := enforceAccountScope(ctx, service.AccountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peer, err := s.peersManager.GetPeerByTunnelIP(ctx, service.AccountID, tunnelIP)
|
||||
if err != nil || peer == nil {
|
||||
log.WithFields(log.Fields{"domain": domain, "tunnel_ip": tunnelIPStr}).Debug("ValidateTunnelPeer: peer not found")
|
||||
//nolint:nilerr
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
DeniedReason: "peer_not_found",
|
||||
}, nil
|
||||
}
|
||||
|
||||
_, peerGroups, err := s.peersManager.GetPeerWithGroups(ctx, service.AccountID, peer.ID)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: peer groups lookup failed")
|
||||
//nolint:nilerr
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
DeniedReason: "peer_not_found",
|
||||
}, nil
|
||||
}
|
||||
|
||||
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
|
||||
|
||||
// Resolve the principal: when the peer is linked to a user, the human
|
||||
// is the principal so multiple peers owned by the same user share a
|
||||
// single identity. Unlinked peers (machine agents) are their own
|
||||
// principal keyed on peer.ID. displayIdentity is what upstream gateways
|
||||
// tag spend with — user.Email when linked, peer.Name when not.
|
||||
principalID := peer.ID
|
||||
displayIdentity := peer.Name
|
||||
if peer.UserID != "" {
|
||||
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
|
||||
principalID = user.Id
|
||||
if user.Email != "" {
|
||||
displayIdentity = user.Email
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
|
||||
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied")
|
||||
//nolint:nilerr
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
UserId: principalID,
|
||||
UserEmail: displayIdentity,
|
||||
DeniedReason: "not_in_group",
|
||||
PeerGroupIds: groupIDs,
|
||||
PeerGroupNames: groupNames,
|
||||
}, nil
|
||||
}
|
||||
|
||||
token, err := s.generateSessionToken(ctx, true, service, principalID, displayIdentity, proxyauth.MethodOIDC, groupIDs, groupNames)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"domain": domain,
|
||||
"tunnel_ip": tunnelIPStr,
|
||||
"peer_id": peer.ID,
|
||||
"principal_id": principalID,
|
||||
}).Debug("ValidateTunnelPeer: access granted")
|
||||
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: true,
|
||||
UserId: principalID,
|
||||
UserEmail: displayIdentity,
|
||||
SessionToken: token,
|
||||
PeerGroupIds: groupIDs,
|
||||
PeerGroupNames: groupNames,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required
|
||||
// groups. Private services authorise against AccessGroups (empty list fails
|
||||
// closed — Validate() rejects that at save time but the RPC is the security
|
||||
// boundary and must not trust upstream state). Bearer-auth services authorise
|
||||
// against DistributionGroups when populated. Non-private non-bearer services
|
||||
// are open.
|
||||
func checkPeerGroupAccess(service *rpservice.Service, peerGroupIDs []string) error {
|
||||
if service.Private {
|
||||
if len(service.AccessGroups) == 0 {
|
||||
return fmt.Errorf("private service has no access groups")
|
||||
}
|
||||
return matchAnyGroup(service.AccessGroups, peerGroupIDs)
|
||||
}
|
||||
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled && len(service.Auth.BearerAuth.DistributionGroups) > 0 {
|
||||
return matchAnyGroup(service.Auth.BearerAuth.DistributionGroups, peerGroupIDs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// matchAnyGroup returns nil when peerGroupIDs intersects allowedGroups,
|
||||
// else a non-nil error.
|
||||
func matchAnyGroup(allowedGroups, peerGroupIDs []string) error {
|
||||
if len(allowedGroups) == 0 {
|
||||
return fmt.Errorf("no allowed groups configured")
|
||||
}
|
||||
allowed := make(map[string]struct{}, len(allowedGroups))
|
||||
for _, g := range allowedGroups {
|
||||
allowed[g] = struct{}{}
|
||||
}
|
||||
for _, g := range peerGroupIDs {
|
||||
if _, ok := allowed[g]; ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("peer not in allowed groups")
|
||||
}
|
||||
|
||||
@@ -129,14 +129,6 @@ func (m *mockUsersManager) GetUser(ctx context.Context, userID string) (*types.U
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (m *mockUsersManager) GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) {
|
||||
user, err := m.GetUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return user, nil, nil
|
||||
}
|
||||
|
||||
func TestValidateUserGroupAccess(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -428,46 +420,3 @@ func TestGetAccountProxyByDomain(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPeerGroupAccess(t *testing.T) {
|
||||
t.Run("private with empty AccessGroups denies", func(t *testing.T) {
|
||||
svc := &service.Service{Private: true, AccessGroups: nil}
|
||||
err := checkPeerGroupAccess(svc, []string{"grp-admins"})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no access groups")
|
||||
})
|
||||
|
||||
t.Run("private with peer in AccessGroups allows", func(t *testing.T) {
|
||||
svc := &service.Service{Private: true, AccessGroups: []string{"grp-admins", "grp-ops"}}
|
||||
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-other", "grp-ops"}))
|
||||
})
|
||||
|
||||
t.Run("private with peer outside AccessGroups denies", func(t *testing.T) {
|
||||
svc := &service.Service{Private: true, AccessGroups: []string{"grp-admins"}}
|
||||
assert.Error(t, checkPeerGroupAccess(svc, []string{"grp-other"}))
|
||||
})
|
||||
|
||||
t.Run("bearer enabled with empty DistributionGroups allows", func(t *testing.T) {
|
||||
svc := &service.Service{
|
||||
Auth: service.AuthConfig{BearerAuth: &service.BearerAuthConfig{Enabled: true}},
|
||||
}
|
||||
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-anyone"}))
|
||||
})
|
||||
|
||||
t.Run("bearer enabled gates on DistributionGroups", func(t *testing.T) {
|
||||
svc := &service.Service{
|
||||
Auth: service.AuthConfig{
|
||||
BearerAuth: &service.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"grp-allowed"},
|
||||
},
|
||||
},
|
||||
}
|
||||
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-allowed"}))
|
||||
assert.Error(t, checkPeerGroupAccess(svc, []string{"grp-other"}))
|
||||
})
|
||||
|
||||
t.Run("non-private non-bearer is open", func(t *testing.T) {
|
||||
assert.NoError(t, checkPeerGroupAccess(&service.Service{}, nil))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -437,7 +437,7 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
|
||||
return nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("received an update for peer %s", peerKey.String())
|
||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
||||
if debouncer.ProcessUpdate(update) {
|
||||
// Send immediately (first update or after quiet period)
|
||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
|
||||
@@ -492,7 +492,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
|
||||
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||
return status.Errorf(codes.Internal, "failed sending update message")
|
||||
}
|
||||
log.WithContext(ctx).Tracef("sent an update to peer %s", peerKey.String())
|
||||
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -102,7 +102,7 @@ func generateSessionKeyPair(t *testing.T) (string, string) {
|
||||
|
||||
func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string {
|
||||
t.Helper()
|
||||
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, nil, time.Hour)
|
||||
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, time.Hour)
|
||||
require.NoError(t, err)
|
||||
return token
|
||||
}
|
||||
@@ -125,7 +125,6 @@ func TestValidateSession_UserAllowed(t *testing.T) {
|
||||
assert.True(t, resp.Valid, "User should be allowed access")
|
||||
assert.Equal(t, "allowedUserId", resp.UserId)
|
||||
assert.Empty(t, resp.DeniedReason)
|
||||
assert.Equal(t, []string{"allowedGroupId"}, resp.GetPeerGroupIds(), "PeerGroupIds must mirror the resolved user's group memberships")
|
||||
}
|
||||
|
||||
func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
|
||||
@@ -146,7 +145,6 @@ func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
|
||||
assert.False(t, resp.Valid, "User not in group should be denied")
|
||||
assert.Equal(t, "not_in_group", resp.DeniedReason)
|
||||
assert.Equal(t, "nonGroupUserId", resp.UserId)
|
||||
assert.Empty(t, resp.GetPeerGroupIds(), "PeerGroupIds must mirror the resolved user's actual (empty) memberships on denial")
|
||||
}
|
||||
|
||||
func TestValidateSession_UserInDifferentAccount(t *testing.T) {
|
||||
|
||||
@@ -15,13 +15,15 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||
@@ -30,10 +32,12 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
|
||||
|
||||
nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
nbgroups "github.com/netbirdio/netbird/management/server/groups"
|
||||
@@ -52,14 +56,17 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
nbinstance "github.com/netbirdio/netbird/management/server/instance"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const apiPrefix = "/api"
|
||||
|
||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, permissionsManager permissions.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter, isValidChildAccount middleware.IsValidChildAccountFunc) (http.Handler, error) {
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) {
|
||||
|
||||
// Register bypass paths for unauthenticated endpoints
|
||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||
@@ -93,16 +100,25 @@ func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager accou
|
||||
accountManager.GetUserFromUserAuth,
|
||||
rateLimiter,
|
||||
appMetrics.GetMeter(),
|
||||
isValidChildAccount,
|
||||
)
|
||||
|
||||
corsMiddleware := cors.AllowAll()
|
||||
|
||||
rootRouter := mux.NewRouter()
|
||||
metricsMiddleware := appMetrics.HTTPMiddleware()
|
||||
|
||||
prefix := apiPrefix
|
||||
router := rootRouter.PathPrefix(prefix).Subrouter()
|
||||
|
||||
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler)
|
||||
|
||||
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), idpManager)
|
||||
if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, integratedValidator, appMetrics.GetMeter(), permissionsManager, peersManager, proxyController, settingsManager); err != nil {
|
||||
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
||||
}
|
||||
|
||||
// Check if embedded IdP is enabled for instance manager
|
||||
embeddedIdP, embeddedIdpEnabled := idpManager.(*idpmanager.EmbeddedIdPManager)
|
||||
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), embeddedIdP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create instance manager: %w", err)
|
||||
}
|
||||
@@ -138,5 +154,10 @@ func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager accou
|
||||
oauthHandler.RegisterEndpoints(router)
|
||||
}
|
||||
|
||||
return router, nil
|
||||
// Mount embedded IdP handler at /oauth2 path if configured
|
||||
if embeddedIdpEnabled {
|
||||
rootRouter.PathPrefix("/oauth2").Handler(corsMiddleware.Handler(embeddedIdP.Handler()))
|
||||
}
|
||||
|
||||
return rootRouter, nil
|
||||
}
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
@@ -25,8 +27,6 @@ type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) err
|
||||
|
||||
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
|
||||
type IsValidChildAccountFunc func(ctx context.Context, userID, accountID, childAccountID string) bool
|
||||
|
||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
authManager serverauth.Manager
|
||||
@@ -35,7 +35,6 @@ type AuthMiddleware struct {
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc
|
||||
rateLimiter *APIRateLimiter
|
||||
patUsageTracker *PATUsageTracker
|
||||
isValidChildAccount IsValidChildAccountFunc
|
||||
}
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
@@ -46,7 +45,6 @@ func NewAuthMiddleware(
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||
rateLimiter *APIRateLimiter,
|
||||
meter metric.Meter,
|
||||
isValidChildAccount IsValidChildAccountFunc,
|
||||
) *AuthMiddleware {
|
||||
var patUsageTracker *PATUsageTracker
|
||||
if meter != nil {
|
||||
@@ -64,7 +62,6 @@ func NewAuthMiddleware(
|
||||
getUserFromUserAuth: getUserFromUserAuth,
|
||||
rateLimiter: rateLimiter,
|
||||
patUsageTracker: patUsageTracker,
|
||||
isValidChildAccount: isValidChildAccount,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,7 +124,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
||||
}
|
||||
|
||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||
if m.isValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
||||
if integrations.IsValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
||||
userAuth.AccountId = impersonate[0]
|
||||
userAuth.IsChild = true
|
||||
}
|
||||
@@ -206,7 +203,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
||||
}
|
||||
|
||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||
if m.isValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
||||
if integrations.IsValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
||||
userAuth.AccountId = impersonate[0]
|
||||
userAuth.IsChild = true
|
||||
}
|
||||
|
||||
@@ -211,7 +211,6 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
},
|
||||
disabledLimiter,
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handlerToTest := authMiddleware.Handler(nextHandler)
|
||||
@@ -271,7 +270,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -324,7 +322,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -368,7 +365,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -413,7 +409,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -478,7 +473,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -538,7 +532,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -594,7 +587,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -695,7 +687,6 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
},
|
||||
disabledLimiter,
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
for _, tc := range tt {
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
@@ -136,8 +135,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -266,8 +264,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
package validator
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
cachestore "github.com/eko/gocache/lib/v4/store"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type IntegratedValidatorImpl struct{}
|
||||
|
||||
func NewIntegratedValidator(_ context.Context, _ peers.Manager, _ settings.Manager, _ activity.Store, _ cachestore.StoreInterface) (*IntegratedValidatorImpl, error) {
|
||||
return &IntegratedValidatorImpl{}, nil
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) ValidateExtraSettings(context.Context, *types.ExtraSettings, *types.ExtraSettings, string, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) ValidatePeer(_ context.Context, update *nbpeer.Peer, _ *nbpeer.Peer, _ string, _ string, _ string, _ []string, _ *types.ExtraSettings) (*nbpeer.Peer, bool, error) {
|
||||
return update, false, nil
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) PreparePeer(_ context.Context, _ string, peer *nbpeer.Peer, _ []string, _ *types.ExtraSettings, _ bool) *nbpeer.Peer {
|
||||
return peer.Copy()
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) IsNotValidPeer(_ context.Context, _ string, _ *nbpeer.Peer, _ []string, _ *types.ExtraSettings) (bool, bool, error) {
|
||||
return false, false, nil
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) GetValidatedPeers(_ context.Context, _ string, _ []*types.Group, peers []*nbpeer.Peer, _ *types.ExtraSettings) (map[string]struct{}, error) {
|
||||
validatedPeers := make(map[string]struct{})
|
||||
for _, p := range peers {
|
||||
validatedPeers[p.ID] = struct{}{}
|
||||
}
|
||||
return validatedPeers, nil
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) GetInvalidPeers(_ context.Context, _ string, _ *types.ExtraSettings) (map[string]string, error) {
|
||||
return make(map[string]string), nil
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) PeerDeleted(_ context.Context, _, _ string, _ *types.ExtraSettings) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) SetPeerInvalidationListener(_ func(accountID string, peerIDs []string)) {
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) Stop(_ context.Context) {
|
||||
}
|
||||
|
||||
func (v *IntegratedValidatorImpl) ValidateFlowResponse(_ context.Context, _ string, flowResponse *proto.PKCEAuthorizationFlow) *proto.PKCEAuthorizationFlow {
|
||||
return flowResponse
|
||||
}
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbversion "github.com/netbirdio/netbird/version"
|
||||
)
|
||||
@@ -54,7 +53,6 @@ type DataSource interface {
|
||||
GetAllAccounts(ctx context.Context) []*types.Account
|
||||
GetStoreEngine() types.Engine
|
||||
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
|
||||
GetProxyMetrics(ctx context.Context) (store.ProxyMetrics, error)
|
||||
}
|
||||
|
||||
// ConnManager peer connection manager that holds state for current active connections
|
||||
@@ -225,12 +223,6 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
servicesAuthPassword int
|
||||
servicesAuthPin int
|
||||
servicesAuthOIDC int
|
||||
// Private-service signals — track adoption of NetBird-only mode
|
||||
// (services backed by an embedded proxy peer + access groups).
|
||||
servicesPrivate int
|
||||
servicesPrivateWithGroups int
|
||||
servicesPrivateAccessGroupsSum int
|
||||
servicesWithDirectUpstream int
|
||||
)
|
||||
start := time.Now()
|
||||
metricsProperties := make(properties)
|
||||
@@ -388,31 +380,9 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled {
|
||||
servicesAuthOIDC++
|
||||
}
|
||||
|
||||
if service.Private {
|
||||
servicesPrivate++
|
||||
if len(service.AccessGroups) > 0 {
|
||||
servicesPrivateWithGroups++
|
||||
}
|
||||
servicesPrivateAccessGroupsSum += len(service.AccessGroups)
|
||||
}
|
||||
|
||||
for _, target := range service.Targets {
|
||||
if target.Options.DirectUpstream {
|
||||
servicesWithDirectUpstream++
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Proxy / BYOP cluster signals come from the proxies table aggregated
|
||||
// across all accounts in a single store query; nil on FileStore.
|
||||
proxyMetrics, err := w.dataSource.GetProxyMetrics(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("collect proxy metrics: %v", err)
|
||||
}
|
||||
|
||||
minActivePeerVersion, maxActivePeerVersion := getMinMaxVersion(peerActiveVersions)
|
||||
metricsProperties["uptime"] = uptime
|
||||
metricsProperties["accounts"] = accounts
|
||||
@@ -460,15 +430,6 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
metricsProperties["services_auth_password"] = servicesAuthPassword
|
||||
metricsProperties["services_auth_pin"] = servicesAuthPin
|
||||
metricsProperties["services_auth_oidc"] = servicesAuthOIDC
|
||||
metricsProperties["services_private"] = servicesPrivate
|
||||
metricsProperties["services_private_with_access_groups"] = servicesPrivateWithGroups
|
||||
metricsProperties["services_private_access_groups_sum"] = servicesPrivateAccessGroupsSum
|
||||
metricsProperties["services_with_direct_upstream"] = servicesWithDirectUpstream
|
||||
metricsProperties["proxy_clusters"] = proxyMetrics.Clusters
|
||||
metricsProperties["proxy_clusters_byop"] = proxyMetrics.ClustersBYOP
|
||||
metricsProperties["proxy_clusters_private"] = proxyMetrics.ClustersPrivate
|
||||
metricsProperties["proxies"] = proxyMetrics.Proxies
|
||||
metricsProperties["proxies_connected"] = proxyMetrics.ProxiesConnected
|
||||
metricsProperties["custom_domains"] = customDomains
|
||||
metricsProperties["custom_domains_validated"] = customDomainsValidated
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -124,7 +123,7 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
||||
Enabled: true,
|
||||
Targets: []*rpservice.Target{
|
||||
{TargetType: "peer"},
|
||||
{TargetType: "host", Options: rpservice.TargetOptions{DirectUpstream: true}},
|
||||
{TargetType: "host"},
|
||||
},
|
||||
Auth: rpservice.AuthConfig{
|
||||
PasswordAuth: &rpservice.PasswordAuthConfig{Enabled: true},
|
||||
@@ -142,16 +141,6 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
||||
},
|
||||
Meta: rpservice.Meta{Status: string(rpservice.StatusPending)},
|
||||
},
|
||||
{
|
||||
ID: "svc3-private",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
AccessGroups: []string{"grp-eng", "grp-ops"},
|
||||
Targets: []*rpservice.Target{
|
||||
{TargetType: "cluster", Options: rpservice.TargetOptions{DirectUpstream: true}},
|
||||
},
|
||||
Meta: rpservice.Meta{Status: string(rpservice.StatusActive)},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -265,18 +254,6 @@ func (mockDatasource) GetCustomDomainsCounts(_ context.Context) (int64, int64, e
|
||||
return 3, 2, nil
|
||||
}
|
||||
|
||||
// GetProxyMetrics returns canned proxy/cluster counts so the
|
||||
// generateProperties test can assert the BYOP signals end-to-end.
|
||||
func (mockDatasource) GetProxyMetrics(_ context.Context) (store.ProxyMetrics, error) {
|
||||
return store.ProxyMetrics{
|
||||
Clusters: 3,
|
||||
ClustersBYOP: 1,
|
||||
ClustersPrivate: 1,
|
||||
Proxies: 4,
|
||||
ProxiesConnected: 2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties
|
||||
func TestGenerateProperties(t *testing.T) {
|
||||
ds := mockDatasource{}
|
||||
@@ -416,17 +393,17 @@ func TestGenerateProperties(t *testing.T) {
|
||||
t.Errorf("expected 3 embedded_idp_count, got %v", properties["embedded_idp_count"])
|
||||
}
|
||||
|
||||
if properties["services"] != 3 {
|
||||
t.Errorf("expected 3 services, got %v", properties["services"])
|
||||
if properties["services"] != 2 {
|
||||
t.Errorf("expected 2 services, got %v", properties["services"])
|
||||
}
|
||||
if properties["services_enabled"] != 2 {
|
||||
t.Errorf("expected 2 services_enabled, got %v", properties["services_enabled"])
|
||||
if properties["services_enabled"] != 1 {
|
||||
t.Errorf("expected 1 services_enabled, got %v", properties["services_enabled"])
|
||||
}
|
||||
if properties["services_targets"] != 4 {
|
||||
t.Errorf("expected 4 services_targets, got %v", properties["services_targets"])
|
||||
if properties["services_targets"] != 3 {
|
||||
t.Errorf("expected 3 services_targets, got %v", properties["services_targets"])
|
||||
}
|
||||
if properties["services_status_active"] != 2 {
|
||||
t.Errorf("expected 2 services_status_active, got %v", properties["services_status_active"])
|
||||
if properties["services_status_active"] != 1 {
|
||||
t.Errorf("expected 1 services_status_active, got %v", properties["services_status_active"])
|
||||
}
|
||||
if properties["services_status_pending"] != 1 {
|
||||
t.Errorf("expected 1 services_status_pending, got %v", properties["services_status_pending"])
|
||||
@@ -443,9 +420,6 @@ func TestGenerateProperties(t *testing.T) {
|
||||
if properties["services_target_type_domain"] != 1 {
|
||||
t.Errorf("expected 1 services_target_type_domain, got %v", properties["services_target_type_domain"])
|
||||
}
|
||||
if properties["services_target_type_cluster"] != 1 {
|
||||
t.Errorf("expected 1 services_target_type_cluster, got %v", properties["services_target_type_cluster"])
|
||||
}
|
||||
if properties["services_auth_password"] != 1 {
|
||||
t.Errorf("expected 1 services_auth_password, got %v", properties["services_auth_password"])
|
||||
}
|
||||
@@ -455,33 +429,6 @@ func TestGenerateProperties(t *testing.T) {
|
||||
if properties["services_auth_pin"] != 0 {
|
||||
t.Errorf("expected 0 services_auth_pin, got %v", properties["services_auth_pin"])
|
||||
}
|
||||
if properties["services_private"] != 1 {
|
||||
t.Errorf("expected 1 services_private, got %v", properties["services_private"])
|
||||
}
|
||||
if properties["services_private_with_access_groups"] != 1 {
|
||||
t.Errorf("expected 1 services_private_with_access_groups, got %v", properties["services_private_with_access_groups"])
|
||||
}
|
||||
if properties["services_private_access_groups_sum"] != 2 {
|
||||
t.Errorf("expected 2 services_private_access_groups_sum, got %v", properties["services_private_access_groups_sum"])
|
||||
}
|
||||
if properties["services_with_direct_upstream"] != 2 {
|
||||
t.Errorf("expected 2 services_with_direct_upstream, got %v", properties["services_with_direct_upstream"])
|
||||
}
|
||||
if properties["proxy_clusters"] != int64(3) {
|
||||
t.Errorf("expected 3 proxy_clusters, got %v", properties["proxy_clusters"])
|
||||
}
|
||||
if properties["proxy_clusters_byop"] != int64(1) {
|
||||
t.Errorf("expected 1 proxy_clusters_byop, got %v", properties["proxy_clusters_byop"])
|
||||
}
|
||||
if properties["proxy_clusters_private"] != int64(1) {
|
||||
t.Errorf("expected 1 proxy_clusters_private, got %v", properties["proxy_clusters_private"])
|
||||
}
|
||||
if properties["proxies"] != int64(4) {
|
||||
t.Errorf("expected 4 proxies, got %v", properties["proxies"])
|
||||
}
|
||||
if properties["proxies_connected"] != int64(2) {
|
||||
t.Errorf("expected 2 proxies_connected, got %v", properties["proxies_connected"])
|
||||
}
|
||||
if properties["custom_domains"] != int64(3) {
|
||||
t.Errorf("expected 3 custom_domains, got %v", properties["custom_domains"])
|
||||
}
|
||||
|
||||
@@ -125,18 +125,6 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
}
|
||||
}
|
||||
|
||||
// An embedded proxy peer flipping to connected is the trigger for
|
||||
// SynthesizePrivateServiceZones to emit DNS A records pointing at its
|
||||
// tunnel IP. Without an account-wide netmap recompute, user peers keep
|
||||
// the stale synth (or no synth at all on first connect) until some
|
||||
// other change pokes the controller. Fire OnPeersUpdated so the
|
||||
// buffered recompute fans the new state out to every peer.
|
||||
if peer.ProxyMeta.Embedded {
|
||||
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
|
||||
log.WithContext(ctx).Warnf("notify network map controller of embedded proxy %s connect: %v", peer.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -172,17 +160,6 @@ func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerP
|
||||
return nil
|
||||
}
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusApplied)
|
||||
|
||||
// Symmetric with MarkPeerConnected: when an embedded proxy peer goes
|
||||
// offline, drive an account-wide netmap recompute so the synthesized
|
||||
// DNS records that pointed at it are pulled. Without this the records
|
||||
// linger client-side at TTL until something else triggers a refresh.
|
||||
if peer.ProxyMeta.Embedded {
|
||||
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
|
||||
log.WithContext(ctx).Warnf("notify network map controller of embedded proxy %s disconnect: %v", peer.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
@@ -32,6 +33,9 @@ func (n *NBVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, err
|
||||
return true, nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("peer %s NB version %s is older than minimum allowed version %s",
|
||||
peer.ID, peer.Meta.WtVersion, n.MinVersion)
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -100,6 +100,8 @@ func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *M
|
||||
return true, nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("peer %s OS version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinVersion)
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -123,5 +125,7 @@ func checkMinKernelVersion(ctx context.Context, peerGoOS, peerVersion string, ch
|
||||
return true, nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("peer %s kernel version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinKernelVersion)
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -274,9 +274,3 @@ func (s *FileStore) SetFieldEncrypt(_ *crypt.FieldEncrypt) {
|
||||
func (s *FileStore) GetCustomDomainsCounts(_ context.Context) (int64, int64, error) {
|
||||
return 0, 0, nil
|
||||
}
|
||||
|
||||
// GetProxyMetrics is a no-op for FileStore — proxy/cluster state isn't
|
||||
// persisted in the JSON file format.
|
||||
func (s *FileStore) GetProxyMetrics(_ context.Context) (ProxyMetrics, error) {
|
||||
return ProxyMetrics{}, nil
|
||||
}
|
||||
|
||||
@@ -1090,38 +1090,6 @@ func (s *SqlStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, er
|
||||
return total, validated, nil
|
||||
}
|
||||
|
||||
// GetProxyMetrics aggregates per-cluster + per-proxy counts for the
|
||||
// self-hosted telemetry payload. Single round-trip via conditional
|
||||
// aggregations so a large proxies table doesn't fan out into multiple
|
||||
// queries.
|
||||
func (s *SqlStore) GetProxyMetrics(ctx context.Context) (ProxyMetrics, error) {
|
||||
var m ProxyMetrics
|
||||
activeCutoff := time.Now().Add(-proxyActiveThreshold)
|
||||
|
||||
// COUNT(DISTINCT ... CASE WHEN ...) is portable across sqlite/postgres
|
||||
// (MySQL too) and keeps the round-trip to one. proxy.StatusConnected
|
||||
// is the same string the cluster-capability queries use; the active
|
||||
// window matches the cluster-capability semantics (only proxies
|
||||
// heartbeating within ~2 * heartbeat interval count as connected).
|
||||
row := s.db.WithContext(ctx).
|
||||
Model(&proxy.Proxy{}).
|
||||
Select(
|
||||
"COUNT(DISTINCT cluster_address) AS clusters, "+
|
||||
"COUNT(DISTINCT CASE WHEN account_id IS NOT NULL THEN cluster_address END) AS clusters_byop, "+
|
||||
"COUNT(DISTINCT CASE WHEN private = ? THEN cluster_address END) AS clusters_private, "+
|
||||
"COUNT(*) AS proxies, "+
|
||||
"COUNT(CASE WHEN status = ? AND last_seen > ? THEN 1 END) AS proxies_connected",
|
||||
true,
|
||||
proxy.StatusConnected,
|
||||
activeCutoff,
|
||||
).
|
||||
Row()
|
||||
if err := row.Scan(&m.Clusters, &m.ClustersBYOP, &m.ClustersPrivate, &m.Proxies, &m.ProxiesConnected); err != nil {
|
||||
return ProxyMetrics{}, fmt.Errorf("scan proxy metrics: %w", err)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) {
|
||||
var accounts []types.Account
|
||||
result := s.db.Find(&accounts)
|
||||
@@ -2210,8 +2178,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
|
||||
const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth,
|
||||
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
|
||||
pass_host_header, rewrite_redirects, session_private_key, session_public_key,
|
||||
mode, listen_port, port_auto_assigned, source, source_peer, terminated,
|
||||
private, access_groups
|
||||
mode, listen_port, port_auto_assigned, source, source_peer, terminated
|
||||
FROM services WHERE account_id = $1`
|
||||
|
||||
const targetsQuery = `SELECT id, account_id, service_id, path, host, port, protocol,
|
||||
@@ -2226,11 +2193,10 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
|
||||
services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*rpservice.Service, error) {
|
||||
var s rpservice.Service
|
||||
var auth []byte
|
||||
var accessGroups []byte
|
||||
var createdAt, certIssuedAt sql.NullTime
|
||||
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
|
||||
var mode, source, sourcePeer sql.NullString
|
||||
var terminated, portAutoAssigned, private sql.NullBool
|
||||
var terminated, portAutoAssigned sql.NullBool
|
||||
var listenPort sql.NullInt64
|
||||
err := row.Scan(
|
||||
&s.ID,
|
||||
@@ -2253,8 +2219,6 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
|
||||
&source,
|
||||
&sourcePeer,
|
||||
&terminated,
|
||||
&private,
|
||||
&accessGroups,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2266,16 +2230,6 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
|
||||
}
|
||||
}
|
||||
|
||||
if len(accessGroups) > 0 {
|
||||
if err := json.Unmarshal(accessGroups, &s.AccessGroups); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal access_groups: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if private.Valid {
|
||||
s.Private = private.Bool
|
||||
}
|
||||
|
||||
s.Meta = rpservice.Meta{}
|
||||
if createdAt.Valid {
|
||||
s.Meta.CreatedAt = createdAt.Time
|
||||
@@ -5872,7 +5826,6 @@ var validCapabilityColumns = map[string]struct{}{
|
||||
"supports_custom_ports": {},
|
||||
"require_subdomain": {},
|
||||
"supports_crowdsec": {},
|
||||
"private": {},
|
||||
}
|
||||
|
||||
// GetClusterSupportsCustomPorts returns whether any active proxy in the cluster
|
||||
@@ -5887,12 +5840,6 @@ func (s *SqlStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr s
|
||||
return s.getClusterCapability(ctx, clusterAddr, "require_subdomain")
|
||||
}
|
||||
|
||||
// GetClusterSupportsPrivate reports whether any active proxy in the cluster
|
||||
// has the private capability (nil = unreported).
|
||||
func (s *SqlStore) GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
|
||||
return s.getClusterCapability(ctx, clusterAddr, "private")
|
||||
}
|
||||
|
||||
// GetClusterSupportsCrowdSec returns whether all active proxies in the cluster
|
||||
// have CrowdSec configured. Returns nil when no proxy reported the capability.
|
||||
// Unlike other capabilities that use ANY-true (for rolling upgrades), CrowdSec
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
)
|
||||
|
||||
func TestSqlStore_GetAccount_PrivateServiceRoundtrip(t *testing.T) {
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
ctx := context.Background()
|
||||
account := newAccountWithId(ctx, "account_private_svc", "testuser", "")
|
||||
require.NoError(t, store.SaveAccount(ctx, account))
|
||||
|
||||
svc := &rpservice.Service{
|
||||
ID: "svc-private",
|
||||
AccountID: account.Id,
|
||||
Name: "private-svc",
|
||||
Domain: "private.example",
|
||||
ProxyCluster: "cluster.example",
|
||||
Enabled: true,
|
||||
Mode: rpservice.ModeHTTP,
|
||||
Private: true,
|
||||
AccessGroups: []string{"grp-admins", "grp-ops"},
|
||||
}
|
||||
require.NoError(t, store.CreateService(ctx, svc))
|
||||
|
||||
loaded, err := store.GetAccount(ctx, account.Id)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, loaded.Services, 1)
|
||||
|
||||
got := loaded.Services[0]
|
||||
assert.True(t, got.Private)
|
||||
assert.Equal(t, []string{"grp-admins", "grp-ops"}, got.AccessGroups)
|
||||
})
|
||||
}
|
||||
@@ -312,7 +312,6 @@ type Store interface {
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
|
||||
@@ -321,38 +320,9 @@ type Store interface {
|
||||
|
||||
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
|
||||
|
||||
// GetProxyMetrics returns aggregated proxy / cluster counts for the
|
||||
// self-hosted metrics worker. Self-hosted only — file-based stores
|
||||
// return a zero-valued struct.
|
||||
GetProxyMetrics(ctx context.Context) (ProxyMetrics, error)
|
||||
|
||||
GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error)
|
||||
}
|
||||
|
||||
// ProxyMetrics aggregates self-hosted proxy + cluster usage signals
|
||||
// surfaced to the telemetry payload. Each field is best-effort: when a
|
||||
// store cannot answer (e.g. FileStore) all fields are zero.
|
||||
type ProxyMetrics struct {
|
||||
// Clusters counts distinct cluster_address values across the proxies
|
||||
// table — every cluster the management server has heard from, online or not.
|
||||
Clusters int64
|
||||
// ClustersBYOP counts distinct cluster_address values that are owned
|
||||
// by an account (account_id IS NOT NULL). These are bring-your-own-proxy
|
||||
// installations as opposed to NetBird-operated shared clusters.
|
||||
ClustersBYOP int64
|
||||
// ClustersPrivate counts distinct cluster_address values where at
|
||||
// least one proxy reported the private capability (embedded
|
||||
// `netbird proxy` running inside a client).
|
||||
ClustersPrivate int64
|
||||
// Proxies is the total number of proxy rows currently persisted.
|
||||
Proxies int64
|
||||
// ProxiesConnected is the subset of proxies whose status is
|
||||
// "connected" AND last_seen falls within the active heartbeat window
|
||||
// (~2 * heartbeat interval). Proxies the controller hasn't pruned
|
||||
// yet but that are visibly stale don't count.
|
||||
ProxiesConnected int64
|
||||
}
|
||||
|
||||
const (
|
||||
postgresDsnEnv = "NB_STORE_ENGINE_POSTGRES_DSN"
|
||||
postgresDsnEnvLegacy = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
||||
|
||||
@@ -1461,20 +1461,6 @@ func (mr *MockStoreMockRecorder) GetClusterSupportsCustomPorts(ctx, clusterAddr
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCustomPorts", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCustomPorts), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// GetClusterSupportsPrivate mocks base method.
|
||||
func (m *MockStore) GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClusterSupportsPrivate", ctx, clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClusterSupportsPrivate indicates an expected call of GetClusterSupportsPrivate.
|
||||
func (mr *MockStoreMockRecorder) GetClusterSupportsPrivate(ctx, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsPrivate", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsPrivate), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// GetCustomDomain mocks base method.
|
||||
func (m *MockStore) GetCustomDomain(ctx context.Context, accountID, domainID string) (*domain.Domain, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2090,21 +2076,6 @@ func (mr *MockStoreMockRecorder) GetProxyClusters(ctx, accountID interface{}) *g
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyClusters", reflect.TypeOf((*MockStore)(nil).GetProxyClusters), ctx, accountID)
|
||||
}
|
||||
|
||||
// GetProxyMetrics mocks base method.
|
||||
func (m *MockStore) GetProxyMetrics(ctx context.Context) (ProxyMetrics, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetProxyMetrics", ctx)
|
||||
ret0, _ := ret[0].(ProxyMetrics)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetProxyMetrics indicates an expected call of GetProxyMetrics.
|
||||
func (mr *MockStoreMockRecorder) GetProxyMetrics(ctx interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyMetrics", reflect.TypeOf((*MockStore)(nil).GetProxyMetrics), ctx)
|
||||
}
|
||||
|
||||
// GetResourceGroups mocks base method.
|
||||
func (m *MockStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types2.Group, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -32,9 +32,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTTL = 300
|
||||
// privateServiceDNSRecordTTL is short so proxy-peer changes propagate quickly to clients.
|
||||
privateServiceDNSRecordTTL = 5
|
||||
defaultTTL = 300
|
||||
DefaultPeerLoginExpiration = 24 * time.Hour
|
||||
DefaultPeerInactivityExpiration = 10 * time.Minute
|
||||
|
||||
@@ -256,117 +254,6 @@ func getUniqueHostLabel(name string, peerLabels LookupMap) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// SynthesizePrivateServiceZones returns in-memory CustomZones with A records pointing each enabled private service the peer can reach at the cluster's proxy-peer IPs. One zone per cluster (multiple services share); records gated by AccessGroups.
|
||||
func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZone {
|
||||
peer, ok := a.Peers[peerID]
|
||||
if !ok || peer == nil {
|
||||
return nil
|
||||
}
|
||||
if len(a.Services) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
proxyPeersByCluster := a.GetProxyPeers()
|
||||
if len(proxyPeersByCluster) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
peerGroups := a.GetPeerGroups(peerID)
|
||||
zonesByCluster := map[string]*nbdns.CustomZone{}
|
||||
|
||||
for _, svc := range a.Services {
|
||||
if svc == nil || !svc.Enabled || !svc.Private {
|
||||
continue
|
||||
}
|
||||
if len(svc.AccessGroups) == 0 {
|
||||
continue
|
||||
}
|
||||
if !peerInDistributionGroups(peerGroups, svc.AccessGroups) {
|
||||
continue
|
||||
}
|
||||
proxyPeers := proxyPeersByCluster[svc.ProxyCluster]
|
||||
if len(proxyPeers) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
zone, exists := zonesByCluster[svc.ProxyCluster]
|
||||
if !exists {
|
||||
// NonAuthoritative makes this a match-only zone: queries for
|
||||
// names without an explicit record fall through to the
|
||||
// upstream resolver instead of returning NXDOMAIN. Without
|
||||
// it, adding a single private service would black-hole every
|
||||
// other name under the cluster apex.
|
||||
zone = &nbdns.CustomZone{
|
||||
Domain: dns.Fqdn(svc.ProxyCluster),
|
||||
Records: []nbdns.SimpleRecord{},
|
||||
NonAuthoritative: true,
|
||||
}
|
||||
zonesByCluster[svc.ProxyCluster] = zone
|
||||
}
|
||||
|
||||
emitted := 0
|
||||
skippedDisconnected := 0
|
||||
for _, p := range proxyPeers {
|
||||
if p == nil || !p.IP.IsValid() {
|
||||
continue
|
||||
}
|
||||
// Only emit a record when the proxy peer is actually
|
||||
// connected. A disconnected proxy peer's tunnel IP won't
|
||||
// answer; pointing DNS at it would produce a black hole
|
||||
// for as long as the record is cached client-side.
|
||||
if p.Status == nil || !p.Status.Connected {
|
||||
skippedDisconnected++
|
||||
continue
|
||||
}
|
||||
zone.Records = append(zone.Records, nbdns.SimpleRecord{
|
||||
Name: dns.Fqdn(svc.Domain),
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: privateServiceDNSRecordTTL,
|
||||
RData: p.IP.String(),
|
||||
})
|
||||
emitted++
|
||||
}
|
||||
// Disagreement with the firewall path is the typical
|
||||
// "domain doesn't reach client but firewall rules do"
|
||||
// symptom: the synth service is otherwise fine, only the
|
||||
// proxy peer's persisted Connected flag is wrong (most
|
||||
// likely the connection reaper marked it disconnected even
|
||||
// though the gRPC stream is alive).
|
||||
if emitted == 0 && skippedDisconnected > 0 {
|
||||
log.Debugf("private-zone synth: svc %s domain=%s cluster=%s emitted_zero proxy_peers=%d all_disconnected=%d (firewall would still fire)",
|
||||
svc.ID, svc.Domain, svc.ProxyCluster, len(proxyPeers), skippedDisconnected)
|
||||
}
|
||||
}
|
||||
|
||||
out := make([]nbdns.CustomZone, 0, len(zonesByCluster))
|
||||
for _, zone := range zonesByCluster {
|
||||
if len(zone.Records) == 0 {
|
||||
continue
|
||||
}
|
||||
out = append(out, *zone)
|
||||
}
|
||||
if len(out) == 0 && len(a.Services) > 0 {
|
||||
// Targeted diagnostic for the "firewall yes, DNS no" divergence —
|
||||
// fires only when services exist but synth returns zero zones,
|
||||
// so accounts without private services produce no noise.
|
||||
log.Debugf("private-zone synth: peer %s account %s returned 0 zones from %d candidate service(s)",
|
||||
peerID, a.Id, len(a.Services))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// peerInDistributionGroups reports whether any of the peer's groups
|
||||
// matches the service's bearer-auth distribution_groups.
|
||||
func peerInDistributionGroups(peerGroups LookupMap, distributionGroups []string) bool {
|
||||
for _, gid := range distributionGroups {
|
||||
if _, ok := peerGroups[gid]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdns.CustomZone {
|
||||
var merr *multierror.Error
|
||||
|
||||
@@ -1611,53 +1498,6 @@ func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *servi
|
||||
a.injectTargetProxyPolicies(ctx, service, target, proxyPeers)
|
||||
}
|
||||
|
||||
a.injectPrivateServicePolicies(service, proxyPeers)
|
||||
}
|
||||
|
||||
// injectPrivateServicePolicies synthesises an in-memory ACL: AccessGroups → cluster proxy peers on TCP 80/443.
|
||||
func (a *Account) injectPrivateServicePolicies(svc *service.Service, proxyPeers []*nbpeer.Peer) {
|
||||
if !svc.Private {
|
||||
return
|
||||
}
|
||||
if len(svc.AccessGroups) == 0 {
|
||||
return
|
||||
}
|
||||
if len(proxyPeers) == 0 {
|
||||
return
|
||||
}
|
||||
for _, proxyPeer := range proxyPeers {
|
||||
a.Policies = append(a.Policies, a.createPrivateServicePolicy(svc, proxyPeer))
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) createPrivateServicePolicy(svc *service.Service, proxyPeer *nbpeer.Peer) *Policy {
|
||||
policyID := fmt.Sprintf("private-access-%s-%s", svc.ID, proxyPeer.ID)
|
||||
sources := append([]string(nil), svc.AccessGroups...)
|
||||
return &Policy{
|
||||
ID: policyID,
|
||||
Name: fmt.Sprintf("Private Access to %s", svc.Name),
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: policyID,
|
||||
PolicyID: policyID,
|
||||
Name: fmt.Sprintf("Allow access groups to reach %s", svc.Name),
|
||||
Enabled: true,
|
||||
Sources: sources,
|
||||
DestinationResource: Resource{
|
||||
ID: proxyPeer.ID,
|
||||
Type: ResourceTypePeer,
|
||||
},
|
||||
Bidirectional: false,
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
PortRanges: []RulePortRange{
|
||||
{Start: 80, End: 80},
|
||||
{Start: 443, End: 443},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *service.Service, target *service.Target, proxyPeers []*nbpeer.Peer) {
|
||||
|
||||
@@ -119,7 +119,6 @@ func (a *Account) GetPeerNetworkMapComponents(
|
||||
|
||||
peerGroups := a.GetPeerGroups(peerID)
|
||||
components.AccountZones = filterPeerAppliedZones(ctx, accountZones, peerGroups)
|
||||
components.AccountZones = append(components.AccountZones, a.SynthesizePrivateServiceZones(peerID)...)
|
||||
|
||||
for _, nsGroup := range a.NameServerGroups {
|
||||
if nsGroup.Enabled {
|
||||
|
||||
@@ -1,85 +0,0 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
func TestPrivateService_NetworkMap_UserPeer_AndProxyPeer(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Peers["user-peer"].Meta.WtVersion = "0.50.0"
|
||||
account.Peers["proxy-peer"].Meta.WtVersion = "0.50.0"
|
||||
|
||||
ctx := context.Background()
|
||||
account.InjectProxyPolicies(ctx)
|
||||
|
||||
validated := map[string]struct{}{
|
||||
"user-peer": {},
|
||||
"proxy-peer": {},
|
||||
}
|
||||
|
||||
t.Run("user-peer update", func(t *testing.T) {
|
||||
nm := account.GetPeerNetworkMapFromComponents(ctx, "user-peer", nbdns.CustomZone{}, nil, validated, nil, nil, nil, nil)
|
||||
require.NotNil(t, nm)
|
||||
|
||||
zone, ok := findCustomZone(nm.DNSConfig.CustomZones, "eu.proxy.netbird.io")
|
||||
require.True(t, ok)
|
||||
require.Len(t, zone.Records, 1)
|
||||
assert.Equal(t, "myapp.eu.proxy.netbird.io.", zone.Records[0].Name)
|
||||
assert.Equal(t, int(dns.TypeA), zone.Records[0].Type)
|
||||
assert.Equal(t, "100.64.0.99", zone.Records[0].RData)
|
||||
|
||||
assert.Contains(t, netmapPeerIDs(nm.Peers), "proxy-peer")
|
||||
assertPrivateServiceFirewallRules(t, nm.FirewallRules, "100.64.0.99", FirewallRuleDirectionOUT)
|
||||
})
|
||||
|
||||
t.Run("proxy-peer update", func(t *testing.T) {
|
||||
nm := account.GetPeerNetworkMapFromComponents(ctx, "proxy-peer", nbdns.CustomZone{}, nil, validated, nil, nil, nil, nil)
|
||||
require.NotNil(t, nm)
|
||||
|
||||
assert.Contains(t, netmapPeerIDs(nm.Peers), "user-peer")
|
||||
assertPrivateServiceFirewallRules(t, nm.FirewallRules, "100.64.0.10", FirewallRuleDirectionIN)
|
||||
})
|
||||
}
|
||||
|
||||
func netmapPeerIDs(peers []*nbpeer.Peer) []string {
|
||||
ids := make([]string, 0, len(peers))
|
||||
for _, p := range peers {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
ids = append(ids, p.ID)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func assertPrivateServiceFirewallRules(t *testing.T, rules []*FirewallRule, peerIP string, direction int) {
|
||||
t.Helper()
|
||||
wantPorts := map[uint16]bool{80: false, 443: false}
|
||||
for _, r := range rules {
|
||||
if r == nil || r.PeerIP != peerIP || r.Direction != direction {
|
||||
continue
|
||||
}
|
||||
if r.Protocol != string(PolicyRuleProtocolTCP) || r.Action != string(PolicyTrafficActionAccept) {
|
||||
continue
|
||||
}
|
||||
switch {
|
||||
case r.PortRange.Start == r.PortRange.End && r.PortRange.Start != 0:
|
||||
wantPorts[r.PortRange.Start] = true
|
||||
case r.Port == "80":
|
||||
wantPorts[80] = true
|
||||
case r.Port == "443":
|
||||
wantPorts[443] = true
|
||||
}
|
||||
}
|
||||
for port, found := range wantPorts {
|
||||
assert.Truef(t, found, "missing TCP accept rule on port %d for peer %s direction %d", port, peerIP, direction)
|
||||
}
|
||||
}
|
||||
@@ -1,256 +0,0 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
func privateZoneTestAccount(t *testing.T) *Account {
|
||||
t.Helper()
|
||||
return &Account{
|
||||
Id: "acct-1",
|
||||
Settings: &Settings{},
|
||||
Network: &Network{
|
||||
Identifier: "net-1",
|
||||
Net: net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32)},
|
||||
},
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"user-peer": {
|
||||
ID: "user-peer",
|
||||
AccountID: "acct-1",
|
||||
Key: "user-peer-key",
|
||||
IP: netip.MustParseAddr("100.64.0.10"),
|
||||
Status: &nbpeer.PeerStatus{Connected: true},
|
||||
},
|
||||
"proxy-peer": {
|
||||
ID: "proxy-peer",
|
||||
AccountID: "acct-1",
|
||||
Key: "proxy-peer-key",
|
||||
IP: netip.MustParseAddr("100.64.0.99"),
|
||||
Status: &nbpeer.PeerStatus{Connected: true},
|
||||
ProxyMeta: nbpeer.ProxyMeta{
|
||||
Embedded: true,
|
||||
Cluster: "eu.proxy.netbird.io",
|
||||
},
|
||||
},
|
||||
},
|
||||
Groups: map[string]*Group{
|
||||
"grp-admins": {
|
||||
ID: "grp-admins",
|
||||
Name: "admins",
|
||||
Peers: []string{"user-peer"},
|
||||
},
|
||||
},
|
||||
Services: []*service.Service{
|
||||
{
|
||||
ID: "svc-1",
|
||||
AccountID: "acct-1",
|
||||
Name: "myapp",
|
||||
Domain: "myapp.eu.proxy.netbird.io",
|
||||
ProxyCluster: "eu.proxy.netbird.io",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
Mode: service.ModeHTTP,
|
||||
AccessGroups: []string{"grp-admins"},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_PeerInGroup_GetsRecord(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 1, "one cluster should produce one zone")
|
||||
zone := zones[0]
|
||||
assert.Equal(t, "eu.proxy.netbird.io.", zone.Domain, "zone apex must be the cluster FQDN")
|
||||
assert.True(t, zone.NonAuthoritative, "synth zone must be match-only so unrelated sibling names fall through to the upstream resolver")
|
||||
require.Len(t, zone.Records, 1, "one private service yields one A record")
|
||||
rec := zone.Records[0]
|
||||
assert.Equal(t, "myapp.eu.proxy.netbird.io.", rec.Name, "record name is the service FQDN")
|
||||
assert.Equal(t, int(dns.TypeA), rec.Type, "record type must be A")
|
||||
assert.Equal(t, "100.64.0.99", rec.RData, "record points at the embedded proxy peer's tunnel IP")
|
||||
assert.Equal(t, privateServiceDNSRecordTTL, rec.TTL, "TTL must match the synth-records constant")
|
||||
assert.Equal(t, nbdns.DefaultClass, rec.Class, "record class must be the package default")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_PeerNotInGroup_NoRecord(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Groups["grp-admins"].Peers = nil
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "peer outside distribution_groups must not see private-service records")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_NotPrivate_NoRecord(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Services[0].Private = false
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "non-private service must not produce DNS records")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_NoAccessGroups_NoRecord(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Services[0].AccessGroups = nil
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "private service without bearer auth must not produce DNS records")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_NoProxyPeers_NoRecord(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
delete(account.Peers, "proxy-peer")
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "no embedded proxy peer in cluster means no record to emit")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_DisabledService_NoRecord(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Services[0].Enabled = false
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "disabled service must not produce DNS records")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_DisconnectedProxyPeer_NoRecord(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Peers["proxy-peer"].Status = &nbpeer.PeerStatus{Connected: false}
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "disconnected proxy peer must not produce a DNS record (would be a black hole)")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_PartiallyDisconnectedProxyPeers_OnlyConnectedSurface(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Peers["proxy-peer-2"] = &nbpeer.Peer{
|
||||
ID: "proxy-peer-2",
|
||||
AccountID: "acct-1",
|
||||
Key: "proxy-peer-2-key",
|
||||
IP: netip.MustParseAddr("100.64.0.100"),
|
||||
Status: &nbpeer.PeerStatus{Connected: false},
|
||||
ProxyMeta: nbpeer.ProxyMeta{Embedded: true, Cluster: "eu.proxy.netbird.io"},
|
||||
}
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 1)
|
||||
require.Len(t, zones[0].Records, 1, "only the connected proxy peer must surface")
|
||||
assert.Equal(t, "100.64.0.99", zones[0].Records[0].RData)
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_MultipleProxyPeers_RoundRobin(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Peers["proxy-peer-2"] = &nbpeer.Peer{
|
||||
ID: "proxy-peer-2",
|
||||
AccountID: "acct-1",
|
||||
Key: "proxy-peer-2-key",
|
||||
IP: netip.MustParseAddr("100.64.0.100"),
|
||||
Status: &nbpeer.PeerStatus{Connected: true},
|
||||
ProxyMeta: nbpeer.ProxyMeta{Embedded: true, Cluster: "eu.proxy.netbird.io"},
|
||||
}
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 1, "still one cluster yields one zone")
|
||||
require.Len(t, zones[0].Records, 2, "two proxy peers must produce two A records on the same name")
|
||||
rdata := []string{zones[0].Records[0].RData, zones[0].Records[1].RData}
|
||||
assert.ElementsMatch(t, []string{"100.64.0.99", "100.64.0.100"}, rdata, "both proxy peer IPs must surface")
|
||||
}
|
||||
|
||||
// findCustomZone returns the CustomZone whose Domain equals the FQDN
|
||||
// of want, or a zero value when not found. Tests use it to assert
|
||||
// that the synth zone reaches dnsUpdate.CustomZones end-to-end.
|
||||
func findCustomZone(zones []nbdns.CustomZone, want string) (nbdns.CustomZone, bool) {
|
||||
wantFqdn := dns.Fqdn(want)
|
||||
for _, z := range zones {
|
||||
if z.Domain == wantFqdn {
|
||||
return z, true
|
||||
}
|
||||
}
|
||||
return nbdns.CustomZone{}, false
|
||||
}
|
||||
|
||||
// TestPrivateZone_GetPeerNetworkMapFromComponents_ShipsSynthZone
|
||||
// covers the components-based builder path. The components builder
|
||||
// appends SynthesizePrivateServiceZones to AccountZones; the
|
||||
// CalculateNetworkMapFromComponents step then merges AccountZones
|
||||
// into dnsUpdate.CustomZones.
|
||||
func TestPrivateZone_GetPeerNetworkMapFromComponents_ShipsSynthZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
ctx := context.Background()
|
||||
validated := map[string]struct{}{
|
||||
"user-peer": {},
|
||||
"proxy-peer": {},
|
||||
}
|
||||
|
||||
nm := account.GetPeerNetworkMapFromComponents(ctx, "user-peer", nbdns.CustomZone{}, nil, validated, nil, nil, nil, nil)
|
||||
require.NotNil(t, nm, "network map must be produced for an in-account peer")
|
||||
|
||||
zone, ok := findCustomZone(nm.DNSConfig.CustomZones, "eu.proxy.netbird.io")
|
||||
require.True(t, ok, "shipped CustomZones must include the synth zone for the cluster")
|
||||
require.Len(t, zone.Records, 1, "exactly one record per private service per connected proxy peer")
|
||||
rec := zone.Records[0]
|
||||
assert.Equal(t, "myapp.eu.proxy.netbird.io.", rec.Name, "record name is the service FQDN")
|
||||
assert.Equal(t, "100.64.0.99", rec.RData, "record points at the embedded proxy peer's tunnel IP")
|
||||
}
|
||||
|
||||
// TestPrivateZone_GetPeerNetworkMap_PeerOutsideGroups_OmitsSynthZone
|
||||
// confirms the negative case the user encountered: a peer whose
|
||||
// groups don't overlap the policy's distribution_groups gets a
|
||||
// network map with no synth zone (and the wildcard / peer zones still
|
||||
// flow through). This is the test mirror of the runtime confusion
|
||||
// where the user looked at a non-distribution-group peer and assumed
|
||||
// the synth path was broken.
|
||||
func TestPrivateZone_GetPeerNetworkMap_PeerOutsideGroups_OmitsSynthZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Peers["outsider"] = &nbpeer.Peer{
|
||||
ID: "outsider",
|
||||
AccountID: "acct-1",
|
||||
Key: "outsider-key",
|
||||
IP: netip.MustParseAddr("100.64.0.20"),
|
||||
Status: &nbpeer.PeerStatus{Connected: true},
|
||||
}
|
||||
ctx := context.Background()
|
||||
validated := map[string]struct{}{
|
||||
"user-peer": {},
|
||||
"proxy-peer": {},
|
||||
"outsider": {},
|
||||
}
|
||||
|
||||
nm := account.GetPeerNetworkMapFromComponents(ctx, "outsider", nbdns.CustomZone{}, nil, validated, nil, nil, nil, nil)
|
||||
require.NotNil(t, nm)
|
||||
|
||||
_, ok := findCustomZone(nm.DNSConfig.CustomZones, "eu.proxy.netbird.io")
|
||||
assert.False(t, ok, "peer outside the distribution_groups must not see the synth zone")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_TwoServicesSameCluster_OneZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Services = append(account.Services, &service.Service{
|
||||
ID: "svc-2",
|
||||
AccountID: "acct-1",
|
||||
Name: "anotherapp",
|
||||
Domain: "anotherapp.eu.proxy.netbird.io",
|
||||
ProxyCluster: "eu.proxy.netbird.io",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
Mode: service.ModeHTTP,
|
||||
AccessGroups: []string{"grp-admins"},
|
||||
})
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 1, "two services on the same cluster must collapse into one zone")
|
||||
require.Len(t, zones[0].Records, 2, "two services yield two A records")
|
||||
names := []string{zones[0].Records[0].Name, zones[0].Records[1].Name}
|
||||
assert.ElementsMatch(t, []string{"myapp.eu.proxy.netbird.io.", "anotherapp.eu.proxy.netbird.io."}, names, "both service domains must surface")
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package types
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
@@ -12,7 +11,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
@@ -84,9 +82,9 @@ func setupTestAccount() *Account {
|
||||
},
|
||||
Groups: map[string]*Group{
|
||||
"groupAll": {
|
||||
ID: "groupAll",
|
||||
Name: "All",
|
||||
Peers: []string{"peer1", "peer2", "peer3", "peer11", "peer12", "peer21", "peer31", "peer32", "peer41", "peer51", "peer61"},
|
||||
ID: "groupAll",
|
||||
Name: "All",
|
||||
Peers: []string{"peer1", "peer2", "peer3", "peer11", "peer12", "peer21", "peer31", "peer32", "peer41", "peer51", "peer61"},
|
||||
Issued: GroupIssuedAPI,
|
||||
},
|
||||
"group1": {
|
||||
@@ -1585,203 +1583,3 @@ func Test_filterPeerAppliedZones(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectPrivateServicePolicies_ProxyPeerGetsInboundRule(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
userPeerIP := netip.MustParseAddr("100.64.0.10")
|
||||
proxyPeerIP := netip.MustParseAddr("100.64.0.99")
|
||||
|
||||
account := &Account{
|
||||
Id: "acct-1",
|
||||
Network: &Network{
|
||||
Identifier: "net-1",
|
||||
Net: net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32)},
|
||||
},
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"user-peer": {
|
||||
ID: "user-peer",
|
||||
AccountID: "acct-1",
|
||||
Key: "user-peer-key",
|
||||
IP: userPeerIP,
|
||||
},
|
||||
"proxy-peer": {
|
||||
ID: "proxy-peer",
|
||||
AccountID: "acct-1",
|
||||
Key: "proxy-peer-key",
|
||||
IP: proxyPeerIP,
|
||||
ProxyMeta: nbpeer.ProxyMeta{
|
||||
Embedded: true,
|
||||
Cluster: "eu.proxy.netbird.io",
|
||||
},
|
||||
},
|
||||
},
|
||||
Groups: map[string]*Group{
|
||||
"grp-admins": {
|
||||
ID: "grp-admins",
|
||||
Name: "admins",
|
||||
Peers: []string{"user-peer"},
|
||||
},
|
||||
},
|
||||
Services: []*service.Service{
|
||||
{
|
||||
ID: "svc-1",
|
||||
AccountID: "acct-1",
|
||||
Name: "myapp",
|
||||
Domain: "myapp.eu.proxy.netbird.io",
|
||||
ProxyCluster: "eu.proxy.netbird.io",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
Mode: service.ModeHTTP,
|
||||
AccessGroups: []string{"grp-admins"},
|
||||
Targets: []*service.Target{
|
||||
{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: service.TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
|
||||
var found *Policy
|
||||
for _, p := range account.Policies {
|
||||
if p != nil && p.ID == "private-access-svc-1-proxy-peer" {
|
||||
found = p
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, found, "expected synthesised private-access policy in account.Policies")
|
||||
require.Len(t, found.Rules, 1, "policy should have exactly one rule")
|
||||
rule := found.Rules[0]
|
||||
assert.Equal(t, []string{"grp-admins"}, rule.Sources, "sources should be group IDs verbatim")
|
||||
assert.Equal(t, "proxy-peer", rule.DestinationResource.ID, "destination resource should be the proxy peer ID")
|
||||
assert.Equal(t, ResourceTypePeer, rule.DestinationResource.Type, "destination resource type should be peer")
|
||||
|
||||
validatedPeersMap := map[string]struct{}{
|
||||
"user-peer": {},
|
||||
"proxy-peer": {},
|
||||
}
|
||||
|
||||
proxyPeer := account.Peers["proxy-peer"]
|
||||
aclPeers, firewallRules, _, _ := account.GetPeerConnectionResources(ctx, proxyPeer, validatedPeersMap, nil)
|
||||
|
||||
var sawUserAsAclPeer bool
|
||||
for _, p := range aclPeers {
|
||||
if p.ID == "user-peer" {
|
||||
sawUserAsAclPeer = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, sawUserAsAclPeer, "proxy peer should see the user peer as an ACL peer")
|
||||
|
||||
var inboundRules []*FirewallRule
|
||||
for _, r := range firewallRules {
|
||||
if r.Direction == FirewallRuleDirectionIN && r.PeerIP == userPeerIP.String() {
|
||||
inboundRules = append(inboundRules, r)
|
||||
}
|
||||
}
|
||||
assert.NotEmpty(t, inboundRules, "proxy peer should have inbound firewall rules from the user peer")
|
||||
}
|
||||
|
||||
func TestInjectPrivateServicePolicies_NotPrivate_NoPolicy(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
account := privateServiceTestAccount(t)
|
||||
account.Services[0].Private = false
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
assert.False(t, hasPrivateAccessPolicy(account, "svc-1"), "non-private service must not synthesise an access policy")
|
||||
}
|
||||
|
||||
func TestInjectPrivateServicePolicies_EmptyAccessGroups_NoPolicy(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
account := privateServiceTestAccount(t)
|
||||
account.Services[0].AccessGroups = nil
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
assert.False(t, hasPrivateAccessPolicy(account, "svc-1"), "private service with no access groups must not synthesise a policy")
|
||||
}
|
||||
|
||||
func TestInjectPrivateServicePolicies_NoProxyPeers_NoPolicy(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
account := privateServiceTestAccount(t)
|
||||
delete(account.Peers, "proxy-peer")
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
assert.False(t, hasPrivateAccessPolicy(account, "svc-1"), "policy must not synthesise when the cluster has no proxy peers")
|
||||
}
|
||||
|
||||
func privateServiceTestAccount(t *testing.T) *Account {
|
||||
t.Helper()
|
||||
return &Account{
|
||||
Id: "acct-1",
|
||||
Network: &Network{
|
||||
Identifier: "net-1",
|
||||
Net: net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32)},
|
||||
},
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"user-peer": {
|
||||
ID: "user-peer",
|
||||
AccountID: "acct-1",
|
||||
Key: "user-peer-key",
|
||||
IP: netip.MustParseAddr("100.64.0.10"),
|
||||
},
|
||||
"proxy-peer": {
|
||||
ID: "proxy-peer",
|
||||
AccountID: "acct-1",
|
||||
Key: "proxy-peer-key",
|
||||
IP: netip.MustParseAddr("100.64.0.99"),
|
||||
ProxyMeta: nbpeer.ProxyMeta{
|
||||
Embedded: true,
|
||||
Cluster: "eu.proxy.netbird.io",
|
||||
},
|
||||
},
|
||||
},
|
||||
Groups: map[string]*Group{
|
||||
"grp-admins": {
|
||||
ID: "grp-admins",
|
||||
Name: "admins",
|
||||
Peers: []string{"user-peer"},
|
||||
},
|
||||
},
|
||||
Services: []*service.Service{
|
||||
{
|
||||
ID: "svc-1",
|
||||
AccountID: "acct-1",
|
||||
Name: "myapp",
|
||||
Domain: "myapp.eu.proxy.netbird.io",
|
||||
ProxyCluster: "eu.proxy.netbird.io",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
Mode: service.ModeHTTP,
|
||||
AccessGroups: []string{"grp-admins"},
|
||||
Targets: []*service.Target{
|
||||
{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: service.TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func hasPrivateAccessPolicy(account *Account, serviceID string) bool {
|
||||
prefix := "private-access-" + serviceID + "-"
|
||||
for _, p := range account.Policies {
|
||||
if p != nil && len(p.ID) > len(prefix) && p.ID[:len(prefix)] == prefix {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -762,7 +762,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
|
||||
}
|
||||
|
||||
// Ensure the initiator still has admin privileges
|
||||
if !freshInitiator.HasAdminPower() {
|
||||
if initiatorUser.HasAdminPower() && !freshInitiator.HasAdminPower() {
|
||||
return false, nil, nil, nil, status.Errorf(status.PermissionDenied, "initiator role was changed during request processing")
|
||||
}
|
||||
initiatorUser = freshInitiator
|
||||
@@ -906,23 +906,19 @@ func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUse
|
||||
return nil
|
||||
}
|
||||
|
||||
if !initiatorUser.HasAdminPower() {
|
||||
return status.Errorf(status.PermissionDenied, "only admins and owners can update users")
|
||||
}
|
||||
|
||||
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked {
|
||||
return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
|
||||
}
|
||||
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && update.Role != initiatorUser.Role {
|
||||
return status.Errorf(status.PermissionDenied, "admins can't change their role")
|
||||
}
|
||||
if initiatorUser.Role != types.UserRoleOwner && oldUser.Role == types.UserRoleOwner && update.Role != oldUser.Role {
|
||||
if initiatorUser.Role == types.UserRoleAdmin && oldUser.Role == types.UserRoleOwner && update.Role != oldUser.Role {
|
||||
return status.Errorf(status.PermissionDenied, "only owners can remove owner role from their user")
|
||||
}
|
||||
if oldUser.Role == types.UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() {
|
||||
if initiatorUser.Role == types.UserRoleAdmin && oldUser.Role == types.UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() {
|
||||
return status.Errorf(status.PermissionDenied, "unable to block owner user")
|
||||
}
|
||||
if initiatorUser.Role != types.UserRoleOwner && update.Role == types.UserRoleOwner && update.Role != oldUser.Role {
|
||||
if initiatorUser.Role == types.UserRoleAdmin && update.Role == types.UserRoleOwner && update.Role != oldUser.Role {
|
||||
return status.Errorf(status.PermissionDenied, "only owners can add owner role to other users")
|
||||
}
|
||||
if oldUser.IsServiceUser && update.Role == types.UserRoleOwner {
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
type Manager interface {
|
||||
GetUser(ctx context.Context, userID string) (*types.User, error)
|
||||
GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error)
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
@@ -30,31 +29,6 @@ func (m *managerImpl) GetUser(ctx context.Context, userID string) (*types.User,
|
||||
return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
}
|
||||
|
||||
// GetUserWithGroups returns the user and the *types.Group records for the user's AutoGroups, in the same order as
|
||||
// AutoGroups. Group ids that don't resolve to a stored group are skipped from the returned slice (the parallel id list is
|
||||
// derivable from the returned User). Wraps two store calls today; can be optimised to a single JOIN later if needed.
|
||||
// Any store error returns (nil, nil, err) so callers never receive a valid user alongside a non-nil error.
|
||||
func (m *managerImpl) GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) {
|
||||
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if len(user.AutoGroups) == 0 {
|
||||
return user, nil, nil
|
||||
}
|
||||
groupsMap, err := m.store.GetGroupsByIDs(ctx, store.LockingStrengthNone, user.AccountID, user.AutoGroups)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
groups := make([]*types.Group, 0, len(user.AutoGroups))
|
||||
for _, id := range user.AutoGroups {
|
||||
if g, ok := groupsMap[id]; ok && g != nil {
|
||||
groups = append(groups, g)
|
||||
}
|
||||
}
|
||||
return user, groups, nil
|
||||
}
|
||||
|
||||
func NewManagerMock() Manager {
|
||||
return &managerMock{}
|
||||
}
|
||||
@@ -73,11 +47,3 @@ func (m *managerMock) GetUser(ctx context.Context, userID string) (*types.User,
|
||||
return nil, errors.New("user not found")
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerMock) GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) {
|
||||
user, err := m.GetUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return user, nil, nil
|
||||
}
|
||||
|
||||
@@ -45,14 +45,10 @@ func ResolveProto(forwardedProto string, conn *tls.ConnectionState) string {
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateSessionJWT validates a session JWT and returns the user ID, the
|
||||
// user's email (when carried), the authentication method, any embedded
|
||||
// group memberships, and the parallel group display names. email,
|
||||
// groups, and groupNames may be empty for tokens minted before those
|
||||
// claims were introduced. groupNames pairs positionally with groups.
|
||||
func ValidateSessionJWT(tokenString, domain string, publicKey ed25519.PublicKey) (userID, email, method string, groups, groupNames []string, err error) {
|
||||
// ValidateSessionJWT validates a session JWT and returns the user ID and method.
|
||||
func ValidateSessionJWT(tokenString, domain string, publicKey ed25519.PublicKey) (userID, method string, err error) {
|
||||
if publicKey == nil {
|
||||
return "", "", "", nil, nil, fmt.Errorf("no public key configured for domain")
|
||||
return "", "", fmt.Errorf("no public key configured for domain")
|
||||
}
|
||||
|
||||
token, err := jwt.Parse(tokenString, func(t *jwt.Token) (interface{}, error) {
|
||||
@@ -62,46 +58,20 @@ func ValidateSessionJWT(tokenString, domain string, publicKey ed25519.PublicKey)
|
||||
return publicKey, nil
|
||||
}, jwt.WithAudience(domain), jwt.WithIssuer(SessionJWTIssuer))
|
||||
if err != nil {
|
||||
return "", "", "", nil, nil, fmt.Errorf("parse token: %w", err)
|
||||
return "", "", fmt.Errorf("parse token: %w", err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid {
|
||||
return "", "", "", nil, nil, fmt.Errorf("invalid token claims")
|
||||
return "", "", fmt.Errorf("invalid token claims")
|
||||
}
|
||||
|
||||
sub, _ := claims.GetSubject()
|
||||
if sub == "" {
|
||||
return "", "", "", nil, nil, fmt.Errorf("missing subject claim")
|
||||
return "", "", fmt.Errorf("missing subject claim")
|
||||
}
|
||||
|
||||
methodClaim, _ := claims["method"].(string)
|
||||
emailClaim, _ := claims["email"].(string)
|
||||
groups = extractGroupsClaim(claims["groups"])
|
||||
groupNames = extractGroupsClaim(claims["group_names"])
|
||||
|
||||
return sub, emailClaim, methodClaim, groups, groupNames, nil
|
||||
}
|
||||
|
||||
// extractGroupsClaim decodes the "groups" claim into a string slice. The JWT
|
||||
// library decodes JSON arrays as []interface{}, so we coerce element-wise
|
||||
// and skip non-string entries silently.
|
||||
func extractGroupsClaim(claim interface{}) []string {
|
||||
raw, ok := claim.([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
groups := make([]string, 0, len(raw))
|
||||
for _, v := range raw {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
groups = append(groups, s)
|
||||
}
|
||||
}
|
||||
if len(groups) == 0 {
|
||||
return nil
|
||||
}
|
||||
return groups
|
||||
return sub, methodClaim, nil
|
||||
}
|
||||
|
||||
@@ -109,22 +109,6 @@ var debugStopCmd = &cobra.Command{
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var debugPerfCmd = &cobra.Command{
|
||||
Use: "perf <pool-cap>",
|
||||
Short: "Live-retune the tunnel buffer pool cap on all running clients",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runDebugPerfSet,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var debugRuntimeCmd = &cobra.Command{
|
||||
Use: "runtime",
|
||||
Short: "Show runtime stats (heap, goroutines, RSS)",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: runDebugRuntime,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var debugCaptureCmd = &cobra.Command{
|
||||
Use: "capture <account-id> [filter expression]",
|
||||
Short: "Capture packets on a client's WireGuard interface",
|
||||
@@ -175,8 +159,6 @@ func init() {
|
||||
debugCmd.AddCommand(debugLogCmd)
|
||||
debugCmd.AddCommand(debugStartCmd)
|
||||
debugCmd.AddCommand(debugStopCmd)
|
||||
debugCmd.AddCommand(debugPerfCmd)
|
||||
debugCmd.AddCommand(debugRuntimeCmd)
|
||||
debugCmd.AddCommand(debugCaptureCmd)
|
||||
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
@@ -238,18 +220,6 @@ func runDebugStop(cmd *cobra.Command, args []string) error {
|
||||
return getDebugClient(cmd).StopClient(cmd.Context(), args[0])
|
||||
}
|
||||
|
||||
func runDebugPerfSet(cmd *cobra.Command, args []string) error {
|
||||
n, err := strconv.ParseUint(args[0], 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid value %q: %w", args[0], err)
|
||||
}
|
||||
return getDebugClient(cmd).PerfSet(cmd.Context(), uint32(n))
|
||||
}
|
||||
|
||||
func runDebugRuntime(cmd *cobra.Command, _ []string) error {
|
||||
return getDebugClient(cmd).Runtime(cmd.Context())
|
||||
}
|
||||
|
||||
func runDebugCapture(cmd *cobra.Command, args []string) error {
|
||||
duration, _ := cmd.Flags().GetDuration("duration")
|
||||
forcePcap, _ := cmd.Flags().GetBool("pcap")
|
||||
|
||||
@@ -15,22 +15,11 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
|
||||
"github.com/netbirdio/netbird/client/embed"
|
||||
"github.com/netbirdio/netbird/proxy"
|
||||
nbacme "github.com/netbirdio/netbird/proxy/internal/acme"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const (
|
||||
// envPreallocatedBuffers caps the per-tunnel buffer pool. Zero (unset)
|
||||
// keeps the upstream uncapped default.
|
||||
envPreallocatedBuffers = "NB_PROXY_PREALLOCATED_BUFFERS"
|
||||
// envMaxBatchSize overrides the per-tunnel batch size, which controls
|
||||
// how many buffers each receive/TUN worker eagerly allocates. Zero
|
||||
// (unset) keeps the platform default.
|
||||
envMaxBatchSize = "NB_PROXY_MAX_BATCH_SIZE"
|
||||
)
|
||||
|
||||
const DefaultManagementURL = "https://api.netbird.io:443"
|
||||
|
||||
// envProxyToken is the environment variable name for the proxy access token.
|
||||
@@ -74,7 +63,6 @@ var (
|
||||
preSharedKey string
|
||||
supportsCustomPorts bool
|
||||
requireSubdomain bool
|
||||
private bool
|
||||
geoDataDir string
|
||||
crowdsecAPIURL string
|
||||
crowdsecAPIKey string
|
||||
@@ -117,8 +105,6 @@ func init() {
|
||||
rootCmd.Flags().StringVar(&preSharedKey, "preshared-key", envStringOrDefault("NB_PROXY_PRESHARED_KEY", ""), "Define a pre-shared key for the tunnel between proxy and peers")
|
||||
rootCmd.Flags().BoolVar(&supportsCustomPorts, "supports-custom-ports", envBoolOrDefault("NB_PROXY_SUPPORTS_CUSTOM_PORTS", true), "Whether the proxy can bind arbitrary ports for UDP/TCP passthrough")
|
||||
rootCmd.Flags().BoolVar(&requireSubdomain, "require-subdomain", envBoolOrDefault("NB_PROXY_REQUIRE_SUBDOMAIN", false), "Require a subdomain label in front of the cluster domain")
|
||||
rootCmd.Flags().BoolVar(&private, "private", envBoolOrDefault("NB_PROXY_PRIVATE", false), "Enable private services accessible with NetBird-Only authentication mode.")
|
||||
_ = rootCmd.Flags().MarkHidden("private")
|
||||
rootCmd.Flags().DurationVar(&maxDialTimeout, "max-dial-timeout", envDurationOrDefault("NB_PROXY_MAX_DIAL_TIMEOUT", 0), "Cap per-service backend dial timeout (0 = no cap)")
|
||||
rootCmd.Flags().DurationVar(&maxSessionIdleTimeout, "max-session-idle-timeout", envDurationOrDefault("NB_PROXY_MAX_SESSION_IDLE_TIMEOUT", 0), "Cap per-service session idle timeout (0 = no cap)")
|
||||
rootCmd.Flags().StringVar(&geoDataDir, "geo-data-dir", envStringOrDefault("NB_PROXY_GEO_DATA_DIR", "/var/lib/netbird/geolocation"), "Directory for the GeoLite2 MMDB file (auto-downloaded if missing)")
|
||||
@@ -159,45 +145,6 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
|
||||
logger.Infof("configured log level: %s", level)
|
||||
|
||||
var wgPool, wgBatch uint64
|
||||
var perf embed.Performance
|
||||
if raw := os.Getenv(envPreallocatedBuffers); raw != "" {
|
||||
n, err := strconv.ParseUint(raw, 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid %s %q: %w", envPreallocatedBuffers, raw, err)
|
||||
}
|
||||
wgPool = n
|
||||
v := uint32(n)
|
||||
perf.PreallocatedBuffersPerPool = &v
|
||||
logger.Infof("tunnel preallocated buffers per pool: %d", n)
|
||||
}
|
||||
if raw := os.Getenv(envMaxBatchSize); raw != "" {
|
||||
n, err := strconv.ParseUint(raw, 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid %s %q: %w", envMaxBatchSize, raw, err)
|
||||
}
|
||||
wgBatch = n
|
||||
v := uint32(n)
|
||||
perf.MaxBatchSize = &v
|
||||
logger.Infof("tunnel max batch size override: %d", n)
|
||||
}
|
||||
if wgPool > 0 {
|
||||
// Each bind recv goroutine (IPv4 + IPv6 + ICE relay) plus
|
||||
// RoutineReadFromTUN eagerly reserves `batch` message buffers for
|
||||
// the lifetime of the Device. A pool cap below that floor blocks
|
||||
// the receive pipeline at startup.
|
||||
batch := wgBatch
|
||||
if batch == 0 {
|
||||
batch = 128
|
||||
}
|
||||
const recvGoroutines = 4
|
||||
floor := batch * recvGoroutines
|
||||
if wgPool < floor {
|
||||
logger.Warnf("%s=%d is below the eager-allocation floor (~%d for batch=%d); startup may deadlock",
|
||||
envPreallocatedBuffers, wgPool, floor, batch)
|
||||
}
|
||||
}
|
||||
|
||||
switch forwardedProto {
|
||||
case "auto", "http", "https":
|
||||
default:
|
||||
@@ -214,8 +161,7 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("invalid --trusted-proxies: %w", err)
|
||||
}
|
||||
|
||||
srv := proxy.New(proxy.Config{
|
||||
ListenAddr: addr,
|
||||
srv := proxy.Server{
|
||||
Logger: logger,
|
||||
Version: Version,
|
||||
ManagementAddress: mgmtAddr,
|
||||
@@ -232,24 +178,22 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
ACMEChallengeType: acmeChallengeType,
|
||||
DebugEndpointEnabled: debugEndpoint,
|
||||
DebugEndpointAddress: debugEndpointAddr,
|
||||
HealthAddr: healthAddr,
|
||||
HealthAddress: healthAddr,
|
||||
ForwardedProto: forwardedProto,
|
||||
TrustedProxies: parsedTrustedProxies,
|
||||
CertLockMethod: nbacme.CertLockMethod(certLockMethod),
|
||||
WildcardCertDir: wildcardCertDir,
|
||||
WireguardPort: wgPort,
|
||||
Performance: perf,
|
||||
ProxyProtocol: proxyProtocol,
|
||||
PreSharedKey: preSharedKey,
|
||||
SupportsCustomPorts: supportsCustomPorts,
|
||||
RequireSubdomain: requireSubdomain,
|
||||
Private: private,
|
||||
MaxDialTimeout: maxDialTimeout,
|
||||
MaxSessionIdleTimeout: maxSessionIdleTimeout,
|
||||
GeoDataDir: geoDataDir,
|
||||
CrowdSecAPIURL: crowdsecAPIURL,
|
||||
CrowdSecAPIKey: crowdsecAPIKey,
|
||||
})
|
||||
}
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
|
||||
defer stop()
|
||||
|
||||
547
proxy/inbound.go
547
proxy/inbound.go
@@ -1,547 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
stdlog "log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/embed"
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/debug"
|
||||
nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
// httpInboundReadHeaderTimeout matches the host-listener read header timeout
|
||||
// so per-account http.Servers don't leak idle connections.
|
||||
const httpInboundReadHeaderTimeout = 30 * time.Second
|
||||
|
||||
// httpInboundIdleTimeout caps idle keep-alive on per-account inbound HTTP
|
||||
// servers; matches the host listener.
|
||||
const httpInboundIdleTimeout = 90 * time.Second
|
||||
|
||||
// inboundShutdownTimeout caps how long a per-account http.Server gets to
|
||||
// drain in-flight requests during teardown.
|
||||
const inboundShutdownTimeout = 5 * time.Second
|
||||
|
||||
// privateInboundPortHTTPS is the WG-side TLS port. Each account's
|
||||
// embedded netstack binds independently, so a fixed port is fine.
|
||||
const privateInboundPortHTTPS = 443
|
||||
|
||||
// privateInboundPortHTTP is the WG-side plain-HTTP port.
|
||||
const privateInboundPortHTTP = 80
|
||||
|
||||
// inboundManager wires per-account inbound listeners into the proxy
|
||||
// pipeline when --private-inbound is enabled. When disabled the manager
|
||||
// is nil and every method on *Server that touches it short-circuits.
|
||||
type inboundManager struct {
|
||||
logger *log.Logger
|
||||
handler http.Handler
|
||||
tlsConfig *tls.Config
|
||||
// muxLock guards entries and pendingRoutes.
|
||||
muxLock sync.Mutex
|
||||
entries map[types.AccountID]*inboundEntry
|
||||
pendingRoutes map[types.AccountID][]pendingInboundRoute
|
||||
}
|
||||
|
||||
// inboundEntry owns the listeners, router and HTTP servers for a single
|
||||
// account's embedded netstack.
|
||||
type inboundEntry struct {
|
||||
router *nbtcp.Router
|
||||
tlsListener net.Listener
|
||||
plainListener net.Listener
|
||||
httpsServer *http.Server
|
||||
httpServer *http.Server
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// pendingInboundRoute holds a route that arrived before the account's
|
||||
// listener finished starting.
|
||||
type pendingInboundRoute struct {
|
||||
host nbtcp.SNIHost
|
||||
route nbtcp.Route
|
||||
}
|
||||
|
||||
// newInboundManager constructs a manager bound to the proxy's HTTP
|
||||
// handler chain and TLS config.
|
||||
func newInboundManager(logger *log.Logger, handler http.Handler, tlsConfig *tls.Config) *inboundManager {
|
||||
return &inboundManager{
|
||||
logger: logger,
|
||||
handler: handler,
|
||||
tlsConfig: tlsConfig,
|
||||
entries: make(map[types.AccountID]*inboundEntry),
|
||||
pendingRoutes: make(map[types.AccountID][]pendingInboundRoute),
|
||||
}
|
||||
}
|
||||
|
||||
// onClientReady is registered with NetBird.SetClientLifecycle so the
|
||||
// listener pair comes up exactly when the embedded client reports ready.
|
||||
// The returned value is opaque to the roundtrip package; it is handed
|
||||
// back verbatim to onClientStop on teardown.
|
||||
func (m *inboundManager) onClientReady(ctx context.Context, accountID types.AccountID, client *embed.Client) any {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
entry, err := m.bringUp(ctx, accountID, client)
|
||||
if err != nil {
|
||||
m.logger.WithField("account_id", accountID).WithError(err).Warn("failed to start per-account inbound listener; continuing without inbound")
|
||||
return nil
|
||||
}
|
||||
|
||||
m.flushPending(accountID, entry)
|
||||
|
||||
m.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"https": entry.tlsListener.Addr().String(),
|
||||
"http": entry.plainListener.Addr().String(),
|
||||
}).Info("per-account inbound listeners up")
|
||||
return entry
|
||||
}
|
||||
|
||||
// onClientStop tears down a per-account listener bundle. State is the
|
||||
// opaque value previously returned by onClientReady.
|
||||
func (m *inboundManager) onClientStop(accountID types.AccountID, state any) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
entry, ok := state.(*inboundEntry)
|
||||
if !ok || entry == nil {
|
||||
return
|
||||
}
|
||||
m.tearDown(accountID, entry)
|
||||
}
|
||||
|
||||
// bringUp opens both listeners on the account's netstack, builds the
|
||||
// router, and starts the parallel HTTP servers.
|
||||
func (m *inboundManager) bringUp(ctx context.Context, accountID types.AccountID, client *embed.Client) (*inboundEntry, error) {
|
||||
tlsListener, err := client.ListenTCP(fmt.Sprintf(":%d", privateInboundPortHTTPS))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen tls on netstack: %w", err)
|
||||
}
|
||||
plainListener, err := client.ListenTCP(fmt.Sprintf(":%d", privateInboundPortHTTP))
|
||||
if err != nil {
|
||||
_ = tlsListener.Close()
|
||||
return nil, fmt.Errorf("listen plain on netstack: %w", err)
|
||||
}
|
||||
|
||||
router := nbtcp.NewRouter(m.logger, accountDialResolver(accountID, client), tlsListener.Addr(), nbtcp.WithPlainHTTP(plainListener.Addr()))
|
||||
|
||||
scopedHandler := withTunnelLookup(m.handler, accountTunnelLookup(client))
|
||||
|
||||
// markOverlayOrigin stamps every connection accepted by an inbound
|
||||
// listener with a context value middlewares can read to skip
|
||||
// geo/CrowdSec checks (the source address is always inside the
|
||||
// NetBird CGNAT range and won't match either dataset).
|
||||
markOverlayOrigin := func(ctx context.Context, _ net.Conn) context.Context {
|
||||
return types.WithOverlayOrigin(ctx)
|
||||
}
|
||||
|
||||
httpsServer := &http.Server{
|
||||
Handler: scopedHandler,
|
||||
TLSConfig: m.tlsConfig,
|
||||
ReadHeaderTimeout: httpInboundReadHeaderTimeout,
|
||||
IdleTimeout: httpInboundIdleTimeout,
|
||||
ErrorLog: newInboundErrorLog(m.logger, "https", accountID),
|
||||
ConnContext: markOverlayOrigin,
|
||||
}
|
||||
httpServer := &http.Server{
|
||||
Handler: scopedHandler,
|
||||
ReadHeaderTimeout: httpInboundReadHeaderTimeout,
|
||||
IdleTimeout: httpInboundIdleTimeout,
|
||||
ErrorLog: newInboundErrorLog(m.logger, "http", accountID),
|
||||
ConnContext: markOverlayOrigin,
|
||||
}
|
||||
|
||||
runCtx, cancel := context.WithCancel(ctx)
|
||||
entry := &inboundEntry{
|
||||
router: router,
|
||||
tlsListener: tlsListener,
|
||||
plainListener: plainListener,
|
||||
httpsServer: httpsServer,
|
||||
httpServer: httpServer,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
entry.wg.Add(1)
|
||||
go func() {
|
||||
defer entry.wg.Done()
|
||||
if err := router.Serve(runCtx, tlsListener); err != nil {
|
||||
m.logger.WithField("account_id", accountID).Debugf("per-account router stopped: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
entry.wg.Add(1)
|
||||
go func() {
|
||||
defer entry.wg.Done()
|
||||
if err := httpsServer.ServeTLS(router.HTTPListener(), "", ""); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
m.logger.WithField("account_id", accountID).Debugf("per-account https server stopped: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
entry.wg.Add(1)
|
||||
go func() {
|
||||
defer entry.wg.Done()
|
||||
if err := httpServer.Serve(router.HTTPListenerPlain()); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
m.logger.WithField("account_id", accountID).Debugf("per-account http server stopped: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
entry.wg.Add(1)
|
||||
go func() {
|
||||
defer entry.wg.Done()
|
||||
feedRouterFromListener(runCtx, plainListener, router, m.logger, accountID)
|
||||
}()
|
||||
|
||||
m.muxLock.Lock()
|
||||
m.entries[accountID] = entry
|
||||
m.muxLock.Unlock()
|
||||
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
// tearDown shuts every goroutine down and closes the netstack listeners.
|
||||
func (m *inboundManager) tearDown(accountID types.AccountID, entry *inboundEntry) {
|
||||
m.muxLock.Lock()
|
||||
if m.entries[accountID] == entry {
|
||||
delete(m.entries, accountID)
|
||||
delete(m.pendingRoutes, accountID)
|
||||
}
|
||||
m.muxLock.Unlock()
|
||||
|
||||
entry.cancel()
|
||||
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), inboundShutdownTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := entry.httpsServer.Shutdown(shutdownCtx); err != nil {
|
||||
m.logger.Debugf("per-account https shutdown: %v", err)
|
||||
}
|
||||
if err := entry.httpServer.Shutdown(shutdownCtx); err != nil {
|
||||
m.logger.Debugf("per-account http shutdown: %v", err)
|
||||
}
|
||||
if err := entry.tlsListener.Close(); err != nil {
|
||||
m.logger.Debugf("close per-account tls listener: %v", err)
|
||||
}
|
||||
if err := entry.plainListener.Close(); err != nil {
|
||||
m.logger.Debugf("close per-account plain listener: %v", err)
|
||||
}
|
||||
entry.wg.Wait()
|
||||
}
|
||||
|
||||
// AddRoute records an SNI/host route on the account's per-account router.
|
||||
// Routes registered before the listener is up are queued and replayed
|
||||
// once startup completes.
|
||||
func (m *inboundManager) AddRoute(accountID types.AccountID, host nbtcp.SNIHost, route nbtcp.Route) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.muxLock.Lock()
|
||||
entry, ok := m.entries[accountID]
|
||||
if !ok {
|
||||
m.queuePendingLocked(accountID, host, route)
|
||||
m.muxLock.Unlock()
|
||||
return
|
||||
}
|
||||
router := entry.router
|
||||
m.muxLock.Unlock()
|
||||
|
||||
router.AddRoute(host, route)
|
||||
}
|
||||
|
||||
// RemoveRoute drops a previously registered route. Safe to call when the
|
||||
// listener is not yet up; queued copies are pruned in that case.
|
||||
func (m *inboundManager) RemoveRoute(accountID types.AccountID, host nbtcp.SNIHost, svcID types.ServiceID) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.muxLock.Lock()
|
||||
m.dropPendingLocked(accountID, host, svcID)
|
||||
entry, ok := m.entries[accountID]
|
||||
if !ok {
|
||||
m.muxLock.Unlock()
|
||||
return
|
||||
}
|
||||
router := entry.router
|
||||
m.muxLock.Unlock()
|
||||
|
||||
router.RemoveRoute(host, svcID)
|
||||
}
|
||||
|
||||
// queuePendingLocked stores or upserts a pending route. Caller holds muxLock.
|
||||
func (m *inboundManager) queuePendingLocked(accountID types.AccountID, host nbtcp.SNIHost, route nbtcp.Route) {
|
||||
queued := m.pendingRoutes[accountID]
|
||||
for i, pr := range queued {
|
||||
if pr.host == host && pr.route.ServiceID == route.ServiceID {
|
||||
queued[i] = pendingInboundRoute{host: host, route: route}
|
||||
m.pendingRoutes[accountID] = queued
|
||||
return
|
||||
}
|
||||
}
|
||||
m.pendingRoutes[accountID] = append(queued, pendingInboundRoute{host: host, route: route})
|
||||
}
|
||||
|
||||
// dropPendingLocked removes any queued route matching host/svcID.
|
||||
// Caller holds muxLock.
|
||||
func (m *inboundManager) dropPendingLocked(accountID types.AccountID, host nbtcp.SNIHost, svcID types.ServiceID) {
|
||||
queued, ok := m.pendingRoutes[accountID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
filtered := queued[:0]
|
||||
for _, pr := range queued {
|
||||
if pr.host == host && pr.route.ServiceID == svcID {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, pr)
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
delete(m.pendingRoutes, accountID)
|
||||
return
|
||||
}
|
||||
m.pendingRoutes[accountID] = filtered
|
||||
}
|
||||
|
||||
// flushPending applies all queued routes to a freshly-up router.
|
||||
func (m *inboundManager) flushPending(accountID types.AccountID, entry *inboundEntry) {
|
||||
m.muxLock.Lock()
|
||||
queued := m.pendingRoutes[accountID]
|
||||
delete(m.pendingRoutes, accountID)
|
||||
m.muxLock.Unlock()
|
||||
|
||||
for _, pr := range queued {
|
||||
entry.router.AddRoute(pr.host, pr.route)
|
||||
}
|
||||
}
|
||||
|
||||
// HasInbound reports whether the manager has a live listener for the account.
|
||||
// Used by tests.
|
||||
func (m *inboundManager) HasInbound(accountID types.AccountID) bool {
|
||||
if m == nil {
|
||||
return false
|
||||
}
|
||||
m.muxLock.Lock()
|
||||
defer m.muxLock.Unlock()
|
||||
_, ok := m.entries[accountID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// PendingRouteCount reports the number of queued routes for the account.
|
||||
// Used by tests.
|
||||
func (m *inboundManager) PendingRouteCount(accountID types.AccountID) int {
|
||||
if m == nil {
|
||||
return 0
|
||||
}
|
||||
m.muxLock.Lock()
|
||||
defer m.muxLock.Unlock()
|
||||
return len(m.pendingRoutes[accountID])
|
||||
}
|
||||
|
||||
// InboundListenerInfo describes the bound addresses of a single
|
||||
// per-account inbound listener. Both addresses live on the embedded
|
||||
// netstack of the account's WireGuard client and share the same tunnel IP.
|
||||
type InboundListenerInfo struct {
|
||||
TunnelIP string
|
||||
HTTPSPort uint16
|
||||
HTTPPort uint16
|
||||
}
|
||||
|
||||
// ListenerInfo returns the inbound listener addresses for the given
|
||||
// account, or ok=false when the account has no live listener. Used by
|
||||
// the status-update RPC and the debug HTTP handler to surface inbound
|
||||
// reachability to operators.
|
||||
func (m *inboundManager) ListenerInfo(accountID types.AccountID) (InboundListenerInfo, bool) {
|
||||
if m == nil {
|
||||
return InboundListenerInfo{}, false
|
||||
}
|
||||
m.muxLock.Lock()
|
||||
defer m.muxLock.Unlock()
|
||||
entry, ok := m.entries[accountID]
|
||||
if !ok || entry == nil {
|
||||
return InboundListenerInfo{}, false
|
||||
}
|
||||
return listenerInfoFromEntry(entry), true
|
||||
}
|
||||
|
||||
// Snapshot returns the inbound listener state for every account that has
|
||||
// a live listener at call time. Empty when --private-inbound is off or
|
||||
// no accounts have come up yet.
|
||||
func (m *inboundManager) Snapshot() map[types.AccountID]InboundListenerInfo {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
m.muxLock.Lock()
|
||||
defer m.muxLock.Unlock()
|
||||
if len(m.entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[types.AccountID]InboundListenerInfo, len(m.entries))
|
||||
for id, entry := range m.entries {
|
||||
if entry == nil {
|
||||
continue
|
||||
}
|
||||
out[id] = listenerInfoFromEntry(entry)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// listenerInfoFromEntry extracts the tunnel IP and ports from a live
|
||||
// per-account entry. Both listeners are bound on the same netstack so
|
||||
// their host components match; we still pull the TLS host as the
|
||||
// authoritative source.
|
||||
func listenerInfoFromEntry(entry *inboundEntry) InboundListenerInfo {
|
||||
info := InboundListenerInfo{HTTPSPort: privateInboundPortHTTPS, HTTPPort: privateInboundPortHTTP}
|
||||
if entry.tlsListener != nil {
|
||||
host, port := splitHostPort(entry.tlsListener.Addr())
|
||||
info.TunnelIP = host
|
||||
if port != 0 {
|
||||
info.HTTPSPort = port
|
||||
}
|
||||
}
|
||||
if entry.plainListener != nil {
|
||||
host, port := splitHostPort(entry.plainListener.Addr())
|
||||
if info.TunnelIP == "" {
|
||||
info.TunnelIP = host
|
||||
}
|
||||
if port != 0 {
|
||||
info.HTTPPort = port
|
||||
}
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
// splitHostPort extracts host and port from a net.Addr, returning the
|
||||
// zero values when the address is missing or malformed.
|
||||
func splitHostPort(addr net.Addr) (string, uint16) {
|
||||
if addr == nil {
|
||||
return "", 0
|
||||
}
|
||||
host, portStr, err := net.SplitHostPort(addr.String())
|
||||
if err != nil {
|
||||
return "", 0
|
||||
}
|
||||
if portStr == "" {
|
||||
return host, 0
|
||||
}
|
||||
port, err := strconv.ParseUint(portStr, 10, 16)
|
||||
if err != nil {
|
||||
return host, 0
|
||||
}
|
||||
return host, uint16(port)
|
||||
}
|
||||
|
||||
// feedRouterFromListener accepts on the plain-HTTP netstack listener and
|
||||
// hands every connection to the account's router. The router peeks the
|
||||
// first byte and dispatches to the plain-HTTP channel for non-TLS
|
||||
// streams or the TLS channel for ClientHellos that arrive on :80.
|
||||
func feedRouterFromListener(ctx context.Context, ln net.Listener, router *nbtcp.Router, logger *log.Logger, accountID types.AccountID) {
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = ln.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
logger.WithField("account_id", accountID).Debugf("plain inbound accept: %v", err)
|
||||
continue
|
||||
}
|
||||
router.HandleConn(ctx, conn)
|
||||
}
|
||||
}
|
||||
|
||||
// accountDialResolver returns a DialResolver bound to a single account's
|
||||
// embedded client. The router only ever serves traffic for that account
|
||||
// so the supplied accountID is ignored at dial time.
|
||||
func accountDialResolver(_ types.AccountID, client *embed.Client) nbtcp.DialResolver {
|
||||
return func(_ types.AccountID) (types.DialContextFunc, error) {
|
||||
return client.DialContext, nil
|
||||
}
|
||||
}
|
||||
|
||||
// accountTunnelLookup returns a TunnelLookupFunc backed by the embedded
|
||||
// client's peerstore for a single account. Phase 3 uses the result to
|
||||
// short-circuit ValidateTunnelPeer when the source IP is not in the
|
||||
// account's roster and to seed the cached identity for known peers.
|
||||
func accountTunnelLookup(client *embed.Client) auth.TunnelLookupFunc {
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
return func(ip netip.Addr) (auth.PeerIdentity, bool) {
|
||||
pubKey, fqdn, ok := client.IdentityForIP(ip)
|
||||
if !ok {
|
||||
return auth.PeerIdentity{}, false
|
||||
}
|
||||
return auth.PeerIdentity{
|
||||
PubKey: pubKey,
|
||||
TunnelIP: ip,
|
||||
FQDN: fqdn,
|
||||
}, true
|
||||
}
|
||||
}
|
||||
|
||||
// withTunnelLookup returns an http.Handler that attaches the per-account
|
||||
// peerstore lookup to every request's context before delegating to next.
|
||||
// Calling on the host-level listener is a no-op because that path never
|
||||
// installs this wrapper, so the existing behaviour stays byte-for-byte
|
||||
// identical when --private-inbound is off or the request didn't arrive
|
||||
// on a per-account listener.
|
||||
func withTunnelLookup(next http.Handler, lookup auth.TunnelLookupFunc) http.Handler {
|
||||
if lookup == nil {
|
||||
return next
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := auth.WithTunnelLookup(r.Context(), lookup)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// inboundDebugAdapter adapts *inboundManager to the debug.InboundProvider
|
||||
// interface so the debug HTTP handler can render per-account inbound
|
||||
// listener state without importing the proxy package.
|
||||
type inboundDebugAdapter struct {
|
||||
mgr *inboundManager
|
||||
}
|
||||
|
||||
// InboundListeners returns a snapshot of the live per-account inbound
|
||||
// listeners formatted for the debug surface.
|
||||
func (a inboundDebugAdapter) InboundListeners() map[types.AccountID]debug.InboundListenerInfo {
|
||||
if a.mgr == nil {
|
||||
return nil
|
||||
}
|
||||
snap := a.mgr.Snapshot()
|
||||
if len(snap) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[types.AccountID]debug.InboundListenerInfo, len(snap))
|
||||
for id, info := range snap {
|
||||
out[id] = debug.InboundListenerInfo{
|
||||
TunnelIP: info.TunnelIP,
|
||||
HTTPSPort: info.HTTPSPort,
|
||||
HTTPPort: info.HTTPPort,
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// newInboundErrorLog routes a per-account http.Server's stdlib error
|
||||
// stream through logrus at warn level.
|
||||
func newInboundErrorLog(logger *log.Logger, scheme string, accountID types.AccountID) *stdlog.Logger {
|
||||
return stdlog.New(logger.WithFields(log.Fields{
|
||||
"inbound-http": scheme,
|
||||
"account_id": accountID,
|
||||
}).WriterLevel(log.WarnLevel), "", 0)
|
||||
}
|
||||
@@ -1,502 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// bufioReader wraps the connection in a buffered reader so http.ReadResponse
|
||||
// can parse the response line + headers off the wire.
|
||||
func bufioReader(conn net.Conn) *bufio.Reader {
|
||||
return bufio.NewReader(conn)
|
||||
}
|
||||
|
||||
// quietLogger returns a logger that emits nothing — keeps test output tidy.
|
||||
func quietLogger() *log.Logger {
|
||||
logger := log.New()
|
||||
logger.SetLevel(log.PanicLevel)
|
||||
return logger
|
||||
}
|
||||
|
||||
func TestInboundManager_RouteScopedToAccount(t *testing.T) {
|
||||
mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
|
||||
|
||||
accountA := types.AccountID("acct-a")
|
||||
accountB := types.AccountID("acct-b")
|
||||
|
||||
mgr.AddRoute(accountA, "shared.example", nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountA, ServiceID: "svc-a", Domain: "shared.example"})
|
||||
mgr.AddRoute(accountB, "other.example", nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountB, ServiceID: "svc-b", Domain: "other.example"})
|
||||
|
||||
require.Equal(t, 1, mgr.PendingRouteCount(accountA), "account A should have one queued route")
|
||||
require.Equal(t, 1, mgr.PendingRouteCount(accountB), "account B should have one queued route")
|
||||
|
||||
mgr.RemoveRoute(accountA, "shared.example", "svc-a")
|
||||
mgr.RemoveRoute(accountB, "other.example", "svc-b")
|
||||
|
||||
assert.Equal(t, 0, mgr.PendingRouteCount(accountA), "queue should drain on remove")
|
||||
assert.Equal(t, 0, mgr.PendingRouteCount(accountB), "queue should drain on remove")
|
||||
}
|
||||
|
||||
func TestInboundManager_PendingThenFlush(t *testing.T) {
|
||||
mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
|
||||
|
||||
accountID := types.AccountID("acct-1")
|
||||
host := nbtcp.SNIHost("example.test")
|
||||
route := nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountID, ServiceID: "svc-1", Domain: "example.test"}
|
||||
|
||||
mgr.AddRoute(accountID, host, route)
|
||||
require.Equal(t, 1, mgr.PendingRouteCount(accountID), "pending count before listener is up")
|
||||
|
||||
// Simulate listener up by registering a fake entry, then flushing.
|
||||
router := nbtcp.NewRouter(quietLogger(), nil, &fakeAddr{addr: "127.0.0.1:0"})
|
||||
entry := &inboundEntry{router: router}
|
||||
mgr.muxLock.Lock()
|
||||
mgr.entries[accountID] = entry
|
||||
mgr.muxLock.Unlock()
|
||||
|
||||
mgr.flushPending(accountID, entry)
|
||||
assert.Equal(t, 0, mgr.PendingRouteCount(accountID), "queue should be empty after flush")
|
||||
}
|
||||
|
||||
// fakeAddr is a stub net.Addr for tests that don't actually bind sockets.
|
||||
type fakeAddr struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func (a *fakeAddr) Network() string { return "tcp" }
|
||||
func (a *fakeAddr) String() string { return a.addr }
|
||||
|
||||
// fakeMgmtClient implements roundtrip.managementClient for tests.
|
||||
type fakeMgmtClient struct{}
|
||||
|
||||
func (fakeMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
|
||||
return &proto.CreateProxyPeerResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
// TestServer_PrivateInbound_NotEnabled_NoManager confirms that with
|
||||
// --private off the inbound manager is nil and the standalone proxy
|
||||
// keeps its zero-overhead default path.
|
||||
func TestServer_PrivateInbound_NotEnabled_NoManager(t *testing.T) {
|
||||
s := &Server{Logger: quietLogger(), Private: false}
|
||||
s.initPrivateInbound(http.NotFoundHandler(), nil)
|
||||
assert.Nil(t, s.inbound, "manager should remain nil when --private is off")
|
||||
}
|
||||
|
||||
// TestServer_PrivateInbound_Enabled_WiresLifecycle confirms that
|
||||
// --private alone wires the manager into the NetBird transport, so
|
||||
// AddPeer / RemovePeer drive the lifecycle.
|
||||
func TestServer_PrivateInbound_Enabled_WiresLifecycle(t *testing.T) {
|
||||
s := &Server{Logger: quietLogger(), Private: true}
|
||||
// Construct a NetBird transport. We can't actually start the embedded
|
||||
// client here (that needs a real management server), but we can
|
||||
// confirm that the lifecycle callbacks are registered.
|
||||
s.netbird = roundtrip.NewNetBird("test", "test", roundtrip.ClientConfig{
|
||||
MgmtAddr: "http://invalid.test",
|
||||
}, quietLogger(), nil, fakeMgmtClient{})
|
||||
|
||||
s.initPrivateInbound(http.NotFoundHandler(), &tls.Config{}) //nolint:gosec
|
||||
require.NotNil(t, s.inbound, "manager should be set when --private is on")
|
||||
assert.NotNil(t, s.inbound.handler, "handler should be set on manager")
|
||||
assert.NotNil(t, s.inbound.tlsConfig, "tls config should be set on manager")
|
||||
}
|
||||
|
||||
// TestInboundManager_AddRouteAfterReady_RegistersDirectly verifies that
|
||||
// when the listener is already up, AddRoute writes straight to the
|
||||
// router without queueing.
|
||||
func TestInboundManager_AddRouteAfterReady_RegistersDirectly(t *testing.T) {
|
||||
mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
|
||||
accountID := types.AccountID("acct-1")
|
||||
router := nbtcp.NewRouter(quietLogger(), nil, &fakeAddr{addr: "127.0.0.1:0"})
|
||||
|
||||
mgr.muxLock.Lock()
|
||||
mgr.entries[accountID] = &inboundEntry{router: router}
|
||||
mgr.muxLock.Unlock()
|
||||
|
||||
host := nbtcp.SNIHost("ready.example")
|
||||
mgr.AddRoute(accountID, host, nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountID, ServiceID: "svc-ready", Domain: string(host)})
|
||||
assert.Equal(t, 0, mgr.PendingRouteCount(accountID), "no pending entries when listener is up")
|
||||
}
|
||||
|
||||
// TestPrivateCapability_DerivedFromPrivateOnly tests that the capability
|
||||
// bit reported upstream tracks --private exclusively. The previous
|
||||
// --private-inbound flag has been folded into --private.
|
||||
func TestPrivateCapability_DerivedFromPrivateOnly(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
private bool
|
||||
expected bool
|
||||
}{
|
||||
{"off", false, false},
|
||||
{"on", true, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Server{Private: tt.private}
|
||||
assert.Equal(t, tt.expected, s.Private, "private capability bit should match --private")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestInboundManager_RouteScopedToAccountB_DoesNotMatchA verifies that a
|
||||
// service registered for account B is invisible to a router serving
|
||||
// account A. We exercise the path through real per-account routers.
|
||||
func TestInboundManager_RouteScopedToAccountB_DoesNotMatchA(t *testing.T) {
|
||||
mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
|
||||
|
||||
accountA := types.AccountID("acct-a")
|
||||
accountB := types.AccountID("acct-b")
|
||||
routerA := nbtcp.NewRouter(quietLogger(), nil, &fakeAddr{addr: "127.0.0.1:0"})
|
||||
routerB := nbtcp.NewRouter(quietLogger(), nil, &fakeAddr{addr: "127.0.0.1:0"})
|
||||
|
||||
mgr.muxLock.Lock()
|
||||
mgr.entries[accountA] = &inboundEntry{router: routerA}
|
||||
mgr.entries[accountB] = &inboundEntry{router: routerB}
|
||||
mgr.muxLock.Unlock()
|
||||
|
||||
host := nbtcp.SNIHost("shared.example")
|
||||
mgr.AddRoute(accountB, host, nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountB, ServiceID: "svc-b", Domain: string(host)})
|
||||
|
||||
// Account A's router should have no routes; account B's should have one.
|
||||
// We check via IsEmpty — true means no routes and no fallback.
|
||||
assert.True(t, routerA.IsEmpty(), "account A router must not see account B's mappings")
|
||||
assert.False(t, routerB.IsEmpty(), "account B router should hold its own mapping")
|
||||
}
|
||||
|
||||
// TestInboundEntry_ShutdownIdempotent ensures that tearDown can run twice
|
||||
// without panicking — callers may invoke it from RemovePeer + StopAll.
|
||||
func TestInboundEntry_ShutdownIdempotent(t *testing.T) {
|
||||
t.Skip("teardown requires real netstack listeners; covered by integration tests")
|
||||
}
|
||||
|
||||
// TestRouter_PlainHTTP_ForwardedProtoIsHTTP exercises the full per-account
|
||||
// router pipeline against a loopback listener (proxy of a netstack
|
||||
// listener for test purposes): a plain HTTP request lands on the plain
|
||||
// http.Server and the inner handler observes a nil r.TLS, which is what
|
||||
// auth.ResolveProto translates to "http" in the real pipeline.
|
||||
func TestRouter_PlainHTTP_ForwardedProtoIsHTTP(t *testing.T) {
|
||||
logger := quietLogger()
|
||||
|
||||
var captured atomic.Value
|
||||
captured.Store("")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.TLS == nil {
|
||||
captured.Store("http")
|
||||
} else {
|
||||
captured.Store("https")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
})
|
||||
|
||||
hostListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err, "loopback listener bind must succeed")
|
||||
defer hostListener.Close()
|
||||
|
||||
router := nbtcp.NewRouter(logger, nil, hostListener.Addr(), nbtcp.WithPlainHTTP(hostListener.Addr()))
|
||||
httpServer := &http.Server{Handler: handler, ReadHeaderTimeout: time.Second}
|
||||
defer func() { _ = httpServer.Close() }()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() { _ = httpServer.Serve(router.HTTPListenerPlain()) }()
|
||||
go func() { _ = router.Serve(ctx, hostListener) }()
|
||||
|
||||
conn, err := net.DialTimeout("tcp", hostListener.Addr().String(), 2*time.Second)
|
||||
require.NoError(t, err, "plain HTTP dial must succeed")
|
||||
defer conn.Close()
|
||||
|
||||
_, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\nConnection: close\r\n\r\n"))
|
||||
require.NoError(t, err, "write must succeed")
|
||||
|
||||
resp, err := http.ReadResponse(bufioReader(conn), nil)
|
||||
require.NoError(t, err, "must read response")
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Equal(t, "http", captured.Load(), "ForwardedProto must be http on plain path")
|
||||
}
|
||||
|
||||
// TestWithTunnelLookup_AttachesLookupToContext verifies that requests
|
||||
// flowing through the per-account handler wrapper carry the peerstore
|
||||
// lookup function. Phase 3's local-first deny path depends on this.
|
||||
func TestWithTunnelLookup_AttachesLookupToContext(t *testing.T) {
|
||||
expected := auth.PeerIdentity{TunnelIP: netip.MustParseAddr("100.64.0.10"), FQDN: "peer.netbird"}
|
||||
lookup := auth.TunnelLookupFunc(func(_ netip.Addr) (auth.PeerIdentity, bool) {
|
||||
return expected, true
|
||||
})
|
||||
|
||||
var observed auth.TunnelLookupFunc
|
||||
inner := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
observed = auth.TunnelLookupFromContext(r.Context())
|
||||
})
|
||||
|
||||
handler := withTunnelLookup(inner, lookup)
|
||||
r := httptest.NewRequest(http.MethodGet, "https://svc.example/", nil)
|
||||
handler.ServeHTTP(httptest.NewRecorder(), r)
|
||||
|
||||
require.NotNil(t, observed, "wrapper must inject the lookup into the request context")
|
||||
got, ok := observed(netip.MustParseAddr("100.64.0.10"))
|
||||
assert.True(t, ok, "lookup must round-trip through context")
|
||||
assert.Equal(t, expected.FQDN, got.FQDN, "lookup must return the same identity it was constructed with")
|
||||
}
|
||||
|
||||
// TestWithTunnelLookup_NilLookupIsNoop confirms the wrapper is a pure
|
||||
// pass-through when no lookup is provided. Required for the host-level
|
||||
// listener path to keep its byte-for-byte previous behaviour.
|
||||
func TestWithTunnelLookup_NilLookupIsNoop(t *testing.T) {
|
||||
var called bool
|
||||
inner := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
assert.Nil(t, auth.TunnelLookupFromContext(r.Context()), "host-level path must not see a lookup function")
|
||||
})
|
||||
|
||||
handler := withTunnelLookup(inner, nil)
|
||||
r := httptest.NewRequest(http.MethodGet, "https://svc.example/", nil)
|
||||
handler.ServeHTTP(httptest.NewRecorder(), r)
|
||||
assert.True(t, called, "wrapper without lookup must still invoke next")
|
||||
}
|
||||
|
||||
// fakeListener satisfies net.Listener for snapshot tests without binding
|
||||
// a real socket on the netstack.
|
||||
type fakeListener struct {
|
||||
addr net.Addr
|
||||
}
|
||||
|
||||
func (f *fakeListener) Accept() (net.Conn, error) { return nil, net.ErrClosed }
|
||||
func (f *fakeListener) Close() error { return nil }
|
||||
func (f *fakeListener) Addr() net.Addr { return f.addr }
|
||||
|
||||
// TestInboundManager_ListenerInfo confirms ListenerInfo and Snapshot
|
||||
// surface the bound tunnel-IP and ports for live entries.
|
||||
func TestInboundManager_ListenerInfo(t *testing.T) {
|
||||
mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
|
||||
accountID := types.AccountID("acct-info")
|
||||
|
||||
tlsAddr := &net.TCPAddr{IP: net.ParseIP("100.64.0.5"), Port: privateInboundPortHTTPS}
|
||||
plainAddr := &net.TCPAddr{IP: net.ParseIP("100.64.0.5"), Port: privateInboundPortHTTP}
|
||||
mgr.muxLock.Lock()
|
||||
mgr.entries[accountID] = &inboundEntry{
|
||||
tlsListener: &fakeListener{addr: tlsAddr},
|
||||
plainListener: &fakeListener{addr: plainAddr},
|
||||
}
|
||||
mgr.muxLock.Unlock()
|
||||
|
||||
info, ok := mgr.ListenerInfo(accountID)
|
||||
require.True(t, ok, "ListenerInfo must report ok for live entry")
|
||||
assert.Equal(t, "100.64.0.5", info.TunnelIP, "tunnel IP must come from listener address")
|
||||
assert.Equal(t, uint16(privateInboundPortHTTPS), info.HTTPSPort, "TLS port must match bound port")
|
||||
assert.Equal(t, uint16(privateInboundPortHTTP), info.HTTPPort, "HTTP port must match bound port")
|
||||
|
||||
snap := mgr.Snapshot()
|
||||
require.Len(t, snap, 1, "snapshot must contain exactly one entry")
|
||||
assert.Equal(t, info, snap[accountID], "snapshot entry must equal direct lookup")
|
||||
|
||||
_, ok = mgr.ListenerInfo(types.AccountID("missing"))
|
||||
assert.False(t, ok, "ListenerInfo must report ok=false for unknown accounts")
|
||||
}
|
||||
|
||||
// TestInboundManager_NilManagerSafe ensures the observability accessors
|
||||
// are safe to call when --private-inbound is off (nil manager).
|
||||
func TestInboundManager_NilManagerSafe(t *testing.T) {
|
||||
var mgr *inboundManager
|
||||
_, ok := mgr.ListenerInfo("anything")
|
||||
assert.False(t, ok, "nil manager must return ok=false")
|
||||
assert.Nil(t, mgr.Snapshot(), "nil manager must return nil snapshot")
|
||||
}
|
||||
|
||||
// TestInboundManager_ConcurrentAddRemove pounds AddRoute / RemoveRoute
|
||||
// from multiple goroutines to expose any locking gaps.
|
||||
func TestInboundManager_ConcurrentAddRemove(t *testing.T) {
|
||||
mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
|
||||
accountID := types.AccountID("acct-1")
|
||||
const workers = 32
|
||||
const iterations = 50
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(workers)
|
||||
for i := 0; i < workers; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
host := nbtcp.SNIHost("example.test")
|
||||
svc := types.ServiceID("svc")
|
||||
route := nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountID, ServiceID: svc, Domain: "example.test"}
|
||||
for j := 0; j < iterations; j++ {
|
||||
mgr.AddRoute(accountID, host, route)
|
||||
mgr.RemoveRoute(accountID, host, svc)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatal("concurrent add/remove timed out")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFeedRouterFromListener_DeliversConnectionToHandler validates the
|
||||
// per-account inbound chain end-to-end with a loopback listener
|
||||
// substituted for the embedded netstack: a TCP connection arriving at
|
||||
// the plain listener flows through feedRouterFromListener, the router's
|
||||
// peek-and-dispatch, the wrapped HTTP server, and reaches the user
|
||||
// handler. If the embedded netstack is delivering connections at all,
|
||||
// this is the path they take. Failures localise to wiring bugs in the
|
||||
// proxy, not the netstack.
|
||||
func TestFeedRouterFromListener_DeliversConnectionToHandler(t *testing.T) {
|
||||
logger := quietLogger()
|
||||
|
||||
hits := make(chan string, 1)
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hits <- r.Host
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("served"))
|
||||
})
|
||||
|
||||
plainLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err, "plain loopback bind must succeed")
|
||||
t.Cleanup(func() { _ = plainLn.Close() })
|
||||
|
||||
router := nbtcp.NewRouter(logger, nil, &fakeAddr{addr: "127.0.0.1:0"}, nbtcp.WithPlainHTTP(plainLn.Addr()))
|
||||
|
||||
httpServer := &http.Server{Handler: handler, ReadHeaderTimeout: time.Second}
|
||||
t.Cleanup(func() { _ = httpServer.Close() })
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
go func() { _ = httpServer.Serve(router.HTTPListenerPlain()) }()
|
||||
go feedRouterFromListener(ctx, plainLn, router, logger, types.AccountID("acct-1"))
|
||||
|
||||
conn, err := net.DialTimeout("tcp", plainLn.Addr().String(), 2*time.Second)
|
||||
require.NoError(t, err, "must connect to the plain listener")
|
||||
t.Cleanup(func() { _ = conn.Close() })
|
||||
|
||||
_, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: app.example\r\nConnection: close\r\n\r\n"))
|
||||
require.NoError(t, err, "request write must succeed")
|
||||
|
||||
resp, err := http.ReadResponse(bufioReader(conn), nil)
|
||||
require.NoError(t, err, "must read response from server")
|
||||
t.Cleanup(func() { _ = resp.Body.Close() })
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode, "handler must be reached")
|
||||
|
||||
select {
|
||||
case host := <-hits:
|
||||
assert.Equal(t, "app.example", host, "handler must observe the request Host")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("handler was not invoked — connection did not flow through router → http server")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFeedRouterFromListener_DispatchesTLSToTLSChannel verifies that a
|
||||
// TLS ClientHello arriving on the plain listener is detected by the
|
||||
// router peek and re-dispatched to the TLS channel — the cross-channel
|
||||
// fallback the inbound stack relies on for HTTPS-on-:80 testing.
|
||||
func TestFeedRouterFromListener_DispatchesTLSToTLSChannel(t *testing.T) {
|
||||
logger := quietLogger()
|
||||
|
||||
hits := make(chan string, 1)
|
||||
tlsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hits <- r.Host
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("served-tls"))
|
||||
})
|
||||
|
||||
plainLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err, "plain loopback bind must succeed")
|
||||
t.Cleanup(func() { _ = plainLn.Close() })
|
||||
|
||||
tlsLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err, "tls loopback bind must succeed")
|
||||
t.Cleanup(func() { _ = tlsLn.Close() })
|
||||
|
||||
router := nbtcp.NewRouter(logger, nil, tlsLn.Addr(), nbtcp.WithPlainHTTP(plainLn.Addr()))
|
||||
|
||||
tlsConfig := selfSignedTLSConfig(t)
|
||||
httpsServer := &http.Server{
|
||||
Handler: tlsHandler,
|
||||
TLSConfig: tlsConfig,
|
||||
ReadHeaderTimeout: time.Second,
|
||||
}
|
||||
t.Cleanup(func() { _ = httpsServer.Close() })
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
go func() { _ = httpsServer.ServeTLS(router.HTTPListener(), "", "") }()
|
||||
go feedRouterFromListener(ctx, plainLn, router, logger, types.AccountID("acct-tls"))
|
||||
|
||||
tlsConn, err := tls.Dial("tcp", plainLn.Addr().String(), &tls.Config{InsecureSkipVerify: true}) //nolint:gosec
|
||||
require.NoError(t, err, "TLS dial against the plain listener must succeed (cross-channel)")
|
||||
t.Cleanup(func() { _ = tlsConn.Close() })
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "https://app.example/", nil)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, req.Write(tlsConn), "TLS request write must succeed")
|
||||
|
||||
resp, err := http.ReadResponse(bufioReader(tlsConn), req)
|
||||
require.NoError(t, err, "must read TLS response")
|
||||
t.Cleanup(func() { _ = resp.Body.Close() })
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode, "TLS handler must be reached")
|
||||
|
||||
select {
|
||||
case host := <-hits:
|
||||
assert.Equal(t, "app.example", host, "TLS handler must observe the request Host")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("TLS handler was not invoked — peek/dispatch path is broken")
|
||||
}
|
||||
}
|
||||
|
||||
func selfSignedTLSConfig(t *testing.T) *tls.Config {
|
||||
t.Helper()
|
||||
cert, err := tls.X509KeyPair(testCertPEM, testKeyPEM)
|
||||
require.NoError(t, err, "load static self-signed cert")
|
||||
return &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12} //nolint:gosec
|
||||
}
|
||||
|
||||
// testCertPEM / testKeyPEM are a minimal RSA self-signed cert for
|
||||
// 127.0.0.1 — only used by tests that need a working TLS handshake.
|
||||
var testCertPEM = []byte(`-----BEGIN CERTIFICATE-----
|
||||
MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw
|
||||
DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow
|
||||
EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d
|
||||
7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B
|
||||
5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr
|
||||
BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1
|
||||
NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l
|
||||
Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc
|
||||
6MF9+Yw1Yy0t
|
||||
-----END CERTIFICATE-----`)
|
||||
var testKeyPEM = []byte(`-----BEGIN EC PRIVATE KEY-----
|
||||
MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
|
||||
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
|
||||
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
|
||||
-----END EC PRIVATE KEY-----`)
|
||||
@@ -1,47 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// PeerIdentity describes the locally-known facts about a peer reachable on
|
||||
// the proxy's per-account WireGuard listener. Phase 3 fills PubKey, TunnelIP
|
||||
// and FQDN from the embedded client's peerstore. UserID, Email and Groups
|
||||
// stay zero in V1 — full identity still travels through ValidateTunnelPeer.
|
||||
// Phase V2 will populate them once RemotePeerConfig carries user identity.
|
||||
type PeerIdentity struct {
|
||||
PubKey string
|
||||
TunnelIP netip.Addr
|
||||
FQDN string
|
||||
|
||||
// V2 fields (zero in V1).
|
||||
UserID string
|
||||
Email string
|
||||
Groups []string
|
||||
}
|
||||
|
||||
// TunnelLookupFunc resolves a tunnel IP to a peer identity using locally
|
||||
// available peerstore data. ok=false means the IP is not in the calling
|
||||
// account's roster.
|
||||
type TunnelLookupFunc func(ip netip.Addr) (PeerIdentity, bool)
|
||||
|
||||
type tunnelLookupContextKey struct{}
|
||||
|
||||
// WithTunnelLookup attaches a per-account peerstore lookup function to
|
||||
// the request context. The auth middleware calls this lookup before
|
||||
// hitting management's ValidateTunnelPeer to short-circuit unknown IPs
|
||||
// and to skip the RPC for already-cached identities.
|
||||
func WithTunnelLookup(ctx context.Context, lookup TunnelLookupFunc) context.Context {
|
||||
if lookup == nil {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, tunnelLookupContextKey{}, lookup)
|
||||
}
|
||||
|
||||
// TunnelLookupFromContext returns the peerstore lookup attached to ctx,
|
||||
// or nil when the request did not arrive on a per-account listener.
|
||||
func TunnelLookupFromContext(ctx context.Context) TunnelLookupFunc {
|
||||
v, _ := ctx.Value(tunnelLookupContextKey{}).(TunnelLookupFunc)
|
||||
return v
|
||||
}
|
||||
@@ -36,7 +36,6 @@ type authenticator interface {
|
||||
// SessionValidator validates session tokens and checks user access permissions.
|
||||
type SessionValidator interface {
|
||||
ValidateSession(ctx context.Context, in *proto.ValidateSessionRequest, opts ...grpc.CallOption) (*proto.ValidateSessionResponse, error)
|
||||
ValidateTunnelPeer(ctx context.Context, in *proto.ValidateTunnelPeerRequest, opts ...grpc.CallOption) (*proto.ValidateTunnelPeerResponse, error)
|
||||
}
|
||||
|
||||
// Scheme defines an authentication mechanism for a domain.
|
||||
@@ -57,21 +56,12 @@ type DomainConfig struct {
|
||||
AccountID types.AccountID
|
||||
ServiceID types.ServiceID
|
||||
IPRestrictions *restrict.Filter
|
||||
// Private routes the domain through ValidateTunnelPeer; failure → 403.
|
||||
Private bool
|
||||
}
|
||||
|
||||
type validationResult struct {
|
||||
UserID string
|
||||
UserEmail string
|
||||
Valid bool
|
||||
DeniedReason string
|
||||
Groups []string
|
||||
// GroupNames carries the human-readable display names for Groups,
|
||||
// ordered identically (positional pairing). May be shorter than
|
||||
// Groups for tokens minted before names were embedded; the consumer
|
||||
// falls back to ids for missing positions.
|
||||
GroupNames []string
|
||||
}
|
||||
|
||||
// Middleware applies per-domain authentication and IP restriction checks.
|
||||
@@ -81,7 +71,6 @@ type Middleware struct {
|
||||
logger *log.Logger
|
||||
sessionValidator SessionValidator
|
||||
geo restrict.GeoResolver
|
||||
tunnelCache *tunnelValidationCache
|
||||
}
|
||||
|
||||
// NewMiddleware creates a new authentication middleware. The sessionValidator is
|
||||
@@ -95,7 +84,6 @@ func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator, geo re
|
||||
logger: logger,
|
||||
sessionValidator: sessionValidator,
|
||||
geo: geo,
|
||||
tunnelCache: newTunnelValidationCache(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -123,15 +111,6 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
// Private services bypass operator schemes and gate on tunnel peer.
|
||||
if config.Private {
|
||||
if mw.forwardWithTunnelPeer(w, r, host, config, next) {
|
||||
return
|
||||
}
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Domains with no authentication schemes pass through after IP checks.
|
||||
if len(config.Schemes) == 0 {
|
||||
next.ServeHTTP(w, r)
|
||||
@@ -150,54 +129,10 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
if mw.forwardWithTunnelPeer(w, r, host, config, next) {
|
||||
return
|
||||
}
|
||||
|
||||
if mw.blockOIDCOnPlainHTTP(w, r, config) {
|
||||
return
|
||||
}
|
||||
|
||||
mw.authenticateWithSchemes(w, r, host, config)
|
||||
})
|
||||
}
|
||||
|
||||
// requestIsPlainHTTP reports whether the request arrived without TLS.
|
||||
// Used to gate cookie-on-plain warnings and the OIDC plain-HTTP block.
|
||||
func requestIsPlainHTTP(r *http.Request) bool {
|
||||
return r.TLS == nil
|
||||
}
|
||||
|
||||
// hasOIDCScheme reports whether any of the configured schemes requires
|
||||
// TLS to round-trip safely with an external IdP.
|
||||
func hasOIDCScheme(schemes []Scheme) bool {
|
||||
for _, s := range schemes {
|
||||
if s.Type() == auth.MethodOIDC {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// blockOIDCOnPlainHTTP fails fast when an OIDC-configured domain is hit
|
||||
// over plain HTTP. Most IdPs reject http:// redirect URIs, so surfacing
|
||||
// the misconfiguration here yields a clearer error than the IdP's
|
||||
// "invalid redirect_uri" round-trip.
|
||||
func (mw *Middleware) blockOIDCOnPlainHTTP(w http.ResponseWriter, r *http.Request, config DomainConfig) bool {
|
||||
if !requestIsPlainHTTP(r) {
|
||||
return false
|
||||
}
|
||||
if !hasOIDCScheme(config.Schemes) {
|
||||
return false
|
||||
}
|
||||
mw.logger.WithFields(log.Fields{
|
||||
"host": r.Host,
|
||||
"remote": r.RemoteAddr,
|
||||
}).Warn("OIDC scheme reached on plain HTTP path; rejecting with 400 — use port 443")
|
||||
http.Error(w, "OIDC requires TLS — use port 443", http.StatusBadRequest)
|
||||
return true
|
||||
}
|
||||
|
||||
func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) {
|
||||
mw.domainsMux.RLock()
|
||||
defer mw.domainsMux.RUnlock()
|
||||
@@ -227,17 +162,7 @@ func (mw *Middleware) checkIPRestrictions(w http.ResponseWriter, r *http.Request
|
||||
return false
|
||||
}
|
||||
|
||||
var verdict restrict.Verdict
|
||||
if types.IsOverlayOrigin(r.Context()) {
|
||||
// Geo/CrowdSec checks don't apply over the WireGuard overlay:
|
||||
// the source address is always inside the NetBird CGNAT range,
|
||||
// which is never in a GeoIP database or a CrowdSec decision
|
||||
// list. Enforcing them here would either no-op (best case) or
|
||||
// fail-closed when the geo database is missing.
|
||||
verdict = config.IPRestrictions.CheckCIDR(clientIP)
|
||||
} else {
|
||||
verdict = config.IPRestrictions.Check(clientIP, mw.geo)
|
||||
}
|
||||
verdict := config.IPRestrictions.Check(clientIP, mw.geo)
|
||||
if verdict == restrict.Allow {
|
||||
return true
|
||||
}
|
||||
@@ -321,111 +246,18 @@ func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Re
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
userID, email, method, groups, groupNames, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey)
|
||||
userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetUserID(userID)
|
||||
cd.SetUserEmail(email)
|
||||
cd.SetUserGroups(groups)
|
||||
cd.SetUserGroupNames(groupNames)
|
||||
cd.SetAuthMethod(method)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
return true
|
||||
}
|
||||
|
||||
// forwardWithTunnelPeer is the OIDC fast-path for requests originating on the
|
||||
// netbird mesh. When the source IP belongs to a private/CGNAT range the proxy
|
||||
// asks management to resolve it to a peer/user and to gate by the service's
|
||||
// distribution_groups. On success the proxy installs the freshly minted JWT
|
||||
// as a session cookie, sets UserID + Method=oidc on the captured data, and
|
||||
// forwards directly — operators see the same access-log shape as if the user
|
||||
// had completed an OIDC redirect. Any failure (private-range mismatch,
|
||||
// management unreachable, peer unknown, user not in group) returns false so
|
||||
// the caller falls back to the existing OIDC scheme dispatch.
|
||||
//
|
||||
// Phase 3 adds a local-first short-circuit: when the request arrived on a
|
||||
// per-account inbound listener the context carries a peerstore lookup
|
||||
// (TunnelLookupFromContext). If the lookup says the IP isn't in the account's
|
||||
// roster the proxy denies fast without calling management. If the lookup
|
||||
// confirms a known peer the RPC still runs for the user-identity tail
|
||||
// (UserID + group access), but its result is cached for tunnelCacheTTL so
|
||||
// repeat requests skip management entirely.
|
||||
func (mw *Middleware) forwardWithTunnelPeer(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
|
||||
if mw.sessionValidator == nil {
|
||||
return false
|
||||
}
|
||||
clientIP := mw.resolveClientIP(r)
|
||||
if !clientIP.IsValid() {
|
||||
return false
|
||||
}
|
||||
if !isTunnelSourceIP(clientIP) {
|
||||
return false
|
||||
}
|
||||
|
||||
if lookup := TunnelLookupFromContext(r.Context()); lookup != nil {
|
||||
if _, ok := lookup(clientIP); !ok {
|
||||
mw.logger.WithFields(log.Fields{
|
||||
"host": host,
|
||||
"remote": clientIP,
|
||||
}).Debug("local peerstore: tunnel IP not in account roster; denying without RPC")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
resp, _, err := mw.tunnelCache.fetch(r.Context(), tunnelCacheKey{
|
||||
accountID: config.AccountID,
|
||||
tunnelIP: clientIP,
|
||||
domain: host,
|
||||
}, mw.validateTunnelPeer)
|
||||
if err != nil {
|
||||
mw.logger.WithError(err).Debug("ValidateTunnelPeer failed; falling back to OIDC")
|
||||
return false
|
||||
}
|
||||
if !resp.GetValid() || resp.GetSessionToken() == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
setSessionCookie(w, resp.GetSessionToken(), config.SessionExpiration)
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(resp.GetUserId())
|
||||
cd.SetUserEmail(resp.GetUserEmail())
|
||||
cd.SetUserGroups(resp.GetPeerGroupIds())
|
||||
cd.SetUserGroupNames(resp.GetPeerGroupNames())
|
||||
cd.SetAuthMethod(auth.MethodOIDC.String())
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
return true
|
||||
}
|
||||
|
||||
// validateTunnelPeer adapts the SessionValidator interface to the cache's
|
||||
// validateTunnelPeerFn signature.
|
||||
func (mw *Middleware) validateTunnelPeer(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
return mw.sessionValidator.ValidateTunnelPeer(ctx, req)
|
||||
}
|
||||
|
||||
// cgnatPrefix covers RFC 6598 100.64.0.0/10, the CGNAT block NetBird
|
||||
// allocates tunnel addresses from by default. IsPrivate() doesn't include
|
||||
// it, so we check it explicitly.
|
||||
var cgnatPrefix = netip.MustParsePrefix("100.64.0.0/10")
|
||||
|
||||
// isTunnelSourceIP reports whether ip falls within an address range typical
|
||||
// of NetBird tunnels: RFC1918 private space, IPv6 ULA, or CGNAT 100.64/10
|
||||
// (NetBird's default range). Loopback and link-local are excluded — the
|
||||
// fast-path is meant for peer-to-peer mesh traffic, not localhost.
|
||||
func isTunnelSourceIP(ip netip.Addr) bool {
|
||||
if !ip.IsValid() || ip.IsLoopback() || ip.IsLinkLocalUnicast() {
|
||||
return false
|
||||
}
|
||||
if ip.IsPrivate() {
|
||||
return true
|
||||
}
|
||||
return cgnatPrefix.Contains(ip)
|
||||
}
|
||||
|
||||
// forwardWithHeaderAuth checks for a Header auth scheme. If the header validates,
|
||||
// the request is forwarded directly (no redirect), which is important for API clients.
|
||||
func (mw *Middleware) forwardWithHeaderAuth(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
|
||||
@@ -454,7 +286,7 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho
|
||||
|
||||
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, auth.MethodHeader)
|
||||
if err != nil {
|
||||
setHeaderCapturedData(r.Context(), "", "", nil, nil)
|
||||
setHeaderCapturedData(r.Context(), "")
|
||||
status := http.StatusBadRequest
|
||||
msg := "invalid session token"
|
||||
if errors.Is(err, errValidationUnavailable) {
|
||||
@@ -466,7 +298,7 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho
|
||||
}
|
||||
|
||||
if !result.Valid {
|
||||
setHeaderCapturedData(r.Context(), result.UserID, result.UserEmail, result.Groups, result.GroupNames)
|
||||
setHeaderCapturedData(r.Context(), result.UserID)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return true
|
||||
}
|
||||
@@ -474,9 +306,6 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho
|
||||
setSessionCookie(w, token, config.SessionExpiration)
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetUserEmail(result.UserEmail)
|
||||
cd.SetUserGroups(result.Groups)
|
||||
cd.SetUserGroupNames(result.GroupNames)
|
||||
cd.SetAuthMethod(auth.MethodHeader.String())
|
||||
}
|
||||
|
||||
@@ -486,7 +315,7 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho
|
||||
|
||||
func (mw *Middleware) handleHeaderAuthError(w http.ResponseWriter, r *http.Request, err error) bool {
|
||||
if errors.Is(err, ErrHeaderAuthFailed) {
|
||||
setHeaderCapturedData(r.Context(), "", "", nil, nil)
|
||||
setHeaderCapturedData(r.Context(), "")
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return true
|
||||
}
|
||||
@@ -498,7 +327,7 @@ func (mw *Middleware) handleHeaderAuthError(w http.ResponseWriter, r *http.Reque
|
||||
return true
|
||||
}
|
||||
|
||||
func setHeaderCapturedData(ctx context.Context, userID, userEmail string, groups, groupNames []string) {
|
||||
func setHeaderCapturedData(ctx context.Context, userID string) {
|
||||
cd := proxy.CapturedDataFromContext(ctx)
|
||||
if cd == nil {
|
||||
return
|
||||
@@ -506,9 +335,6 @@ func setHeaderCapturedData(ctx context.Context, userID, userEmail string, groups
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetAuthMethod(auth.MethodHeader.String())
|
||||
cd.SetUserID(userID)
|
||||
cd.SetUserEmail(userEmail)
|
||||
cd.SetUserGroups(groups)
|
||||
cd.SetUserGroupNames(groupNames)
|
||||
}
|
||||
|
||||
// authenticateWithSchemes tries each configured auth scheme in order.
|
||||
@@ -579,9 +405,6 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetUserEmail(result.UserEmail)
|
||||
cd.SetUserGroups(result.Groups)
|
||||
cd.SetUserGroupNames(result.GroupNames)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
requestID = cd.GetRequestID()
|
||||
}
|
||||
@@ -596,9 +419,6 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetUserEmail(result.UserEmail)
|
||||
cd.SetUserGroups(result.Groups)
|
||||
cd.SetUserGroupNames(result.GroupNames)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
}
|
||||
redirectURL := stripSessionTokenParam(r.URL)
|
||||
@@ -634,9 +454,12 @@ func wasCredentialSubmitted(r *http.Request, method auth.Method) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// AddDomain registers authentication schemes for the given domain. With schemes a valid session public key is required.
|
||||
// private=true forces ValidateTunnelPeer enforcement (403 on failure) regardless of the schemes list.
|
||||
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID, ipRestrictions *restrict.Filter, private bool) error {
|
||||
// AddDomain registers authentication schemes for the given domain.
|
||||
// If schemes are provided, a valid session public key is required to sign/verify
|
||||
// session JWTs. Returns an error if the key is missing or invalid.
|
||||
// Callers must not serve the domain if this returns an error, to avoid
|
||||
// exposing an unauthenticated service.
|
||||
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID, ipRestrictions *restrict.Filter) error {
|
||||
if len(schemes) == 0 {
|
||||
mw.domainsMux.Lock()
|
||||
defer mw.domainsMux.Unlock()
|
||||
@@ -644,7 +467,6 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st
|
||||
AccountID: accountID,
|
||||
ServiceID: serviceID,
|
||||
IPRestrictions: ipRestrictions,
|
||||
Private: private,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -666,7 +488,6 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st
|
||||
AccountID: accountID,
|
||||
ServiceID: serviceID,
|
||||
IPRestrictions: ipRestrictions,
|
||||
Private: private,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -697,25 +518,18 @@ func (mw *Middleware) validateSessionToken(ctx context.Context, host, token stri
|
||||
}).Debug("Session validation denied")
|
||||
return &validationResult{
|
||||
UserID: resp.UserId,
|
||||
UserEmail: resp.GetUserEmail(),
|
||||
Valid: false,
|
||||
DeniedReason: resp.DeniedReason,
|
||||
}, nil
|
||||
}
|
||||
return &validationResult{
|
||||
UserID: resp.UserId,
|
||||
UserEmail: resp.GetUserEmail(),
|
||||
Valid: true,
|
||||
Groups: resp.GetPeerGroupIds(),
|
||||
GroupNames: resp.GetPeerGroupNames(),
|
||||
}, nil
|
||||
return &validationResult{UserID: resp.UserId, Valid: true}, nil
|
||||
}
|
||||
|
||||
userID, email, _, groups, groupNames, err := auth.ValidateSessionJWT(token, host, publicKey)
|
||||
userID, _, err := auth.ValidateSessionJWT(token, host, publicKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &validationResult{UserID: userID, UserEmail: email, Valid: true, Groups: groups, GroupNames: groupNames}, nil
|
||||
return &validationResult{UserID: userID, Valid: true}, nil
|
||||
}
|
||||
|
||||
// stripSessionTokenParam returns the request URI with the session_token query
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net/http"
|
||||
@@ -24,7 +23,6 @@ import (
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/restrict"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -64,7 +62,7 @@ func TestAddDomain_ValidKey(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
@@ -81,7 +79,7 @@ func TestAddDomain_EmptyKey(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "", nil, false)
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "", nil)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid session public key size")
|
||||
|
||||
@@ -95,7 +93,7 @@ func TestAddDomain_InvalidBase64(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "", nil, false)
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "", nil)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "decode session public key")
|
||||
|
||||
@@ -110,7 +108,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
|
||||
|
||||
shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort"))
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "", nil, false)
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "", nil)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid session public key size")
|
||||
|
||||
@@ -123,7 +121,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
|
||||
func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil, false)
|
||||
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil)
|
||||
require.NoError(t, err, "domains with no auth schemes should not require a key")
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
@@ -139,8 +137,8 @@ func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "", nil))
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
config := mw.domains["example.com"]
|
||||
@@ -156,7 +154,7 @@ func TestRemoveDomain(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
mw.RemoveDomain("example.com")
|
||||
|
||||
@@ -180,7 +178,7 @@ func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
|
||||
|
||||
func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
@@ -197,7 +195,7 @@ func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -218,7 +216,7 @@ func TestProtect_HostWithPortIsMatched(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -239,9 +237,9 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, nil, nil, time.Hour)
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
@@ -264,48 +262,15 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
|
||||
assert.Equal(t, "authenticated", rec.Body.String())
|
||||
}
|
||||
|
||||
// TestProtect_SessionCookieGroupsPropagate verifies the cookie path lifts the
|
||||
// JWT's groups claim into CapturedData so policy-aware middlewares can
|
||||
// authorise without an extra management round-trip.
|
||||
func TestProtect_SessionCookieGroupsPropagate(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
groups := []string{"engineering", "sre"}
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, groups, nil, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cd := proxy.CapturedDataFromContext(r.Context())
|
||||
require.NotNil(t, cd, "captured data must be present in request context")
|
||||
assert.Equal(t, "test-user", cd.GetUserID())
|
||||
assert.Equal(t, groups, cd.GetUserGroups(), "JWT groups claim must propagate to CapturedData")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
|
||||
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "request with valid groups-bearing cookie must succeed")
|
||||
assert.Equal(t, groups, capturedData.GetUserGroups(), "CapturedData groups must be retained after handler completes")
|
||||
}
|
||||
|
||||
func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
// Sign a token that expired 1 second ago.
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, nil, nil, -time.Second)
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, -time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
var backendCalled bool
|
||||
@@ -328,10 +293,10 @@ func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
// Token signed for a different domain audience.
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "other.com", auth.MethodPIN, nil, nil, time.Hour)
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "other.com", auth.MethodPIN, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
var backendCalled bool
|
||||
@@ -355,10 +320,10 @@ func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
|
||||
kp2 := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
// Token signed with a different private key.
|
||||
token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, nil, nil, time.Hour)
|
||||
token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
var backendCalled bool
|
||||
@@ -380,7 +345,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "", "example.com", auth.MethodPIN, nil, nil, time.Hour)
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
scheme := &stubScheme{
|
||||
@@ -392,7 +357,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -445,7 +410,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
@@ -462,7 +427,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "", "example.com", auth.MethodPassword, nil, nil, time.Hour)
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First scheme (PIN) always fails, second scheme (password) succeeds.
|
||||
@@ -481,7 +446,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
|
||||
return "", "password", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -511,7 +476,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
|
||||
return "invalid-jwt-token", "", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
@@ -535,7 +500,7 @@ func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
|
||||
key := base64.StdEncoding.EncodeToString(randomBytes)
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
|
||||
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "", nil, false)
|
||||
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "", nil)
|
||||
require.NoError(t, err, "any 32-byte key should be accepted at registration time")
|
||||
}
|
||||
|
||||
@@ -544,10 +509,10 @@ func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
// Attempt to overwrite with an invalid key.
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "", nil, false)
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "", nil)
|
||||
require.Error(t, err)
|
||||
|
||||
// The original valid config should still be intact.
|
||||
@@ -571,7 +536,7 @@ func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
@@ -598,7 +563,7 @@ func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
|
||||
return "", "password", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
@@ -625,7 +590,7 @@ func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
@@ -713,7 +678,7 @@ func TestCheckIPRestrictions_UnparseableAddress(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
|
||||
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}), false)
|
||||
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}))
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -749,7 +714,7 @@ func TestCheckIPRestrictions_UsesCapturedDataClientIP(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
|
||||
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"203.0.113.0/24"}}), false)
|
||||
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"203.0.113.0/24"}}))
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -790,7 +755,7 @@ func TestCheckIPRestrictions_NilGeoWithCountryRules(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
|
||||
restrict.ParseFilter(restrict.FilterConfig{AllowedCountries: []string{"US"}}), false)
|
||||
restrict.ParseFilter(restrict.FilterConfig{AllowedCountries: []string{"US"}}))
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -805,69 +770,6 @@ func TestCheckIPRestrictions_NilGeoWithCountryRules(t *testing.T) {
|
||||
assert.Equal(t, http.StatusForbidden, rr.Code, "country restrictions with nil geo must deny")
|
||||
}
|
||||
|
||||
// TestCheckIPRestrictions_OverlayOriginSkipsCountryRules covers the
|
||||
// inbound (WG) listener path: requests stamped with WithOverlayOrigin
|
||||
// must skip country lookups, even when no geo database is configured.
|
||||
// Without this short-circuit the inbound flow would fail-closed for
|
||||
// every overlay request whenever country rules are configured.
|
||||
func TestCheckIPRestrictions_OverlayOriginSkipsCountryRules(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
|
||||
restrict.ParseFilter(restrict.FilterConfig{
|
||||
AllowedCIDRs: []string{"100.64.0.0/10"},
|
||||
AllowedCountries: []string{"US"},
|
||||
}), false)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.RemoteAddr = "100.64.5.6:5000"
|
||||
req.Host = "example.com"
|
||||
req = req.WithContext(types.WithOverlayOrigin(req.Context()))
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusOK, rr.Code,
|
||||
"overlay-origin requests must not be denied by country rules they would fail without geo data")
|
||||
|
||||
// Sanity check: the same filter without the overlay flag denies (no geo,
|
||||
// country allowlist active → DenyGeoUnavailable).
|
||||
req2 := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req2.RemoteAddr = "100.64.5.6:5000"
|
||||
req2.Host = "example.com"
|
||||
rr2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr2, req2)
|
||||
assert.Equal(t, http.StatusForbidden, rr2.Code,
|
||||
"WAN-origin requests must still hit the full Check path and be denied without geo data")
|
||||
}
|
||||
|
||||
// TestCheckIPRestrictions_OverlayOriginRespectsCIDR confirms CIDR
|
||||
// rules still apply on the overlay path so operators retain a way to
|
||||
// scope private services to specific peer subnets.
|
||||
func TestCheckIPRestrictions_OverlayOriginRespectsCIDR(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
|
||||
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"100.64.0.0/16"}}), false)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.RemoteAddr = "100.65.5.6:5000" // outside 100.64.0.0/16
|
||||
req.Host = "example.com"
|
||||
req = req.WithContext(types.WithOverlayOrigin(req.Context()))
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusForbidden, rr.Code,
|
||||
"CIDR rules must still apply on the overlay path")
|
||||
}
|
||||
|
||||
func TestProtect_OIDCOnlyRedirectsDirectly(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
@@ -879,12 +781,11 @@ func TestProtect_OIDCOnlyRedirectsDirectly(t *testing.T) {
|
||||
return "", oidcURL, nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil)
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
@@ -908,12 +809,11 @@ func TestProtect_OIDCWithOtherMethodShowsLoginPage(t *testing.T) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{oidcScheme, pinScheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{oidcScheme, pinScheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil)
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
@@ -934,7 +834,7 @@ func (m *mockAuthenticator) Authenticate(ctx context.Context, in *proto.Authenti
|
||||
// returns a signed session token when the expected header value is provided.
|
||||
func newHeaderSchemeWithToken(t *testing.T, kp *sessionkey.KeyPair, headerName, expectedValue string) Header {
|
||||
t.Helper()
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "", "example.com", auth.MethodHeader, nil, nil, time.Hour)
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
@@ -952,7 +852,7 @@ func TestProtect_HeaderAuth_ForwardsOnSuccess(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
|
||||
var backendCalled bool
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
@@ -995,7 +895,7 @@ func TestProtect_HeaderAuth_MissingHeaderFallsThrough(t *testing.T) {
|
||||
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
|
||||
// Also add a PIN scheme so we can verify fallthrough behavior.
|
||||
pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr, pinScheme}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr, pinScheme}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
@@ -1015,7 +915,7 @@ func TestProtect_HeaderAuth_WrongValueReturns401(t *testing.T) {
|
||||
return &proto.AuthenticateResponse{Success: false}, nil
|
||||
}}
|
||||
hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
@@ -1038,7 +938,7 @@ func TestProtect_HeaderAuth_InfraErrorReturns502(t *testing.T) {
|
||||
return nil, errors.New("gRPC unavailable")
|
||||
}}
|
||||
hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
@@ -1055,7 +955,7 @@ func TestProtect_HeaderAuth_SubsequentRequestUsesSessionCookie(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -1106,7 +1006,7 @@ func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) {
|
||||
mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
ha := req.GetHeaderAuth()
|
||||
if ha != nil && accepted[ha.GetHeaderValue()] {
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "", "example.com", auth.MethodHeader, nil, nil, time.Hour)
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour)
|
||||
require.NoError(t, err)
|
||||
return &proto.AuthenticateResponse{Success: true, SessionToken: token}, nil
|
||||
}
|
||||
@@ -1115,7 +1015,7 @@ func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) {
|
||||
|
||||
// Single Header scheme (as if one entry existed), but the mock checks both values.
|
||||
hdr := NewHeader(mock, "svc1", "acc1", "Authorization")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
|
||||
var backendCalled bool
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -1159,71 +1059,3 @@ func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) {
|
||||
assert.False(t, backendCalled, "unknown token should be rejected")
|
||||
})
|
||||
}
|
||||
|
||||
// TestProtect_OIDCOnPlainHTTP_BlockedWith400 verifies that when an OIDC
|
||||
// scheme is configured and the request arrived without TLS, the middleware
|
||||
// short-circuits with a 400 instead of dispatching to the IdP redirect.
|
||||
func TestProtect_OIDCOnPlainHTTP_BlockedWith400(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{
|
||||
method: auth.MethodOIDC,
|
||||
authFn: func(_ *http.Request) (string, string, error) {
|
||||
return "", "https://idp.example.com/authorize", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code, "OIDC over plain HTTP should be rejected")
|
||||
assert.Contains(t, rec.Body.String(), "OIDC requires TLS", "response body should explain the rejection")
|
||||
}
|
||||
|
||||
// TestProtect_OIDCOverTLS_NotBlocked confirms the same configuration works
|
||||
// over TLS — the block only fires on plain HTTP.
|
||||
func TestProtect_OIDCOverTLS_NotBlocked(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{
|
||||
method: auth.MethodOIDC,
|
||||
authFn: func(_ *http.Request) (string, string, error) {
|
||||
return "", "https://idp.example.com/authorize", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil)
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusFound, rec.Code, "OIDC over TLS should redirect to IdP")
|
||||
}
|
||||
|
||||
// TestProtect_NonOIDCSchemes_PlainHTTP_NotBlocked confirms that the OIDC
|
||||
// block only fires when an OIDC scheme is configured. PIN-only domains
|
||||
// pass through normally on plain HTTP.
|
||||
func TestProtect_NonOIDCSchemes_PlainHTTP_NotBlocked(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code, "PIN-only domain should serve the login page on plain HTTP")
|
||||
}
|
||||
|
||||
@@ -1,171 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// tunnelCacheTTL caps how long a positive ValidateTunnelPeer result is
|
||||
// reused before re-fetching from management. 5 minutes balances freshness
|
||||
// against management load on busy mesh networks.
|
||||
const tunnelCacheTTL = 300 * time.Second
|
||||
|
||||
// tunnelCachePerAccount caps the number of cached identities per account.
|
||||
// Bounded eviction avoids memory growth in pathological cases (huge peer
|
||||
// roster, brief request bursts) while staying generous for normal use.
|
||||
const tunnelCachePerAccount = 1024
|
||||
|
||||
// tunnelCacheKey identifies a cached entry by tunnel IP and originating
|
||||
// account. Domain is part of the value, not the key, because the
|
||||
// management response is per (account, IP) — domain only gates whether a
|
||||
// re-fetch is needed if the operator is accessing a different service.
|
||||
type tunnelCacheKey struct {
|
||||
accountID types.AccountID
|
||||
tunnelIP netip.Addr
|
||||
domain string
|
||||
}
|
||||
|
||||
// tunnelCacheEntry stores a positive validation response with the time it
|
||||
// was minted. Entries past tunnelCacheTTL are treated as misses.
|
||||
type tunnelCacheEntry struct {
|
||||
resp *proto.ValidateTunnelPeerResponse
|
||||
cachedAt time.Time
|
||||
}
|
||||
|
||||
// tunnelValidationCache memoizes ValidateTunnelPeer responses keyed by
|
||||
// (accountID, tunnelIP, domain). Only successful, valid responses are
|
||||
// cached — denials skip the cache so policy changes apply immediately.
|
||||
// Single-flight de-duplicates concurrent fetches for the same key so a
|
||||
// burst of cold requests collapses into a single RPC.
|
||||
type tunnelValidationCache struct {
|
||||
mu sync.Mutex
|
||||
entries map[types.AccountID]*accountBucket
|
||||
flight singleflight.Group
|
||||
ttl time.Duration
|
||||
maxSize int
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
// accountBucket holds the cached entries for a single account, with a
|
||||
// FIFO eviction queue used when the bucket exceeds maxSize.
|
||||
type accountBucket struct {
|
||||
items map[tunnelCacheKey]tunnelCacheEntry
|
||||
order []tunnelCacheKey
|
||||
}
|
||||
|
||||
// newTunnelValidationCache constructs a cache with default TTL and bounds.
|
||||
func newTunnelValidationCache() *tunnelValidationCache {
|
||||
return &tunnelValidationCache{
|
||||
entries: make(map[types.AccountID]*accountBucket),
|
||||
ttl: tunnelCacheTTL,
|
||||
maxSize: tunnelCachePerAccount,
|
||||
now: time.Now,
|
||||
}
|
||||
}
|
||||
|
||||
// get returns a cached response for the key, or nil when missing or
|
||||
// expired. Expired entries are evicted lazily on read.
|
||||
func (c *tunnelValidationCache) get(key tunnelCacheKey) *proto.ValidateTunnelPeerResponse {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
bucket, ok := c.entries[key.accountID]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
entry, ok := bucket.items[key]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if c.now().Sub(entry.cachedAt) > c.ttl {
|
||||
delete(bucket.items, key)
|
||||
bucket.order = removeKey(bucket.order, key)
|
||||
return nil
|
||||
}
|
||||
return entry.resp
|
||||
}
|
||||
|
||||
// put records a positive response under the key. Evicts the oldest entry
|
||||
// in the account's bucket when the bound is exceeded.
|
||||
func (c *tunnelValidationCache) put(key tunnelCacheKey, resp *proto.ValidateTunnelPeerResponse) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
bucket, ok := c.entries[key.accountID]
|
||||
if !ok {
|
||||
bucket = &accountBucket{items: make(map[tunnelCacheKey]tunnelCacheEntry)}
|
||||
c.entries[key.accountID] = bucket
|
||||
}
|
||||
if _, exists := bucket.items[key]; !exists {
|
||||
bucket.order = append(bucket.order, key)
|
||||
}
|
||||
bucket.items[key] = tunnelCacheEntry{resp: resp, cachedAt: c.now()}
|
||||
|
||||
for len(bucket.order) > c.maxSize {
|
||||
oldest := bucket.order[0]
|
||||
bucket.order = bucket.order[1:]
|
||||
delete(bucket.items, oldest)
|
||||
}
|
||||
}
|
||||
|
||||
// removeKey drops the first occurrence of needle from order. The cache
|
||||
// uses small slices so a linear scan is cheaper than a map+slice combo.
|
||||
func removeKey(order []tunnelCacheKey, needle tunnelCacheKey) []tunnelCacheKey {
|
||||
for i, k := range order {
|
||||
if k == needle {
|
||||
return append(order[:i], order[i+1:]...)
|
||||
}
|
||||
}
|
||||
return order
|
||||
}
|
||||
|
||||
// flightKey turns a cache key into a single-flight string. AccountID and
|
||||
// IP isolation by themselves are insufficient because different domains
|
||||
// for the same peer/account may have different group access.
|
||||
func flightKey(key tunnelCacheKey) string {
|
||||
return string(key.accountID) + "|" + key.tunnelIP.String() + "|" + key.domain
|
||||
}
|
||||
|
||||
// validateTunnelPeerFn is the RPC entry point the cache wraps. It matches
|
||||
// the SessionValidator.ValidateTunnelPeer signature without exposing the
|
||||
// gRPC option variadic, since callers don't need it on the cache hot path.
|
||||
type validateTunnelPeerFn func(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error)
|
||||
|
||||
// fetch returns a cached response when present, otherwise calls validate
|
||||
// under single-flight and caches the result. Denied responses pass
|
||||
// through but are not cached so policy changes apply immediately.
|
||||
func (c *tunnelValidationCache) fetch(ctx context.Context, key tunnelCacheKey, validate validateTunnelPeerFn) (*proto.ValidateTunnelPeerResponse, bool, error) {
|
||||
if resp := c.get(key); resp != nil {
|
||||
return resp, true, nil
|
||||
}
|
||||
|
||||
flight := flightKey(key)
|
||||
res, err, _ := c.flight.Do(flight, func() (any, error) {
|
||||
if cached := c.get(key); cached != nil {
|
||||
return cached, nil
|
||||
}
|
||||
resp, err := validate(ctx, &proto.ValidateTunnelPeerRequest{
|
||||
TunnelIp: key.tunnelIP.String(),
|
||||
Domain: key.domain,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.GetValid() && resp.GetSessionToken() != "" {
|
||||
c.put(key, resp)
|
||||
}
|
||||
return resp, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
resp, _ := res.(*proto.ValidateTunnelPeerResponse)
|
||||
return resp, false, nil
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user