mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-08 01:39:55 +00:00
Compare commits
267 Commits
v0.64.4
...
dn-reverse
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
76fb153d76 | ||
|
|
eee4d75932 | ||
|
|
62b8875f67 | ||
|
|
47a5478964 | ||
|
|
9922d6f953 | ||
|
|
f9bab22f61 | ||
|
|
3d8fdb7a89 | ||
|
|
fb10153ab8 | ||
|
|
57d3ee5aac | ||
|
|
cfdfdecc14 | ||
|
|
b00babb8b1 | ||
|
|
ac995bae6d | ||
|
|
41a5509ce0 | ||
|
|
db5e26db94 | ||
|
|
fe975fb834 | ||
|
|
e368d2995b | ||
|
|
a3241d8376 | ||
|
|
6dfc5772ba | ||
|
|
f70925178c | ||
|
|
9554934b92 | ||
|
|
7fdb824a37 | ||
|
|
412407adc0 | ||
|
|
e0874d7de7 | ||
|
|
8df1536cbb | ||
|
|
fcbacc62ec | ||
|
|
ee2ae45653 | ||
|
|
3bc8cbb13f | ||
|
|
bf7bdf6c4f | ||
|
|
6f2f0f9ae4 | ||
|
|
c37ebc6fb3 | ||
|
|
23abb5743c | ||
|
|
0a895ffc22 | ||
|
|
b87aa0bc15 | ||
|
|
69d4b5d821 | ||
|
|
f1a65d732d | ||
|
|
a3c0ea3e71 | ||
|
|
abaf061c2a | ||
|
|
e531fb54b1 | ||
|
|
5fcfed5b16 | ||
|
|
b81837a364 | ||
|
|
5f43449f67 | ||
|
|
6796601aa6 | ||
|
|
1fc25c301b | ||
|
|
08ae281b2d | ||
|
|
3dfa97dcbd | ||
|
|
1ddc9ce2bf | ||
|
|
bd47f44c63 | ||
|
|
381260911b | ||
|
|
38db42e7d6 | ||
|
|
5d606d909d | ||
|
|
d689718b50 | ||
|
|
54a73c6649 | ||
|
|
418377842e | ||
|
|
15ef56e03d | ||
|
|
917035f8e8 | ||
|
|
963e3f5457 | ||
|
|
e20b969188 | ||
|
|
1c7059ee67 | ||
|
|
22a3365658 | ||
|
|
2de1949018 | ||
|
|
08ab1e3478 | ||
|
|
ebb1f4007d | ||
|
|
acb53ece93 | ||
|
|
e020950cfd | ||
|
|
9dba262a20 | ||
|
|
5bcdf36377 | ||
|
|
1ffe8deb10 | ||
|
|
d069145bd1 | ||
|
|
f3493ee042 | ||
|
|
b782ac6f56 | ||
|
|
bf48044e5c | ||
|
|
fb4cc37a4a | ||
|
|
55b8d89a79 | ||
|
|
6968a32a5a | ||
|
|
cfe6753349 | ||
|
|
5ae15b3af3 | ||
|
|
b79adb706c | ||
|
|
f22497d5da | ||
|
|
95d672c9df | ||
|
|
7d08a609e6 | ||
|
|
eea6120cd0 | ||
|
|
fc88399c23 | ||
|
|
0cb02bd906 | ||
|
|
08d3867f41 | ||
|
|
b16d63643c | ||
|
|
940d01bdea | ||
|
|
ba9158d159 | ||
|
|
ca9a7e11ef | ||
|
|
a803f47685 | ||
|
|
79fed32f01 | ||
|
|
6b00bb0a66 | ||
|
|
e2adef1eea | ||
|
|
9e5fa11792 | ||
|
|
1ff75acb31 | ||
|
|
1754160686 | ||
|
|
423f6266fb | ||
|
|
16d1b4a14a | ||
|
|
7c14056faf | ||
|
|
62e37dc2e2 | ||
|
|
6a08695ee8 | ||
|
|
9a67a8e427 | ||
|
|
73aa0785ba | ||
|
|
53c1016a8e | ||
|
|
fd442138e6 | ||
|
|
be5f30225a | ||
|
|
7467e9fb8c | ||
|
|
2390c2e46e | ||
|
|
6981fdce7e | ||
|
|
08403f64aa | ||
|
|
391221a986 | ||
|
|
778c223176 | ||
|
|
36cd0dd85c | ||
|
|
09a1d5a02d | ||
|
|
7c996ac9b5 | ||
|
|
cf9fd5d960 | ||
|
|
1c5ab7cb8f | ||
|
|
aaad3b25a7 | ||
|
|
9904235a2f | ||
|
|
780e9f57a5 | ||
|
|
a8db73285b | ||
|
|
3b43c00d12 | ||
|
|
2f390e1794 | ||
|
|
3630ebb3ae | ||
|
|
260c46df04 | ||
|
|
7f11e3205d | ||
|
|
1c8f92a96f | ||
|
|
7b6294b624 | ||
|
|
156d0b1fef | ||
|
|
2cf00dba58 | ||
|
|
d2a7f3ae36 | ||
|
|
6a64d4e4dd | ||
|
|
51e63c246b | ||
|
|
99e6b1eda4 | ||
|
|
dc26a5a436 | ||
|
|
3883b2fb41 | ||
|
|
ed58659a01 | ||
|
|
5190923c70 | ||
|
|
7c647dd160 | ||
|
|
07e59b2708 | ||
|
|
0a3a9f977d | ||
|
|
2f263bf7e6 | ||
|
|
f65f4fc280 | ||
|
|
7bc85107eb | ||
|
|
3be16d19a0 | ||
|
|
af8f730bda | ||
|
|
adbd7ab4c3 | ||
|
|
0419834482 | ||
|
|
c3f176f348 | ||
|
|
0119f3e9f4 | ||
|
|
f797d2d9cb | ||
|
|
5ae7efe8f7 | ||
|
|
d6e35bd0fe | ||
|
|
0e00f1c8f7 | ||
|
|
1b96648d4d | ||
|
|
4433f44a12 | ||
|
|
7504e718d7 | ||
|
|
9b0387e7ee | ||
|
|
d2f9653cea | ||
|
|
5ccce1ab3f | ||
|
|
e366fe340e | ||
|
|
b01809f8e3 | ||
|
|
790ef39187 | ||
|
|
3af16cf333 | ||
|
|
194a986926 | ||
|
|
d09c69f303 | ||
|
|
096d4ac529 | ||
|
|
f7732557fa | ||
|
|
8fafde614a | ||
|
|
694ae13418 | ||
|
|
b5b7dd4f53 | ||
|
|
476785b122 | ||
|
|
907677f835 | ||
|
|
7d844b9410 | ||
|
|
eeabc64a73 | ||
|
|
5da2b0fdcc | ||
|
|
a0005a604e | ||
|
|
a89bb807a6 | ||
|
|
28f3354ffa | ||
|
|
562923c600 | ||
|
|
d488f58311 | ||
|
|
0dd0c67b3b | ||
|
|
ca33849f31 | ||
|
|
18cd0f1480 | ||
|
|
b02982f6b1 | ||
|
|
4d89ae27ef | ||
|
|
733ea77c5c | ||
|
|
92f72bfce6 | ||
|
|
6fdc00ff41 | ||
|
|
bffb25bea7 | ||
|
|
3af4543e80 | ||
|
|
146774860b | ||
|
|
5243481316 | ||
|
|
76a39c1dcb | ||
|
|
02ce918114 | ||
|
|
30cfc22cb6 | ||
|
|
3168afbfcb | ||
|
|
a73ee47557 | ||
|
|
fa6ff005f2 | ||
|
|
095379fa60 | ||
|
|
30572fe1b8 | ||
|
|
b20d484972 | ||
|
|
3a6f364b03 | ||
|
|
5345d716ee | ||
|
|
f882c36e0a | ||
|
|
e95cfa1a00 | ||
|
|
0d480071b6 | ||
|
|
8e0b7b6c25 | ||
|
|
f204da0d68 | ||
|
|
7d74904d62 | ||
|
|
760ac5e07d | ||
|
|
4352228797 | ||
|
|
74c770609c | ||
|
|
f4ca36ed7e | ||
|
|
c86da92fc6 | ||
|
|
3f0c577456 | ||
|
|
717da8c7b7 | ||
|
|
a0a61d4f47 | ||
|
|
5b1fced872 | ||
|
|
c98dcf5ef9 | ||
|
|
57cb6bfccb | ||
|
|
95bf97dc3c | ||
|
|
3d116c9d33 | ||
|
|
a9ce9f8d5a | ||
|
|
10b981a855 | ||
|
|
7700b4333d | ||
|
|
7d0131111e | ||
|
|
1daea35e4b | ||
|
|
f97544af0d | ||
|
|
231e80cc15 | ||
|
|
a4c1362bff | ||
|
|
b611d4a751 | ||
|
|
2c9decfa55 | ||
|
|
3c5ac17e2f | ||
|
|
ae42bbb898 | ||
|
|
b86722394b | ||
|
|
a103f69767 | ||
|
|
73fbb3fc62 | ||
|
|
7b3523e25e | ||
|
|
6e4e1386e7 | ||
|
|
671e9af6eb | ||
|
|
50f42caf94 | ||
|
|
b7eeefc102 | ||
|
|
8dd22f3a4f | ||
|
|
4b89427447 | ||
|
|
b71e2860cf | ||
|
|
160b27bc60 | ||
|
|
c084386b88 | ||
|
|
6889047350 | ||
|
|
245bbb4acf | ||
|
|
2b2fc02d83 | ||
|
|
703ef29199 | ||
|
|
b0b60b938a | ||
|
|
e3a026bf1c | ||
|
|
94503465ee | ||
|
|
8d959b0abc | ||
|
|
1d8390b935 | ||
|
|
2851e38a1f | ||
|
|
51261fe7a9 | ||
|
|
304321d019 | ||
|
|
f8c3295645 | ||
|
|
183619d1e1 | ||
|
|
3b832d1f21 | ||
|
|
fcb849698f | ||
|
|
7527e0ebdb | ||
|
|
ed5f98da5b | ||
|
|
12b38e25da | ||
|
|
626e892e3b |
6
.dockerignore
Normal file
6
.dockerignore
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
.env
|
||||||
|
.env.*
|
||||||
|
*.pem
|
||||||
|
*.key
|
||||||
|
*.crt
|
||||||
|
*.p12
|
||||||
10
.github/workflows/check-license-dependencies.yml
vendored
10
.github/workflows/check-license-dependencies.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Check for problematic license dependencies
|
- name: Check for problematic license dependencies
|
||||||
run: |
|
run: |
|
||||||
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
echo "Checking for dependencies on management/, signal/, relay/, and proxy/ packages..."
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# Find all directories except the problematic ones and system dirs
|
# Find all directories except the problematic ones and system dirs
|
||||||
@@ -31,7 +31,7 @@ jobs:
|
|||||||
while IFS= read -r dir; do
|
while IFS= read -r dir; do
|
||||||
echo "=== Checking $dir ==="
|
echo "=== Checking $dir ==="
|
||||||
# Search for problematic imports, excluding test files
|
# Search for problematic imports, excluding test files
|
||||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||||
if [ -n "$RESULTS" ]; then
|
if [ -n "$RESULTS" ]; then
|
||||||
echo "❌ Found problematic dependencies:"
|
echo "❌ Found problematic dependencies:"
|
||||||
echo "$RESULTS"
|
echo "$RESULTS"
|
||||||
@@ -39,11 +39,11 @@ jobs:
|
|||||||
else
|
else
|
||||||
echo "✓ No problematic dependencies found"
|
echo "✓ No problematic dependencies found"
|
||||||
fi
|
fi
|
||||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
|
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name ".git*" | sort)
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
echo "❌ Found dependencies on management/, signal/, relay/, or proxy/ packages"
|
||||||
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
||||||
exit 1
|
exit 1
|
||||||
else
|
else
|
||||||
@@ -88,7 +88,7 @@ jobs:
|
|||||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||||
|
|
||||||
# Check if any importer is NOT in management/signal/relay
|
# Check if any importer is NOT in management/signal/relay
|
||||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
|
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" | head -1)
|
||||||
|
|
||||||
if [ -n "$BSD_IMPORTER" ]; then
|
if [ -n "$BSD_IMPORTER" ]; then
|
||||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||||
|
|||||||
2
.github/workflows/golang-test-darwin.yml
vendored
2
.github/workflows/golang-test-darwin.yml
vendored
@@ -43,5 +43,5 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
|
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy)
|
||||||
|
|
||||||
|
|||||||
1
.github/workflows/golang-test-freebsd.yml
vendored
1
.github/workflows/golang-test-freebsd.yml
vendored
@@ -46,6 +46,5 @@ jobs:
|
|||||||
time go test -timeout 1m -failfast ./client/iface/...
|
time go test -timeout 1m -failfast ./client/iface/...
|
||||||
time go test -timeout 1m -failfast ./route/...
|
time go test -timeout 1m -failfast ./route/...
|
||||||
time go test -timeout 1m -failfast ./sharedsock/...
|
time go test -timeout 1m -failfast ./sharedsock/...
|
||||||
time go test -timeout 1m -failfast ./signal/...
|
|
||||||
time go test -timeout 1m -failfast ./util/...
|
time go test -timeout 1m -failfast ./util/...
|
||||||
time go test -timeout 1m -failfast ./version/...
|
time go test -timeout 1m -failfast ./version/...
|
||||||
|
|||||||
51
.github/workflows/golang-test-linux.yml
vendored
51
.github/workflows/golang-test-linux.yml
vendored
@@ -144,7 +144,7 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy)
|
||||||
|
|
||||||
test_client_on_docker:
|
test_client_on_docker:
|
||||||
name: "Client (Docker) / Unit"
|
name: "Client (Docker) / Unit"
|
||||||
@@ -204,7 +204,7 @@ jobs:
|
|||||||
sh -c ' \
|
sh -c ' \
|
||||||
apk update; apk add --no-cache \
|
apk update; apk add --no-cache \
|
||||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
|
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /client/ui -e /upload-server)
|
||||||
'
|
'
|
||||||
|
|
||||||
test_relay:
|
test_relay:
|
||||||
@@ -261,6 +261,53 @@ jobs:
|
|||||||
-exec 'sudo' \
|
-exec 'sudo' \
|
||||||
-timeout 10m -p 1 ./relay/... ./shared/relay/...
|
-timeout 10m -p 1 ./relay/... ./shared/relay/...
|
||||||
|
|
||||||
|
test_proxy:
|
||||||
|
name: "Proxy / Unit"
|
||||||
|
needs: [build-cache]
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
arch: [ '386','amd64' ]
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: "go.mod"
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
run: |
|
||||||
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ env.cache }}
|
||||||
|
${{ env.modcache }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
|
- name: Install modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: |
|
||||||
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
|
go test -timeout 10m -p 1 ./proxy/...
|
||||||
|
|
||||||
test_signal:
|
test_signal:
|
||||||
name: "Signal / Unit"
|
name: "Signal / Unit"
|
||||||
needs: [build-cache]
|
needs: [build-cache]
|
||||||
|
|||||||
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@@ -63,7 +63,7 @@ jobs:
|
|||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
||||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV
|
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' })" >> $env:GITHUB_ENV
|
||||||
|
|
||||||
- name: test
|
- name: test
|
||||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
||||||
|
|||||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
|||||||
uses: codespell-project/actions-codespell@v2
|
uses: codespell-project/actions-codespell@v2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
|
||||||
skip: go.mod,go.sum
|
skip: go.mod,go.sum,**/proxy/web/**
|
||||||
golangci:
|
golangci:
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|||||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.1.0"
|
SIGN_PIPE_VER: "v0.1.1"
|
||||||
GORELEASER_VER: "v2.3.2"
|
GORELEASER_VER: "v2.3.2"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "NetBird GmbH"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -2,6 +2,7 @@
|
|||||||
.run
|
.run
|
||||||
*.iml
|
*.iml
|
||||||
dist/
|
dist/
|
||||||
|
!proxy/web/dist/
|
||||||
bin/
|
bin/
|
||||||
.env
|
.env
|
||||||
conf.json
|
conf.json
|
||||||
|
|||||||
@@ -60,8 +60,8 @@
|
|||||||
|
|
||||||
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||||
|
|
||||||
### NetBird on Lawrence Systems (Video)
|
### Self-Host NetBird (Video)
|
||||||
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
[](https://youtu.be/bZAgpT6nzaQ)
|
||||||
|
|
||||||
### Key features
|
### Key features
|
||||||
|
|
||||||
|
|||||||
@@ -282,13 +282,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
|||||||
}
|
}
|
||||||
defer authClient.Close()
|
defer authClient.Close()
|
||||||
|
|
||||||
needsLogin := false
|
needsLogin, err := authClient.IsLoginRequired(ctx)
|
||||||
|
if err != nil {
|
||||||
err, isAuthError := authClient.Login(ctx, "", "")
|
return fmt.Errorf("check login required: %v", err)
|
||||||
if isAuthError {
|
|
||||||
needsLogin = true
|
|
||||||
} else if err != nil {
|
|
||||||
return fmt.Errorf("login check failed: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
|
|||||||
@@ -31,6 +31,14 @@ var (
|
|||||||
ErrConfigNotInitialized = errors.New("config not initialized")
|
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PeerConnStatus is a peer's connection status.
|
||||||
|
type PeerConnStatus = peer.ConnStatus
|
||||||
|
|
||||||
|
const (
|
||||||
|
// PeerStatusConnected indicates the peer is in connected state.
|
||||||
|
PeerStatusConnected = peer.StatusConnected
|
||||||
|
)
|
||||||
|
|
||||||
// Client manages a netbird embedded client instance.
|
// Client manages a netbird embedded client instance.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
deviceName string
|
deviceName string
|
||||||
@@ -71,6 +79,8 @@ type Options struct {
|
|||||||
DisableClientRoutes bool
|
DisableClientRoutes bool
|
||||||
// BlockInbound blocks all inbound connections from peers
|
// BlockInbound blocks all inbound connections from peers
|
||||||
BlockInbound bool
|
BlockInbound bool
|
||||||
|
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
|
||||||
|
WireguardPort *int
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateCredentials checks that exactly one credential type is provided
|
// validateCredentials checks that exactly one credential type is provided
|
||||||
@@ -140,6 +150,7 @@ func New(opts Options) (*Client, error) {
|
|||||||
DisableServerRoutes: &t,
|
DisableServerRoutes: &t,
|
||||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||||
BlockInbound: &opts.BlockInbound,
|
BlockInbound: &opts.BlockInbound,
|
||||||
|
WireguardPort: opts.WireguardPort,
|
||||||
}
|
}
|
||||||
if opts.ConfigPath != "" {
|
if opts.ConfigPath != "" {
|
||||||
config, err = profilemanager.UpdateOrCreateConfig(input)
|
config, err = profilemanager.UpdateOrCreateConfig(input)
|
||||||
@@ -159,6 +170,7 @@ func New(opts Options) (*Client, error) {
|
|||||||
setupKey: opts.SetupKey,
|
setupKey: opts.SetupKey,
|
||||||
jwtToken: opts.JWTToken,
|
jwtToken: opts.JWTToken,
|
||||||
config: config,
|
config: config,
|
||||||
|
recorder: peer.NewRecorder(config.ManagementURL.String()),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -180,6 +192,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
|
|
||||||
// nolint:staticcheck
|
// nolint:staticcheck
|
||||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||||
|
|
||||||
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
|
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create auth client: %w", err)
|
return fmt.Errorf("create auth client: %w", err)
|
||||||
@@ -189,10 +202,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
||||||
return fmt.Errorf("login: %w", err)
|
return fmt.Errorf("login: %w", err)
|
||||||
}
|
}
|
||||||
|
client := internal.NewConnectClient(ctx, c.config, c.recorder, false)
|
||||||
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
|
||||||
c.recorder = recorder
|
|
||||||
client := internal.NewConnectClient(ctx, c.config, recorder, false)
|
|
||||||
client.SetSyncResponsePersistence(true)
|
client.SetSyncResponsePersistence(true)
|
||||||
|
|
||||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||||
@@ -345,14 +355,9 @@ func (c *Client) NewHTTPClient() *http.Client {
|
|||||||
// Status returns the current status of the client.
|
// Status returns the current status of the client.
|
||||||
func (c *Client) Status() (peer.FullStatus, error) {
|
func (c *Client) Status() (peer.FullStatus, error) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
recorder := c.recorder
|
|
||||||
connect := c.connect
|
connect := c.connect
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
|
|
||||||
if recorder == nil {
|
|
||||||
return peer.FullStatus{}, errors.New("client not started")
|
|
||||||
}
|
|
||||||
|
|
||||||
if connect != nil {
|
if connect != nil {
|
||||||
engine := connect.Engine()
|
engine := connect.Engine()
|
||||||
if engine != nil {
|
if engine != nil {
|
||||||
@@ -360,7 +365,7 @@ func (c *Client) Status() (peer.FullStatus, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return recorder.GetFullStatus(), nil
|
return c.recorder.GetFullStatus(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLatestSyncResponse returns the latest sync response from the management server.
|
// GetLatestSyncResponse returns the latest sync response from the management server.
|
||||||
|
|||||||
@@ -483,7 +483,12 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if nftRule.Handle == 0 {
|
if nftRule.Handle == 0 {
|
||||||
return fmt.Errorf("route rule %s has no handle", ruleKey)
|
log.Warnf("route rule %s has no handle, removing stale entry", ruleKey)
|
||||||
|
if err := r.decrementSetCounter(nftRule); err != nil {
|
||||||
|
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
||||||
@@ -660,13 +665,32 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
// TODO: rollback ipset counter
|
r.rollbackRules(pair)
|
||||||
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
|
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
|
||||||
|
func (r *router) rollbackRules(pair firewall.RouterPair) {
|
||||||
|
keys := []string{
|
||||||
|
firewall.GenKey(firewall.ForwardingFormat, pair),
|
||||||
|
firewall.GenKey(firewall.PreroutingFormat, pair),
|
||||||
|
firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair)),
|
||||||
|
}
|
||||||
|
for _, key := range keys {
|
||||||
|
rule, ok := r.rules[key]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
log.Warnf("rollback set counter for %s: %v", key, err)
|
||||||
|
}
|
||||||
|
delete(r.rules, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// addNatRule inserts a nftables rule to the conn client flush queue
|
// addNatRule inserts a nftables rule to the conn client flush queue
|
||||||
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||||
@@ -928,18 +952,30 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
|||||||
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
rule, exists := r.rules[ruleKey]
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
if !exists {
|
||||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
|
||||||
|
|
||||||
delete(r.rules, ruleKey)
|
|
||||||
|
|
||||||
|
if rule.Handle == 0 {
|
||||||
|
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleKey)
|
||||||
if err := r.decrementSetCounter(rule); err != nil {
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
return fmt.Errorf("decrement set counter: %w", err)
|
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||||
}
|
}
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||||
|
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement set counter: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -1329,65 +1365,89 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
if pair.Masquerade {
|
if pair.Masquerade {
|
||||||
if err := r.removeNatRule(pair); err != nil {
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||||
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set counters are decremented in the sub-methods above before flush. If flush fails,
|
||||||
|
// counters will be off until the next successful removal or refresh cycle.
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
// TODO: rollback set counter
|
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
|
||||||
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
rule, exists := r.rules[ruleKey]
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
if !exists {
|
||||||
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
|
||||||
|
|
||||||
delete(r.rules, ruleKey)
|
|
||||||
|
|
||||||
if err := r.decrementSetCounter(rule); err != nil {
|
|
||||||
return fmt.Errorf("decrement set counter: %w", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.Debugf("prerouting rule %s not found", ruleKey)
|
log.Debugf("prerouting rule %s not found", ruleKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if rule.Handle == 0 {
|
||||||
|
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleKey)
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||||
|
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement set counter: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
|
// refreshRulesMap rebuilds the rule map from the kernel. This removes stale entries
|
||||||
// duplicates and to get missing attributes that we don't have when adding new rules
|
// (e.g. from failed flushes) and updates handles for all existing rules.
|
||||||
func (r *router) refreshRulesMap() error {
|
func (r *router) refreshRulesMap() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
newRules := make(map[string]*nftables.Rule)
|
||||||
for _, chain := range r.chains {
|
for _, chain := range r.chains {
|
||||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("list rules: %w", err)
|
merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
|
||||||
|
// preserve existing entries for this chain since we can't verify their state
|
||||||
|
for k, v := range r.rules {
|
||||||
|
if v.Chain != nil && v.Chain.Name == chain.Name {
|
||||||
|
newRules[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 {
|
if len(rule.UserData) > 0 {
|
||||||
r.rules[string(rule.UserData)] = rule
|
newRules[string(rule.UserData)] = rule
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
r.rules = newRules
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
@@ -1629,20 +1689,34 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
var needsFlush bool
|
||||||
|
|
||||||
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||||
if err := r.conn.DelRule(dnatRule); err != nil {
|
if dnatRule.Handle == 0 {
|
||||||
|
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleKey+dnatSuffix)
|
||||||
|
delete(r.rules, ruleKey+dnatSuffix)
|
||||||
|
} else if err := r.conn.DelRule(dnatRule); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
|
||||||
|
} else {
|
||||||
|
needsFlush = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
||||||
if err := r.conn.DelRule(masqRule); err != nil {
|
if masqRule.Handle == 0 {
|
||||||
|
log.Warnf("snat rule %s has no handle, removing stale entry", ruleKey+snatSuffix)
|
||||||
|
delete(r.rules, ruleKey+snatSuffix)
|
||||||
|
} else if err := r.conn.DelRule(masqRule); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
|
||||||
|
} else {
|
||||||
|
needsFlush = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if needsFlush {
|
||||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if merr == nil {
|
if merr == nil {
|
||||||
@@ -1757,16 +1831,25 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
|||||||
|
|
||||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleID]; exists {
|
rule, exists := r.rules[ruleID]
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
if !exists {
|
||||||
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
return nil
|
||||||
}
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
|
||||||
}
|
|
||||||
delete(r.rules, ruleID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if rule.Handle == 0 {
|
||||||
|
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
||||||
|
}
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/test"
|
"github.com/netbirdio/netbird/client/firewall/test"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -719,3 +720,137 @@ func deleteWorkTable() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
workTable, err := createWorkTable()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer deleteWorkTable()
|
||||||
|
|
||||||
|
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, r.init(workTable))
|
||||||
|
defer func() { require.NoError(t, r.Reset()) }()
|
||||||
|
|
||||||
|
// Add a real rule to the kernel
|
||||||
|
ruleKey, err := r.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
||||||
|
firewall.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&firewall.Port{Values: []uint16{80}},
|
||||||
|
firewall.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, r.DeleteRouteRule(ruleKey))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
|
||||||
|
staleKey := "stale-rule-that-does-not-exist"
|
||||||
|
r.rules[staleKey] = &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingFw],
|
||||||
|
Handle: 0,
|
||||||
|
UserData: []byte(staleKey),
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Contains(t, r.rules, staleKey, "stale entry should be in map before refresh")
|
||||||
|
|
||||||
|
err = r.refreshRulesMap()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.NotContains(t, r.rules, staleKey, "stale entry should be removed after refresh")
|
||||||
|
|
||||||
|
realRule, ok := r.rules[ruleKey.ID()]
|
||||||
|
assert.True(t, ok, "real rule should still exist after refresh")
|
||||||
|
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
workTable, err := createWorkTable()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer deleteWorkTable()
|
||||||
|
|
||||||
|
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, r.init(workTable))
|
||||||
|
defer func() { require.NoError(t, r.Reset()) }()
|
||||||
|
|
||||||
|
// Inject a stale entry with Handle=0
|
||||||
|
staleKey := "stale-route-rule"
|
||||||
|
r.rules[staleKey] = &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingFw],
|
||||||
|
Handle: 0,
|
||||||
|
UserData: []byte(staleKey),
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRouteRule should not return an error for stale handles
|
||||||
|
err = r.DeleteRouteRule(id.RuleID(staleKey))
|
||||||
|
assert.NoError(t, err, "deleting a stale rule should not error")
|
||||||
|
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
pair := firewall.RouterPair{
|
||||||
|
ID: "staletest",
|
||||||
|
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||||
|
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||||
|
Masquerade: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
rtr := manager.router
|
||||||
|
|
||||||
|
// First add succeeds
|
||||||
|
err = rtr.AddNatRule(pair)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, rtr.RemoveNatRule(pair))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Corrupt the handle to simulate stale state
|
||||||
|
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||||
|
if rule, exists := rtr.rules[natRuleKey]; exists {
|
||||||
|
rule.Handle = 0
|
||||||
|
}
|
||||||
|
inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
|
||||||
|
if rule, exists := rtr.rules[inverseKey]; exists {
|
||||||
|
rule.Handle = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adding the same rule again should succeed despite stale handles
|
||||||
|
err = rtr.AddNatRule(pair)
|
||||||
|
assert.NoError(t, err, "AddNatRule should succeed even with stale entries")
|
||||||
|
|
||||||
|
// Verify rules exist in kernel
|
||||||
|
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
found := 0
|
||||||
|
for _, rule := range rules {
|
||||||
|
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||||
|
found++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.Equal(t, 1, found, "NAT rule should exist in kernel")
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,12 +3,6 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/netip"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -17,33 +11,7 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
m.resetState()
|
||||||
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
|
|
||||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
|
||||||
|
|
||||||
if m.udpTracker != nil {
|
|
||||||
m.udpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
|
||||||
m.icmpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
|
||||||
m.tcpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if fwder := m.forwarder.Load(); fwder != nil {
|
|
||||||
fwder.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.logger != nil {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := m.logger.Stop(ctx); err != nil {
|
|
||||||
log.Errorf("failed to shutdown logger: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.Close(stateManager)
|
return m.nativeFirewall.Close(stateManager)
|
||||||
|
|||||||
@@ -1,12 +1,9 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -26,33 +23,7 @@ func (m *Manager) Close(*statemanager.Manager) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
m.resetState()
|
||||||
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
|
|
||||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
|
||||||
|
|
||||||
if m.udpTracker != nil {
|
|
||||||
m.udpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
|
||||||
m.icmpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
|
||||||
m.tcpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if fwder := m.forwarder.Load(); fwder != nil {
|
|
||||||
fwder.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.logger != nil {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := m.logger.Stop(ctx); err != nil {
|
|
||||||
log.Errorf("failed to shutdown logger: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isWindowsFirewallReachable() {
|
if !isWindowsFirewallReachable() {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -115,6 +115,17 @@ func (t *TCPConnTrack) IsTombstone() bool {
|
|||||||
return t.tombstone.Load()
|
return t.tombstone.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsSupersededBy returns true if this connection should be replaced by a new one
|
||||||
|
// carrying the given flags. Tombstoned connections are always superseded; TIME-WAIT
|
||||||
|
// connections are superseded by a pure SYN (a new connection attempt for the same
|
||||||
|
// four-tuple, as contemplated by RFC 1122 §4.2.2.13 and RFC 6191).
|
||||||
|
func (t *TCPConnTrack) IsSupersededBy(flags uint8) bool {
|
||||||
|
if t.tombstone.Load() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return flags&TCPSyn != 0 && flags&TCPAck == 0 && TCPState(t.state.Load()) == TCPStateTimeWait
|
||||||
|
}
|
||||||
|
|
||||||
// SetTombstone safely marks the connection for deletion
|
// SetTombstone safely marks the connection for deletion
|
||||||
func (t *TCPConnTrack) SetTombstone() {
|
func (t *TCPConnTrack) SetTombstone() {
|
||||||
t.tombstone.Store(true)
|
t.tombstone.Store(true)
|
||||||
@@ -169,7 +180,7 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
|||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if exists {
|
if exists && !conn.IsSupersededBy(flags) {
|
||||||
t.updateState(key, conn, flags, direction, size)
|
t.updateState(key, conn, flags, direction, size)
|
||||||
return key, uint16(conn.DNATOrigPort.Load()), true
|
return key, uint16(conn.DNATOrigPort.Load()), true
|
||||||
}
|
}
|
||||||
@@ -241,7 +252,7 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
|||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if !exists || conn.IsTombstone() {
|
if !exists || conn.IsSupersededBy(flags) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -485,6 +485,261 @@ func TestTCPAbnormalSequences(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestTCPPortReuseTombstone verifies that a new connection on a port with a
|
||||||
|
// tombstoned (closed) conntrack entry is properly tracked. Without the fix,
|
||||||
|
// updateIfExists treats tombstoned entries as live, causing track() to skip
|
||||||
|
// creating a new connection. The subsequent SYN-ACK then fails IsValidInbound
|
||||||
|
// because the entry is tombstoned, and the response packet gets dropped by ACL.
|
||||||
|
func TestTCPPortReuseTombstone(t *testing.T) {
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
t.Run("Outbound port reuse after graceful close", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish and gracefully close a connection (server-initiated close)
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Server sends FIN
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
// Client sends FIN-ACK
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
|
||||||
|
// Server sends final ACK
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
// Connection should be tombstoned
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn, "old connection should still be in map")
|
||||||
|
require.True(t, conn.IsTombstone(), "old connection should be tombstoned")
|
||||||
|
|
||||||
|
// Now reuse the same port for a new connection
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||||
|
|
||||||
|
// The old tombstoned entry should be replaced with a new one
|
||||||
|
newConn := tracker.connections[key]
|
||||||
|
require.NotNil(t, newConn, "new connection should exist")
|
||||||
|
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
|
||||||
|
require.Equal(t, TCPStateSynSent, newConn.GetState())
|
||||||
|
|
||||||
|
// SYN-ACK for the new connection should be valid
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||||
|
require.True(t, valid, "SYN-ACK for new connection on reused port should be accepted")
|
||||||
|
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||||
|
|
||||||
|
// Data transfer should work
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 500)
|
||||||
|
require.True(t, valid, "data should be allowed on new connection")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Outbound port reuse after RST", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish and RST a connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.True(t, conn.IsTombstone(), "RST connection should be tombstoned")
|
||||||
|
|
||||||
|
// Reuse the same port
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||||
|
|
||||||
|
newConn := tracker.connections[key]
|
||||||
|
require.NotNil(t, newConn)
|
||||||
|
require.False(t, newConn.IsTombstone())
|
||||||
|
require.Equal(t, TCPStateSynSent, newConn.GetState())
|
||||||
|
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||||
|
require.True(t, valid, "SYN-ACK should be accepted after RST tombstone")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Inbound port reuse after close", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
clientIP := srcIP
|
||||||
|
serverIP := dstIP
|
||||||
|
clientPort := srcPort
|
||||||
|
serverPort := dstPort
|
||||||
|
key := ConnKey{SrcIP: clientIP, DstIP: serverIP, SrcPort: clientPort, DstPort: serverPort}
|
||||||
|
|
||||||
|
// Inbound connection: client SYN → server SYN-ACK → client ACK
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||||
|
|
||||||
|
// Server-initiated close to reach Closed/tombstoned:
|
||||||
|
// Server FIN (opposite dir) → CloseWait
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPFin|TCPAck, 100)
|
||||||
|
require.Equal(t, TCPStateCloseWait, conn.GetState())
|
||||||
|
// Client FIN-ACK (same dir as conn) → LastAck
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPFin|TCPAck, nil, 100, 0)
|
||||||
|
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||||
|
// Server final ACK (opposite dir) → Closed → tombstoned
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
|
||||||
|
|
||||||
|
require.True(t, conn.IsTombstone())
|
||||||
|
|
||||||
|
// New inbound connection on same ports
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
|
||||||
|
|
||||||
|
newConn := tracker.connections[key]
|
||||||
|
require.NotNil(t, newConn)
|
||||||
|
require.False(t, newConn.IsTombstone())
|
||||||
|
require.Equal(t, TCPStateSynReceived, newConn.GetState())
|
||||||
|
|
||||||
|
// Complete handshake: server SYN-ACK, then client ACK
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||||
|
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Late ACK on tombstoned connection is harmless", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish and close via passive close (server-initiated FIN → Closed → tombstoned)
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) // CloseWait
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // LastAck
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) // Closed
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.True(t, conn.IsTombstone())
|
||||||
|
|
||||||
|
// Late ACK should be rejected (tombstoned)
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.False(t, valid, "late ACK on tombstoned connection should be rejected")
|
||||||
|
|
||||||
|
// Late outbound ACK should not create a new connection (not a SYN)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
require.True(t, tracker.connections[key].IsTombstone(), "late outbound ACK should not replace tombstoned entry")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPPortReuseTimeWait(t *testing.T) {
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
t.Run("Outbound port reuse during TIME-WAIT (active close)", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Active close: client (outbound initiator) sends FIN first
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||||
|
|
||||||
|
// Server ACKs the FIN
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||||
|
|
||||||
|
// Server sends its own FIN
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
|
||||||
|
// Client sends final ACK (TIME-WAIT stays, not tombstoned)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
require.False(t, conn.IsTombstone(), "TIME-WAIT should not be tombstoned")
|
||||||
|
|
||||||
|
// New outbound SYN on the same port (port reuse during TIME-WAIT)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||||
|
|
||||||
|
// Per RFC 1122/6191, new SYN during TIME-WAIT should start a new connection
|
||||||
|
newConn := tracker.connections[key]
|
||||||
|
require.NotNil(t, newConn, "new connection should exist")
|
||||||
|
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
|
||||||
|
require.Equal(t, TCPStateSynSent, newConn.GetState(), "new connection should be in SYN-SENT")
|
||||||
|
|
||||||
|
// SYN-ACK for new connection should be valid
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||||
|
require.True(t, valid, "SYN-ACK for new connection should be accepted")
|
||||||
|
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Inbound SYN during TIME-WAIT falls through to normal tracking", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish outbound connection and close via active close → TIME-WAIT
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
|
||||||
|
// Inbound SYN on same ports during TIME-WAIT: IsValidInbound returns false
|
||||||
|
// so the filter falls through to ACL check + TrackInbound (which creates
|
||||||
|
// a new connection via track() → updateIfExists skips TIME-WAIT for SYN)
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, 0)
|
||||||
|
require.False(t, valid, "inbound SYN during TIME-WAIT should fail conntrack validation")
|
||||||
|
|
||||||
|
// Simulate what the filter does next: TrackInbound via the normal path
|
||||||
|
tracker.TrackInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, nil, 100, 0)
|
||||||
|
|
||||||
|
// The new inbound connection uses the inverted key (dst→src becomes src→dst in track)
|
||||||
|
invertedKey := ConnKey{SrcIP: dstIP, DstIP: srcIP, SrcPort: dstPort, DstPort: srcPort}
|
||||||
|
newConn := tracker.connections[invertedKey]
|
||||||
|
require.NotNil(t, newConn, "new inbound connection should be tracked")
|
||||||
|
require.Equal(t, TCPStateSynReceived, newConn.GetState())
|
||||||
|
require.False(t, newConn.IsTombstone())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Late retransmit during TIME-WAIT still allowed", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish and active close → TIME-WAIT
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
|
||||||
|
// Late ACK retransmits during TIME-WAIT should still be accepted
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid, "retransmitted ACK during TIME-WAIT should be accepted")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestTCPTimeoutHandling(t *testing.T) {
|
func TestTCPTimeoutHandling(t *testing.T) {
|
||||||
// Create tracker with a very short timeout for testing
|
// Create tracker with a very short timeout for testing
|
||||||
shortTimeout := 100 * time.Millisecond
|
shortTimeout := 100 * time.Millisecond
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -12,11 +13,13 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
@@ -24,6 +27,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
@@ -89,6 +93,7 @@ type Manager struct {
|
|||||||
incomingDenyRules map[netip.Addr]RuleSet
|
incomingDenyRules map[netip.Addr]RuleSet
|
||||||
incomingRules map[netip.Addr]RuleSet
|
incomingRules map[netip.Addr]RuleSet
|
||||||
routeRules RouteRules
|
routeRules RouteRules
|
||||||
|
routeRulesMap map[nbid.RuleID]*RouteRule
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
wgIface common.IFaceMapper
|
wgIface common.IFaceMapper
|
||||||
nativeFirewall firewall.Manager
|
nativeFirewall firewall.Manager
|
||||||
@@ -229,6 +234,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
flowLogger: flowLogger,
|
flowLogger: flowLogger,
|
||||||
netstack: netstack.IsEnabled(),
|
netstack: netstack.IsEnabled(),
|
||||||
localForwarding: enableLocalForwarding,
|
localForwarding: enableLocalForwarding,
|
||||||
|
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
|
||||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||||
portDNATRules: []portDNATRule{},
|
portDNATRules: []portDNATRule{},
|
||||||
netstackServices: make(map[serviceKey]struct{}),
|
netstackServices: make(map[serviceKey]struct{}),
|
||||||
@@ -480,11 +486,15 @@ func (m *Manager) addRouteFiltering(
|
|||||||
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleID := uuid.New().String()
|
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||||
|
|
||||||
|
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
|
||||||
|
return existingRule, nil
|
||||||
|
}
|
||||||
|
|
||||||
rule := RouteRule{
|
rule := RouteRule{
|
||||||
// TODO: consolidate these IDs
|
// TODO: consolidate these IDs
|
||||||
id: ruleID,
|
id: string(ruleKey),
|
||||||
mgmtId: id,
|
mgmtId: id,
|
||||||
sources: sources,
|
sources: sources,
|
||||||
dstSet: destination.Set,
|
dstSet: destination.Set,
|
||||||
@@ -499,6 +509,7 @@ func (m *Manager) addRouteFiltering(
|
|||||||
|
|
||||||
m.routeRules = append(m.routeRules, &rule)
|
m.routeRules = append(m.routeRules, &rule)
|
||||||
m.routeRules.Sort()
|
m.routeRules.Sort()
|
||||||
|
m.routeRulesMap[ruleKey] = &rule
|
||||||
|
|
||||||
return &rule, nil
|
return &rule, nil
|
||||||
}
|
}
|
||||||
@@ -515,15 +526,20 @@ func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
|
|||||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleID := rule.ID()
|
ruleKey := nbid.RuleID(rule.ID())
|
||||||
|
if _, ok := m.routeRulesMap[ruleKey]; !ok {
|
||||||
|
return fmt.Errorf("route rule not found: %s", ruleKey)
|
||||||
|
}
|
||||||
|
|
||||||
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
||||||
return r.id == ruleID
|
return r.id == string(ruleKey)
|
||||||
})
|
})
|
||||||
if idx < 0 {
|
if idx < 0 {
|
||||||
return fmt.Errorf("route rule not found: %s", ruleID)
|
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
||||||
|
delete(m.routeRulesMap, ruleKey)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -570,6 +586,40 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
|||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
|
// resetState clears all firewall rules and closes connection trackers.
|
||||||
|
// Must be called with m.mutex held.
|
||||||
|
func (m *Manager) resetState() {
|
||||||
|
maps.Clear(m.outgoingRules)
|
||||||
|
maps.Clear(m.incomingDenyRules)
|
||||||
|
maps.Clear(m.incomingRules)
|
||||||
|
maps.Clear(m.routeRulesMap)
|
||||||
|
m.routeRules = m.routeRules[:0]
|
||||||
|
|
||||||
|
if m.udpTracker != nil {
|
||||||
|
m.udpTracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.icmpTracker != nil {
|
||||||
|
m.icmpTracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.tcpTracker != nil {
|
||||||
|
m.tcpTracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if fwder := m.forwarder.Load(); fwder != nil {
|
||||||
|
fwder.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.logger != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := m.logger.Stop(ctx); err != nil {
|
||||||
|
log.Errorf("failed to shutdown logger: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||||
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||||
if m.nativeFirewall == nil {
|
if m.nativeFirewall == nil {
|
||||||
|
|||||||
376
client/firewall/uspfilter/filter_routeacl_test.go
Normal file
376
client/firewall/uspfilter/filter_routeacl_test.go
Normal file
@@ -0,0 +1,376 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestAddRouteFilteringReturnsExistingRule verifies that adding the same route
|
||||||
|
// filtering rule twice returns the same rule ID (idempotent behavior).
|
||||||
|
func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("100.64.1.0/24"),
|
||||||
|
netip.MustParsePrefix("100.64.2.0/24"),
|
||||||
|
}
|
||||||
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
|
// Add rule first time
|
||||||
|
rule1, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule1)
|
||||||
|
|
||||||
|
// Add the same rule again
|
||||||
|
rule2, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule2)
|
||||||
|
|
||||||
|
// These should be the same (idempotent) like nftables/iptables implementations
|
||||||
|
assert.Equal(t, rule1.ID(), rule2.ID(),
|
||||||
|
"Adding the same rule twice should return the same rule ID (idempotent)")
|
||||||
|
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 2, ruleCount,
|
||||||
|
"Should have exactly 2 rules (1 user rule + 1 block rule)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAddRouteFilteringDifferentRulesGetDifferentIDs verifies that rules with
|
||||||
|
// different parameters get distinct IDs.
|
||||||
|
func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||||
|
|
||||||
|
// Add first rule
|
||||||
|
rule1, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Add different rule (different destination)
|
||||||
|
rule2, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-2"),
|
||||||
|
sources,
|
||||||
|
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.NotEqual(t, rule1.ID(), rule2.ID(),
|
||||||
|
"Different rules should have different IDs")
|
||||||
|
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 3, ruleCount, "Should have 3 rules (2 user rules + 1 block rule)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRouteRuleUpdateDoesNotCauseGap verifies that re-adding the same route
|
||||||
|
// rule during a network map update does not disrupt existing traffic.
|
||||||
|
func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||||
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
|
rule1, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.1.5")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.10")
|
||||||
|
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||||
|
require.True(t, pass, "Traffic should pass with rule in place")
|
||||||
|
|
||||||
|
// Re-add same rule (simulates network map update)
|
||||||
|
rule2, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Idempotent IDs mean rule1.ID() == rule2.ID(), so the ACL manager
|
||||||
|
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
|
||||||
|
// would remove the only matching rule and cause a traffic gap.
|
||||||
|
if rule1.ID() != rule2.ID() {
|
||||||
|
err = manager.DeleteRouteRule(rule1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, passAfter := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||||
|
assert.True(t, passAfter,
|
||||||
|
"Traffic should still pass after rule update - no gap should occur")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
|
||||||
|
// exactly one drop rule for the WireGuard network prefix, and calling it again
|
||||||
|
// returns the same rule without duplicating.
|
||||||
|
func TestBlockInvalidRoutedIdempotent(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
dev := mocks.NewMockDevice(ctrl)
|
||||||
|
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||||
|
|
||||||
|
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||||
|
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: wgNet.Addr(),
|
||||||
|
Network: wgNet,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
GetDeviceFunc: func() *device.FilteredDevice {
|
||||||
|
return &device.FilteredDevice{Device: dev}
|
||||||
|
},
|
||||||
|
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||||
|
return &wgdevice.Device{}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Call blockInvalidRouted directly multiple times
|
||||||
|
rule1, err := manager.blockInvalidRouted(ifaceMock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule1)
|
||||||
|
|
||||||
|
rule2, err := manager.blockInvalidRouted(ifaceMock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule2)
|
||||||
|
|
||||||
|
rule3, err := manager.blockInvalidRouted(ifaceMock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule3)
|
||||||
|
|
||||||
|
// All should return the same rule
|
||||||
|
assert.Equal(t, rule1.ID(), rule2.ID(), "Second call should return same rule")
|
||||||
|
assert.Equal(t, rule2.ID(), rule3.ID(), "Third call should return same rule")
|
||||||
|
|
||||||
|
// Should have exactly 1 route rule
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 1, ruleCount, "Should have exactly 1 block rule after 3 calls")
|
||||||
|
|
||||||
|
// Verify the rule blocks traffic to the WG network
|
||||||
|
srcIP := netip.MustParseAddr("10.0.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.50")
|
||||||
|
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
|
||||||
|
assert.False(t, pass, "Block rule should deny traffic to WG prefix")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBlockRuleNotAccumulatedOnRepeatedEnableRouting verifies that calling
|
||||||
|
// EnableRouting multiple times (as happens on each route update) does not
|
||||||
|
// accumulate duplicate block rules in the routeRules slice.
|
||||||
|
func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
dev := mocks.NewMockDevice(ctrl)
|
||||||
|
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||||
|
|
||||||
|
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||||
|
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: wgNet.Addr(),
|
||||||
|
Network: wgNet,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
GetDeviceFunc: func() *device.FilteredDevice {
|
||||||
|
return &device.FilteredDevice{Device: dev}
|
||||||
|
},
|
||||||
|
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||||
|
return &wgdevice.Device{}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Call EnableRouting multiple times (simulating repeated route updates)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
require.NoError(t, manager.EnableRouting())
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 1, ruleCount,
|
||||||
|
"Repeated EnableRouting should not accumulate block rules")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRouteRuleCountStableAcrossUpdates verifies that adding the same route
|
||||||
|
// rule multiple times does not create duplicate entries.
|
||||||
|
func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||||
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
|
// Simulate 5 network map updates with the same route rule
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 2, ruleCount,
|
||||||
|
"Should have exactly 2 rules (1 user rule + 1 block rule) after 5 updates")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDeleteRouteRuleAfterIdempotentAdd verifies that deleting a route rule
|
||||||
|
// after adding it multiple times works correctly.
|
||||||
|
func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||||
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
|
// Add same rule twice
|
||||||
|
rule1, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rule2, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
|
||||||
|
|
||||||
|
// Delete using first reference
|
||||||
|
err = manager.DeleteRouteRule(rule1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify traffic no longer passes
|
||||||
|
srcIP := netip.MustParseAddr("100.64.1.5")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.10")
|
||||||
|
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||||
|
assert.False(t, pass, "Traffic should not pass after rule deletion")
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTestManager(t *testing.T) *Manager {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
dev := mocks.NewMockDevice(ctrl)
|
||||||
|
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||||
|
|
||||||
|
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||||
|
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: wgNet.Addr(),
|
||||||
|
Network: wgNet,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
GetDeviceFunc: func() *device.FilteredDevice {
|
||||||
|
return &device.FilteredDevice{Device: dev}
|
||||||
|
},
|
||||||
|
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||||
|
return &wgdevice.Device{}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, manager.EnableRouting())
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
return manager
|
||||||
|
}
|
||||||
@@ -263,6 +263,158 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
|
||||||
|
// to the deny map and can be cleanly deleted without leaving orphans.
|
||||||
|
func TestPeerRuleLifecycleDenyRules(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, m.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
addr := netip.MustParseAddr("192.168.1.1")
|
||||||
|
|
||||||
|
// Add multiple deny rules for different ports
|
||||||
|
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
denyCount := len(m.incomingDenyRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
|
||||||
|
|
||||||
|
// Delete the first deny rule
|
||||||
|
err = m.DeletePeerRule(rule1[0])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
denyCount = len(m.incomingDenyRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
|
||||||
|
|
||||||
|
// Delete the second deny rule
|
||||||
|
err = m.DeletePeerRule(rule2[0])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
_, exists := m.incomingDenyRules[addr]
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
|
||||||
|
// peer rules (simulating network map updates) does not leak rules in the maps.
|
||||||
|
func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, m.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
addr := netip.MustParseAddr("192.168.1.1")
|
||||||
|
|
||||||
|
// Simulate 10 network map updates: add rule, delete old, add new
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
// Add a deny rule
|
||||||
|
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Add an allow rule
|
||||||
|
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Delete them (simulating ACL manager cleanup)
|
||||||
|
for _, r := range rules {
|
||||||
|
require.NoError(t, m.DeletePeerRule(r))
|
||||||
|
}
|
||||||
|
for _, r := range allowRules {
|
||||||
|
require.NoError(t, m.DeletePeerRule(r))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
denyCount := len(m.incomingDenyRules[addr])
|
||||||
|
allowCount := len(m.incomingRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
|
||||||
|
require.Equal(t, 0, allowCount, "No allow rules should remain after cleanup")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMixedAllowDenyRulesSameIP verifies that allow and deny rules for the same
|
||||||
|
// IP are stored in separate maps and don't interfere with each other.
|
||||||
|
func TestMixedAllowDenyRulesSameIP(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, m.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
|
||||||
|
// Add allow rule for port 80
|
||||||
|
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Add deny rule for port 22
|
||||||
|
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
addr := netip.MustParseAddr("192.168.1.1")
|
||||||
|
m.mutex.RLock()
|
||||||
|
allowCount := len(m.incomingRules[addr])
|
||||||
|
denyCount := len(m.incomingDenyRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
|
||||||
|
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
|
||||||
|
|
||||||
|
// Delete allow rule should not affect deny rule
|
||||||
|
err = m.DeletePeerRule(allowRule[0])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
denyCountAfter := len(m.incomingDenyRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
|
||||||
|
|
||||||
|
// Delete deny rule
|
||||||
|
err = m.DeletePeerRule(denyRule[0])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
_, denyExists := m.incomingDenyRules[addr]
|
||||||
|
_, allowExists := m.incomingRules[addr]
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
require.False(t, denyExists, "Deny rules should be empty")
|
||||||
|
require.False(t, allowExists, "Allow rules should be empty")
|
||||||
|
}
|
||||||
|
|
||||||
func TestManagerReset(t *testing.T) {
|
func TestManagerReset(t *testing.T) {
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -16,9 +18,18 @@ const (
|
|||||||
maxBatchSize = 1024 * 16
|
maxBatchSize = 1024 * 16
|
||||||
maxMessageSize = 1024 * 2
|
maxMessageSize = 1024 * 2
|
||||||
defaultFlushInterval = 2 * time.Second
|
defaultFlushInterval = 2 * time.Second
|
||||||
logChannelSize = 1000
|
defaultLogChanSize = 1000
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func getLogChannelSize() int {
|
||||||
|
if v := os.Getenv("NB_USPFILTER_LOG_BUFFER"); v != "" {
|
||||||
|
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultLogChanSize
|
||||||
|
}
|
||||||
|
|
||||||
type Level uint32
|
type Level uint32
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -69,7 +80,7 @@ type Logger struct {
|
|||||||
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||||
l := &Logger{
|
l := &Logger{
|
||||||
output: logrusLogger.Out,
|
output: logrusLogger.Out,
|
||||||
msgChannel: make(chan logMessage, logChannelSize),
|
msgChannel: make(chan logMessage, getLogChannelSize()),
|
||||||
shutdown: make(chan struct{}),
|
shutdown: make(chan struct{}),
|
||||||
bufPool: sync.Pool{
|
bufPool: sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
|
|||||||
@@ -29,8 +29,9 @@ type PacketFilter interface {
|
|||||||
type FilteredDevice struct {
|
type FilteredDevice struct {
|
||||||
tun.Device
|
tun.Device
|
||||||
|
|
||||||
filter PacketFilter
|
filter PacketFilter
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
closeOnce sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
// newDeviceFilter constructor function
|
// newDeviceFilter constructor function
|
||||||
@@ -40,6 +41,20 @@ func newDeviceFilter(device tun.Device) *FilteredDevice {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close closes the underlying tun device exactly once.
|
||||||
|
// wireguard-go's netTun.Close() panics on double-close due to a bare close(channel),
|
||||||
|
// and multiple code paths can trigger Close on the same device.
|
||||||
|
func (d *FilteredDevice) Close() error {
|
||||||
|
var err error
|
||||||
|
d.closeOnce.Do(func() {
|
||||||
|
err = d.Device.Close()
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Read wraps read method with filtering feature
|
// Read wraps read method with filtering feature
|
||||||
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||||
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
|
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
|
||||||
|
|||||||
@@ -82,7 +82,9 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
|
|||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tunIface.Close()
|
if cErr := tunIface.Close(); cErr != nil {
|
||||||
|
log.Debugf("failed to close tun device: %v", cErr)
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("error configuring interface: %s", err)
|
return nil, fmt.Errorf("error configuring interface: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/errors"
|
"github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
@@ -228,6 +229,10 @@ func (w *WGIface) Close() error {
|
|||||||
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if nbnetstack.IsEnabled() {
|
||||||
|
return errors.FormatErrorOrNil(result)
|
||||||
|
}
|
||||||
|
|
||||||
if err := w.waitUntilRemoved(); err != nil {
|
if err := w.waitUntilRemoved(); err != nil {
|
||||||
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
|
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
|
||||||
if err := w.Destroy(); err != nil {
|
if err := w.Destroy(); err != nil {
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return nsTunDev, tunNet, nil
|
return t.tundev, tunNet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *NetStackTun) Close() error {
|
func (t *NetStackTun) Close() error {
|
||||||
|
|||||||
@@ -189,6 +189,212 @@ func TestDefaultManagerStateless(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestDenyRulesNotAccumulatedOnRepeatedApply verifies that applying the same
|
||||||
|
// deny rules repeatedly does not accumulate duplicate rules in the uspfilter.
|
||||||
|
// This tests the full ACL manager -> uspfilter integration.
|
||||||
|
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
|
||||||
|
networkMap := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "22",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "80",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FirewallRulesIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
|
IP: network.Addr(),
|
||||||
|
Network: network,
|
||||||
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, fw.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
|
// Apply the same rules 5 times (simulating repeated network map updates)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The ACL manager should track exactly 3 rule pairs (2 deny + 1 accept inbound)
|
||||||
|
assert.Equal(t, 3, len(acl.peerRulesPairs),
|
||||||
|
"Should have exactly 3 rule pairs after 5 identical updates")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDenyRulesCleanedUpOnRemoval verifies that deny rules are properly cleaned
|
||||||
|
// up when they're removed from the network map in a subsequent update.
|
||||||
|
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
|
IP: network.Addr(),
|
||||||
|
Network: network,
|
||||||
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, fw.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
|
// First update: add deny and accept rules
|
||||||
|
networkMap1 := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "22",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FirewallRulesIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
acl.ApplyFiltering(networkMap1, false)
|
||||||
|
assert.Equal(t, 2, len(acl.peerRulesPairs), "Should have 2 rules after first update")
|
||||||
|
|
||||||
|
// Second update: remove the deny rule, keep only accept
|
||||||
|
networkMap2 := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FirewallRulesIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
acl.ApplyFiltering(networkMap2, false)
|
||||||
|
assert.Equal(t, 1, len(acl.peerRulesPairs),
|
||||||
|
"Should have 1 rule after removing deny rule")
|
||||||
|
|
||||||
|
// Third update: remove all rules
|
||||||
|
networkMap3 := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{},
|
||||||
|
FirewallRulesIsEmpty: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
acl.ApplyFiltering(networkMap3, false)
|
||||||
|
assert.Equal(t, 0, len(acl.peerRulesPairs),
|
||||||
|
"Should have 0 rules after removing all rules")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRuleUpdateChangingAction verifies that when a rule's action changes from
|
||||||
|
// accept to deny (or vice versa), the old rule is properly removed and the new
|
||||||
|
// one added without leaking.
|
||||||
|
func TestRuleUpdateChangingAction(t *testing.T) {
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
|
IP: network.Addr(),
|
||||||
|
Network: network,
|
||||||
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, fw.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
|
// First update: accept rule
|
||||||
|
networkMap := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "22",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FirewallRulesIsEmpty: false,
|
||||||
|
}
|
||||||
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
assert.Equal(t, 1, len(acl.peerRulesPairs))
|
||||||
|
|
||||||
|
// Second update: change to deny (same IP/port/proto, different action)
|
||||||
|
networkMap.FirewallRules = []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "22",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
|
||||||
|
// Should still have exactly 1 rule (the old accept removed, new deny added)
|
||||||
|
assert.Equal(t, 1, len(acl.peerRulesPairs),
|
||||||
|
"Changing action should result in exactly 1 rule, not 2")
|
||||||
|
}
|
||||||
|
|
||||||
func TestPortInfoEmpty(t *testing.T) {
|
func TestPortInfoEmpty(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
@@ -244,7 +245,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
localPeerState := peer.LocalPeerState{
|
localPeerState := peer.LocalPeerState{
|
||||||
IP: loginResp.GetPeerConfig().GetAddress(),
|
IP: loginResp.GetPeerConfig().GetAddress(),
|
||||||
PubKey: myPrivateKey.PublicKey().String(),
|
PubKey: myPrivateKey.PublicKey().String(),
|
||||||
KernelInterface: device.WireGuardModuleIsLoaded(),
|
KernelInterface: device.WireGuardModuleIsLoaded() && !netstack.IsEnabled(),
|
||||||
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
||||||
}
|
}
|
||||||
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
||||||
|
|||||||
@@ -6,7 +6,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -27,6 +29,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const envSkipDNSProbe = "NB_SKIP_DNS_PROBE"
|
||||||
|
|
||||||
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||||
type ReadyListener interface {
|
type ReadyListener interface {
|
||||||
OnReady()
|
OnReady()
|
||||||
@@ -439,6 +443,17 @@ func (s *DefaultServer) SearchDomains() []string {
|
|||||||
// ProbeAvailability tests each upstream group's servers for availability
|
// ProbeAvailability tests each upstream group's servers for availability
|
||||||
// and deactivates the group if no server responds
|
// and deactivates the group if no server responds
|
||||||
func (s *DefaultServer) ProbeAvailability() {
|
func (s *DefaultServer) ProbeAvailability() {
|
||||||
|
if val := os.Getenv(envSkipDNSProbe); val != "" {
|
||||||
|
skipProbe, err := strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", envSkipDNSProbe, err)
|
||||||
|
}
|
||||||
|
if skipProbe {
|
||||||
|
log.Infof("skipping DNS probe due to %s", envSkipDNSProbe)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for _, mux := range s.dnsMuxMap {
|
for _, mux := range s.dnsMuxMap {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
|||||||
@@ -190,50 +190,75 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
|
|||||||
return nberrors.FormatErrorOrNil(result)
|
return nberrors.FormatErrorOrNil(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
|
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg, startTime time.Time) {
|
||||||
if len(query.Question) == 0 {
|
if len(query.Question) == 0 {
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
question := query.Question[0]
|
question := query.Question[0]
|
||||||
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
|
qname := strings.ToLower(question.Name)
|
||||||
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
|
||||||
|
|
||||||
domain := strings.ToLower(question.Name)
|
logger.Tracef("question: domain=%s type=%s class=%s",
|
||||||
|
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
||||||
|
|
||||||
resp := query.SetReply(query)
|
resp := query.SetReply(query)
|
||||||
network := resutil.NetworkForQtype(question.Qtype)
|
network := resutil.NetworkForQtype(question.Qtype)
|
||||||
if network == "" {
|
if network == "" {
|
||||||
resp.Rcode = dns.RcodeNotImplemented
|
resp.Rcode = dns.RcodeNotImplemented
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
f.writeResponse(logger, w, resp, qname, startTime)
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
return
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
|
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
|
||||||
// query doesn't match any configured domain
|
|
||||||
if mostSpecificResId == "" {
|
if mostSpecificResId == "" {
|
||||||
resp.Rcode = dns.RcodeRefused
|
resp.Rcode = dns.RcodeRefused
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
f.writeResponse(logger, w, resp, qname, startTime)
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
return
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
|
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
|
||||||
if result.Err != nil {
|
if result.Err != nil {
|
||||||
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
|
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
|
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
|
||||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
|
resp.Answer = append(resp.Answer, resutil.IPsToRRs(qname, result.IPs, f.ttl)...)
|
||||||
f.cache.set(domain, question.Qtype, result.IPs)
|
f.cache.set(qname, question.Qtype, result.IPs)
|
||||||
|
|
||||||
return resp
|
f.writeResponse(logger, w, resp, qname, startTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
logger.Errorf("failed to write DNS response: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||||
|
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||||
|
}
|
||||||
|
|
||||||
|
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
|
||||||
|
type udpResponseWriter struct {
|
||||||
|
dns.ResponseWriter
|
||||||
|
query *dns.Msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
|
||||||
|
opt := u.query.IsEdns0()
|
||||||
|
maxSize := dns.MinMsgSize
|
||||||
|
if opt != nil {
|
||||||
|
maxSize = int(opt.UDPSize())
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Len() > maxSize {
|
||||||
|
resp.Truncate(maxSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
return u.ResponseWriter.WriteMsg(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
@@ -243,30 +268,7 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||||
})
|
})
|
||||||
|
|
||||||
resp := f.handleDNSQuery(logger, w, query)
|
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
|
||||||
if resp == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
opt := query.IsEdns0()
|
|
||||||
maxSize := dns.MinMsgSize
|
|
||||||
if opt != nil {
|
|
||||||
// client advertised a larger EDNS0 buffer
|
|
||||||
maxSize = int(opt.UDPSize())
|
|
||||||
}
|
|
||||||
|
|
||||||
// if our response is too big, truncate and set the TC bit
|
|
||||||
if resp.Len() > maxSize {
|
|
||||||
resp.Truncate(maxSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
|
||||||
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
@@ -276,18 +278,7 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||||
})
|
})
|
||||||
|
|
||||||
resp := f.handleDNSQuery(logger, w, query)
|
f.handleDNSQuery(logger, w, query, startTime)
|
||||||
if resp == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
|
||||||
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
||||||
@@ -334,6 +325,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
resp *dns.Msg,
|
resp *dns.Msg,
|
||||||
domain string,
|
domain string,
|
||||||
result resutil.LookupResult,
|
result resutil.LookupResult,
|
||||||
|
startTime time.Time,
|
||||||
) {
|
) {
|
||||||
qType := question.Qtype
|
qType := question.Qtype
|
||||||
qTypeName := dns.TypeToString[qType]
|
qTypeName := dns.TypeToString[qType]
|
||||||
@@ -343,9 +335,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
// NotFound: cache negative result and respond
|
// NotFound: cache negative result and respond
|
||||||
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
|
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
|
||||||
f.cache.set(domain, question.Qtype, nil)
|
f.cache.set(domain, question.Qtype, nil)
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -355,9 +345,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
||||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
|
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
|
||||||
resp.Rcode = dns.RcodeSuccess
|
resp.Rcode = dns.RcodeSuccess
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
logger.Errorf("failed to write cached DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -365,9 +353,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
|
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
|
||||||
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
|
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
|
||||||
resp.Rcode = verifyResult.Rcode
|
resp.Rcode = verifyResult.Rcode
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -375,15 +361,12 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
// No cache or verification failed. Log with or without the server field for more context.
|
// No cache or verification failed. Log with or without the server field for more context.
|
||||||
var dnsErr *net.DNSError
|
var dnsErr *net.DNSError
|
||||||
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
|
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
|
||||||
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
|
logger.Warnf("upstream failure: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
|
||||||
} else {
|
} else {
|
||||||
logger.Warnf(errResolveFailed, domain, result.Err)
|
logger.Warnf(errResolveFailed, domain, result.Err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write final failure response.
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
|
||||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getMatchingEntries retrieves the resource IDs for a given domain.
|
// getMatchingEntries retrieves the resource IDs for a given domain.
|
||||||
|
|||||||
@@ -318,8 +318,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
|||||||
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
if tt.shouldResolve {
|
if tt.shouldResolve {
|
||||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
|
||||||
@@ -329,10 +330,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
|||||||
mockFirewall.AssertExpectations(t)
|
mockFirewall.AssertExpectations(t)
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
} else {
|
} else {
|
||||||
if resp != nil {
|
require.NotNil(t, resp, "Expected response")
|
||||||
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
||||||
"Unauthorized domain should not return successful answers")
|
"Unauthorized domain should not return successful answers")
|
||||||
}
|
|
||||||
mockFirewall.AssertNotCalled(t, "UpdateSet")
|
mockFirewall.AssertNotCalled(t, "UpdateSet")
|
||||||
mockResolver.AssertNotCalled(t, "LookupNetIP")
|
mockResolver.AssertNotCalled(t, "LookupNetIP")
|
||||||
}
|
}
|
||||||
@@ -466,14 +466,16 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
|
|||||||
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery, time.Now())
|
||||||
|
|
||||||
// Verify response
|
// Verify response
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
if tt.shouldResolve {
|
if tt.shouldResolve {
|
||||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
require.NotEmpty(t, resp.Answer)
|
require.NotEmpty(t, resp.Answer)
|
||||||
} else if resp != nil {
|
} else {
|
||||||
|
require.NotNil(t, resp, "Expected response")
|
||||||
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
|
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
|
||||||
"Unauthorized domain should be refused or have no answers")
|
"Unauthorized domain should be refused or have no answers")
|
||||||
}
|
}
|
||||||
@@ -528,9 +530,10 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
|||||||
query.SetQuestion("example.com.", dns.TypeA)
|
query.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
// Verify response contains all IPs
|
// Verify response contains all IPs
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
require.NotNil(t, resp)
|
require.NotNil(t, resp)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
|
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
|
||||||
@@ -605,7 +608,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
// Check the response written to the writer
|
// Check the response written to the writer
|
||||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||||
@@ -675,7 +678,8 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
|||||||
q1 := &dns.Msg{}
|
q1 := &dns.Msg{}
|
||||||
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||||
w1 := &test.MockResponseWriter{}
|
w1 := &test.MockResponseWriter{}
|
||||||
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
|
||||||
|
resp1 := w1.GetLastResponse()
|
||||||
require.NotNil(t, resp1)
|
require.NotNil(t, resp1)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||||
require.Len(t, resp1.Answer, 1)
|
require.Len(t, resp1.Answer, 1)
|
||||||
@@ -683,13 +687,13 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
|||||||
// Second query: serve from cache after upstream failure
|
// Second query: serve from cache after upstream failure
|
||||||
q2 := &dns.Msg{}
|
q2 := &dns.Msg{}
|
||||||
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||||
var writtenResp *dns.Msg
|
w2 := &test.MockResponseWriter{}
|
||||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
|
||||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
|
||||||
|
|
||||||
require.NotNil(t, writtenResp, "expected response to be written")
|
resp2 := w2.GetLastResponse()
|
||||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
require.NotNil(t, resp2, "expected response to be written")
|
||||||
require.Len(t, writtenResp.Answer, 1)
|
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
|
||||||
|
require.Len(t, resp2.Answer, 1)
|
||||||
|
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
@@ -715,7 +719,8 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
|||||||
q1 := &dns.Msg{}
|
q1 := &dns.Msg{}
|
||||||
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
||||||
w1 := &test.MockResponseWriter{}
|
w1 := &test.MockResponseWriter{}
|
||||||
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
|
||||||
|
resp1 := w1.GetLastResponse()
|
||||||
require.NotNil(t, resp1)
|
require.NotNil(t, resp1)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||||
require.Len(t, resp1.Answer, 1)
|
require.Len(t, resp1.Answer, 1)
|
||||||
@@ -727,13 +732,13 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
|||||||
|
|
||||||
q2 := &dns.Msg{}
|
q2 := &dns.Msg{}
|
||||||
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
||||||
var writtenResp *dns.Msg
|
w2 := &test.MockResponseWriter{}
|
||||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
|
||||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
|
||||||
|
|
||||||
require.NotNil(t, writtenResp)
|
resp2 := w2.GetLastResponse()
|
||||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
require.NotNil(t, resp2)
|
||||||
require.Len(t, writtenResp.Answer, 1)
|
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
|
||||||
|
require.Len(t, resp2.Answer, 1)
|
||||||
|
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
@@ -784,8 +789,9 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
|||||||
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
require.NotNil(t, resp)
|
require.NotNil(t, resp)
|
||||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
|
||||||
@@ -897,26 +903,15 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
|||||||
query := &dns.Msg{}
|
query := &dns.Msg{}
|
||||||
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
|
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
|
||||||
|
|
||||||
var writtenResp *dns.Msg
|
mockWriter := &test.MockResponseWriter{}
|
||||||
mockWriter := &test.MockResponseWriter{
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
WriteMsgFunc: func(m *dns.Msg) error {
|
|
||||||
writtenResp = m
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
resp := mockWriter.GetLastResponse()
|
||||||
|
require.NotNil(t, resp, "Expected response to be written")
|
||||||
// If a response was returned, it means it should be written (happens in wrapper functions)
|
assert.Equal(t, tt.expectedCode, resp.Rcode, tt.description)
|
||||||
if resp != nil && writtenResp == nil {
|
|
||||||
writtenResp = resp
|
|
||||||
}
|
|
||||||
|
|
||||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
|
||||||
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
|
|
||||||
|
|
||||||
if tt.expectNoAnswer {
|
if tt.expectNoAnswer {
|
||||||
assert.Empty(t, writtenResp.Answer, "Response should have no answer records")
|
assert.Empty(t, resp.Answer, "Response should have no answer records")
|
||||||
}
|
}
|
||||||
|
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
@@ -931,15 +926,8 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
|||||||
query := &dns.Msg{}
|
query := &dns.Msg{}
|
||||||
// Don't set any question
|
// Don't set any question
|
||||||
|
|
||||||
writeCalled := false
|
mockWriter := &test.MockResponseWriter{}
|
||||||
mockWriter := &test.MockResponseWriter{
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
WriteMsgFunc: func(m *dns.Msg) error {
|
|
||||||
writeCalled = true
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
|
||||||
|
|
||||||
assert.Nil(t, resp, "Should return nil for empty query")
|
assert.Nil(t, mockWriter.GetLastResponse(), "Should not write response for empty query")
|
||||||
assert.False(t, writeCalled, "Should not write response for empty query")
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
@@ -543,11 +544,12 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
// monitor WireGuard interface lifecycle and restart engine on changes
|
// monitor WireGuard interface lifecycle and restart engine on changes
|
||||||
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
||||||
e.shutdownWg.Add(1)
|
e.shutdownWg.Add(1)
|
||||||
|
wgIfaceName := e.wgInterface.Name()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer e.shutdownWg.Done()
|
defer e.shutdownWg.Done()
|
||||||
|
|
||||||
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
|
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, wgIfaceName); shouldRestart {
|
||||||
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
||||||
e.triggerClientRestart()
|
e.triggerClientRestart()
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
@@ -828,6 +830,10 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||||
|
started := time.Now()
|
||||||
|
defer func() {
|
||||||
|
log.Infof("sync finished in %s", time.Since(started))
|
||||||
|
}()
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
@@ -1017,7 +1023,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
state := e.statusRecorder.GetLocalPeerState()
|
state := e.statusRecorder.GetLocalPeerState()
|
||||||
state.IP = e.wgInterface.Address().String()
|
state.IP = e.wgInterface.Address().String()
|
||||||
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
||||||
state.KernelInterface = device.WireGuardModuleIsLoaded()
|
state.KernelInterface = !e.wgInterface.IsUserspaceBind()
|
||||||
state.FQDN = conf.GetFqdn()
|
state.FQDN = conf.GetFqdn()
|
||||||
|
|
||||||
e.statusRecorder.UpdateLocalPeerState(state)
|
e.statusRecorder.UpdateLocalPeerState(state)
|
||||||
@@ -1918,7 +1924,7 @@ func (e *Engine) triggerClientRestart() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) startNetworkMonitor() {
|
func (e *Engine) startNetworkMonitor() {
|
||||||
if !e.config.NetworkMonitor {
|
if !e.config.NetworkMonitor || nbnetstack.IsEnabled() {
|
||||||
log.Infof("Network monitor is disabled, not starting")
|
log.Infof("Network monitor is disabled, not starting")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||||
@@ -94,6 +95,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
|||||||
|
|
||||||
// updateSSHClientConfig updates the SSH client configuration with peer information
|
// updateSSHClientConfig updates the SSH client configuration with peer information
|
||||||
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
|
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
peerInfo := e.extractPeerSSHInfo(remotePeers)
|
peerInfo := e.extractPeerSSHInfo(remotePeers)
|
||||||
if len(peerInfo) == 0 {
|
if len(peerInfo) == 0 {
|
||||||
log.Debug("no SSH-enabled peers found, skipping SSH config update")
|
log.Debug("no SSH-enabled peers found, skipping SSH config update")
|
||||||
@@ -216,6 +221,10 @@ func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
|
|||||||
|
|
||||||
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
|
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
|
||||||
func (e *Engine) cleanupSSHConfig() {
|
func (e *Engine) cleanupSSHConfig() {
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
configMgr := sshconfig.New()
|
configMgr := sshconfig.New()
|
||||||
|
|
||||||
if err := configMgr.RemoveSSHClientConfig(); err != nil {
|
if err := configMgr.RemoveSSHClientConfig(); err != nil {
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
@@ -74,12 +75,13 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error)
|
|||||||
return NewUDPListener(m.wgIface, peerCfg)
|
return NewUDPListener(m.wgIface, peerCfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BindListener is only used on Windows and JS platforms:
|
// BindListener is used on Windows, JS, and netstack platforms:
|
||||||
// - JS: Cannot listen to UDP sockets
|
// - JS: Cannot listen to UDP sockets
|
||||||
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
|
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
|
||||||
// gateway points to, preventing them from reaching the loopback interface.
|
// gateway points to, preventing them from reaching the loopback interface.
|
||||||
// BindListener bypasses this by passing data directly through the bind.
|
// - Netstack: Allows multiple instances on the same host without port conflicts.
|
||||||
if runtime.GOOS != "windows" && runtime.GOOS != "js" {
|
// BindListener bypasses these issues by passing data directly through the bind.
|
||||||
|
if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() {
|
||||||
return NewUDPListener(m.wgIface, peerCfg)
|
return NewUDPListener(m.wgIface, peerCfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -38,11 +37,6 @@ func New() *NetworkMonitor {
|
|||||||
|
|
||||||
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
|
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
|
||||||
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
|
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
|
||||||
if netstack.IsEnabled() {
|
|
||||||
log.Debugf("Network monitor: skipping in netstack mode")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
nw.mu.Lock()
|
nw.mu.Lock()
|
||||||
if nw.cancel != nil {
|
if nw.cancel != nil {
|
||||||
nw.mu.Unlock()
|
nw.mu.Unlock()
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package ice
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -32,24 +33,6 @@ type ThreadSafeAgent struct {
|
|||||||
once sync.Once
|
once sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *ThreadSafeAgent) Close() error {
|
|
||||||
var err error
|
|
||||||
a.once.Do(func() {
|
|
||||||
done := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
done <- a.Agent.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err = <-done:
|
|
||||||
case <-time.After(iceAgentCloseTimeout):
|
|
||||||
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
||||||
iceKeepAlive := iceKeepAlive()
|
iceKeepAlive := iceKeepAlive()
|
||||||
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
||||||
@@ -93,9 +76,41 @@ func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, c
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if agent == nil {
|
||||||
|
return nil, fmt.Errorf("ice.NewAgent returned nil agent without error")
|
||||||
|
}
|
||||||
|
|
||||||
return &ThreadSafeAgent{Agent: agent}, nil
|
return &ThreadSafeAgent{Agent: agent}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *ThreadSafeAgent) Close() error {
|
||||||
|
var err error
|
||||||
|
a.once.Do(func() {
|
||||||
|
// Defensive check to prevent nil pointer dereference
|
||||||
|
// This can happen during sleep/wake transitions or memory corruption scenarios
|
||||||
|
// github.com/netbirdio/netbird/client/internal/peer/ice.(*ThreadSafeAgent).Close(0x40006883f0?)
|
||||||
|
// [signal 0xc0000005 code=0x0 addr=0x0 pc=0x7ff7e73af83c]
|
||||||
|
agent := a.Agent
|
||||||
|
if agent == nil {
|
||||||
|
log.Warnf("ICE agent is nil during close, skipping")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- agent.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err = <-done:
|
||||||
|
case <-time.After(iceAgentCloseTimeout):
|
||||||
|
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func GenerateICECredentials() (string, string, error) {
|
func GenerateICECredentials() (string, string, error) {
|
||||||
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
|
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -107,8 +107,10 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
}
|
}
|
||||||
w.log.Debugf("agent already exists, recreate the connection")
|
w.log.Debugf("agent already exists, recreate the connection")
|
||||||
w.agentDialerCancel()
|
w.agentDialerCancel()
|
||||||
if err := w.agent.Close(); err != nil {
|
if w.agent != nil {
|
||||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
if err := w.agent.Close(); err != nil {
|
||||||
|
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionID, err := NewICESessionID()
|
sessionID, err := NewICESessionID()
|
||||||
|
|||||||
@@ -252,7 +252,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if config.AdminURL == nil {
|
if config.AdminURL == nil {
|
||||||
log.Infof("using default Admin URL %s", DefaultManagementURL)
|
log.Infof("using default Admin URL %s", DefaultAdminURL)
|
||||||
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
|
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
|
|||||||
@@ -173,12 +173,21 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
||||||
|
var once sync.Once
|
||||||
|
var wgIface *net.Interface
|
||||||
|
toInterface := func() *net.Interface {
|
||||||
|
once.Do(func() {
|
||||||
|
wgIface = m.wgInterface.ToInterface()
|
||||||
|
})
|
||||||
|
return wgIface
|
||||||
|
}
|
||||||
|
|
||||||
m.routeRefCounter = refcounter.New(
|
m.routeRefCounter = refcounter.New(
|
||||||
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
|
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
|
||||||
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface())
|
return struct{}{}, m.sysOps.AddVPNRoute(prefix, toInterface())
|
||||||
},
|
},
|
||||||
func(prefix netip.Prefix, _ struct{}) error {
|
func(prefix netip.Prefix, _ struct{}) error {
|
||||||
return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface())
|
return m.sysOps.RemoveVPNRoute(prefix, toInterface())
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,16 +4,17 @@ package systemops
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
||||||
func filterRoutesByFlags(routeMessageFlags int) bool {
|
func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||||
if routeMessageFlags&syscall.RTF_UP == 0 {
|
if routeMessageFlags&unix.RTF_UP == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
|
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE|unix.RTF_WASCLONED) != 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -24,42 +25,51 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
|
|||||||
func formatBSDFlags(flags int) string {
|
func formatBSDFlags(flags int) string {
|
||||||
var flagStrs []string
|
var flagStrs []string
|
||||||
|
|
||||||
if flags&syscall.RTF_UP != 0 {
|
if flags&unix.RTF_UP != 0 {
|
||||||
flagStrs = append(flagStrs, "U")
|
flagStrs = append(flagStrs, "U")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_GATEWAY != 0 {
|
if flags&unix.RTF_GATEWAY != 0 {
|
||||||
flagStrs = append(flagStrs, "G")
|
flagStrs = append(flagStrs, "G")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_HOST != 0 {
|
if flags&unix.RTF_HOST != 0 {
|
||||||
flagStrs = append(flagStrs, "H")
|
flagStrs = append(flagStrs, "H")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_REJECT != 0 {
|
if flags&unix.RTF_REJECT != 0 {
|
||||||
flagStrs = append(flagStrs, "R")
|
flagStrs = append(flagStrs, "R")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_DYNAMIC != 0 {
|
if flags&unix.RTF_DYNAMIC != 0 {
|
||||||
flagStrs = append(flagStrs, "D")
|
flagStrs = append(flagStrs, "D")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_MODIFIED != 0 {
|
if flags&unix.RTF_MODIFIED != 0 {
|
||||||
flagStrs = append(flagStrs, "M")
|
flagStrs = append(flagStrs, "M")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_STATIC != 0 {
|
if flags&unix.RTF_STATIC != 0 {
|
||||||
flagStrs = append(flagStrs, "S")
|
flagStrs = append(flagStrs, "S")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_LLINFO != 0 {
|
if flags&unix.RTF_LLINFO != 0 {
|
||||||
flagStrs = append(flagStrs, "L")
|
flagStrs = append(flagStrs, "L")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_LOCAL != 0 {
|
if flags&unix.RTF_LOCAL != 0 {
|
||||||
flagStrs = append(flagStrs, "l")
|
flagStrs = append(flagStrs, "l")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_BLACKHOLE != 0 {
|
if flags&unix.RTF_BLACKHOLE != 0 {
|
||||||
flagStrs = append(flagStrs, "B")
|
flagStrs = append(flagStrs, "B")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_CLONING != 0 {
|
if flags&unix.RTF_CLONING != 0 {
|
||||||
flagStrs = append(flagStrs, "C")
|
flagStrs = append(flagStrs, "C")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_WASCLONED != 0 {
|
if flags&unix.RTF_WASCLONED != 0 {
|
||||||
flagStrs = append(flagStrs, "W")
|
flagStrs = append(flagStrs, "W")
|
||||||
}
|
}
|
||||||
|
if flags&unix.RTF_PROTO1 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "1")
|
||||||
|
}
|
||||||
|
if flags&unix.RTF_PROTO2 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "2")
|
||||||
|
}
|
||||||
|
if flags&unix.RTF_PROTO3 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "3")
|
||||||
|
}
|
||||||
|
|
||||||
if len(flagStrs) == 0 {
|
if len(flagStrs) == 0 {
|
||||||
return "-"
|
return "-"
|
||||||
|
|||||||
@@ -4,17 +4,18 @@ package systemops
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
||||||
func filterRoutesByFlags(routeMessageFlags int) bool {
|
func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||||
if routeMessageFlags&syscall.RTF_UP == 0 {
|
if routeMessageFlags&unix.RTF_UP == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0
|
// NOTE: RTF_WASCLONED deprecated in FreeBSD 8.0
|
||||||
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 {
|
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE) != 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -25,37 +26,46 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
|
|||||||
func formatBSDFlags(flags int) string {
|
func formatBSDFlags(flags int) string {
|
||||||
var flagStrs []string
|
var flagStrs []string
|
||||||
|
|
||||||
if flags&syscall.RTF_UP != 0 {
|
if flags&unix.RTF_UP != 0 {
|
||||||
flagStrs = append(flagStrs, "U")
|
flagStrs = append(flagStrs, "U")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_GATEWAY != 0 {
|
if flags&unix.RTF_GATEWAY != 0 {
|
||||||
flagStrs = append(flagStrs, "G")
|
flagStrs = append(flagStrs, "G")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_HOST != 0 {
|
if flags&unix.RTF_HOST != 0 {
|
||||||
flagStrs = append(flagStrs, "H")
|
flagStrs = append(flagStrs, "H")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_REJECT != 0 {
|
if flags&unix.RTF_REJECT != 0 {
|
||||||
flagStrs = append(flagStrs, "R")
|
flagStrs = append(flagStrs, "R")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_DYNAMIC != 0 {
|
if flags&unix.RTF_DYNAMIC != 0 {
|
||||||
flagStrs = append(flagStrs, "D")
|
flagStrs = append(flagStrs, "D")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_MODIFIED != 0 {
|
if flags&unix.RTF_MODIFIED != 0 {
|
||||||
flagStrs = append(flagStrs, "M")
|
flagStrs = append(flagStrs, "M")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_STATIC != 0 {
|
if flags&unix.RTF_STATIC != 0 {
|
||||||
flagStrs = append(flagStrs, "S")
|
flagStrs = append(flagStrs, "S")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_LLINFO != 0 {
|
if flags&unix.RTF_LLINFO != 0 {
|
||||||
flagStrs = append(flagStrs, "L")
|
flagStrs = append(flagStrs, "L")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_LOCAL != 0 {
|
if flags&unix.RTF_LOCAL != 0 {
|
||||||
flagStrs = append(flagStrs, "l")
|
flagStrs = append(flagStrs, "l")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_BLACKHOLE != 0 {
|
if flags&unix.RTF_BLACKHOLE != 0 {
|
||||||
flagStrs = append(flagStrs, "B")
|
flagStrs = append(flagStrs, "B")
|
||||||
}
|
}
|
||||||
// Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0
|
// Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0
|
||||||
|
if flags&unix.RTF_PROTO1 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "1")
|
||||||
|
}
|
||||||
|
if flags&unix.RTF_PROTO2 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "2")
|
||||||
|
}
|
||||||
|
if flags&unix.RTF_PROTO3 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "3")
|
||||||
|
}
|
||||||
|
|
||||||
if len(flagStrs) == 0 {
|
if len(flagStrs) == 0 {
|
||||||
return "-"
|
return "-"
|
||||||
|
|||||||
4
go.mod
4
go.mod
@@ -42,6 +42,7 @@ require (
|
|||||||
github.com/cilium/ebpf v0.15.0
|
github.com/cilium/ebpf v0.15.0
|
||||||
github.com/coder/websocket v1.8.13
|
github.com/coder/websocket v1.8.13
|
||||||
github.com/coreos/go-iptables v0.7.0
|
github.com/coreos/go-iptables v0.7.0
|
||||||
|
github.com/coreos/go-oidc/v3 v3.14.1
|
||||||
github.com/creack/pty v1.1.24
|
github.com/creack/pty v1.1.24
|
||||||
github.com/dexidp/dex v0.0.0-00010101000000-000000000000
|
github.com/dexidp/dex v0.0.0-00010101000000-000000000000
|
||||||
github.com/dexidp/dex/api/v2 v2.4.0
|
github.com/dexidp/dex/api/v2 v2.4.0
|
||||||
@@ -68,7 +69,7 @@ require (
|
|||||||
github.com/mdlayher/socket v0.5.1
|
github.com/mdlayher/socket v0.5.1
|
||||||
github.com/miekg/dns v1.1.59
|
github.com/miekg/dns v1.1.59
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
||||||
github.com/oapi-codegen/runtime v1.1.2
|
github.com/oapi-codegen/runtime v1.1.2
|
||||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||||
@@ -167,7 +168,6 @@ require (
|
|||||||
github.com/containerd/containerd v1.7.29 // indirect
|
github.com/containerd/containerd v1.7.29 // indirect
|
||||||
github.com/containerd/log v0.1.0 // indirect
|
github.com/containerd/log v0.1.0 // indirect
|
||||||
github.com/containerd/platforms v0.2.1 // indirect
|
github.com/containerd/platforms v0.2.1 // indirect
|
||||||
github.com/coreos/go-oidc/v3 v3.14.1 // indirect
|
|
||||||
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -406,8 +406,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
|||||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
||||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f h1:CTBf0je/FpKr2lVSMZLak7m8aaWcS6ur4SOfhSSazFI=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25 h1:iwAq/Ncaq0etl4uAlVsbNBzC1yY52o0AmY7uCm2AMTs=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||||
|
|||||||
@@ -91,16 +91,17 @@ read_reverse_proxy_type() {
|
|||||||
echo " [3] Nginx Proxy Manager (generates config + instructions)" > /dev/stderr
|
echo " [3] Nginx Proxy Manager (generates config + instructions)" > /dev/stderr
|
||||||
echo " [4] External Caddy (generates Caddyfile snippet)" > /dev/stderr
|
echo " [4] External Caddy (generates Caddyfile snippet)" > /dev/stderr
|
||||||
echo " [5] Other/Manual (displays setup documentation)" > /dev/stderr
|
echo " [5] Other/Manual (displays setup documentation)" > /dev/stderr
|
||||||
|
echo " [6] Traefik TCP Proxy (single port 443 + STUN)" > /dev/stderr
|
||||||
echo "" > /dev/stderr
|
echo "" > /dev/stderr
|
||||||
echo -n "Enter choice [0-5] (default: 0): " > /dev/stderr
|
echo -n "Enter choice [0-6] (default: 0): " > /dev/stderr
|
||||||
read -r CHOICE < /dev/tty
|
read -r CHOICE < /dev/tty
|
||||||
|
|
||||||
if [[ -z "$CHOICE" ]]; then
|
if [[ -z "$CHOICE" ]]; then
|
||||||
CHOICE="0"
|
CHOICE="0"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [[ ! "$CHOICE" =~ ^[0-5]$ ]]; then
|
if [[ ! "$CHOICE" =~ ^[0-6]$ ]]; then
|
||||||
echo "Invalid choice. Please enter a number between 0 and 5." > /dev/stderr
|
echo "Invalid choice. Please enter a number between 0 and 6." > /dev/stderr
|
||||||
read_reverse_proxy_type
|
read_reverse_proxy_type
|
||||||
return
|
return
|
||||||
fi
|
fi
|
||||||
@@ -140,6 +141,35 @@ read_traefik_certresolver() {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
read_traefik_tcp_acme_email() {
|
||||||
|
echo "" > /dev/stderr
|
||||||
|
echo "Enter your email for Let's Encrypt certificate notifications." > /dev/stderr
|
||||||
|
echo -n "Email address: " > /dev/stderr
|
||||||
|
read -r EMAIL < /dev/tty
|
||||||
|
if [[ -z "$EMAIL" ]]; then
|
||||||
|
echo "Email is required for Let's Encrypt." > /dev/stderr
|
||||||
|
read_traefik_tcp_acme_email
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
echo "$EMAIL"
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
read_enable_proxy() {
|
||||||
|
echo "" > /dev/stderr
|
||||||
|
echo "Do you want to enable the NetBird Proxy service?" > /dev/stderr
|
||||||
|
echo "The proxy exposes internal NetBird network resources to the internet." > /dev/stderr
|
||||||
|
echo -n "Enable proxy? [y/N]: " > /dev/stderr
|
||||||
|
read -r CHOICE < /dev/tty
|
||||||
|
|
||||||
|
if [[ "$CHOICE" =~ ^[Yy]$ ]]; then
|
||||||
|
echo "true"
|
||||||
|
else
|
||||||
|
echo "false"
|
||||||
|
fi
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
read_port_binding_preference() {
|
read_port_binding_preference() {
|
||||||
echo "" > /dev/stderr
|
echo "" > /dev/stderr
|
||||||
echo "Should container ports be bound to localhost only (127.0.0.1)?" > /dev/stderr
|
echo "Should container ports be bound to localhost only (127.0.0.1)?" > /dev/stderr
|
||||||
@@ -206,6 +236,30 @@ wait_management() {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
wait_management_traefik() {
|
||||||
|
set +e
|
||||||
|
echo -n "Waiting for Management server to become ready"
|
||||||
|
counter=1
|
||||||
|
while true; do
|
||||||
|
# Check the embedded IdP endpoint through Traefik
|
||||||
|
if curl -sk -f -o /dev/null "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/oauth2/.well-known/openid-configuration" 2>/dev/null; then
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
if [[ $counter -eq 60 ]]; then
|
||||||
|
echo ""
|
||||||
|
echo "Taking too long. Checking logs..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND logs --tail=20 traefik
|
||||||
|
$DOCKER_COMPOSE_COMMAND logs --tail=20 management
|
||||||
|
fi
|
||||||
|
echo -n " ."
|
||||||
|
sleep 2
|
||||||
|
counter=$((counter + 1))
|
||||||
|
done
|
||||||
|
echo " done"
|
||||||
|
set -e
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
wait_management_direct() {
|
wait_management_direct() {
|
||||||
set +e
|
set +e
|
||||||
local upstream_host=$(get_upstream_host)
|
local upstream_host=$(get_upstream_host)
|
||||||
@@ -246,10 +300,12 @@ initialize_default_values() {
|
|||||||
|
|
||||||
# Docker images
|
# Docker images
|
||||||
CADDY_IMAGE="caddy"
|
CADDY_IMAGE="caddy"
|
||||||
DASHBOARD_IMAGE="netbirdio/dashboard:latest"
|
#DASHBOARD_IMAGE="netbirdio/dashboard:latest"
|
||||||
|
DASHBOARD_IMAGE="netbirdio/dashboard:pr-552"
|
||||||
SIGNAL_IMAGE="netbirdio/signal:latest"
|
SIGNAL_IMAGE="netbirdio/signal:latest"
|
||||||
RELAY_IMAGE="netbirdio/relay:latest"
|
RELAY_IMAGE="netbirdio/relay:latest"
|
||||||
MANAGEMENT_IMAGE="netbirdio/management:latest"
|
MANAGEMENT_IMAGE="netbirdio/management:latest"
|
||||||
|
PROXY_IMAGE=""
|
||||||
|
|
||||||
# Reverse proxy configuration
|
# Reverse proxy configuration
|
||||||
REVERSE_PROXY_TYPE="0"
|
REVERSE_PROXY_TYPE="0"
|
||||||
@@ -263,6 +319,14 @@ initialize_default_values() {
|
|||||||
RELAY_HOST_PORT="8084"
|
RELAY_HOST_PORT="8084"
|
||||||
BIND_LOCALHOST_ONLY="true"
|
BIND_LOCALHOST_ONLY="true"
|
||||||
EXTERNAL_PROXY_NETWORK=""
|
EXTERNAL_PROXY_NETWORK=""
|
||||||
|
|
||||||
|
# Traefik TCP proxy configuration
|
||||||
|
TRAEFIK_IMAGE="traefik:v3.6"
|
||||||
|
TRAEFIK_TCP_ACME_EMAIL=""
|
||||||
|
|
||||||
|
# NetBird Proxy configuration
|
||||||
|
ENABLE_PROXY="false"
|
||||||
|
PROXY_TOKEN=""
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -293,8 +357,17 @@ configure_reverse_proxy() {
|
|||||||
TRAEFIK_CERTRESOLVER=$(read_traefik_certresolver)
|
TRAEFIK_CERTRESOLVER=$(read_traefik_certresolver)
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# Handle Traefik TCP proxy prompts
|
||||||
|
if [[ "$REVERSE_PROXY_TYPE" == "6" ]]; then
|
||||||
|
TRAEFIK_TCP_ACME_EMAIL=$(read_traefik_tcp_acme_email)
|
||||||
|
|
||||||
|
# Prompt for NetBird Proxy configuration
|
||||||
|
ENABLE_PROXY=$(read_enable_proxy)
|
||||||
|
# Note: PROXY_TOKEN will be auto-generated after Management starts
|
||||||
|
fi
|
||||||
|
|
||||||
# Handle port binding for external proxy options (2-5)
|
# Handle port binding for external proxy options (2-5)
|
||||||
if [[ "$REVERSE_PROXY_TYPE" -ge 2 ]]; then
|
if [[ "$REVERSE_PROXY_TYPE" -ge 2 && "$REVERSE_PROXY_TYPE" -le 5 ]]; then
|
||||||
BIND_LOCALHOST_ONLY=$(read_port_binding_preference)
|
BIND_LOCALHOST_ONLY=$(read_port_binding_preference)
|
||||||
fi
|
fi
|
||||||
|
|
||||||
@@ -313,7 +386,7 @@ check_existing_installation() {
|
|||||||
echo "Generated files already exist, if you want to reinitialize the environment, please remove them first."
|
echo "Generated files already exist, if you want to reinitialize the environment, please remove them first."
|
||||||
echo "You can use the following commands:"
|
echo "You can use the following commands:"
|
||||||
echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes"
|
echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes"
|
||||||
echo " rm -f docker-compose.yml Caddyfile dashboard.env management.json relay.env nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt"
|
echo " rm -f docker-compose.yml Caddyfile dashboard.env management.json relay.env nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt proxy.env"
|
||||||
echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard."
|
echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard."
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
@@ -347,6 +420,15 @@ generate_configuration_files() {
|
|||||||
5)
|
5)
|
||||||
render_docker_compose_exposed_ports > docker-compose.yml
|
render_docker_compose_exposed_ports > docker-compose.yml
|
||||||
;;
|
;;
|
||||||
|
6)
|
||||||
|
render_docker_compose_traefik_tcp > docker-compose.yml
|
||||||
|
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||||
|
# Create placeholder proxy.env so docker-compose can validate
|
||||||
|
# This will be overwritten with the actual token after Management starts
|
||||||
|
echo "# Placeholder - will be updated with token after Management starts" > proxy.env
|
||||||
|
echo "NB_PROXY_TOKEN=placeholder" >> proxy.env
|
||||||
|
fi
|
||||||
|
;;
|
||||||
*)
|
*)
|
||||||
echo "Invalid reverse proxy type: $REVERSE_PROXY_TYPE" > /dev/stderr
|
echo "Invalid reverse proxy type: $REVERSE_PROXY_TYPE" > /dev/stderr
|
||||||
exit 1
|
exit 1
|
||||||
@@ -402,6 +484,50 @@ start_services_and_show_instructions() {
|
|||||||
echo ""
|
echo ""
|
||||||
echo "NetBird containers are running. Configure NPM as shown above, then access:"
|
echo "NetBird containers are running. Configure NPM as shown above, then access:"
|
||||||
echo " $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
|
echo " $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
|
||||||
|
elif [[ "$REVERSE_PROXY_TYPE" == "6" ]]; then
|
||||||
|
# Traefik TCP Proxy - two-phase startup if proxy is enabled
|
||||||
|
echo -e "$MSG_STARTING_SERVICES"
|
||||||
|
|
||||||
|
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||||
|
# Phase 1: Start core services (without proxy)
|
||||||
|
echo "Starting core services..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND up -d traefik dashboard signal relay management
|
||||||
|
|
||||||
|
sleep 3
|
||||||
|
wait_management_traefik
|
||||||
|
|
||||||
|
# Phase 2: Create proxy token and start proxy
|
||||||
|
echo ""
|
||||||
|
echo "Creating proxy access token..."
|
||||||
|
# Use docker exec with bash to run the token command directly
|
||||||
|
# (bypassing the entrypoint which adds 'management' as first arg)
|
||||||
|
PROXY_TOKEN=$($DOCKER_COMPOSE_COMMAND exec -T management \
|
||||||
|
bash -c '/go/bin/netbird-mgmt token create --name "default-proxy" --config /etc/netbird/management.json' 2>/dev/null | grep "^Token:" | awk '{print $2}')
|
||||||
|
|
||||||
|
if [[ -z "$PROXY_TOKEN" ]]; then
|
||||||
|
echo "ERROR: Failed to create proxy token. Check management logs." > /dev/stderr
|
||||||
|
$DOCKER_COMPOSE_COMMAND logs --tail=20 management
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Proxy token created successfully."
|
||||||
|
|
||||||
|
# Generate proxy.env with the token
|
||||||
|
render_proxy_env > proxy.env
|
||||||
|
|
||||||
|
# Start proxy service
|
||||||
|
echo "Starting proxy service..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND up -d proxy
|
||||||
|
else
|
||||||
|
# No proxy - start all services at once
|
||||||
|
$DOCKER_COMPOSE_COMMAND up -d
|
||||||
|
|
||||||
|
sleep 3
|
||||||
|
wait_management_traefik
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo -e "$MSG_DONE"
|
||||||
|
print_post_setup_instructions
|
||||||
else
|
else
|
||||||
# External proxies (nginx, external Caddy, other) - need manual config first
|
# External proxies (nginx, external Caddy, other) - need manual config first
|
||||||
print_post_setup_instructions
|
print_post_setup_instructions
|
||||||
@@ -547,6 +673,29 @@ EOF
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
render_proxy_env() {
|
||||||
|
cat <<EOF
|
||||||
|
# NetBird Proxy Configuration
|
||||||
|
NB_PROXY_DEBUG_LOGS=false
|
||||||
|
# Use internal Docker network to connect to management (avoids hairpin NAT issues)
|
||||||
|
NB_PROXY_MANAGEMENT_ADDRESS=http://management:80
|
||||||
|
# Allow insecure gRPC connection to management (required for internal Docker network)
|
||||||
|
NB_PROXY_ALLOW_INSECURE=true
|
||||||
|
# Public URL where this proxy is reachable (used for cluster registration)
|
||||||
|
NB_PROXY_DOMAIN=$NETBIRD_DOMAIN
|
||||||
|
NB_PROXY_ADDRESS=:8443
|
||||||
|
NB_PROXY_TOKEN=$PROXY_TOKEN
|
||||||
|
NB_PROXY_CERTIFICATE_DIRECTORY=/certs
|
||||||
|
NB_PROXY_ACME_CERTIFICATES=true
|
||||||
|
NB_PROXY_ACME_CHALLENGE_TYPE=tls-alpn-01
|
||||||
|
NB_PROXY_OIDC_CLIENT_ID=netbird-proxy
|
||||||
|
NB_PROXY_OIDC_ENDPOINT=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/oauth2
|
||||||
|
NB_PROXY_OIDC_SCOPES=openid,profile,email
|
||||||
|
NB_PROXY_FORWARDED_PROTO=https
|
||||||
|
EOF
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
render_docker_compose() {
|
render_docker_compose() {
|
||||||
cat <<EOF
|
cat <<EOF
|
||||||
services:
|
services:
|
||||||
@@ -736,7 +885,18 @@ $(if [[ -n "$tls_labels" ]]; then echo " - traefik.http.routers.netbird-rel
|
|||||||
|
|
||||||
# Management (includes embedded IdP)
|
# Management (includes embedded IdP)
|
||||||
management:
|
management:
|
||||||
|
$(if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||||
|
cat <<MGMT_BUILD
|
||||||
|
build:
|
||||||
|
context: ..
|
||||||
|
dockerfile: management/Dockerfile.multistage
|
||||||
|
pull_policy: build
|
||||||
|
MGMT_BUILD
|
||||||
|
else
|
||||||
|
cat <<MGMT_IMAGE
|
||||||
image: $MANAGEMENT_IMAGE
|
image: $MANAGEMENT_IMAGE
|
||||||
|
MGMT_IMAGE
|
||||||
|
fi)
|
||||||
container_name: netbird-management
|
container_name: netbird-management
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
networks: [$network_name]
|
networks: [$network_name]
|
||||||
@@ -1115,6 +1275,258 @@ EOF
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
render_docker_compose_traefik_tcp() {
|
||||||
|
# Generate proxy service section if enabled
|
||||||
|
local proxy_service=""
|
||||||
|
local proxy_volumes=""
|
||||||
|
local proxy_tcp_labels=""
|
||||||
|
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||||
|
proxy_service="
|
||||||
|
# NetBird Proxy - exposes internal resources to the internet
|
||||||
|
proxy:
|
||||||
|
build:
|
||||||
|
context: ../
|
||||||
|
dockerfile: proxy/Dockerfile
|
||||||
|
# Always rebuild to pick up code changes during testing
|
||||||
|
pull_policy: build
|
||||||
|
#image: $PROXY_IMAGE
|
||||||
|
container_name: netbird-proxy
|
||||||
|
# Hairpin NAT fix: route domain back to traefik's static IP within Docker
|
||||||
|
extra_hosts:
|
||||||
|
- \"$NETBIRD_DOMAIN:172.30.0.10\"
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [netbird]
|
||||||
|
depends_on:
|
||||||
|
- signal
|
||||||
|
env_file:
|
||||||
|
- ./proxy.env
|
||||||
|
volumes:
|
||||||
|
- netbird_proxy_certs:/certs
|
||||||
|
labels:
|
||||||
|
# TCP passthrough for any unmatched domain (proxy handles its own TLS)
|
||||||
|
- traefik.enable=true
|
||||||
|
- traefik.tcp.routers.proxy-passthrough.entrypoints=websecure
|
||||||
|
- traefik.tcp.routers.proxy-passthrough.rule=HostSNI(\`*\`)
|
||||||
|
- traefik.tcp.routers.proxy-passthrough.tls.passthrough=true
|
||||||
|
- traefik.tcp.routers.proxy-passthrough.service=proxy-tls
|
||||||
|
- traefik.tcp.routers.proxy-passthrough.priority=1
|
||||||
|
- traefik.tcp.services.proxy-tls.loadbalancer.server.port=8443
|
||||||
|
logging:
|
||||||
|
driver: \"json-file\"
|
||||||
|
options:
|
||||||
|
max-size: \"500m\"
|
||||||
|
max-file: \"2\"
|
||||||
|
"
|
||||||
|
proxy_volumes="
|
||||||
|
netbird_proxy_certs:"
|
||||||
|
fi
|
||||||
|
|
||||||
|
cat <<EOF
|
||||||
|
services:
|
||||||
|
# Traefik - single port 443 entry point with TLS termination
|
||||||
|
traefik:
|
||||||
|
image: $TRAEFIK_IMAGE
|
||||||
|
container_name: netbird-traefik
|
||||||
|
restart: unless-stopped
|
||||||
|
networks:
|
||||||
|
netbird:
|
||||||
|
ipv4_address: 172.30.0.10
|
||||||
|
ports:
|
||||||
|
- '443:443'
|
||||||
|
volumes:
|
||||||
|
- netbird_traefik_data:/data
|
||||||
|
- /var/run/docker.sock:/var/run/docker.sock:ro
|
||||||
|
command:
|
||||||
|
# Logging
|
||||||
|
- --log.level=INFO
|
||||||
|
- --accesslog=true
|
||||||
|
# Docker provider
|
||||||
|
- --providers.docker=true
|
||||||
|
- --providers.docker.exposedbydefault=false
|
||||||
|
- --providers.docker.network=netbird
|
||||||
|
# Entrypoints
|
||||||
|
- --entrypoints.websecure.address=:443
|
||||||
|
- --entrypoints.websecure.allowACMEByPass=true
|
||||||
|
# Disable timeouts for long-lived gRPC streams
|
||||||
|
- --entrypoints.websecure.transport.respondingtimeouts.readtimeout=0s
|
||||||
|
- --entrypoints.websecure.transport.respondingtimeouts.writetimeout=0s
|
||||||
|
- --entrypoints.websecure.transport.respondingtimeouts.idletimeout=0s
|
||||||
|
# Let's Encrypt ACME
|
||||||
|
- --certificatesresolvers.letsencrypt.acme.email=$TRAEFIK_TCP_ACME_EMAIL
|
||||||
|
- --certificatesresolvers.letsencrypt.acme.storage=/data/acme.json
|
||||||
|
- --certificatesresolvers.letsencrypt.acme.tlschallenge=true
|
||||||
|
# gRPC transport settings (disable response timeout for long-lived streams)
|
||||||
|
- --serverstransport.forwardingtimeouts.responseheadertimeout=0s
|
||||||
|
- --serverstransport.forwardingtimeouts.idleconntimeout=0s
|
||||||
|
logging:
|
||||||
|
driver: "json-file"
|
||||||
|
options:
|
||||||
|
max-size: "500m"
|
||||||
|
max-file: "2"
|
||||||
|
|
||||||
|
# UI dashboard
|
||||||
|
dashboard:
|
||||||
|
image: $DASHBOARD_IMAGE
|
||||||
|
container_name: netbird-dashboard
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [netbird]
|
||||||
|
env_file:
|
||||||
|
- ./dashboard.env
|
||||||
|
labels:
|
||||||
|
- traefik.enable=true
|
||||||
|
- traefik.http.routers.netbird-dashboard.entrypoints=websecure
|
||||||
|
- traefik.http.routers.netbird-dashboard.rule=Host(\`$NETBIRD_DOMAIN\`)
|
||||||
|
- traefik.http.routers.netbird-dashboard.tls=true
|
||||||
|
- traefik.http.routers.netbird-dashboard.tls.certresolver=letsencrypt
|
||||||
|
- traefik.http.routers.netbird-dashboard.service=dashboard
|
||||||
|
- traefik.http.routers.netbird-dashboard.priority=1
|
||||||
|
- traefik.http.services.dashboard.loadbalancer.server.port=80
|
||||||
|
logging:
|
||||||
|
driver: "json-file"
|
||||||
|
options:
|
||||||
|
max-size: "500m"
|
||||||
|
max-file: "2"
|
||||||
|
|
||||||
|
# Signal
|
||||||
|
signal:
|
||||||
|
image: $SIGNAL_IMAGE
|
||||||
|
container_name: netbird-signal
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [netbird]
|
||||||
|
labels:
|
||||||
|
- traefik.enable=true
|
||||||
|
# Signal WebSocket
|
||||||
|
- traefik.http.routers.netbird-signal-ws.entrypoints=websecure
|
||||||
|
- traefik.http.routers.netbird-signal-ws.rule=Host(\`$NETBIRD_DOMAIN\`) && PathPrefix(\`/ws-proxy/signal\`)
|
||||||
|
- traefik.http.routers.netbird-signal-ws.tls=true
|
||||||
|
- traefik.http.routers.netbird-signal-ws.tls.certresolver=letsencrypt
|
||||||
|
- traefik.http.routers.netbird-signal-ws.service=signal-ws
|
||||||
|
- traefik.http.routers.netbird-signal-ws.priority=100
|
||||||
|
- traefik.http.services.signal-ws.loadbalancer.server.port=80
|
||||||
|
# Signal gRPC
|
||||||
|
- traefik.http.routers.netbird-signal-grpc.entrypoints=websecure
|
||||||
|
- traefik.http.routers.netbird-signal-grpc.rule=Host(\`$NETBIRD_DOMAIN\`) && PathPrefix(\`/signalexchange.SignalExchange/\`)
|
||||||
|
- traefik.http.routers.netbird-signal-grpc.tls=true
|
||||||
|
- traefik.http.routers.netbird-signal-grpc.tls.certresolver=letsencrypt
|
||||||
|
- traefik.http.routers.netbird-signal-grpc.service=signal-grpc
|
||||||
|
- traefik.http.routers.netbird-signal-grpc.priority=100
|
||||||
|
- traefik.http.services.signal-grpc.loadbalancer.server.port=10000
|
||||||
|
- traefik.http.services.signal-grpc.loadbalancer.server.scheme=h2c
|
||||||
|
logging:
|
||||||
|
driver: "json-file"
|
||||||
|
options:
|
||||||
|
max-size: "500m"
|
||||||
|
max-file: "2"
|
||||||
|
|
||||||
|
# Relay (includes embedded STUN server)
|
||||||
|
relay:
|
||||||
|
image: $RELAY_IMAGE
|
||||||
|
container_name: netbird-relay
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [netbird]
|
||||||
|
ports:
|
||||||
|
- '$NETBIRD_STUN_PORT:$NETBIRD_STUN_PORT/udp'
|
||||||
|
env_file:
|
||||||
|
- ./relay.env
|
||||||
|
labels:
|
||||||
|
- traefik.enable=true
|
||||||
|
- traefik.http.routers.netbird-relay.entrypoints=websecure
|
||||||
|
- traefik.http.routers.netbird-relay.rule=Host(\`$NETBIRD_DOMAIN\`) && PathPrefix(\`/relay\`)
|
||||||
|
- traefik.http.routers.netbird-relay.tls=true
|
||||||
|
- traefik.http.routers.netbird-relay.tls.certresolver=letsencrypt
|
||||||
|
- traefik.http.routers.netbird-relay.service=relay
|
||||||
|
- traefik.http.routers.netbird-relay.priority=100
|
||||||
|
- traefik.http.services.relay.loadbalancer.server.port=80
|
||||||
|
logging:
|
||||||
|
driver: "json-file"
|
||||||
|
options:
|
||||||
|
max-size: "500m"
|
||||||
|
max-file: "2"
|
||||||
|
|
||||||
|
# Management (includes embedded IdP)
|
||||||
|
management:
|
||||||
|
$(if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||||
|
cat <<MGMT_BUILD
|
||||||
|
build:
|
||||||
|
context: ..
|
||||||
|
dockerfile: management/Dockerfile.multistage
|
||||||
|
pull_policy: build
|
||||||
|
MGMT_BUILD
|
||||||
|
else
|
||||||
|
cat <<MGMT_IMAGE
|
||||||
|
image: $MANAGEMENT_IMAGE
|
||||||
|
MGMT_IMAGE
|
||||||
|
fi)
|
||||||
|
container_name: netbird-management
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [netbird]
|
||||||
|
volumes:
|
||||||
|
- netbird_management:/var/lib/netbird
|
||||||
|
- ./management.json:/etc/netbird/management.json
|
||||||
|
command: [
|
||||||
|
"--port", "80",
|
||||||
|
"--log-file", "console",
|
||||||
|
"--log-level", "info",
|
||||||
|
"--disable-anonymous-metrics=false",
|
||||||
|
"--single-account-mode-domain=netbird.selfhosted",
|
||||||
|
"--dns-domain=netbird.selfhosted",
|
||||||
|
"--idp-sign-key-refresh-enabled",
|
||||||
|
]
|
||||||
|
labels:
|
||||||
|
- traefik.enable=true
|
||||||
|
# Management API
|
||||||
|
- traefik.http.routers.netbird-api.entrypoints=websecure
|
||||||
|
- traefik.http.routers.netbird-api.rule=Host(\`$NETBIRD_DOMAIN\`) && PathPrefix(\`/api\`)
|
||||||
|
- traefik.http.routers.netbird-api.tls=true
|
||||||
|
- traefik.http.routers.netbird-api.tls.certresolver=letsencrypt
|
||||||
|
- traefik.http.routers.netbird-api.service=management
|
||||||
|
- traefik.http.routers.netbird-api.priority=100
|
||||||
|
# Management WebSocket
|
||||||
|
- traefik.http.routers.netbird-mgmt-ws.entrypoints=websecure
|
||||||
|
- traefik.http.routers.netbird-mgmt-ws.rule=Host(\`$NETBIRD_DOMAIN\`) && PathPrefix(\`/ws-proxy/management\`)
|
||||||
|
- traefik.http.routers.netbird-mgmt-ws.tls=true
|
||||||
|
- traefik.http.routers.netbird-mgmt-ws.tls.certresolver=letsencrypt
|
||||||
|
- traefik.http.routers.netbird-mgmt-ws.service=management
|
||||||
|
- traefik.http.routers.netbird-mgmt-ws.priority=100
|
||||||
|
# Management gRPC
|
||||||
|
- traefik.http.routers.netbird-mgmt-grpc.entrypoints=websecure
|
||||||
|
- traefik.http.routers.netbird-mgmt-grpc.rule=Host(\`$NETBIRD_DOMAIN\`) && PathPrefix(\`/management.ManagementService/\`)
|
||||||
|
- traefik.http.routers.netbird-mgmt-grpc.tls=true
|
||||||
|
- traefik.http.routers.netbird-mgmt-grpc.tls.certresolver=letsencrypt
|
||||||
|
- traefik.http.routers.netbird-mgmt-grpc.service=management-grpc
|
||||||
|
- traefik.http.routers.netbird-mgmt-grpc.priority=100
|
||||||
|
# OAuth2 (embedded IdP)
|
||||||
|
- traefik.http.routers.netbird-oauth2.entrypoints=websecure
|
||||||
|
- traefik.http.routers.netbird-oauth2.rule=Host(\`$NETBIRD_DOMAIN\`) && PathPrefix(\`/oauth2\`)
|
||||||
|
- traefik.http.routers.netbird-oauth2.tls=true
|
||||||
|
- traefik.http.routers.netbird-oauth2.tls.certresolver=letsencrypt
|
||||||
|
- traefik.http.routers.netbird-oauth2.service=management
|
||||||
|
- traefik.http.routers.netbird-oauth2.priority=100
|
||||||
|
# Services
|
||||||
|
- traefik.http.services.management.loadbalancer.server.port=80
|
||||||
|
- traefik.http.services.management-grpc.loadbalancer.server.port=80
|
||||||
|
- traefik.http.services.management-grpc.loadbalancer.server.scheme=h2c
|
||||||
|
logging:
|
||||||
|
driver: "json-file"
|
||||||
|
options:
|
||||||
|
max-size: "500m"
|
||||||
|
max-file: "2"
|
||||||
|
${proxy_service}
|
||||||
|
volumes:
|
||||||
|
netbird_traefik_data:
|
||||||
|
netbird_management:${proxy_volumes}
|
||||||
|
|
||||||
|
networks:
|
||||||
|
netbird:
|
||||||
|
driver: bridge
|
||||||
|
ipam:
|
||||||
|
config:
|
||||||
|
- subnet: 172.30.0.0/24
|
||||||
|
gateway: 172.30.0.1
|
||||||
|
EOF
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
render_npm_advanced_config() {
|
render_npm_advanced_config() {
|
||||||
local upstream_host=$(get_upstream_host)
|
local upstream_host=$(get_upstream_host)
|
||||||
local relay_addr="${upstream_host}:${RELAY_HOST_PORT}"
|
local relay_addr="${upstream_host}:${RELAY_HOST_PORT}"
|
||||||
@@ -1424,6 +1836,36 @@ print_manual_instructions() {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
print_traefik_tcp_instructions() {
|
||||||
|
echo ""
|
||||||
|
echo "$MSG_SEPARATOR"
|
||||||
|
echo " TRAEFIK TCP PROXY SETUP"
|
||||||
|
echo "$MSG_SEPARATOR"
|
||||||
|
echo ""
|
||||||
|
echo "This configuration uses Traefik as a single entry point on port 443."
|
||||||
|
echo "Traefik handles TLS termination with Let's Encrypt and routes to services."
|
||||||
|
echo ""
|
||||||
|
echo "Open ports:"
|
||||||
|
echo " - 443/tcp (HTTPS - all NetBird services)"
|
||||||
|
echo " - $NETBIRD_STUN_PORT/udp (STUN - required for NAT traversal)"
|
||||||
|
echo ""
|
||||||
|
echo "Generated files:"
|
||||||
|
echo " - docker-compose.yml (container definitions with Traefik labels)"
|
||||||
|
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||||
|
echo " - proxy.env (NetBird Proxy configuration)"
|
||||||
|
echo ""
|
||||||
|
echo "NetBird Proxy:"
|
||||||
|
echo " The proxy service is enabled and will be built from source."
|
||||||
|
echo " Any domain NOT matching $NETBIRD_DOMAIN will be passed through to the proxy."
|
||||||
|
echo " The proxy handles its own TLS certificates via ACME TLS-ALPN-01 challenge."
|
||||||
|
echo " Point your proxy domains (CNAMEs) to this server's IP address."
|
||||||
|
fi
|
||||||
|
echo ""
|
||||||
|
echo "You can access the NetBird dashboard at $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
|
||||||
|
echo "Follow the onboarding steps to set up your NetBird instance."
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
print_post_setup_instructions() {
|
print_post_setup_instructions() {
|
||||||
case "$REVERSE_PROXY_TYPE" in
|
case "$REVERSE_PROXY_TYPE" in
|
||||||
0)
|
0)
|
||||||
@@ -1444,6 +1886,9 @@ print_post_setup_instructions() {
|
|||||||
5)
|
5)
|
||||||
print_manual_instructions
|
print_manual_instructions
|
||||||
;;
|
;;
|
||||||
|
6)
|
||||||
|
print_traefik_tcp_instructions
|
||||||
|
;;
|
||||||
*)
|
*)
|
||||||
echo "Unknown reverse proxy type: $REVERSE_PROXY_TYPE" > /dev/stderr
|
echo "Unknown reverse proxy type: $REVERSE_PROXY_TYPE" > /dev/stderr
|
||||||
;;
|
;;
|
||||||
|
|||||||
17
management/Dockerfile.multistage
Normal file
17
management/Dockerfile.multistage
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
FROM golang:1.25-bookworm AS builder
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Install build dependencies
|
||||||
|
RUN apt-get update && apt-get install -y gcc libc6-dev && rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
COPY go.mod go.sum ./
|
||||||
|
RUN go mod download
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
RUN CGO_ENABLED=1 GOOS=linux go build -ldflags="-s -w" -o netbird-mgmt ./management
|
||||||
|
|
||||||
|
FROM ubuntu:24.04
|
||||||
|
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
|
||||||
|
ENTRYPOINT [ "/go/bin/netbird-mgmt","management"]
|
||||||
|
CMD ["--log-file", "console"]
|
||||||
|
COPY --from=builder /app/netbird-mgmt /go/bin/netbird-mgmt
|
||||||
@@ -19,6 +19,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/formatter/hook"
|
"github.com/netbirdio/netbird/formatter/hook"
|
||||||
"github.com/netbirdio/netbird/management/internals/server"
|
"github.com/netbirdio/netbird/management/internals/server"
|
||||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
@@ -213,11 +215,14 @@ func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
|||||||
// Set HttpConfig values from EmbeddedIdP
|
// Set HttpConfig values from EmbeddedIdP
|
||||||
cfg.HttpConfig.AuthIssuer = issuer
|
cfg.HttpConfig.AuthIssuer = issuer
|
||||||
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
|
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
|
||||||
|
cfg.HttpConfig.AuthClientID = cfg.HttpConfig.AuthAudience
|
||||||
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
|
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
|
||||||
cfg.HttpConfig.AuthUserIDClaim = "sub"
|
cfg.HttpConfig.AuthUserIDClaim = "sub"
|
||||||
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
|
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
|
||||||
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
|
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
|
||||||
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
|
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
|
||||||
|
callbackURL := strings.TrimSuffix(cfg.HttpConfig.AuthIssuer, "/oauth2")
|
||||||
|
cfg.HttpConfig.AuthCallbackURL = callbackURL + types.ProxyCallbackEndpointFull
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -80,4 +80,10 @@ func init() {
|
|||||||
migrationCmd.AddCommand(upCmd)
|
migrationCmd.AddCommand(upCmd)
|
||||||
|
|
||||||
rootCmd.AddCommand(migrationCmd)
|
rootCmd.AddCommand(migrationCmd)
|
||||||
|
|
||||||
|
tokenCmd.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
|
||||||
|
tokenCmd.AddCommand(tokenCreateCmd)
|
||||||
|
tokenCmd.AddCommand(tokenListCmd)
|
||||||
|
tokenCmd.AddCommand(tokenRevokeCmd)
|
||||||
|
rootCmd.AddCommand(tokenCmd)
|
||||||
}
|
}
|
||||||
|
|||||||
209
management/cmd/token.go
Normal file
209
management/cmd/token.go
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"text/tabwriter"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/formatter/hook"
|
||||||
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
tokenName string
|
||||||
|
tokenExpireIn string
|
||||||
|
tokenDatadir string
|
||||||
|
|
||||||
|
tokenCmd = &cobra.Command{
|
||||||
|
Use: "token",
|
||||||
|
Short: "Manage proxy access tokens",
|
||||||
|
Long: "Commands for creating, listing, and revoking proxy access tokens used by reverse proxy instances to authenticate with the management server.",
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenCreateCmd = &cobra.Command{
|
||||||
|
Use: "create",
|
||||||
|
Short: "Create a new proxy access token",
|
||||||
|
Long: "Creates a new proxy access token. The plain text token is displayed only once at creation time.",
|
||||||
|
RunE: tokenCreateRun,
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenListCmd = &cobra.Command{
|
||||||
|
Use: "list",
|
||||||
|
Aliases: []string{"ls"},
|
||||||
|
Short: "List all proxy access tokens",
|
||||||
|
Long: "Lists all proxy access tokens with their IDs, names, creation dates, expiration, and revocation status.",
|
||||||
|
RunE: tokenListRun,
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenRevokeCmd = &cobra.Command{
|
||||||
|
Use: "revoke [token-id]",
|
||||||
|
Short: "Revoke a proxy access token",
|
||||||
|
Long: "Revokes a proxy access token by its ID. Revoked tokens can no longer be used for authentication.",
|
||||||
|
Args: cobra.ExactArgs(1),
|
||||||
|
RunE: tokenRevokeRun,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
tokenCmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
|
||||||
|
|
||||||
|
tokenCreateCmd.Flags().StringVar(&tokenName, "name", "", "Name for the token (required)")
|
||||||
|
tokenCreateCmd.Flags().StringVar(&tokenExpireIn, "expires-in", "", "Token expiration duration (e.g., 365d, 24h, 30d). Empty means no expiration")
|
||||||
|
tokenCreateCmd.MarkFlagRequired("name") //nolint
|
||||||
|
}
|
||||||
|
|
||||||
|
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
|
||||||
|
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||||
|
if err := util.InitLog("error", "console"); err != nil {
|
||||||
|
return fmt.Errorf("init log: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint
|
||||||
|
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource)
|
||||||
|
|
||||||
|
config, err := loadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("load config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
datadir := config.Datadir
|
||||||
|
if tokenDatadir != "" {
|
||||||
|
datadir = tokenDatadir
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create store: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := s.Close(ctx); err != nil {
|
||||||
|
log.Debugf("close store: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return fn(ctx, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func tokenCreateRun(cmd *cobra.Command, _ []string) error {
|
||||||
|
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
|
||||||
|
expiresIn, err := parseDuration(tokenExpireIn)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse expiration: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
generated, err := types.CreateNewProxyAccessToken(tokenName, expiresIn, nil, "CLI")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generate token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.SaveProxyAccessToken(ctx, &generated.ProxyAccessToken); err != nil {
|
||||||
|
return fmt.Errorf("save token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Token created successfully!") //nolint:forbidigo
|
||||||
|
fmt.Printf("Token: %s\n", generated.PlainToken) //nolint:forbidigo
|
||||||
|
fmt.Println() //nolint:forbidigo
|
||||||
|
fmt.Println("IMPORTANT: Save this token now. It will not be shown again.") //nolint:forbidigo
|
||||||
|
fmt.Printf("Token ID: %s\n", generated.ID) //nolint:forbidigo
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func tokenListRun(cmd *cobra.Command, _ []string) error {
|
||||||
|
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
|
||||||
|
tokens, err := s.GetAllProxyAccessTokens(ctx, store.LockingStrengthNone)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("list tokens: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tokens) == 0 {
|
||||||
|
fmt.Println("No proxy access tokens found.") //nolint:forbidigo
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
|
||||||
|
fmt.Fprintln(w, "ID\tNAME\tCREATED\tEXPIRES\tLAST USED\tREVOKED")
|
||||||
|
fmt.Fprintln(w, "--\t----\t-------\t-------\t---------\t-------")
|
||||||
|
|
||||||
|
for _, t := range tokens {
|
||||||
|
expires := "never"
|
||||||
|
if t.ExpiresAt != nil {
|
||||||
|
expires = t.ExpiresAt.Format("2006-01-02")
|
||||||
|
}
|
||||||
|
|
||||||
|
lastUsed := "never"
|
||||||
|
if t.LastUsed != nil {
|
||||||
|
lastUsed = t.LastUsed.Format("2006-01-02 15:04")
|
||||||
|
}
|
||||||
|
|
||||||
|
revoked := "no"
|
||||||
|
if t.Revoked {
|
||||||
|
revoked = "yes"
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
|
||||||
|
t.ID,
|
||||||
|
t.Name,
|
||||||
|
t.CreatedAt.Format("2006-01-02"),
|
||||||
|
expires,
|
||||||
|
lastUsed,
|
||||||
|
revoked,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Flush()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func tokenRevokeRun(cmd *cobra.Command, args []string) error {
|
||||||
|
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
|
||||||
|
tokenID := args[0]
|
||||||
|
|
||||||
|
if err := s.RevokeProxyAccessToken(ctx, tokenID); err != nil {
|
||||||
|
return fmt.Errorf("revoke token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Token %s revoked successfully.\n", tokenID) //nolint:forbidigo
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseDuration parses a duration string with support for days (e.g., "30d", "365d").
|
||||||
|
// An empty string returns zero duration (no expiration).
|
||||||
|
func parseDuration(s string) (time.Duration, error) {
|
||||||
|
if len(s) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if s[len(s)-1] == 'd' {
|
||||||
|
d, err := strconv.Atoi(s[:len(s)-1])
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("invalid day format: %s", s)
|
||||||
|
}
|
||||||
|
if d <= 0 {
|
||||||
|
return 0, fmt.Errorf("duration must be positive: %s", s)
|
||||||
|
}
|
||||||
|
return time.Duration(d) * 24 * time.Hour, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := time.ParseDuration(s)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if d <= 0 {
|
||||||
|
return 0, fmt.Errorf("duration must be positive: %s", s)
|
||||||
|
}
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
101
management/cmd/token_test.go
Normal file
101
management/cmd/token_test.go
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseDuration(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected time.Duration
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty string returns zero",
|
||||||
|
input: "",
|
||||||
|
expected: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "days suffix",
|
||||||
|
input: "30d",
|
||||||
|
expected: 30 * 24 * time.Hour,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "one day",
|
||||||
|
input: "1d",
|
||||||
|
expected: 24 * time.Hour,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "365 days",
|
||||||
|
input: "365d",
|
||||||
|
expected: 365 * 24 * time.Hour,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hours via Go duration",
|
||||||
|
input: "24h",
|
||||||
|
expected: 24 * time.Hour,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "minutes via Go duration",
|
||||||
|
input: "30m",
|
||||||
|
expected: 30 * time.Minute,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex Go duration",
|
||||||
|
input: "1h30m",
|
||||||
|
expected: 90 * time.Minute,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid day format",
|
||||||
|
input: "abcd",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative days",
|
||||||
|
input: "-1d",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero days",
|
||||||
|
input: "0d",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-numeric days",
|
||||||
|
input: "xyzd",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative Go duration",
|
||||||
|
input: "-24h",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero Go duration",
|
||||||
|
input: "0s",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid Go duration",
|
||||||
|
input: "notaduration",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := parseDuration(tt.input)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -174,6 +174,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
semaphore := make(chan struct{}, 10)
|
semaphore := make(chan struct{}, 10)
|
||||||
|
|
||||||
|
account.InjectProxyPolicies(ctx)
|
||||||
dnsCache := &cache.DNSConfigCache{}
|
dnsCache := &cache.DNSConfigCache{}
|
||||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||||
@@ -247,7 +248,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||||
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
||||||
|
|
||||||
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update})
|
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
|
||||||
|
Update: update,
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
})
|
||||||
}(peer)
|
}(peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -323,6 +327,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
return fmt.Errorf("failed to get validated peers: %v", err)
|
return fmt.Errorf("failed to get validated peers: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
account.InjectProxyPolicies(ctx)
|
||||||
dnsCache := &cache.DNSConfigCache{}
|
dnsCache := &cache.DNSConfigCache{}
|
||||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||||
@@ -370,7 +375,10 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||||
|
|
||||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||||
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update})
|
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{
|
||||||
|
Update: update,
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -435,6 +443,8 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
account.InjectProxyPolicies(ctx)
|
||||||
|
|
||||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, 0, err
|
return nil, nil, nil, 0, err
|
||||||
@@ -778,6 +788,7 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
})
|
})
|
||||||
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
||||||
|
|
||||||
@@ -840,6 +851,7 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
|||||||
if c.experimentalNetworkMap(peer.AccountID) {
|
if c.experimentalNetworkMap(peer.AccountID) {
|
||||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
|
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
|
||||||
} else {
|
} else {
|
||||||
|
account.InjectProxyPolicies(ctx)
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
|
|||||||
@@ -25,11 +25,14 @@ func TestCreateChannel(t *testing.T) {
|
|||||||
func TestSendUpdate(t *testing.T) {
|
func TestSendUpdate(t *testing.T) {
|
||||||
peer := "test-sendupdate"
|
peer := "test-sendupdate"
|
||||||
peersUpdater := NewPeersUpdateManager(nil)
|
peersUpdater := NewPeersUpdateManager(nil)
|
||||||
update1 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
|
update1 := &network_map.UpdateMessage{
|
||||||
NetworkMap: &proto.NetworkMap{
|
Update: &proto.SyncResponse{
|
||||||
Serial: 0,
|
NetworkMap: &proto.NetworkMap{
|
||||||
|
Serial: 0,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}}
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
_ = peersUpdater.CreateChannel(context.Background(), peer)
|
_ = peersUpdater.CreateChannel(context.Background(), peer)
|
||||||
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
||||||
t.Error("Error creating the channel")
|
t.Error("Error creating the channel")
|
||||||
@@ -45,11 +48,14 @@ func TestSendUpdate(t *testing.T) {
|
|||||||
peersUpdater.SendUpdate(context.Background(), peer, update1)
|
peersUpdater.SendUpdate(context.Background(), peer, update1)
|
||||||
}
|
}
|
||||||
|
|
||||||
update2 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
|
update2 := &network_map.UpdateMessage{
|
||||||
NetworkMap: &proto.NetworkMap{
|
Update: &proto.SyncResponse{
|
||||||
Serial: 10,
|
NetworkMap: &proto.NetworkMap{
|
||||||
|
Serial: 10,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}}
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
peersUpdater.SendUpdate(context.Background(), peer, update2)
|
peersUpdater.SendUpdate(context.Background(), peer, update2)
|
||||||
timeout := time.After(5 * time.Second)
|
timeout := time.After(5 * time.Second)
|
||||||
|
|||||||
@@ -4,6 +4,19 @@ import (
|
|||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// MessageType indicates the type of update message for debouncing strategy
|
||||||
|
type MessageType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// MessageTypeNetworkMap represents network map updates (peers, routes, DNS, firewall)
|
||||||
|
// These updates can be safely debounced - only the latest state matters
|
||||||
|
MessageTypeNetworkMap MessageType = iota
|
||||||
|
// MessageTypeControlConfig represents control/config updates (tokens, peer expiration)
|
||||||
|
// These updates should not be dropped as they contain time-sensitive information
|
||||||
|
MessageTypeControlConfig
|
||||||
|
)
|
||||||
|
|
||||||
type UpdateMessage struct {
|
type UpdateMessage struct {
|
||||||
Update *proto.SyncResponse
|
Update *proto.SyncResponse
|
||||||
|
MessageType MessageType
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
@@ -32,6 +33,7 @@ type Manager interface {
|
|||||||
SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator)
|
SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator)
|
||||||
SetAccountManager(accountManager account.Manager)
|
SetAccountManager(accountManager account.Manager)
|
||||||
GetPeerID(ctx context.Context, peerKey string) (string, error)
|
GetPeerID(ctx context.Context, peerKey string) (string, error)
|
||||||
|
CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type managerImpl struct {
|
type managerImpl struct {
|
||||||
@@ -182,3 +184,36 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
|
|||||||
func (m *managerImpl) GetPeerID(ctx context.Context, peerKey string) (string, error) {
|
func (m *managerImpl) GetPeerID(ctx context.Context, peerKey string) (string, error) {
|
||||||
return m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
|
return m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
|
||||||
|
existingPeerID, err := m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
|
||||||
|
if err == nil && existingPeerID != "" {
|
||||||
|
// Peer already exists
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
name := fmt.Sprintf("proxy-%s", xid.New().String())
|
||||||
|
peer := &peer.Peer{
|
||||||
|
Ephemeral: true,
|
||||||
|
ProxyMeta: peer.ProxyMeta{
|
||||||
|
Cluster: cluster,
|
||||||
|
Embedded: true,
|
||||||
|
},
|
||||||
|
Name: name,
|
||||||
|
Key: peerKey,
|
||||||
|
LoginExpirationEnabled: false,
|
||||||
|
InactivityExpirationEnabled: false,
|
||||||
|
Meta: peer.PeerSystemMeta{
|
||||||
|
Hostname: name,
|
||||||
|
GoOS: "proxy",
|
||||||
|
OS: "proxy",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create proxy peer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -162,3 +162,17 @@ func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController
|
|||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateProxyPeer mocks base method.
|
||||||
|
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
|
||||||
|
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,105 @@
|
|||||||
|
package accesslogs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AccessLogEntry struct {
|
||||||
|
ID string `gorm:"primaryKey"`
|
||||||
|
AccountID string `gorm:"index"`
|
||||||
|
ServiceID string `gorm:"index"`
|
||||||
|
Timestamp time.Time `gorm:"index"`
|
||||||
|
GeoLocation peer.Location `gorm:"embedded;embeddedPrefix:location_"`
|
||||||
|
Method string `gorm:"index"`
|
||||||
|
Host string `gorm:"index"`
|
||||||
|
Path string `gorm:"index"`
|
||||||
|
Duration time.Duration `gorm:"index"`
|
||||||
|
StatusCode int `gorm:"index"`
|
||||||
|
Reason string
|
||||||
|
UserId string `gorm:"index"`
|
||||||
|
AuthMethodUsed string `gorm:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromProto creates an AccessLogEntry from a proto.AccessLog
|
||||||
|
func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) {
|
||||||
|
a.ID = serviceLog.GetLogId()
|
||||||
|
a.ServiceID = serviceLog.GetServiceId()
|
||||||
|
a.Timestamp = serviceLog.GetTimestamp().AsTime()
|
||||||
|
a.Method = serviceLog.GetMethod()
|
||||||
|
a.Host = serviceLog.GetHost()
|
||||||
|
a.Path = serviceLog.GetPath()
|
||||||
|
a.Duration = time.Duration(serviceLog.GetDurationMs()) * time.Millisecond
|
||||||
|
a.StatusCode = int(serviceLog.GetResponseCode())
|
||||||
|
a.UserId = serviceLog.GetUserId()
|
||||||
|
a.AuthMethodUsed = serviceLog.GetAuthMechanism()
|
||||||
|
a.AccountID = serviceLog.GetAccountId()
|
||||||
|
|
||||||
|
if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" {
|
||||||
|
if ip, err := netip.ParseAddr(sourceIP); err == nil {
|
||||||
|
a.GeoLocation.ConnectionIP = net.IP(ip.AsSlice())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !serviceLog.GetAuthSuccess() {
|
||||||
|
a.Reason = "Authentication failed"
|
||||||
|
} else if serviceLog.GetResponseCode() >= 400 {
|
||||||
|
a.Reason = "Request failed"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToAPIResponse converts an AccessLogEntry to the API ProxyAccessLog type
|
||||||
|
func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
|
||||||
|
var sourceIP *string
|
||||||
|
if a.GeoLocation.ConnectionIP != nil {
|
||||||
|
ip := a.GeoLocation.ConnectionIP.String()
|
||||||
|
sourceIP = &ip
|
||||||
|
}
|
||||||
|
|
||||||
|
var reason *string
|
||||||
|
if a.Reason != "" {
|
||||||
|
reason = &a.Reason
|
||||||
|
}
|
||||||
|
|
||||||
|
var userID *string
|
||||||
|
if a.UserId != "" {
|
||||||
|
userID = &a.UserId
|
||||||
|
}
|
||||||
|
|
||||||
|
var authMethod *string
|
||||||
|
if a.AuthMethodUsed != "" {
|
||||||
|
authMethod = &a.AuthMethodUsed
|
||||||
|
}
|
||||||
|
|
||||||
|
var countryCode *string
|
||||||
|
if a.GeoLocation.CountryCode != "" {
|
||||||
|
countryCode = &a.GeoLocation.CountryCode
|
||||||
|
}
|
||||||
|
|
||||||
|
var cityName *string
|
||||||
|
if a.GeoLocation.CityName != "" {
|
||||||
|
cityName = &a.GeoLocation.CityName
|
||||||
|
}
|
||||||
|
|
||||||
|
return &api.ProxyAccessLog{
|
||||||
|
Id: a.ID,
|
||||||
|
ServiceId: a.ServiceID,
|
||||||
|
Timestamp: a.Timestamp,
|
||||||
|
Method: a.Method,
|
||||||
|
Host: a.Host,
|
||||||
|
Path: a.Path,
|
||||||
|
DurationMs: int(a.Duration.Milliseconds()),
|
||||||
|
StatusCode: a.StatusCode,
|
||||||
|
SourceIp: sourceIP,
|
||||||
|
Reason: reason,
|
||||||
|
UserId: userID,
|
||||||
|
AuthMethodUsed: authMethod,
|
||||||
|
CountryCode: countryCode,
|
||||||
|
CityName: cityName,
|
||||||
|
}
|
||||||
|
}
|
||||||
124
management/internals/modules/reverseproxy/accesslogs/filter.go
Normal file
124
management/internals/modules/reverseproxy/accesslogs/filter.go
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
package accesslogs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultPageSize is the default number of records per page
|
||||||
|
DefaultPageSize = 50
|
||||||
|
// MaxPageSize is the maximum number of records allowed per page
|
||||||
|
MaxPageSize = 100
|
||||||
|
)
|
||||||
|
|
||||||
|
// AccessLogFilter holds pagination and filtering parameters for access logs
|
||||||
|
type AccessLogFilter struct {
|
||||||
|
// Page is the current page number (1-indexed)
|
||||||
|
Page int
|
||||||
|
// PageSize is the number of records per page
|
||||||
|
PageSize int
|
||||||
|
|
||||||
|
// Filtering parameters
|
||||||
|
Search *string // General search across log ID, host, path, source IP, and user fields
|
||||||
|
SourceIP *string // Filter by source IP address
|
||||||
|
Host *string // Filter by host header
|
||||||
|
Path *string // Filter by request path (supports LIKE pattern)
|
||||||
|
UserID *string // Filter by authenticated user ID
|
||||||
|
UserEmail *string // Filter by user email (requires user lookup)
|
||||||
|
UserName *string // Filter by user name (requires user lookup)
|
||||||
|
Method *string // Filter by HTTP method
|
||||||
|
Status *string // Filter by status: "success" (2xx/3xx) or "failed" (1xx/4xx/5xx)
|
||||||
|
StatusCode *int // Filter by HTTP status code
|
||||||
|
StartDate *time.Time // Filter by timestamp >= start_date
|
||||||
|
EndDate *time.Time // Filter by timestamp <= end_date
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseFromRequest parses pagination and filter parameters from HTTP request query parameters
|
||||||
|
func (f *AccessLogFilter) ParseFromRequest(r *http.Request) {
|
||||||
|
queryParams := r.URL.Query()
|
||||||
|
|
||||||
|
f.Page = 1
|
||||||
|
if pageStr := queryParams.Get("page"); pageStr != "" {
|
||||||
|
if page, err := strconv.Atoi(pageStr); err == nil && page > 0 {
|
||||||
|
f.Page = page
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f.PageSize = DefaultPageSize
|
||||||
|
if pageSizeStr := queryParams.Get("page_size"); pageSizeStr != "" {
|
||||||
|
if pageSize, err := strconv.Atoi(pageSizeStr); err == nil && pageSize > 0 {
|
||||||
|
f.PageSize = pageSize
|
||||||
|
if f.PageSize > MaxPageSize {
|
||||||
|
f.PageSize = MaxPageSize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if search := queryParams.Get("search"); search != "" {
|
||||||
|
f.Search = &search
|
||||||
|
}
|
||||||
|
|
||||||
|
if sourceIP := queryParams.Get("source_ip"); sourceIP != "" {
|
||||||
|
f.SourceIP = &sourceIP
|
||||||
|
}
|
||||||
|
|
||||||
|
if host := queryParams.Get("host"); host != "" {
|
||||||
|
f.Host = &host
|
||||||
|
}
|
||||||
|
|
||||||
|
if path := queryParams.Get("path"); path != "" {
|
||||||
|
f.Path = &path
|
||||||
|
}
|
||||||
|
|
||||||
|
if userID := queryParams.Get("user_id"); userID != "" {
|
||||||
|
f.UserID = &userID
|
||||||
|
}
|
||||||
|
|
||||||
|
if userEmail := queryParams.Get("user_email"); userEmail != "" {
|
||||||
|
f.UserEmail = &userEmail
|
||||||
|
}
|
||||||
|
|
||||||
|
if userName := queryParams.Get("user_name"); userName != "" {
|
||||||
|
f.UserName = &userName
|
||||||
|
}
|
||||||
|
|
||||||
|
if method := queryParams.Get("method"); method != "" {
|
||||||
|
f.Method = &method
|
||||||
|
}
|
||||||
|
|
||||||
|
if status := queryParams.Get("status"); status != "" {
|
||||||
|
f.Status = &status
|
||||||
|
}
|
||||||
|
|
||||||
|
if statusCodeStr := queryParams.Get("status_code"); statusCodeStr != "" {
|
||||||
|
if statusCode, err := strconv.Atoi(statusCodeStr); err == nil && statusCode > 0 {
|
||||||
|
f.StatusCode = &statusCode
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if startDate := queryParams.Get("start_date"); startDate != "" {
|
||||||
|
parsedStartDate, err := time.Parse(time.RFC3339, startDate)
|
||||||
|
if err == nil {
|
||||||
|
f.StartDate = &parsedStartDate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if endDate := queryParams.Get("end_date"); endDate != "" {
|
||||||
|
parsedEndDate, err := time.Parse(time.RFC3339, endDate)
|
||||||
|
if err == nil {
|
||||||
|
f.EndDate = &parsedEndDate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOffset calculates the database offset for pagination
|
||||||
|
func (f *AccessLogFilter) GetOffset() int {
|
||||||
|
return (f.Page - 1) * f.PageSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLimit returns the page size for database queries
|
||||||
|
func (f *AccessLogFilter) GetLimit() int {
|
||||||
|
return f.PageSize
|
||||||
|
}
|
||||||
@@ -0,0 +1,161 @@
|
|||||||
|
package accesslogs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAccessLogFilter_ParseFromRequest(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
queryParams map[string]string
|
||||||
|
expectedPage int
|
||||||
|
expectedPageSize int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default values when no params provided",
|
||||||
|
queryParams: map[string]string{},
|
||||||
|
expectedPage: 1,
|
||||||
|
expectedPageSize: DefaultPageSize,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid page and page_size",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "2",
|
||||||
|
"page_size": "25",
|
||||||
|
},
|
||||||
|
expectedPage: 2,
|
||||||
|
expectedPageSize: 25,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "page_size exceeds max, should cap at MaxPageSize",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "1",
|
||||||
|
"page_size": "200",
|
||||||
|
},
|
||||||
|
expectedPage: 1,
|
||||||
|
expectedPageSize: MaxPageSize,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid page number, should use default",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "invalid",
|
||||||
|
"page_size": "10",
|
||||||
|
},
|
||||||
|
expectedPage: 1,
|
||||||
|
expectedPageSize: 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid page_size, should use default",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "2",
|
||||||
|
"page_size": "invalid",
|
||||||
|
},
|
||||||
|
expectedPage: 2,
|
||||||
|
expectedPageSize: DefaultPageSize,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero page number, should use default",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "0",
|
||||||
|
"page_size": "10",
|
||||||
|
},
|
||||||
|
expectedPage: 1,
|
||||||
|
expectedPageSize: 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative page number, should use default",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "-1",
|
||||||
|
"page_size": "10",
|
||||||
|
},
|
||||||
|
expectedPage: 1,
|
||||||
|
expectedPageSize: 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero page_size, should use default",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "1",
|
||||||
|
"page_size": "0",
|
||||||
|
},
|
||||||
|
expectedPage: 1,
|
||||||
|
expectedPageSize: DefaultPageSize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
q := req.URL.Query()
|
||||||
|
for key, value := range tt.queryParams {
|
||||||
|
q.Set(key, value)
|
||||||
|
}
|
||||||
|
req.URL.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
filter := &AccessLogFilter{}
|
||||||
|
filter.ParseFromRequest(req)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedPage, filter.Page, "Page mismatch")
|
||||||
|
assert.Equal(t, tt.expectedPageSize, filter.PageSize, "PageSize mismatch")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogFilter_GetOffset(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
page int
|
||||||
|
pageSize int
|
||||||
|
expectedOffset int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "first page",
|
||||||
|
page: 1,
|
||||||
|
pageSize: 50,
|
||||||
|
expectedOffset: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "second page",
|
||||||
|
page: 2,
|
||||||
|
pageSize: 50,
|
||||||
|
expectedOffset: 50,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "third page with page size 25",
|
||||||
|
page: 3,
|
||||||
|
pageSize: 25,
|
||||||
|
expectedOffset: 50,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "page 10 with page size 10",
|
||||||
|
page: 10,
|
||||||
|
pageSize: 10,
|
||||||
|
expectedOffset: 90,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
filter := &AccessLogFilter{
|
||||||
|
Page: tt.page,
|
||||||
|
PageSize: tt.pageSize,
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := filter.GetOffset()
|
||||||
|
assert.Equal(t, tt.expectedOffset, offset)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogFilter_GetLimit(t *testing.T) {
|
||||||
|
filter := &AccessLogFilter{
|
||||||
|
Page: 2,
|
||||||
|
PageSize: 25,
|
||||||
|
}
|
||||||
|
|
||||||
|
limit := filter.GetLimit()
|
||||||
|
assert.Equal(t, 25, limit, "GetLimit should return PageSize")
|
||||||
|
}
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
package accesslogs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Manager interface {
|
||||||
|
SaveAccessLog(ctx context.Context, proxyLog *AccessLogEntry) error
|
||||||
|
GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *AccessLogFilter) ([]*AccessLogEntry, int64, error)
|
||||||
|
}
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
type handler struct {
|
||||||
|
manager accesslogs.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager) {
|
||||||
|
h := &handler{
|
||||||
|
manager: manager,
|
||||||
|
}
|
||||||
|
|
||||||
|
router.HandleFunc("/events/proxy", h.getAccessLogs).Methods("GET", "OPTIONS")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var filter accesslogs.AccessLogFilter
|
||||||
|
filter.ParseFromRequest(r)
|
||||||
|
|
||||||
|
logs, totalCount, err := h.manager.GetAllAccessLogs(r.Context(), userAuth.AccountId, userAuth.UserId, &filter)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
apiLogs := make([]api.ProxyAccessLog, 0, len(logs))
|
||||||
|
for _, log := range logs {
|
||||||
|
apiLogs = append(apiLogs, *log.ToAPIResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
response := &api.ProxyAccessLogsResponse{
|
||||||
|
Data: apiLogs,
|
||||||
|
Page: filter.Page,
|
||||||
|
PageSize: filter.PageSize,
|
||||||
|
TotalRecords: int(totalCount),
|
||||||
|
TotalPages: getTotalPageCount(int(totalCount), filter.PageSize),
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTotalPageCount calculates the total number of pages
|
||||||
|
func getTotalPageCount(totalCount, pageSize int) int {
|
||||||
|
if pageSize <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return (totalCount + pageSize - 1) / pageSize
|
||||||
|
}
|
||||||
@@ -0,0 +1,108 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
type managerImpl struct {
|
||||||
|
store store.Store
|
||||||
|
permissionsManager permissions.Manager
|
||||||
|
geo geolocation.Geolocation
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(store store.Store, permissionsManager permissions.Manager, geo geolocation.Geolocation) accesslogs.Manager {
|
||||||
|
return &managerImpl{
|
||||||
|
store: store,
|
||||||
|
permissionsManager: permissionsManager,
|
||||||
|
geo: geo,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveAccessLog saves an access log entry to the database after enriching it
|
||||||
|
func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.AccessLogEntry) error {
|
||||||
|
if m.geo != nil && logEntry.GeoLocation.ConnectionIP != nil {
|
||||||
|
location, err := m.geo.Lookup(logEntry.GeoLocation.ConnectionIP)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to get location for access log source IP [%s]: %v", logEntry.GeoLocation.ConnectionIP.String(), err)
|
||||||
|
} else {
|
||||||
|
logEntry.GeoLocation.CountryCode = location.Country.ISOCode
|
||||||
|
logEntry.GeoLocation.CityName = location.City.Names.En
|
||||||
|
logEntry.GeoLocation.GeoNameID = location.City.GeonameID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.store.CreateAccessLog(ctx, logEntry); err != nil {
|
||||||
|
log.WithContext(ctx).WithFields(log.Fields{
|
||||||
|
"service_id": logEntry.ServiceID,
|
||||||
|
"method": logEntry.Method,
|
||||||
|
"host": logEntry.Host,
|
||||||
|
"path": logEntry.Path,
|
||||||
|
"status": logEntry.StatusCode,
|
||||||
|
}).Errorf("failed to save access log: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllAccessLogs retrieves access logs for an account with pagination and filtering
|
||||||
|
func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, 0, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.resolveUserFilters(ctx, accountID, filter); err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to resolve user filters: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logs, totalCount, err := m.store.GetAccountAccessLogs(ctx, store.LockingStrengthNone, accountID, *filter)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return logs, totalCount, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveUserFilters converts user email/name filters to user ID filter
|
||||||
|
func (m *managerImpl) resolveUserFilters(ctx context.Context, accountID string, filter *accesslogs.AccessLogFilter) error {
|
||||||
|
if filter.UserEmail == nil && filter.UserName == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
users, err := m.store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var matchingUserIDs []string
|
||||||
|
for _, user := range users {
|
||||||
|
if filter.UserEmail != nil && strings.Contains(strings.ToLower(user.Email), strings.ToLower(*filter.UserEmail)) {
|
||||||
|
matchingUserIDs = append(matchingUserIDs, user.Id)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if filter.UserName != nil && strings.Contains(strings.ToLower(user.Name), strings.ToLower(*filter.UserName)) {
|
||||||
|
matchingUserIDs = append(matchingUserIDs, user.Id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(matchingUserIDs) > 0 {
|
||||||
|
filter.UserID = &matchingUserIDs[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
17
management/internals/modules/reverseproxy/domain/domain.go
Normal file
17
management/internals/modules/reverseproxy/domain/domain.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
type Type string
|
||||||
|
|
||||||
|
const (
|
||||||
|
TypeFree Type = "free"
|
||||||
|
TypeCustom Type = "custom"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Domain struct {
|
||||||
|
ID string `gorm:"unique;primaryKey;autoIncrement"`
|
||||||
|
Domain string `gorm:"unique"` // Domain records must be unique, this avoids domain reuse across accounts.
|
||||||
|
AccountID string `gorm:"index"`
|
||||||
|
TargetCluster string // The proxy cluster this domain should be validated against
|
||||||
|
Type Type `gorm:"-"`
|
||||||
|
Validated bool
|
||||||
|
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Manager interface {
|
||||||
|
GetDomains(ctx context.Context, accountID, userID string) ([]*Domain, error)
|
||||||
|
CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*Domain, error)
|
||||||
|
DeleteDomain(ctx context.Context, accountID, userID, domainID string) error
|
||||||
|
ValidateDomain(ctx context.Context, accountID, userID, domainID string)
|
||||||
|
}
|
||||||
136
management/internals/modules/reverseproxy/domain/manager/api.go
Normal file
136
management/internals/modules/reverseproxy/domain/manager/api.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||||
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
type handler struct {
|
||||||
|
manager Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterEndpoints(router *mux.Router, manager Manager) {
|
||||||
|
h := &handler{
|
||||||
|
manager: manager,
|
||||||
|
}
|
||||||
|
|
||||||
|
router.HandleFunc("/domains", h.getAllDomains).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/domains", h.createCustomDomain).Methods("POST", "OPTIONS")
|
||||||
|
router.HandleFunc("/domains/{domainId}", h.deleteCustomDomain).Methods("DELETE", "OPTIONS")
|
||||||
|
router.HandleFunc("/domains/{domainId}/validate", h.triggerCustomDomainValidation).Methods("GET", "OPTIONS")
|
||||||
|
}
|
||||||
|
|
||||||
|
func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType {
|
||||||
|
switch t {
|
||||||
|
case domain.TypeCustom:
|
||||||
|
return api.ReverseProxyDomainTypeCustom
|
||||||
|
case domain.TypeFree:
|
||||||
|
return api.ReverseProxyDomainTypeFree
|
||||||
|
}
|
||||||
|
// By default return as a "free" domain as that is more restrictive.
|
||||||
|
// TODO: is this correct?
|
||||||
|
return api.ReverseProxyDomainTypeFree
|
||||||
|
}
|
||||||
|
|
||||||
|
func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
|
||||||
|
resp := api.ReverseProxyDomain{
|
||||||
|
Domain: d.Domain,
|
||||||
|
Id: d.ID,
|
||||||
|
Type: domainTypeToApi(d.Type),
|
||||||
|
Validated: d.Validated,
|
||||||
|
}
|
||||||
|
if d.TargetCluster != "" {
|
||||||
|
resp.TargetCluster = &d.TargetCluster
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
domains, err := h.manager.GetDomains(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ret := make([]api.ReverseProxyDomain, 0)
|
||||||
|
for _, d := range domains {
|
||||||
|
ret = append(ret, domainToApi(d))
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req api.PostApiReverseProxiesDomainsJSONRequestBody
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
domain, err := h.manager.CreateDomain(r.Context(), userAuth.AccountId, userAuth.UserId, req.Domain, req.TargetCluster)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, domainToApi(domain))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
domainID := mux.Vars(r)["domainId"]
|
||||||
|
if domainID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.manager.DeleteDomain(r.Context(), userAuth.AccountId, userAuth.UserId, domainID); err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
domainID := mux.Vars(r)["domainId"]
|
||||||
|
if domainID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go h.manager.ValidateDomain(r.Context(), userAuth.AccountId, userAuth.UserId, domainID)
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusAccepted)
|
||||||
|
}
|
||||||
@@ -0,0 +1,279 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
type store interface {
|
||||||
|
GetAccount(ctx context.Context, accountID string) (*types.Account, error)
|
||||||
|
|
||||||
|
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)
|
||||||
|
ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
|
||||||
|
ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error)
|
||||||
|
CreateCustomDomain(ctx context.Context, accountID string, domainName string, targetCluster string, validated bool) (*domain.Domain, error)
|
||||||
|
UpdateCustomDomain(ctx context.Context, accountID string, d *domain.Domain) (*domain.Domain, error)
|
||||||
|
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type proxyURLProvider interface {
|
||||||
|
GetConnectedProxyURLs() []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
store store
|
||||||
|
validator domain.Validator
|
||||||
|
proxyURLProvider proxyURLProvider
|
||||||
|
permissionsManager permissions.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(store store, proxyURLProvider proxyURLProvider, permissionsManager permissions.Manager) Manager {
|
||||||
|
return Manager{
|
||||||
|
store: store,
|
||||||
|
proxyURLProvider: proxyURLProvider,
|
||||||
|
validator: domain.Validator{
|
||||||
|
Resolver: net.DefaultResolver,
|
||||||
|
},
|
||||||
|
permissionsManager: permissionsManager,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
domains, err := m.store.ListCustomDomains(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("list custom domains: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var ret []*domain.Domain
|
||||||
|
|
||||||
|
// Add connected proxy clusters as free domains.
|
||||||
|
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
|
||||||
|
allowList := m.proxyURLAllowList()
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"proxyAllowList": allowList,
|
||||||
|
}).Debug("getting domains with proxy allow list")
|
||||||
|
|
||||||
|
for _, cluster := range allowList {
|
||||||
|
ret = append(ret, &domain.Domain{
|
||||||
|
Domain: cluster,
|
||||||
|
AccountID: accountID,
|
||||||
|
Type: domain.TypeFree,
|
||||||
|
Validated: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add custom domains.
|
||||||
|
for _, d := range domains {
|
||||||
|
ret = append(ret, &domain.Domain{
|
||||||
|
ID: d.ID,
|
||||||
|
Domain: d.Domain,
|
||||||
|
AccountID: accountID,
|
||||||
|
TargetCluster: d.TargetCluster,
|
||||||
|
Type: domain.TypeCustom,
|
||||||
|
Validated: d.Validated,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*domain.Domain, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the target cluster is in the available clusters
|
||||||
|
allowList := m.proxyURLAllowList()
|
||||||
|
clusterValid := false
|
||||||
|
for _, cluster := range allowList {
|
||||||
|
if cluster == targetCluster {
|
||||||
|
clusterValid = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !clusterValid {
|
||||||
|
return nil, fmt.Errorf("target cluster %s is not available", targetCluster)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt an initial validation against the specified cluster only
|
||||||
|
var validated bool
|
||||||
|
if m.validator.IsValid(ctx, domainName, []string{targetCluster}) {
|
||||||
|
validated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := m.store.CreateCustomDomain(ctx, accountID, domainName, targetCluster, validated)
|
||||||
|
if err != nil {
|
||||||
|
return d, fmt.Errorf("create domain in store: %w", err)
|
||||||
|
}
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID string) error {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||||
|
if err != nil {
|
||||||
|
return status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.store.DeleteCustomDomain(ctx, accountID, domainID); err != nil {
|
||||||
|
// TODO: check for "no records" type error. Because that is a success condition.
|
||||||
|
return fmt.Errorf("delete domain from store: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID string) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
}).WithError(err).Error("validate domain")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
}).WithError(err).Error("validate domain")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
}).Info("starting domain validation")
|
||||||
|
|
||||||
|
d, err := m.store.GetCustomDomain(context.Background(), accountID, domainID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
}).WithError(err).Error("get custom domain from store")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate only against the domain's target cluster
|
||||||
|
targetCluster := d.TargetCluster
|
||||||
|
if targetCluster == "" {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
"domain": d.Domain,
|
||||||
|
}).Warn("domain has no target cluster set, skipping validation")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
"domain": d.Domain,
|
||||||
|
"targetCluster": targetCluster,
|
||||||
|
}).Info("validating domain against target cluster")
|
||||||
|
|
||||||
|
if m.validator.IsValid(context.Background(), d.Domain, []string{targetCluster}) {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
"domain": d.Domain,
|
||||||
|
}).Info("domain validated successfully")
|
||||||
|
d.Validated = true
|
||||||
|
if _, err := m.store.UpdateCustomDomain(context.Background(), accountID, d); err != nil {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
"domain": d.Domain,
|
||||||
|
}).WithError(err).Error("update custom domain in store")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
"domain": d.Domain,
|
||||||
|
"targetCluster": targetCluster,
|
||||||
|
}).Warn("domain validation failed - CNAME does not match target cluster")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// proxyURLAllowList retrieves a list of currently connected proxies and
|
||||||
|
// their URLs
|
||||||
|
func (m Manager) proxyURLAllowList() []string {
|
||||||
|
var reverseProxyAddresses []string
|
||||||
|
if m.proxyURLProvider != nil {
|
||||||
|
reverseProxyAddresses = m.proxyURLProvider.GetConnectedProxyURLs()
|
||||||
|
}
|
||||||
|
return reverseProxyAddresses
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeriveClusterFromDomain determines the proxy cluster for a given domain.
|
||||||
|
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
|
||||||
|
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
|
||||||
|
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
|
||||||
|
allowList := m.proxyURLAllowList()
|
||||||
|
if len(allowList) == 0 {
|
||||||
|
return "", fmt.Errorf("no proxy clusters available")
|
||||||
|
}
|
||||||
|
|
||||||
|
if cluster, ok := ExtractClusterFromFreeDomain(domain, allowList); ok {
|
||||||
|
return cluster, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
customDomains, err := m.store.ListCustomDomains(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("list custom domains: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
targetCluster, valid := extractClusterFromCustomDomains(domain, customDomains)
|
||||||
|
if valid {
|
||||||
|
return targetCluster, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractClusterFromCustomDomains(domain string, customDomains []*domain.Domain) (string, bool) {
|
||||||
|
for _, customDomain := range customDomains {
|
||||||
|
if strings.HasSuffix(domain, "."+customDomain.Domain) {
|
||||||
|
return customDomain.TargetCluster, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractClusterFromFreeDomain extracts the cluster address from a free domain.
|
||||||
|
// Free domains have the format: <name>.<nonce>.<cluster> (e.g., myapp.abc123.eu.proxy.netbird.io)
|
||||||
|
// It matches the domain suffix against available clusters and returns the matching cluster.
|
||||||
|
func ExtractClusterFromFreeDomain(domain string, availableClusters []string) (string, bool) {
|
||||||
|
for _, cluster := range availableClusters {
|
||||||
|
if strings.HasSuffix(domain, "."+cluster) {
|
||||||
|
return cluster, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type resolver interface {
|
||||||
|
LookupCNAME(context.Context, string) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Validator struct {
|
||||||
|
Resolver resolver
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewValidator initializes a validator with a specific DNS Resolver.
|
||||||
|
// If a Validator is used without specifying a Resolver, then it will
|
||||||
|
// use the net.DefaultResolver.
|
||||||
|
func NewValidator(resolver resolver) *Validator {
|
||||||
|
return &Validator{
|
||||||
|
Resolver: resolver,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValid looks up the CNAME record for the passed domain with a prefix
|
||||||
|
// and compares it against the acceptable domains.
|
||||||
|
// If the returned CNAME matches any accepted domain, it will return true,
|
||||||
|
// otherwise, including in the event of a DNS error, it will return false.
|
||||||
|
// The comparison is very simple, so wildcards will not match if included
|
||||||
|
// in the acceptable domain list.
|
||||||
|
func (v *Validator) IsValid(ctx context.Context, domain string, accept []string) bool {
|
||||||
|
_, valid := v.ValidateWithCluster(ctx, domain, accept)
|
||||||
|
return valid
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateWithCluster validates a custom domain and returns the matched cluster address.
|
||||||
|
// Returns the cluster address and true if valid, or empty string and false if invalid.
|
||||||
|
func (v *Validator) ValidateWithCluster(ctx context.Context, domain string, accept []string) (string, bool) {
|
||||||
|
if v.Resolver == nil {
|
||||||
|
v.Resolver = net.DefaultResolver
|
||||||
|
}
|
||||||
|
|
||||||
|
lookupDomain := "validation." + domain
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"domain": domain,
|
||||||
|
"lookupDomain": lookupDomain,
|
||||||
|
"acceptList": accept,
|
||||||
|
}).Debug("looking up CNAME for domain validation")
|
||||||
|
|
||||||
|
cname, err := v.Resolver.LookupCNAME(ctx, lookupDomain)
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"domain": domain,
|
||||||
|
"lookupDomain": lookupDomain,
|
||||||
|
}).WithError(err).Warn("CNAME lookup failed for domain validation")
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
nakedCNAME := strings.TrimSuffix(cname, ".")
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"domain": domain,
|
||||||
|
"cname": cname,
|
||||||
|
"nakedCNAME": nakedCNAME,
|
||||||
|
"acceptList": accept,
|
||||||
|
}).Debug("CNAME lookup result for domain validation")
|
||||||
|
|
||||||
|
for _, acceptDomain := range accept {
|
||||||
|
normalizedAccept := strings.TrimSuffix(acceptDomain, ".")
|
||||||
|
if nakedCNAME == normalizedAccept {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"domain": domain,
|
||||||
|
"cname": nakedCNAME,
|
||||||
|
"cluster": acceptDomain,
|
||||||
|
}).Info("domain CNAME matched cluster")
|
||||||
|
return acceptDomain, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"domain": domain,
|
||||||
|
"cname": nakedCNAME,
|
||||||
|
"acceptList": accept,
|
||||||
|
}).Warn("domain CNAME does not match any accepted cluster")
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
package domain_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
type resolver struct {
|
||||||
|
CNAME string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r resolver) LookupCNAME(_ context.Context, _ string) (string, error) {
|
||||||
|
return r.CNAME, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsValid(t *testing.T) {
|
||||||
|
tests := map[string]struct {
|
||||||
|
resolver interface {
|
||||||
|
LookupCNAME(context.Context, string) (string, error)
|
||||||
|
}
|
||||||
|
domain string
|
||||||
|
accept []string
|
||||||
|
expect bool
|
||||||
|
}{
|
||||||
|
"match": {
|
||||||
|
resolver: resolver{"bar.example.com."}, // Including trailing "." in response.
|
||||||
|
domain: "foo.example.com",
|
||||||
|
accept: []string{"bar.example.com"},
|
||||||
|
expect: true,
|
||||||
|
},
|
||||||
|
"no match": {
|
||||||
|
resolver: resolver{"invalid"},
|
||||||
|
domain: "foo.example.com",
|
||||||
|
accept: []string{"bar.example.com"},
|
||||||
|
expect: false,
|
||||||
|
},
|
||||||
|
"accept trailing dot": {
|
||||||
|
resolver: resolver{"bar.example.com."},
|
||||||
|
domain: "foo.example.com",
|
||||||
|
accept: []string{"bar.example.com."}, // Including trailing "." in accept.
|
||||||
|
expect: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, test := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
validator := domain.NewValidator(test.resolver)
|
||||||
|
actual := validator.IsValid(t.Context(), test.domain, test.accept)
|
||||||
|
if test.expect != actual {
|
||||||
|
t.Errorf("Incorrect return value:\nexpect: %v\nactual: %v", test.expect, actual)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
23
management/internals/modules/reverseproxy/interface.go
Normal file
23
management/internals/modules/reverseproxy/interface.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package reverseproxy
|
||||||
|
|
||||||
|
//go:generate go run github.com/golang/mock/mockgen -package reverseproxy -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Manager interface {
|
||||||
|
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
|
||||||
|
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
|
||||||
|
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
|
||||||
|
UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
|
||||||
|
DeleteService(ctx context.Context, accountID, userID, serviceID string) error
|
||||||
|
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
|
||||||
|
SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error
|
||||||
|
ReloadAllServicesForAccount(ctx context.Context, accountID string) error
|
||||||
|
ReloadService(ctx context.Context, accountID, serviceID string) error
|
||||||
|
GetGlobalServices(ctx context.Context) ([]*Service, error)
|
||||||
|
GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error)
|
||||||
|
GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
|
||||||
|
GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
|
||||||
|
}
|
||||||
225
management/internals/modules/reverseproxy/interface_mock.go
Normal file
225
management/internals/modules/reverseproxy/interface_mock.go
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: ./interface.go
|
||||||
|
|
||||||
|
// Package reverseproxy is a generated GoMock package.
|
||||||
|
package reverseproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
context "context"
|
||||||
|
reflect "reflect"
|
||||||
|
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockManager is a mock of Manager interface.
|
||||||
|
type MockManager struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockManagerMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockManagerMockRecorder is the mock recorder for MockManager.
|
||||||
|
type MockManagerMockRecorder struct {
|
||||||
|
mock *MockManager
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockManager creates a new mock instance.
|
||||||
|
func NewMockManager(ctrl *gomock.Controller) *MockManager {
|
||||||
|
mock := &MockManager{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockManagerMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
|
func (m *MockManager) EXPECT() *MockManagerMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateService mocks base method.
|
||||||
|
func (m *MockManager) CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "CreateService", ctx, accountID, userID, service)
|
||||||
|
ret0, _ := ret[0].(*Service)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateService indicates an expected call of CreateService.
|
||||||
|
func (mr *MockManagerMockRecorder) CreateService(ctx, accountID, userID, service interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteService mocks base method.
|
||||||
|
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "DeleteService", ctx, accountID, userID, serviceID)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteService indicates an expected call of DeleteService.
|
||||||
|
func (mr *MockManagerMockRecorder) DeleteService(ctx, accountID, userID, serviceID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteService", reflect.TypeOf((*MockManager)(nil).DeleteService), ctx, accountID, userID, serviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountServices mocks base method.
|
||||||
|
func (m *MockManager) GetAccountServices(ctx context.Context, accountID string) ([]*Service, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetAccountServices", ctx, accountID)
|
||||||
|
ret0, _ := ret[0].([]*Service)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountServices indicates an expected call of GetAccountServices.
|
||||||
|
func (mr *MockManagerMockRecorder) GetAccountServices(ctx, accountID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockManager)(nil).GetAccountServices), ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllServices mocks base method.
|
||||||
|
func (m *MockManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetAllServices", ctx, accountID, userID)
|
||||||
|
ret0, _ := ret[0].([]*Service)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllServices indicates an expected call of GetAllServices.
|
||||||
|
func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGlobalServices mocks base method.
|
||||||
|
func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetGlobalServices", ctx)
|
||||||
|
ret0, _ := ret[0].([]*Service)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGlobalServices indicates an expected call of GetGlobalServices.
|
||||||
|
func (mr *MockManagerMockRecorder) GetGlobalServices(ctx interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGlobalServices", reflect.TypeOf((*MockManager)(nil).GetGlobalServices), ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetService mocks base method.
|
||||||
|
func (m *MockManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetService", ctx, accountID, userID, serviceID)
|
||||||
|
ret0, _ := ret[0].(*Service)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetService indicates an expected call of GetService.
|
||||||
|
func (mr *MockManagerMockRecorder) GetService(ctx, accountID, userID, serviceID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetService", reflect.TypeOf((*MockManager)(nil).GetService), ctx, accountID, userID, serviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetServiceByID mocks base method.
|
||||||
|
func (m *MockManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetServiceByID", ctx, accountID, serviceID)
|
||||||
|
ret0, _ := ret[0].(*Service)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetServiceByID indicates an expected call of GetServiceByID.
|
||||||
|
func (mr *MockManagerMockRecorder) GetServiceByID(ctx, accountID, serviceID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByID", reflect.TypeOf((*MockManager)(nil).GetServiceByID), ctx, accountID, serviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetServiceIDByTargetID mocks base method.
|
||||||
|
func (m *MockManager) GetServiceIDByTargetID(ctx context.Context, accountID, resourceID string) (string, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetServiceIDByTargetID", ctx, accountID, resourceID)
|
||||||
|
ret0, _ := ret[0].(string)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetServiceIDByTargetID indicates an expected call of GetServiceIDByTargetID.
|
||||||
|
func (mr *MockManagerMockRecorder) GetServiceIDByTargetID(ctx, accountID, resourceID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceIDByTargetID", reflect.TypeOf((*MockManager)(nil).GetServiceIDByTargetID), ctx, accountID, resourceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReloadAllServicesForAccount mocks base method.
|
||||||
|
func (m *MockManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "ReloadAllServicesForAccount", ctx, accountID)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReloadAllServicesForAccount indicates an expected call of ReloadAllServicesForAccount.
|
||||||
|
func (mr *MockManagerMockRecorder) ReloadAllServicesForAccount(ctx, accountID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadAllServicesForAccount", reflect.TypeOf((*MockManager)(nil).ReloadAllServicesForAccount), ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReloadService mocks base method.
|
||||||
|
func (m *MockManager) ReloadService(ctx context.Context, accountID, serviceID string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "ReloadService", ctx, accountID, serviceID)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReloadService indicates an expected call of ReloadService.
|
||||||
|
func (mr *MockManagerMockRecorder) ReloadService(ctx, accountID, serviceID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadService", reflect.TypeOf((*MockManager)(nil).ReloadService), ctx, accountID, serviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCertificateIssuedAt mocks base method.
|
||||||
|
func (m *MockManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "SetCertificateIssuedAt", ctx, accountID, serviceID)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCertificateIssuedAt indicates an expected call of SetCertificateIssuedAt.
|
||||||
|
func (mr *MockManagerMockRecorder) SetCertificateIssuedAt(ctx, accountID, serviceID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCertificateIssuedAt", reflect.TypeOf((*MockManager)(nil).SetCertificateIssuedAt), ctx, accountID, serviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetStatus mocks base method.
|
||||||
|
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetStatus indicates an expected call of SetStatus.
|
||||||
|
func (mr *MockManagerMockRecorder) SetStatus(ctx, accountID, serviceID, status interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetStatus", reflect.TypeOf((*MockManager)(nil).SetStatus), ctx, accountID, serviceID, status)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateService mocks base method.
|
||||||
|
func (m *MockManager) UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "UpdateService", ctx, accountID, userID, service)
|
||||||
|
ret0, _ := ret[0].(*Service)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateService indicates an expected call of UpdateService.
|
||||||
|
func (mr *MockManagerMockRecorder) UpdateService(ctx, accountID, userID, service interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateService", reflect.TypeOf((*MockManager)(nil).UpdateService), ctx, accountID, userID, service)
|
||||||
|
}
|
||||||
170
management/internals/modules/reverseproxy/manager/api.go
Normal file
170
management/internals/modules/reverseproxy/manager/api.go
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
|
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||||
|
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||||
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
type handler struct {
|
||||||
|
manager reverseproxy.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterEndpoints registers all service HTTP endpoints.
|
||||||
|
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
|
||||||
|
h := &handler{
|
||||||
|
manager: manager,
|
||||||
|
}
|
||||||
|
|
||||||
|
domainRouter := router.PathPrefix("/reverse-proxies").Subrouter()
|
||||||
|
domainmanager.RegisterEndpoints(domainRouter, domainManager)
|
||||||
|
|
||||||
|
accesslogsmanager.RegisterEndpoints(router, accessLogsManager)
|
||||||
|
|
||||||
|
router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
|
||||||
|
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.updateService).Methods("PUT", "OPTIONS")
|
||||||
|
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.deleteService).Methods("DELETE", "OPTIONS")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
allServices, err := h.manager.GetAllServices(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
apiServices := make([]*api.Service, 0, len(allServices))
|
||||||
|
for _, service := range allServices {
|
||||||
|
apiServices = append(apiServices, service.ToAPIResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, apiServices)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req api.ServiceRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
service := new(reverseproxy.Service)
|
||||||
|
service.FromAPIRequest(&req, userAuth.AccountId)
|
||||||
|
|
||||||
|
if err = service.Validate(); err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
createdService, err := h.manager.CreateService(r.Context(), userAuth.AccountId, userAuth.UserId, service)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, createdService.ToAPIResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) getService(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serviceID := mux.Vars(r)["serviceId"]
|
||||||
|
if serviceID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
service, err := h.manager.GetService(r.Context(), userAuth.AccountId, userAuth.UserId, serviceID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, service.ToAPIResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serviceID := mux.Vars(r)["serviceId"]
|
||||||
|
if serviceID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req api.ServiceRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
service := new(reverseproxy.Service)
|
||||||
|
service.ID = serviceID
|
||||||
|
service.FromAPIRequest(&req, userAuth.AccountId)
|
||||||
|
|
||||||
|
if err = service.Validate(); err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedService, err := h.manager.UpdateService(r.Context(), userAuth.AccountId, userAuth.UserId, service)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, updatedService.ToAPIResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serviceID := mux.Vars(r)["serviceId"]
|
||||||
|
if serviceID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.manager.DeleteService(r.Context(), userAuth.AccountId, userAuth.UserId, serviceID); err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||||
|
}
|
||||||
500
management/internals/modules/reverseproxy/manager/manager.go
Normal file
500
management/internals/modules/reverseproxy/manager/manager.go
Normal file
@@ -0,0 +1,500 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
const unknownHostPlaceholder = "unknown"
|
||||||
|
|
||||||
|
// ClusterDeriver derives the proxy cluster from a domain.
|
||||||
|
type ClusterDeriver interface {
|
||||||
|
DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type managerImpl struct {
|
||||||
|
store store.Store
|
||||||
|
accountManager account.Manager
|
||||||
|
permissionsManager permissions.Manager
|
||||||
|
proxyGRPCServer *nbgrpc.ProxyServiceServer
|
||||||
|
clusterDeriver ClusterDeriver
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new service manager.
|
||||||
|
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager {
|
||||||
|
return &managerImpl{
|
||||||
|
store: store,
|
||||||
|
accountManager: accountManager,
|
||||||
|
permissionsManager: permissionsManager,
|
||||||
|
proxyGRPCServer: proxyGRPCServer,
|
||||||
|
clusterDeriver: clusterDeriver,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, service := range services {
|
||||||
|
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return services, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error {
|
||||||
|
for _, target := range service.Targets {
|
||||||
|
switch target.TargetType {
|
||||||
|
case reverseproxy.TargetTypePeer:
|
||||||
|
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err)
|
||||||
|
target.Host = unknownHostPlaceholder
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
target.Host = peer.IP.String()
|
||||||
|
case reverseproxy.TargetTypeHost:
|
||||||
|
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
|
||||||
|
target.Host = unknownHostPlaceholder
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
target.Host = resource.Prefix.Addr().String()
|
||||||
|
case reverseproxy.TargetTypeDomain:
|
||||||
|
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
|
||||||
|
target.Host = unknownHostPlaceholder
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
target.Host = resource.Domain
|
||||||
|
case reverseproxy.TargetTypeSubnet:
|
||||||
|
// For subnets we do not do any lookups on the resource
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown target type: %s", target.TargetType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||||
|
}
|
||||||
|
return service, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
var proxyCluster string
|
||||||
|
if m.clusterDeriver != nil {
|
||||||
|
proxyCluster, err = m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain)
|
||||||
|
return nil, status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
service.AccountID = accountID
|
||||||
|
service.ProxyCluster = proxyCluster
|
||||||
|
service.InitNewRecord()
|
||||||
|
err = service.Auth.HashSecrets()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("hash secrets: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate session JWT signing keys
|
||||||
|
keyPair, err := sessionkey.GenerateKeyPair()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generate session keys: %w", err)
|
||||||
|
}
|
||||||
|
service.SessionPrivateKey = keyPair.PrivateKey
|
||||||
|
service.SessionPublicKey = keyPair.PublicKey
|
||||||
|
|
||||||
|
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
// Check for duplicate domain
|
||||||
|
existingService, err := transaction.GetServiceByDomain(ctx, accountID, service.Domain)
|
||||||
|
if err != nil {
|
||||||
|
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||||
|
return fmt.Errorf("failed to check existing service: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if existingService != nil {
|
||||||
|
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", service.Domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.CreateService(ctx, service); err != nil {
|
||||||
|
return fmt.Errorf("failed to create service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta())
|
||||||
|
|
||||||
|
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
||||||
|
|
||||||
|
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
|
return service, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
var oldCluster string
|
||||||
|
var domainChanged bool
|
||||||
|
var serviceEnabledChanged bool
|
||||||
|
|
||||||
|
err = service.Auth.HashSecrets()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("hash secrets: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
oldCluster = existingService.ProxyCluster
|
||||||
|
|
||||||
|
if existingService.Domain != service.Domain {
|
||||||
|
domainChanged = true
|
||||||
|
conflictService, err := transaction.GetServiceByDomain(ctx, accountID, service.Domain)
|
||||||
|
if err != nil {
|
||||||
|
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||||
|
return fmt.Errorf("check existing service: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if conflictService != nil && conflictService.ID != service.ID {
|
||||||
|
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", service.Domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.clusterDeriver != nil {
|
||||||
|
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
|
||||||
|
}
|
||||||
|
service.ProxyCluster = newCluster
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
service.ProxyCluster = existingService.ProxyCluster
|
||||||
|
}
|
||||||
|
|
||||||
|
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
|
||||||
|
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
|
||||||
|
service.Auth.PasswordAuth.Password == "" {
|
||||||
|
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
|
||||||
|
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
|
||||||
|
service.Auth.PinAuth.Pin == "" {
|
||||||
|
service.Auth.PinAuth = existingService.Auth.PinAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
service.Meta = existingService.Meta
|
||||||
|
service.SessionPrivateKey = existingService.SessionPrivateKey
|
||||||
|
service.SessionPublicKey = existingService.SessionPublicKey
|
||||||
|
serviceEnabledChanged = existingService.Enabled != service.Enabled
|
||||||
|
|
||||||
|
if err = validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.UpdateService(ctx, service); err != nil {
|
||||||
|
return fmt.Errorf("update service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceUpdated, service.EventMeta())
|
||||||
|
|
||||||
|
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
|
||||||
|
switch {
|
||||||
|
case domainChanged && oldCluster != service.ProxyCluster:
|
||||||
|
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), oldCluster)
|
||||||
|
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
|
||||||
|
case !service.Enabled && serviceEnabledChanged:
|
||||||
|
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), service.ProxyCluster)
|
||||||
|
case service.Enabled && serviceEnabledChanged:
|
||||||
|
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
|
||||||
|
default:
|
||||||
|
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", oidcCfg), service.ProxyCluster)
|
||||||
|
}
|
||||||
|
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
|
return service, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
|
||||||
|
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*reverseproxy.Target) error {
|
||||||
|
for _, target := range targets {
|
||||||
|
switch target.TargetType {
|
||||||
|
case reverseproxy.TargetTypePeer:
|
||||||
|
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
||||||
|
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||||
|
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
|
||||||
|
}
|
||||||
|
case reverseproxy.TargetTypeHost, reverseproxy.TargetTypeSubnet, reverseproxy.TargetTypeDomain:
|
||||||
|
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
||||||
|
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||||
|
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||||
|
if err != nil {
|
||||||
|
return status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
var service *reverseproxy.Service
|
||||||
|
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
var err error
|
||||||
|
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta())
|
||||||
|
|
||||||
|
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
||||||
|
|
||||||
|
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
|
||||||
|
// Call this when receiving a gRPC notification that the certificate was issued.
|
||||||
|
func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
||||||
|
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
service.Meta.CertificateIssuedAt = time.Now()
|
||||||
|
|
||||||
|
if err = transaction.UpdateService(ctx, service); err != nil {
|
||||||
|
return fmt.Errorf("failed to update service certificate timestamp: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
|
||||||
|
func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
|
||||||
|
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
service.Meta.Status = string(status)
|
||||||
|
|
||||||
|
if err = transaction.UpdateService(ctx, service); err != nil {
|
||||||
|
return fmt.Errorf("failed to update service status: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID string) error {
|
||||||
|
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
||||||
|
|
||||||
|
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
|
||||||
|
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get services: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, service := range services {
|
||||||
|
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||||
|
}
|
||||||
|
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
||||||
|
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, service := range services {
|
||||||
|
err = m.replaceHostByLookup(ctx, service.AccountID, service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return services, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
|
||||||
|
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return service, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||||
|
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, service := range services {
|
||||||
|
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return services, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
|
||||||
|
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
|
||||||
|
if err != nil {
|
||||||
|
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("failed to get service target by resource ID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if target == nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return target.ServiceID, nil
|
||||||
|
}
|
||||||
463
management/internals/modules/reverseproxy/reverseproxy.go
Normal file
463
management/internals/modules/reverseproxy/reverseproxy.go
Normal file
@@ -0,0 +1,463 @@
|
|||||||
|
package reverseproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/xid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||||
|
"github.com/netbirdio/netbird/util/crypt"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Operation string
|
||||||
|
|
||||||
|
const (
|
||||||
|
Create Operation = "create"
|
||||||
|
Update Operation = "update"
|
||||||
|
Delete Operation = "delete"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ProxyStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
StatusPending ProxyStatus = "pending"
|
||||||
|
StatusActive ProxyStatus = "active"
|
||||||
|
StatusTunnelNotCreated ProxyStatus = "tunnel_not_created"
|
||||||
|
StatusCertificatePending ProxyStatus = "certificate_pending"
|
||||||
|
StatusCertificateFailed ProxyStatus = "certificate_failed"
|
||||||
|
StatusError ProxyStatus = "error"
|
||||||
|
|
||||||
|
TargetTypePeer = "peer"
|
||||||
|
TargetTypeHost = "host"
|
||||||
|
TargetTypeDomain = "domain"
|
||||||
|
TargetTypeSubnet = "subnet"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Target struct {
|
||||||
|
ID uint `gorm:"primaryKey" json:"-"`
|
||||||
|
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
|
||||||
|
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
|
||||||
|
Path *string `json:"path,omitempty"`
|
||||||
|
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
|
||||||
|
Port int `gorm:"index:idx_target_port" json:"port"`
|
||||||
|
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
|
||||||
|
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
|
||||||
|
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
|
||||||
|
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PasswordAuthConfig struct {
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PINAuthConfig struct {
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
Pin string `json:"pin"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BearerAuthConfig struct {
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthConfig struct {
|
||||||
|
PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"`
|
||||||
|
PinAuth *PINAuthConfig `json:"pin_auth,omitempty" gorm:"serializer:json"`
|
||||||
|
BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthConfig) HashSecrets() error {
|
||||||
|
if a.PasswordAuth != nil && a.PasswordAuth.Enabled && a.PasswordAuth.Password != "" {
|
||||||
|
hashedPassword, err := argon2id.Hash(a.PasswordAuth.Password)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("hash password: %w", err)
|
||||||
|
}
|
||||||
|
a.PasswordAuth.Password = hashedPassword
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.PinAuth != nil && a.PinAuth.Enabled && a.PinAuth.Pin != "" {
|
||||||
|
hashedPin, err := argon2id.Hash(a.PinAuth.Pin)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("hash pin: %w", err)
|
||||||
|
}
|
||||||
|
a.PinAuth.Pin = hashedPin
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthConfig) ClearSecrets() {
|
||||||
|
if a.PasswordAuth != nil {
|
||||||
|
a.PasswordAuth.Password = ""
|
||||||
|
}
|
||||||
|
if a.PinAuth != nil {
|
||||||
|
a.PinAuth.Pin = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type OIDCValidationConfig struct {
|
||||||
|
Issuer string
|
||||||
|
Audiences []string
|
||||||
|
KeysLocation string
|
||||||
|
MaxTokenAgeSeconds int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServiceMeta struct {
|
||||||
|
CreatedAt time.Time
|
||||||
|
CertificateIssuedAt time.Time
|
||||||
|
Status string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Service struct {
|
||||||
|
ID string `gorm:"primaryKey"`
|
||||||
|
AccountID string `gorm:"index"`
|
||||||
|
Name string
|
||||||
|
Domain string `gorm:"index"`
|
||||||
|
ProxyCluster string `gorm:"index"`
|
||||||
|
Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
|
||||||
|
Enabled bool
|
||||||
|
PassHostHeader bool
|
||||||
|
RewriteRedirects bool
|
||||||
|
Auth AuthConfig `gorm:"serializer:json"`
|
||||||
|
Meta ServiceMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||||
|
SessionPrivateKey string `gorm:"column:session_private_key"`
|
||||||
|
SessionPublicKey string `gorm:"column:session_public_key"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service {
|
||||||
|
for _, target := range targets {
|
||||||
|
target.AccountID = accountID
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &Service{
|
||||||
|
AccountID: accountID,
|
||||||
|
Name: name,
|
||||||
|
Domain: domain,
|
||||||
|
ProxyCluster: proxyCluster,
|
||||||
|
Targets: targets,
|
||||||
|
Enabled: enabled,
|
||||||
|
}
|
||||||
|
s.InitNewRecord()
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitNewRecord generates a new unique ID and resets metadata for a newly created
|
||||||
|
// Service record. This overwrites any existing ID and Meta fields and should
|
||||||
|
// only be called during initial creation, not for updates.
|
||||||
|
func (s *Service) InitNewRecord() {
|
||||||
|
s.ID = xid.New().String()
|
||||||
|
s.Meta = ServiceMeta{
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
Status: string(StatusPending),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) ToAPIResponse() *api.Service {
|
||||||
|
s.Auth.ClearSecrets()
|
||||||
|
|
||||||
|
authConfig := api.ServiceAuthConfig{}
|
||||||
|
|
||||||
|
if s.Auth.PasswordAuth != nil {
|
||||||
|
authConfig.PasswordAuth = &api.PasswordAuthConfig{
|
||||||
|
Enabled: s.Auth.PasswordAuth.Enabled,
|
||||||
|
Password: s.Auth.PasswordAuth.Password,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.Auth.PinAuth != nil {
|
||||||
|
authConfig.PinAuth = &api.PINAuthConfig{
|
||||||
|
Enabled: s.Auth.PinAuth.Enabled,
|
||||||
|
Pin: s.Auth.PinAuth.Pin,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.Auth.BearerAuth != nil {
|
||||||
|
authConfig.BearerAuth = &api.BearerAuthConfig{
|
||||||
|
Enabled: s.Auth.BearerAuth.Enabled,
|
||||||
|
DistributionGroups: &s.Auth.BearerAuth.DistributionGroups,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert internal targets to API targets
|
||||||
|
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
|
||||||
|
for _, target := range s.Targets {
|
||||||
|
apiTargets = append(apiTargets, api.ServiceTarget{
|
||||||
|
Path: target.Path,
|
||||||
|
Host: &target.Host,
|
||||||
|
Port: target.Port,
|
||||||
|
Protocol: api.ServiceTargetProtocol(target.Protocol),
|
||||||
|
TargetId: target.TargetId,
|
||||||
|
TargetType: api.ServiceTargetTargetType(target.TargetType),
|
||||||
|
Enabled: target.Enabled,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
meta := api.ServiceMeta{
|
||||||
|
CreatedAt: s.Meta.CreatedAt,
|
||||||
|
Status: api.ServiceMetaStatus(s.Meta.Status),
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.Meta.CertificateIssuedAt.IsZero() {
|
||||||
|
meta.CertificateIssuedAt = &s.Meta.CertificateIssuedAt
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &api.Service{
|
||||||
|
Id: s.ID,
|
||||||
|
Name: s.Name,
|
||||||
|
Domain: s.Domain,
|
||||||
|
Targets: apiTargets,
|
||||||
|
Enabled: s.Enabled,
|
||||||
|
PassHostHeader: &s.PassHostHeader,
|
||||||
|
RewriteRedirects: &s.RewriteRedirects,
|
||||||
|
Auth: authConfig,
|
||||||
|
Meta: meta,
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.ProxyCluster != "" {
|
||||||
|
resp.ProxyCluster = &s.ProxyCluster
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping {
|
||||||
|
pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
|
||||||
|
for _, target := range s.Targets {
|
||||||
|
if !target.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Make path prefix stripping configurable per-target.
|
||||||
|
// Currently the matching prefix is baked into the target URL path,
|
||||||
|
// so the proxy strips-then-re-adds it (effectively a no-op).
|
||||||
|
targetURL := url.URL{
|
||||||
|
Scheme: target.Protocol,
|
||||||
|
Host: target.Host,
|
||||||
|
Path: "/", // TODO: support service path
|
||||||
|
}
|
||||||
|
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
|
||||||
|
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port))
|
||||||
|
}
|
||||||
|
|
||||||
|
path := "/"
|
||||||
|
if target.Path != nil {
|
||||||
|
path = *target.Path
|
||||||
|
}
|
||||||
|
pathMappings = append(pathMappings, &proto.PathMapping{
|
||||||
|
Path: path,
|
||||||
|
Target: targetURL.String(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
auth := &proto.Authentication{
|
||||||
|
SessionKey: s.SessionPublicKey,
|
||||||
|
MaxSessionAgeSeconds: int64((time.Hour * 24).Seconds()),
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled {
|
||||||
|
auth.Password = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled {
|
||||||
|
auth.Pin = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
|
||||||
|
auth.Oidc = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return &proto.ProxyMapping{
|
||||||
|
Type: operationToProtoType(operation),
|
||||||
|
Id: s.ID,
|
||||||
|
Domain: s.Domain,
|
||||||
|
Path: pathMappings,
|
||||||
|
AuthToken: authToken,
|
||||||
|
Auth: auth,
|
||||||
|
AccountId: s.AccountID,
|
||||||
|
PassHostHeader: s.PassHostHeader,
|
||||||
|
RewriteRedirects: s.RewriteRedirects,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
|
||||||
|
switch op {
|
||||||
|
case Create:
|
||||||
|
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
|
||||||
|
case Update:
|
||||||
|
return proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED
|
||||||
|
case Delete:
|
||||||
|
return proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED
|
||||||
|
default:
|
||||||
|
log.Fatalf("unknown operation type: %v", op)
|
||||||
|
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isDefaultPort reports whether port is the standard default for the given scheme
|
||||||
|
// (443 for https, 80 for http).
|
||||||
|
func isDefaultPort(scheme string, port int) bool {
|
||||||
|
return (scheme == "https" && port == 443) || (scheme == "http" && port == 80)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
|
||||||
|
s.Name = req.Name
|
||||||
|
s.Domain = req.Domain
|
||||||
|
s.AccountID = accountID
|
||||||
|
|
||||||
|
targets := make([]*Target, 0, len(req.Targets))
|
||||||
|
for _, apiTarget := range req.Targets {
|
||||||
|
target := &Target{
|
||||||
|
AccountID: accountID,
|
||||||
|
Path: apiTarget.Path,
|
||||||
|
Port: apiTarget.Port,
|
||||||
|
Protocol: string(apiTarget.Protocol),
|
||||||
|
TargetId: apiTarget.TargetId,
|
||||||
|
TargetType: string(apiTarget.TargetType),
|
||||||
|
Enabled: apiTarget.Enabled,
|
||||||
|
}
|
||||||
|
if apiTarget.Host != nil {
|
||||||
|
target.Host = *apiTarget.Host
|
||||||
|
}
|
||||||
|
targets = append(targets, target)
|
||||||
|
}
|
||||||
|
s.Targets = targets
|
||||||
|
|
||||||
|
s.Enabled = req.Enabled
|
||||||
|
|
||||||
|
if req.PassHostHeader != nil {
|
||||||
|
s.PassHostHeader = *req.PassHostHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.RewriteRedirects != nil {
|
||||||
|
s.RewriteRedirects = *req.RewriteRedirects
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Auth.PasswordAuth != nil {
|
||||||
|
s.Auth.PasswordAuth = &PasswordAuthConfig{
|
||||||
|
Enabled: req.Auth.PasswordAuth.Enabled,
|
||||||
|
Password: req.Auth.PasswordAuth.Password,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Auth.PinAuth != nil {
|
||||||
|
s.Auth.PinAuth = &PINAuthConfig{
|
||||||
|
Enabled: req.Auth.PinAuth.Enabled,
|
||||||
|
Pin: req.Auth.PinAuth.Pin,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Auth.BearerAuth != nil {
|
||||||
|
bearerAuth := &BearerAuthConfig{
|
||||||
|
Enabled: req.Auth.BearerAuth.Enabled,
|
||||||
|
}
|
||||||
|
if req.Auth.BearerAuth.DistributionGroups != nil {
|
||||||
|
bearerAuth.DistributionGroups = *req.Auth.BearerAuth.DistributionGroups
|
||||||
|
}
|
||||||
|
s.Auth.BearerAuth = bearerAuth
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Validate() error {
|
||||||
|
if s.Name == "" {
|
||||||
|
return errors.New("service name is required")
|
||||||
|
}
|
||||||
|
if len(s.Name) > 255 {
|
||||||
|
return errors.New("service name exceeds maximum length of 255 characters")
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.Domain == "" {
|
||||||
|
return errors.New("service domain is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(s.Targets) == 0 {
|
||||||
|
return errors.New("at least one target is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, target := range s.Targets {
|
||||||
|
switch target.TargetType {
|
||||||
|
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
|
||||||
|
// host field will be ignored
|
||||||
|
case TargetTypeSubnet:
|
||||||
|
if target.Host == "" {
|
||||||
|
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
|
||||||
|
}
|
||||||
|
if target.TargetId == "" {
|
||||||
|
return fmt.Errorf("target %d has empty target_id", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) EventMeta() map[string]any {
|
||||||
|
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Copy() *Service {
|
||||||
|
targets := make([]*Target, len(s.Targets))
|
||||||
|
for i, target := range s.Targets {
|
||||||
|
targetCopy := *target
|
||||||
|
targets[i] = &targetCopy
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Service{
|
||||||
|
ID: s.ID,
|
||||||
|
AccountID: s.AccountID,
|
||||||
|
Name: s.Name,
|
||||||
|
Domain: s.Domain,
|
||||||
|
ProxyCluster: s.ProxyCluster,
|
||||||
|
Targets: targets,
|
||||||
|
Enabled: s.Enabled,
|
||||||
|
PassHostHeader: s.PassHostHeader,
|
||||||
|
RewriteRedirects: s.RewriteRedirects,
|
||||||
|
Auth: s.Auth,
|
||||||
|
Meta: s.Meta,
|
||||||
|
SessionPrivateKey: s.SessionPrivateKey,
|
||||||
|
SessionPublicKey: s.SessionPublicKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||||
|
if enc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.SessionPrivateKey != "" {
|
||||||
|
var err error
|
||||||
|
s.SessionPrivateKey, err = enc.Encrypt(s.SessionPrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||||
|
if enc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.SessionPrivateKey != "" {
|
||||||
|
var err error
|
||||||
|
s.SessionPrivateKey, err = enc.Decrypt(s.SessionPrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
405
management/internals/modules/reverseproxy/reverseproxy_test.go
Normal file
405
management/internals/modules/reverseproxy/reverseproxy_test.go
Normal file
@@ -0,0 +1,405 @@
|
|||||||
|
package reverseproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func validProxy() *Service {
|
||||||
|
return &Service{
|
||||||
|
Name: "test",
|
||||||
|
Domain: "example.com",
|
||||||
|
Targets: []*Target{
|
||||||
|
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 80, Protocol: "http", Enabled: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_Valid(t *testing.T) {
|
||||||
|
require.NoError(t, validProxy().Validate())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_EmptyName(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Name = ""
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "name is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_EmptyDomain(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Domain = ""
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "domain is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_NoTargets(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets = nil
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "at least one target")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_EmptyTargetId(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].TargetId = ""
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "empty target_id")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_InvalidTargetType(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].TargetType = "invalid"
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "invalid target_type")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_ResourceTarget(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets = append(rp.Targets, &Target{
|
||||||
|
TargetId: "resource-1",
|
||||||
|
TargetType: TargetTypeHost,
|
||||||
|
Host: "example.org",
|
||||||
|
Port: 443,
|
||||||
|
Protocol: "https",
|
||||||
|
Enabled: true,
|
||||||
|
})
|
||||||
|
require.NoError(t, rp.Validate())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_MultipleTargetsOneInvalid(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets = append(rp.Targets, &Target{
|
||||||
|
TargetId: "",
|
||||||
|
TargetType: TargetTypePeer,
|
||||||
|
Host: "10.0.0.2",
|
||||||
|
Port: 80,
|
||||||
|
Protocol: "http",
|
||||||
|
Enabled: true,
|
||||||
|
})
|
||||||
|
err := rp.Validate()
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "target 1")
|
||||||
|
assert.Contains(t, err.Error(), "empty target_id")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsDefaultPort(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
scheme string
|
||||||
|
port int
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"http", 80, true},
|
||||||
|
{"https", 443, true},
|
||||||
|
{"http", 443, false},
|
||||||
|
{"https", 80, false},
|
||||||
|
{"http", 8080, false},
|
||||||
|
{"https", 8443, false},
|
||||||
|
{"http", 0, false},
|
||||||
|
{"https", 0, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(fmt.Sprintf("%s/%d", tt.scheme, tt.port), func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.want, isDefaultPort(tt.scheme, tt.port))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToProtoMapping_PortInTargetURL(t *testing.T) {
|
||||||
|
oidcConfig := OIDCValidationConfig{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
protocol string
|
||||||
|
host string
|
||||||
|
port int
|
||||||
|
wantTarget string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "http with default port 80 omits port",
|
||||||
|
protocol: "http",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 80,
|
||||||
|
wantTarget: "http://10.0.0.1/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "https with default port 443 omits port",
|
||||||
|
protocol: "https",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 443,
|
||||||
|
wantTarget: "https://10.0.0.1/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "port 0 omits port",
|
||||||
|
protocol: "http",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 0,
|
||||||
|
wantTarget: "http://10.0.0.1/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-default port is included",
|
||||||
|
protocol: "http",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 8080,
|
||||||
|
wantTarget: "http://10.0.0.1:8080/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "https with non-default port is included",
|
||||||
|
protocol: "https",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 8443,
|
||||||
|
wantTarget: "https://10.0.0.1:8443/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "http port 443 is included",
|
||||||
|
protocol: "http",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 443,
|
||||||
|
wantTarget: "http://10.0.0.1:443/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "https port 80 is included",
|
||||||
|
protocol: "https",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 80,
|
||||||
|
wantTarget: "https://10.0.0.1:80/",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rp := &Service{
|
||||||
|
ID: "test-id",
|
||||||
|
AccountID: "acc-1",
|
||||||
|
Domain: "example.com",
|
||||||
|
Targets: []*Target{
|
||||||
|
{
|
||||||
|
TargetId: "peer-1",
|
||||||
|
TargetType: TargetTypePeer,
|
||||||
|
Host: tt.host,
|
||||||
|
Port: tt.port,
|
||||||
|
Protocol: tt.protocol,
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
pm := rp.ToProtoMapping(Create, "token", oidcConfig)
|
||||||
|
require.Len(t, pm.Path, 1, "should have one path mapping")
|
||||||
|
assert.Equal(t, tt.wantTarget, pm.Path[0].Target)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) {
|
||||||
|
rp := &Service{
|
||||||
|
ID: "test-id",
|
||||||
|
AccountID: "acc-1",
|
||||||
|
Domain: "example.com",
|
||||||
|
Targets: []*Target{
|
||||||
|
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 8080, Protocol: "http", Enabled: false},
|
||||||
|
{TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
pm := rp.ToProtoMapping(Create, "token", OIDCValidationConfig{})
|
||||||
|
require.Len(t, pm.Path, 1)
|
||||||
|
assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToProtoMapping_OperationTypes(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
tests := []struct {
|
||||||
|
op Operation
|
||||||
|
want proto.ProxyMappingUpdateType
|
||||||
|
}{
|
||||||
|
{Create, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED},
|
||||||
|
{Update, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED},
|
||||||
|
{Delete, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(string(tt.op), func(t *testing.T) {
|
||||||
|
pm := rp.ToProtoMapping(tt.op, "", OIDCValidationConfig{})
|
||||||
|
assert.Equal(t, tt.want, pm.Type)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthConfig_HashSecrets(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config *AuthConfig
|
||||||
|
wantErr bool
|
||||||
|
validate func(*testing.T, *AuthConfig)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "hash password successfully",
|
||||||
|
config: &AuthConfig{
|
||||||
|
PasswordAuth: &PasswordAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Password: "testPassword123",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
validate: func(t *testing.T, config *AuthConfig) {
|
||||||
|
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
|
||||||
|
t.Errorf("Password not hashed with argon2id, got: %s", config.PasswordAuth.Password)
|
||||||
|
}
|
||||||
|
// Verify the hash can be verified
|
||||||
|
if err := argon2id.Verify("testPassword123", config.PasswordAuth.Password); err != nil {
|
||||||
|
t.Errorf("Hash verification failed: %v", err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hash PIN successfully",
|
||||||
|
config: &AuthConfig{
|
||||||
|
PinAuth: &PINAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Pin: "123456",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
validate: func(t *testing.T, config *AuthConfig) {
|
||||||
|
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
|
||||||
|
t.Errorf("PIN not hashed with argon2id, got: %s", config.PinAuth.Pin)
|
||||||
|
}
|
||||||
|
// Verify the hash can be verified
|
||||||
|
if err := argon2id.Verify("123456", config.PinAuth.Pin); err != nil {
|
||||||
|
t.Errorf("Hash verification failed: %v", err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hash both password and PIN",
|
||||||
|
config: &AuthConfig{
|
||||||
|
PasswordAuth: &PasswordAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Password: "password",
|
||||||
|
},
|
||||||
|
PinAuth: &PINAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Pin: "9999",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
validate: func(t *testing.T, config *AuthConfig) {
|
||||||
|
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
|
||||||
|
t.Errorf("Password not hashed with argon2id")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
|
||||||
|
t.Errorf("PIN not hashed with argon2id")
|
||||||
|
}
|
||||||
|
if err := argon2id.Verify("password", config.PasswordAuth.Password); err != nil {
|
||||||
|
t.Errorf("Password hash verification failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := argon2id.Verify("9999", config.PinAuth.Pin); err != nil {
|
||||||
|
t.Errorf("PIN hash verification failed: %v", err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "skip disabled password auth",
|
||||||
|
config: &AuthConfig{
|
||||||
|
PasswordAuth: &PasswordAuthConfig{
|
||||||
|
Enabled: false,
|
||||||
|
Password: "password",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
validate: func(t *testing.T, config *AuthConfig) {
|
||||||
|
if config.PasswordAuth.Password != "password" {
|
||||||
|
t.Errorf("Disabled password auth should not be hashed")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "skip empty password",
|
||||||
|
config: &AuthConfig{
|
||||||
|
PasswordAuth: &PasswordAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Password: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
validate: func(t *testing.T, config *AuthConfig) {
|
||||||
|
if config.PasswordAuth.Password != "" {
|
||||||
|
t.Errorf("Empty password should remain empty")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "skip nil password auth",
|
||||||
|
config: &AuthConfig{
|
||||||
|
PasswordAuth: nil,
|
||||||
|
PinAuth: &PINAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Pin: "1234",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
validate: func(t *testing.T, config *AuthConfig) {
|
||||||
|
if config.PasswordAuth != nil {
|
||||||
|
t.Errorf("PasswordAuth should remain nil")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
|
||||||
|
t.Errorf("PIN should still be hashed")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := tt.config.HashSecrets()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("HashSecrets() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.validate != nil {
|
||||||
|
tt.validate(t, tt.config)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthConfig_HashSecrets_VerifyIncorrectSecret(t *testing.T) {
|
||||||
|
config := &AuthConfig{
|
||||||
|
PasswordAuth: &PasswordAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Password: "correctPassword",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := config.HashSecrets(); err != nil {
|
||||||
|
t.Fatalf("HashSecrets() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify with wrong password should fail
|
||||||
|
err := argon2id.Verify("wrongPassword", config.PasswordAuth.Password)
|
||||||
|
if !errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
|
||||||
|
t.Errorf("Expected ErrMismatchedHashAndPassword, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthConfig_ClearSecrets(t *testing.T) {
|
||||||
|
config := &AuthConfig{
|
||||||
|
PasswordAuth: &PasswordAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Password: "hashedPassword",
|
||||||
|
},
|
||||||
|
PinAuth: &PINAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Pin: "hashedPin",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
config.ClearSecrets()
|
||||||
|
|
||||||
|
if config.PasswordAuth.Password != "" {
|
||||||
|
t.Errorf("Password not cleared, got: %s", config.PasswordAuth.Password)
|
||||||
|
}
|
||||||
|
if config.PinAuth.Pin != "" {
|
||||||
|
t.Errorf("PIN not cleared, got: %s", config.PinAuth.Pin)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
package sessionkey
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/proxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
type KeyPair struct {
|
||||||
|
PrivateKey string
|
||||||
|
PublicKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Claims struct {
|
||||||
|
jwt.RegisteredClaims
|
||||||
|
Method auth.Method `json:"method"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateKeyPair() (*KeyPair, error) {
|
||||||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generate ed25519 key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &KeyPair{
|
||||||
|
PrivateKey: base64.StdEncoding.EncodeToString(priv),
|
||||||
|
PublicKey: base64.StdEncoding.EncodeToString(pub),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration time.Duration) (string, error) {
|
||||||
|
privKeyBytes, err := base64.StdEncoding.DecodeString(privKeyB64)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("decode private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(privKeyBytes) != ed25519.PrivateKeySize {
|
||||||
|
return "", fmt.Errorf("invalid private key size: got %d, want %d", len(privKeyBytes), ed25519.PrivateKeySize)
|
||||||
|
}
|
||||||
|
|
||||||
|
privKey := ed25519.PrivateKey(privKeyBytes)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
claims := Claims{
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Issuer: auth.SessionJWTIssuer,
|
||||||
|
Subject: userID,
|
||||||
|
Audience: jwt.ClaimStrings{domain},
|
||||||
|
ExpiresAt: jwt.NewNumericDate(now.Add(expiration)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
|
NotBefore: jwt.NewNumericDate(now),
|
||||||
|
},
|
||||||
|
Method: method,
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
|
||||||
|
signedToken, err := token.SignedString(privKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("sign token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return signedToken, nil
|
||||||
|
}
|
||||||
@@ -21,6 +21,8 @@ import (
|
|||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
"github.com/netbirdio/netbird/formatter/hook"
|
"github.com/netbirdio/netbird/formatter/hook"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
|
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||||
@@ -92,7 +94,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
|||||||
|
|
||||||
func (s *BaseServer) APIHandler() http.Handler {
|
func (s *BaseServer) APIHandler() http.Handler {
|
||||||
return Create(s, func() http.Handler {
|
return Create(s, func() http.Handler {
|
||||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager())
|
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create API handler: %v", err)
|
log.Fatalf("failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
@@ -120,11 +122,13 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
|||||||
realip.WithTrustedProxiesCount(trustedProxiesCount),
|
realip.WithTrustedProxiesCount(trustedProxiesCount),
|
||||||
realip.WithHeaders([]string{realip.XForwardedFor, realip.XRealIp}),
|
realip.WithHeaders([]string{realip.XForwardedFor, realip.XRealIp}),
|
||||||
}
|
}
|
||||||
|
proxyUnary, proxyStream, proxyAuthClose := nbgrpc.NewProxyAuthInterceptors(s.Store())
|
||||||
|
s.proxyAuthClose = proxyAuthClose
|
||||||
gRPCOpts := []grpc.ServerOption{
|
gRPCOpts := []grpc.ServerOption{
|
||||||
grpc.KeepaliveEnforcementPolicy(kaep),
|
grpc.KeepaliveEnforcementPolicy(kaep),
|
||||||
grpc.KeepaliveParams(kasp),
|
grpc.KeepaliveParams(kasp),
|
||||||
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor),
|
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor, proxyUnary),
|
||||||
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
|
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor, proxyStream),
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.Config.HttpConfig.LetsEncryptDomain != "" {
|
if s.Config.HttpConfig.LetsEncryptDomain != "" {
|
||||||
@@ -150,10 +154,53 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
|||||||
}
|
}
|
||||||
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
|
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
|
||||||
|
|
||||||
|
mgmtProto.RegisterProxyServiceServer(gRPCAPIHandler, s.ReverseProxyGRPCServer())
|
||||||
|
log.Info("ProxyService registered on gRPC server")
|
||||||
|
|
||||||
return gRPCAPIHandler
|
return gRPCAPIHandler
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
||||||
|
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
||||||
|
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager())
|
||||||
|
s.AfterInit(func(s *BaseServer) {
|
||||||
|
proxyService.SetProxyManager(s.ReverseProxyManager())
|
||||||
|
})
|
||||||
|
return proxyService
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig {
|
||||||
|
return Create(s, func() nbgrpc.ProxyOIDCConfig {
|
||||||
|
return nbgrpc.ProxyOIDCConfig{
|
||||||
|
Issuer: s.Config.HttpConfig.AuthIssuer,
|
||||||
|
// todo: double check auth clientID value
|
||||||
|
ClientID: s.Config.HttpConfig.AuthClientID, // Reuse dashboard client
|
||||||
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
|
CallbackURL: s.Config.HttpConfig.AuthCallbackURL,
|
||||||
|
HMACKey: []byte(s.Config.DataStoreEncryptionKey), // Use the datastore encryption key for OIDC state HMACs, this should ensure all management instances are using the same key.
|
||||||
|
Audience: s.Config.HttpConfig.AuthAudience,
|
||||||
|
KeysLocation: s.Config.HttpConfig.AuthKeysLocation,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
|
||||||
|
return Create(s, func() *nbgrpc.OneTimeTokenStore {
|
||||||
|
tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute)
|
||||||
|
log.Info("One-time token store initialized for proxy authentication")
|
||||||
|
return tokenStore
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
|
||||||
|
return Create(s, func() accesslogs.Manager {
|
||||||
|
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager())
|
||||||
|
return accessLogManager
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
|
func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
|
||||||
// Load server's certificate and private key
|
// Load server's certificate and private key
|
||||||
serverCert, err := tls.LoadX509KeyPair(certFile, certKey)
|
serverCert, err := tls.LoadX509KeyPair(certFile, certKey)
|
||||||
|
|||||||
@@ -100,6 +100,8 @@ type HttpServerConfig struct {
|
|||||||
CertFile string
|
CertFile string
|
||||||
// CertKey is the location of the certificate private key
|
// CertKey is the location of the certificate private key
|
||||||
CertKey string
|
CertKey string
|
||||||
|
// AuthClientID is the client id used for proxy SSO auth
|
||||||
|
AuthClientID string
|
||||||
// AuthAudience identifies the recipients that the JWT is intended for (aud in JWT)
|
// AuthAudience identifies the recipients that the JWT is intended for (aud in JWT)
|
||||||
AuthAudience string
|
AuthAudience string
|
||||||
// CLIAuthAudience identifies the client app recipients that the JWT is intended for (aud in JWT)
|
// CLIAuthAudience identifies the client app recipients that the JWT is intended for (aud in JWT)
|
||||||
@@ -117,6 +119,8 @@ type HttpServerConfig struct {
|
|||||||
IdpSignKeyRefreshEnabled bool
|
IdpSignKeyRefreshEnabled bool
|
||||||
// Extra audience
|
// Extra audience
|
||||||
ExtraAuthAudience string
|
ExtraAuthAudience string
|
||||||
|
// AuthCallbackDomain contains the callback domain
|
||||||
|
AuthCallbackURL string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Host represents a Netbird host (e.g. STUN, TURN, Signal)
|
// Host represents a Netbird host (e.g. STUN, TURN, Signal)
|
||||||
|
|||||||
@@ -8,6 +8,9 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||||
|
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||||
@@ -98,6 +101,11 @@ func (s *BaseServer) AccountManager() account.Manager {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create account manager: %v", err)
|
log.Fatalf("failed to create account manager: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.AfterInit(func(s *BaseServer) {
|
||||||
|
accountManager.SetServiceManager(s.ReverseProxyManager())
|
||||||
|
})
|
||||||
|
|
||||||
return accountManager
|
return accountManager
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -154,7 +162,7 @@ func (s *BaseServer) GroupsManager() groups.Manager {
|
|||||||
|
|
||||||
func (s *BaseServer) ResourcesManager() resources.Manager {
|
func (s *BaseServer) ResourcesManager() resources.Manager {
|
||||||
return Create(s, func() resources.Manager {
|
return Create(s, func() resources.Manager {
|
||||||
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager())
|
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ReverseProxyManager())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -181,3 +189,16 @@ func (s *BaseServer) RecordsManager() records.Manager {
|
|||||||
return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager())
|
return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) ReverseProxyManager() reverseproxy.Manager {
|
||||||
|
return Create(s, func() reverseproxy.Manager {
|
||||||
|
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), s.ReverseProxyDomainManager())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
||||||
|
return Create(s, func() *manager.Manager {
|
||||||
|
m := manager.NewManager(s.Store(), s.ReverseProxyGRPCServer(), s.PermissionsManager())
|
||||||
|
return &m
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"go.opentelemetry.io/otel/metric"
|
"go.opentelemetry.io/otel/metric"
|
||||||
"golang.org/x/crypto/acme/autocert"
|
"golang.org/x/crypto/acme/autocert"
|
||||||
@@ -21,6 +20,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
"github.com/netbirdio/netbird/management/server/metrics"
|
"github.com/netbirdio/netbird/management/server/metrics"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/util/wsproxy"
|
"github.com/netbirdio/netbird/util/wsproxy"
|
||||||
@@ -58,6 +58,8 @@ type BaseServer struct {
|
|||||||
mgmtMetricsPort int
|
mgmtMetricsPort int
|
||||||
mgmtPort int
|
mgmtPort int
|
||||||
|
|
||||||
|
proxyAuthClose func()
|
||||||
|
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
certManager *autocert.Manager
|
certManager *autocert.Manager
|
||||||
update *version.Update
|
update *version.Update
|
||||||
@@ -215,6 +217,11 @@ func (s *BaseServer) Stop() error {
|
|||||||
_ = s.certManager.Listener().Close()
|
_ = s.certManager.Listener().Close()
|
||||||
}
|
}
|
||||||
s.GRPCServer().Stop()
|
s.GRPCServer().Stop()
|
||||||
|
s.ReverseProxyGRPCServer().Close()
|
||||||
|
if s.proxyAuthClose != nil {
|
||||||
|
s.proxyAuthClose()
|
||||||
|
s.proxyAuthClose = nil
|
||||||
|
}
|
||||||
_ = s.Store().Close(ctx)
|
_ = s.Store().Close(ctx)
|
||||||
_ = s.EventStore().Close(ctx)
|
_ = s.EventStore().Close(ctx)
|
||||||
if s.update != nil {
|
if s.update != nil {
|
||||||
|
|||||||
167
management/internals/shared/grpc/onetime_token.go
Normal file
167
management/internals/shared/grpc/onetime_token.go
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OneTimeTokenStore manages short-lived, single-use authentication tokens
|
||||||
|
// for proxy-to-management RPC authentication. Tokens are generated when
|
||||||
|
// a service is created and must be used exactly once by the proxy
|
||||||
|
// to authenticate a subsequent RPC call.
|
||||||
|
type OneTimeTokenStore struct {
|
||||||
|
tokens map[string]*tokenMetadata
|
||||||
|
mu sync.RWMutex
|
||||||
|
cleanup *time.Ticker
|
||||||
|
cleanupDone chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tokenMetadata stores information about a one-time token
|
||||||
|
type tokenMetadata struct {
|
||||||
|
ServiceID string
|
||||||
|
AccountID string
|
||||||
|
ExpiresAt time.Time
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOneTimeTokenStore creates a new token store with automatic cleanup
|
||||||
|
// of expired tokens. The cleanupInterval determines how often expired
|
||||||
|
// tokens are removed from memory.
|
||||||
|
func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
|
||||||
|
store := &OneTimeTokenStore{
|
||||||
|
tokens: make(map[string]*tokenMetadata),
|
||||||
|
cleanup: time.NewTicker(cleanupInterval),
|
||||||
|
cleanupDone: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start background cleanup goroutine
|
||||||
|
go store.cleanupExpired()
|
||||||
|
|
||||||
|
return store
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateToken creates a new cryptographically secure one-time token
|
||||||
|
// with the specified TTL. The token is associated with a specific
|
||||||
|
// accountID and serviceID for validation purposes.
|
||||||
|
//
|
||||||
|
// Returns the generated token string or an error if random generation fails.
|
||||||
|
func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) {
|
||||||
|
// Generate 32 bytes (256 bits) of cryptographically secure random data
|
||||||
|
randomBytes := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(randomBytes); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to generate random token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode as URL-safe base64 for easy transmission in gRPC
|
||||||
|
token := base64.URLEncoding.EncodeToString(randomBytes)
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
s.tokens[token] = &tokenMetadata{
|
||||||
|
ServiceID: serviceID,
|
||||||
|
AccountID: accountID,
|
||||||
|
ExpiresAt: time.Now().Add(ttl),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)",
|
||||||
|
serviceID, accountID, ttl)
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateAndConsume verifies the token against the provided accountID and
|
||||||
|
// serviceID, checks expiration, and then deletes it to enforce single-use.
|
||||||
|
//
|
||||||
|
// This method uses constant-time comparison to prevent timing attacks.
|
||||||
|
//
|
||||||
|
// Returns nil on success, or an error if:
|
||||||
|
// - Token doesn't exist
|
||||||
|
// - Token has expired
|
||||||
|
// - Account ID doesn't match
|
||||||
|
// - Reverse proxy ID doesn't match
|
||||||
|
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
metadata, exists := s.tokens[token]
|
||||||
|
if !exists {
|
||||||
|
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)",
|
||||||
|
serviceID, accountID)
|
||||||
|
return fmt.Errorf("invalid token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check expiration
|
||||||
|
if time.Now().After(metadata.ExpiresAt) {
|
||||||
|
delete(s.tokens, token)
|
||||||
|
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)",
|
||||||
|
serviceID, accountID)
|
||||||
|
return fmt.Errorf("token expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate account ID using constant-time comparison (prevents timing attacks)
|
||||||
|
if subtle.ConstantTimeCompare([]byte(metadata.AccountID), []byte(accountID)) != 1 {
|
||||||
|
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)",
|
||||||
|
metadata.AccountID, accountID)
|
||||||
|
return fmt.Errorf("account ID mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate service ID using constant-time comparison
|
||||||
|
if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 {
|
||||||
|
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)",
|
||||||
|
metadata.ServiceID, serviceID)
|
||||||
|
return fmt.Errorf("service ID mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete token immediately to enforce single-use
|
||||||
|
delete(s.tokens, token)
|
||||||
|
|
||||||
|
log.Infof("Token validated and consumed for proxy %s in account %s",
|
||||||
|
serviceID, accountID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupExpired removes expired tokens in the background to prevent memory leaks
|
||||||
|
func (s *OneTimeTokenStore) cleanupExpired() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.cleanup.C:
|
||||||
|
s.mu.Lock()
|
||||||
|
now := time.Now()
|
||||||
|
removed := 0
|
||||||
|
for token, metadata := range s.tokens {
|
||||||
|
if now.After(metadata.ExpiresAt) {
|
||||||
|
delete(s.tokens, token)
|
||||||
|
removed++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if removed > 0 {
|
||||||
|
log.Debugf("Cleaned up %d expired one-time tokens", removed)
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
case <-s.cleanupDone:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops the cleanup goroutine and releases resources
|
||||||
|
func (s *OneTimeTokenStore) Close() {
|
||||||
|
s.cleanup.Stop()
|
||||||
|
close(s.cleanupDone)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTokenCount returns the current number of tokens in the store (for debugging/metrics)
|
||||||
|
func (s *OneTimeTokenStore) GetTokenCount() int {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return len(s.tokens)
|
||||||
|
}
|
||||||
1065
management/internals/shared/grpc/proxy.go
Normal file
1065
management/internals/shared/grpc/proxy.go
Normal file
File diff suppressed because it is too large
Load Diff
234
management/internals/shared/grpc/proxy_auth.go
Normal file
234
management/internals/shared/grpc/proxy_auth.go
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// lastUsedUpdateInterval is the minimum interval between last_used updates for the same token.
|
||||||
|
lastUsedUpdateInterval = time.Minute
|
||||||
|
// lastUsedCleanupInterval is how often stale lastUsed entries are removed.
|
||||||
|
lastUsedCleanupInterval = 2 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
type proxyTokenContextKey struct{}
|
||||||
|
|
||||||
|
// ProxyTokenContextKey is the typed key used to store validated token info in context.
|
||||||
|
var ProxyTokenContextKey = proxyTokenContextKey{}
|
||||||
|
|
||||||
|
// proxyTokenID identifies a proxy access token by its database ID.
|
||||||
|
type proxyTokenID = string
|
||||||
|
|
||||||
|
// proxyTokenStore defines the store interface needed for token validation
|
||||||
|
type proxyTokenStore interface {
|
||||||
|
GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength store.LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error)
|
||||||
|
MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// proxyAuthInterceptor holds state for proxy authentication interceptors.
|
||||||
|
type proxyAuthInterceptor struct {
|
||||||
|
store proxyTokenStore
|
||||||
|
failureLimiter *authFailureLimiter
|
||||||
|
|
||||||
|
// lastUsedMu protects lastUsedTimes
|
||||||
|
lastUsedMu sync.Mutex
|
||||||
|
lastUsedTimes map[proxyTokenID]time.Time
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func newProxyAuthInterceptor(tokenStore proxyTokenStore) *proxyAuthInterceptor {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
i := &proxyAuthInterceptor{
|
||||||
|
store: tokenStore,
|
||||||
|
failureLimiter: newAuthFailureLimiter(),
|
||||||
|
lastUsedTimes: make(map[proxyTokenID]time.Time),
|
||||||
|
cancel: cancel,
|
||||||
|
}
|
||||||
|
go i.lastUsedCleanupLoop(ctx)
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProxyAuthInterceptors creates gRPC unary and stream interceptors that validate proxy access tokens.
|
||||||
|
// They only intercept ProxyService methods. Both interceptors share state for last-used and failure rate limiting.
|
||||||
|
// The returned close function must be called on shutdown to stop background goroutines.
|
||||||
|
func NewProxyAuthInterceptors(tokenStore proxyTokenStore) (grpc.UnaryServerInterceptor, grpc.StreamServerInterceptor, func()) {
|
||||||
|
interceptor := newProxyAuthInterceptor(tokenStore)
|
||||||
|
|
||||||
|
unary := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||||
|
if !strings.HasPrefix(info.FullMethod, "/management.ProxyService/") {
|
||||||
|
return handler(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := interceptor.validateProxyToken(ctx)
|
||||||
|
if err != nil {
|
||||||
|
// Log auth failures explicitly; gRPC doesn't log these by default.
|
||||||
|
log.WithContext(ctx).Warnf("proxy auth failed: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx = context.WithValue(ctx, ProxyTokenContextKey, token)
|
||||||
|
return handler(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
stream := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||||
|
if !strings.HasPrefix(info.FullMethod, "/management.ProxyService/") {
|
||||||
|
return handler(srv, ss)
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := interceptor.validateProxyToken(ss.Context())
|
||||||
|
if err != nil {
|
||||||
|
// Log auth failures explicitly; gRPC doesn't log these by default.
|
||||||
|
log.WithContext(ss.Context()).Warnf("proxy auth failed: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(ss.Context(), ProxyTokenContextKey, token)
|
||||||
|
wrapped := &wrappedServerStream{
|
||||||
|
ServerStream: ss,
|
||||||
|
ctx: ctx,
|
||||||
|
}
|
||||||
|
|
||||||
|
return handler(srv, wrapped)
|
||||||
|
}
|
||||||
|
|
||||||
|
return unary, stream, interceptor.close
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *proxyAuthInterceptor) validateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
|
||||||
|
clientIP := peerIPFromContext(ctx)
|
||||||
|
|
||||||
|
if clientIP != "" && i.failureLimiter.isLimited(clientIP) {
|
||||||
|
return nil, status.Errorf(codes.ResourceExhausted, "too many failed authentication attempts")
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := i.doValidateProxyToken(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if clientIP != "" {
|
||||||
|
i.failureLimiter.recordFailure(clientIP)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
i.maybeUpdateLastUsed(ctx, token.ID)
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *proxyAuthInterceptor) doValidateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
|
||||||
|
md, ok := metadata.FromIncomingContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "missing metadata")
|
||||||
|
}
|
||||||
|
|
||||||
|
authValues := md.Get("authorization")
|
||||||
|
if len(authValues) == 0 {
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "missing authorization header")
|
||||||
|
}
|
||||||
|
|
||||||
|
authValue := authValues[0]
|
||||||
|
if !strings.HasPrefix(authValue, "Bearer ") {
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "invalid authorization format")
|
||||||
|
}
|
||||||
|
|
||||||
|
plainToken := types.PlainProxyToken(strings.TrimPrefix(authValue, "Bearer "))
|
||||||
|
|
||||||
|
if err := plainToken.Validate(); err != nil {
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "invalid token format")
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := i.store.GetProxyAccessTokenByHashedToken(ctx, store.LockingStrengthNone, plainToken.Hash())
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "invalid token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Enforce AccountID scope for "bring your own proxy" feature.
|
||||||
|
// Currently tokens are management-wide; AccountID field is reserved for future use.
|
||||||
|
|
||||||
|
if !token.IsValid() {
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked")
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// maybeUpdateLastUsed updates the last_used timestamp if enough time has passed since the last update.
|
||||||
|
func (i *proxyAuthInterceptor) maybeUpdateLastUsed(ctx context.Context, tokenID string) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
i.lastUsedMu.Lock()
|
||||||
|
lastUpdate, exists := i.lastUsedTimes[tokenID]
|
||||||
|
if exists && now.Sub(lastUpdate) < lastUsedUpdateInterval {
|
||||||
|
i.lastUsedMu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
i.lastUsedTimes[tokenID] = now
|
||||||
|
i.lastUsedMu.Unlock()
|
||||||
|
|
||||||
|
if err := i.store.MarkProxyAccessTokenUsed(ctx, tokenID); err != nil {
|
||||||
|
log.WithContext(ctx).Debugf("failed to mark proxy token as used: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *proxyAuthInterceptor) lastUsedCleanupLoop(ctx context.Context) {
|
||||||
|
ticker := time.NewTicker(lastUsedCleanupInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
i.cleanupStaleLastUsed()
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupStaleLastUsed removes entries older than 2x the update interval.
|
||||||
|
func (i *proxyAuthInterceptor) cleanupStaleLastUsed() {
|
||||||
|
i.lastUsedMu.Lock()
|
||||||
|
defer i.lastUsedMu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
staleThreshold := 2 * lastUsedUpdateInterval
|
||||||
|
for id, lastUpdate := range i.lastUsedTimes {
|
||||||
|
if now.Sub(lastUpdate) > staleThreshold {
|
||||||
|
delete(i.lastUsedTimes, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *proxyAuthInterceptor) close() {
|
||||||
|
i.cancel()
|
||||||
|
i.failureLimiter.stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProxyTokenFromContext retrieves the validated proxy token from the context
|
||||||
|
func GetProxyTokenFromContext(ctx context.Context) *types.ProxyAccessToken {
|
||||||
|
token, ok := ctx.Value(ProxyTokenContextKey).(*types.ProxyAccessToken)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrappedServerStream wraps a grpc.ServerStream to provide a custom context
|
||||||
|
type wrappedServerStream struct {
|
||||||
|
grpc.ServerStream
|
||||||
|
ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wrappedServerStream) Context() context.Context {
|
||||||
|
return w.ctx
|
||||||
|
}
|
||||||
134
management/internals/shared/grpc/proxy_auth_ratelimit.go
Normal file
134
management/internals/shared/grpc/proxy_auth_ratelimit.go
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
"google.golang.org/grpc/peer"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// proxyAuthFailureBurst is the maximum number of failed attempts before rate limiting kicks in.
|
||||||
|
proxyAuthFailureBurst = 5
|
||||||
|
// proxyAuthLimiterCleanup is how often stale limiters are removed.
|
||||||
|
proxyAuthLimiterCleanup = 5 * time.Minute
|
||||||
|
// proxyAuthLimiterTTL is how long a limiter is kept after the last failure.
|
||||||
|
proxyAuthLimiterTTL = 15 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
// defaultProxyAuthFailureRate is the token replenishment rate for failed auth attempts.
|
||||||
|
// One token every 12 seconds = 5 per minute.
|
||||||
|
var defaultProxyAuthFailureRate = rate.Every(12 * time.Second)
|
||||||
|
|
||||||
|
// clientIP identifies a client by its IP address for rate limiting purposes.
|
||||||
|
type clientIP = string
|
||||||
|
|
||||||
|
type limiterEntry struct {
|
||||||
|
limiter *rate.Limiter
|
||||||
|
lastAccess time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// authFailureLimiter tracks per-IP rate limits for failed proxy authentication attempts.
|
||||||
|
type authFailureLimiter struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
limiters map[clientIP]*limiterEntry
|
||||||
|
failureRate rate.Limit
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAuthFailureLimiter() *authFailureLimiter {
|
||||||
|
return newAuthFailureLimiterWithRate(defaultProxyAuthFailureRate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAuthFailureLimiterWithRate(failureRate rate.Limit) *authFailureLimiter {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
l := &authFailureLimiter{
|
||||||
|
limiters: make(map[clientIP]*limiterEntry),
|
||||||
|
failureRate: failureRate,
|
||||||
|
cancel: cancel,
|
||||||
|
}
|
||||||
|
go l.cleanupLoop(ctx)
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
// isLimited returns true if the given IP has exhausted its failure budget.
|
||||||
|
func (l *authFailureLimiter) isLimited(ip clientIP) bool {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
|
||||||
|
entry, exists := l.limiters[ip]
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return entry.limiter.Tokens() < 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// recordFailure consumes a token from the rate limiter for the given IP.
|
||||||
|
func (l *authFailureLimiter) recordFailure(ip clientIP) {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
entry, exists := l.limiters[ip]
|
||||||
|
if !exists {
|
||||||
|
entry = &limiterEntry{
|
||||||
|
limiter: rate.NewLimiter(l.failureRate, proxyAuthFailureBurst),
|
||||||
|
}
|
||||||
|
l.limiters[ip] = entry
|
||||||
|
}
|
||||||
|
entry.lastAccess = now
|
||||||
|
entry.limiter.Allow()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *authFailureLimiter) cleanupLoop(ctx context.Context) {
|
||||||
|
ticker := time.NewTicker(proxyAuthLimiterCleanup)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
l.cleanup()
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *authFailureLimiter) cleanup() {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for ip, entry := range l.limiters {
|
||||||
|
if now.Sub(entry.lastAccess) > proxyAuthLimiterTTL {
|
||||||
|
delete(l.limiters, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *authFailureLimiter) stop() {
|
||||||
|
l.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// peerIPFromContext extracts the client IP from the gRPC context.
|
||||||
|
// Uses realip (from trusted proxy headers) first, falls back to the transport peer address.
|
||||||
|
func peerIPFromContext(ctx context.Context) clientIP {
|
||||||
|
if addr, ok := realip.FromContext(ctx); ok {
|
||||||
|
return addr.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
if p, ok := peer.FromContext(ctx); ok {
|
||||||
|
host, _, err := net.SplitHostPort(p.Addr.String())
|
||||||
|
if err != nil {
|
||||||
|
return p.Addr.String()
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
@@ -0,0 +1,98 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthFailureLimiter_NotLimitedInitially(t *testing.T) {
|
||||||
|
l := newAuthFailureLimiter()
|
||||||
|
defer l.stop()
|
||||||
|
|
||||||
|
assert.False(t, l.isLimited("192.168.1.1"), "new IP should not be rate limited")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthFailureLimiter_LimitedAfterBurst(t *testing.T) {
|
||||||
|
l := newAuthFailureLimiter()
|
||||||
|
defer l.stop()
|
||||||
|
|
||||||
|
ip := "192.168.1.1"
|
||||||
|
for i := 0; i < proxyAuthFailureBurst; i++ {
|
||||||
|
l.recordFailure(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, l.isLimited(ip), "IP should be limited after exhausting burst")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthFailureLimiter_DifferentIPsIndependent(t *testing.T) {
|
||||||
|
l := newAuthFailureLimiter()
|
||||||
|
defer l.stop()
|
||||||
|
|
||||||
|
for i := 0; i < proxyAuthFailureBurst; i++ {
|
||||||
|
l.recordFailure("192.168.1.1")
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, l.isLimited("192.168.1.1"))
|
||||||
|
assert.False(t, l.isLimited("192.168.1.2"), "different IP should not be affected")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthFailureLimiter_RecoveryOverTime(t *testing.T) {
|
||||||
|
l := newAuthFailureLimiterWithRate(rate.Limit(100)) // 100 tokens/sec for fast recovery
|
||||||
|
defer l.stop()
|
||||||
|
|
||||||
|
ip := "10.0.0.1"
|
||||||
|
|
||||||
|
// Exhaust burst
|
||||||
|
for i := 0; i < proxyAuthFailureBurst; i++ {
|
||||||
|
l.recordFailure(ip)
|
||||||
|
}
|
||||||
|
require.True(t, l.isLimited(ip))
|
||||||
|
|
||||||
|
// Wait for token replenishment
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
assert.False(t, l.isLimited(ip), "should recover after tokens replenish")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthFailureLimiter_Cleanup(t *testing.T) {
|
||||||
|
l := newAuthFailureLimiter()
|
||||||
|
defer l.stop()
|
||||||
|
|
||||||
|
l.recordFailure("10.0.0.1")
|
||||||
|
|
||||||
|
l.mu.Lock()
|
||||||
|
require.Len(t, l.limiters, 1)
|
||||||
|
// Backdate the entry so it looks stale
|
||||||
|
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
|
||||||
|
l.mu.Unlock()
|
||||||
|
|
||||||
|
l.cleanup()
|
||||||
|
|
||||||
|
l.mu.Lock()
|
||||||
|
assert.Empty(t, l.limiters, "stale entries should be cleaned up")
|
||||||
|
l.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthFailureLimiter_CleanupKeepsFresh(t *testing.T) {
|
||||||
|
l := newAuthFailureLimiter()
|
||||||
|
defer l.stop()
|
||||||
|
|
||||||
|
l.recordFailure("10.0.0.1")
|
||||||
|
l.recordFailure("10.0.0.2")
|
||||||
|
|
||||||
|
l.mu.Lock()
|
||||||
|
// Only backdate one entry
|
||||||
|
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
|
||||||
|
l.mu.Unlock()
|
||||||
|
|
||||||
|
l.cleanup()
|
||||||
|
|
||||||
|
l.mu.Lock()
|
||||||
|
assert.Len(t, l.limiters, 1, "only stale entries should be removed")
|
||||||
|
assert.Contains(t, l.limiters, "10.0.0.2")
|
||||||
|
l.mu.Unlock()
|
||||||
|
}
|
||||||
381
management/internals/shared/grpc/proxy_group_access_test.go
Normal file
381
management/internals/shared/grpc/proxy_group_access_test.go
Normal file
@@ -0,0 +1,381 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockReverseProxyManager struct {
|
||||||
|
proxiesByAccount map[string][]*reverseproxy.Service
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||||
|
if m.err != nil {
|
||||||
|
return nil, m.err
|
||||||
|
}
|
||||||
|
return m.proxiesByAccount[accountID], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
|
||||||
|
return []*reverseproxy.Service{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.Service, error) {
|
||||||
|
return &reverseproxy.Service{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||||
|
return &reverseproxy.Service{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||||
|
return &reverseproxy.Service{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) ReloadService(ctx context.Context, accountID, reverseProxyID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.Service, error) {
|
||||||
|
return &reverseproxy.Service{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockUsersManager struct {
|
||||||
|
users map[string]*types.User
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
|
||||||
|
if m.err != nil {
|
||||||
|
return nil, m.err
|
||||||
|
}
|
||||||
|
user, ok := m.users[userID]
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("user not found")
|
||||||
|
}
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateUserGroupAccess(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
domain string
|
||||||
|
userID string
|
||||||
|
proxiesByAccount map[string][]*reverseproxy.Service
|
||||||
|
users map[string]*types.User
|
||||||
|
proxyErr error
|
||||||
|
userErr error
|
||||||
|
expectErr bool
|
||||||
|
expectErrMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "user not found",
|
||||||
|
domain: "app.example.com",
|
||||||
|
userID: "unknown-user",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||||
|
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
|
||||||
|
},
|
||||||
|
users: map[string]*types.User{},
|
||||||
|
expectErr: true,
|
||||||
|
expectErrMsg: "user not found",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "proxy not found in user's account",
|
||||||
|
domain: "app.example.com",
|
||||||
|
userID: "user1",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{},
|
||||||
|
users: map[string]*types.User{
|
||||||
|
"user1": {Id: "user1", AccountID: "account1"},
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
expectErrMsg: "service not found",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "proxy exists in different account - not accessible",
|
||||||
|
domain: "app.example.com",
|
||||||
|
userID: "user1",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||||
|
"account2": {{Domain: "app.example.com", AccountID: "account2"}},
|
||||||
|
},
|
||||||
|
users: map[string]*types.User{
|
||||||
|
"user1": {Id: "user1", AccountID: "account1"},
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
expectErrMsg: "service not found",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no bearer auth configured - same account allows access",
|
||||||
|
domain: "app.example.com",
|
||||||
|
userID: "user1",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||||
|
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}},
|
||||||
|
},
|
||||||
|
users: map[string]*types.User{
|
||||||
|
"user1": {Id: "user1", AccountID: "account1"},
|
||||||
|
},
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bearer auth disabled - same account allows access",
|
||||||
|
domain: "app.example.com",
|
||||||
|
userID: "user1",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||||
|
"account1": {{
|
||||||
|
Domain: "app.example.com",
|
||||||
|
AccountID: "account1",
|
||||||
|
Auth: reverseproxy.AuthConfig{
|
||||||
|
BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: false},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
users: map[string]*types.User{
|
||||||
|
"user1": {Id: "user1", AccountID: "account1"},
|
||||||
|
},
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bearer auth enabled but no groups configured - same account allows access",
|
||||||
|
domain: "app.example.com",
|
||||||
|
userID: "user1",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||||
|
"account1": {{
|
||||||
|
Domain: "app.example.com",
|
||||||
|
AccountID: "account1",
|
||||||
|
Auth: reverseproxy.AuthConfig{
|
||||||
|
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
DistributionGroups: []string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
users: map[string]*types.User{
|
||||||
|
"user1": {Id: "user1", AccountID: "account1"},
|
||||||
|
},
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user not in allowed groups",
|
||||||
|
domain: "app.example.com",
|
||||||
|
userID: "user1",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||||
|
"account1": {{
|
||||||
|
Domain: "app.example.com",
|
||||||
|
AccountID: "account1",
|
||||||
|
Auth: reverseproxy.AuthConfig{
|
||||||
|
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
DistributionGroups: []string{"group1", "group2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
users: map[string]*types.User{
|
||||||
|
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group3", "group4"}},
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
expectErrMsg: "not in allowed groups",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user in one of the allowed groups - allow access",
|
||||||
|
domain: "app.example.com",
|
||||||
|
userID: "user1",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||||
|
"account1": {{
|
||||||
|
Domain: "app.example.com",
|
||||||
|
AccountID: "account1",
|
||||||
|
Auth: reverseproxy.AuthConfig{
|
||||||
|
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
DistributionGroups: []string{"group1", "group2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
users: map[string]*types.User{
|
||||||
|
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group2", "group3"}},
|
||||||
|
},
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user in all allowed groups - allow access",
|
||||||
|
domain: "app.example.com",
|
||||||
|
userID: "user1",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||||
|
"account1": {{
|
||||||
|
Domain: "app.example.com",
|
||||||
|
AccountID: "account1",
|
||||||
|
Auth: reverseproxy.AuthConfig{
|
||||||
|
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
DistributionGroups: []string{"group1", "group2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
users: map[string]*types.User{
|
||||||
|
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group1", "group2", "group3"}},
|
||||||
|
},
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "proxy manager error",
|
||||||
|
domain: "app.example.com",
|
||||||
|
userID: "user1",
|
||||||
|
proxiesByAccount: nil,
|
||||||
|
proxyErr: errors.New("database error"),
|
||||||
|
users: map[string]*types.User{
|
||||||
|
"user1": {Id: "user1", AccountID: "account1"},
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
expectErrMsg: "get account services",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple proxies in account - finds correct one",
|
||||||
|
domain: "app2.example.com",
|
||||||
|
userID: "user1",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||||
|
"account1": {
|
||||||
|
{Domain: "app1.example.com", AccountID: "account1"},
|
||||||
|
{Domain: "app2.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}},
|
||||||
|
{Domain: "app3.example.com", AccountID: "account1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
users: map[string]*types.User{
|
||||||
|
"user1": {Id: "user1", AccountID: "account1"},
|
||||||
|
},
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
server := &ProxyServiceServer{
|
||||||
|
reverseProxyManager: &mockReverseProxyManager{
|
||||||
|
proxiesByAccount: tt.proxiesByAccount,
|
||||||
|
err: tt.proxyErr,
|
||||||
|
},
|
||||||
|
usersManager: &mockUsersManager{
|
||||||
|
users: tt.users,
|
||||||
|
err: tt.userErr,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := server.ValidateUserGroupAccess(context.Background(), tt.domain, tt.userID)
|
||||||
|
|
||||||
|
if tt.expectErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), tt.expectErrMsg)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAccountProxyByDomain(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
accountID string
|
||||||
|
domain string
|
||||||
|
proxiesByAccount map[string][]*reverseproxy.Service
|
||||||
|
err error
|
||||||
|
expectProxy bool
|
||||||
|
expectErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "proxy found",
|
||||||
|
accountID: "account1",
|
||||||
|
domain: "app.example.com",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||||
|
"account1": {
|
||||||
|
{Domain: "other.example.com", AccountID: "account1"},
|
||||||
|
{Domain: "app.example.com", AccountID: "account1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectProxy: true,
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "proxy not found in account",
|
||||||
|
accountID: "account1",
|
||||||
|
domain: "unknown.example.com",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||||
|
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
|
||||||
|
},
|
||||||
|
expectProxy: false,
|
||||||
|
expectErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty proxy list for account",
|
||||||
|
accountID: "account1",
|
||||||
|
domain: "app.example.com",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{},
|
||||||
|
expectProxy: false,
|
||||||
|
expectErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "manager error",
|
||||||
|
accountID: "account1",
|
||||||
|
domain: "app.example.com",
|
||||||
|
proxiesByAccount: nil,
|
||||||
|
err: errors.New("database error"),
|
||||||
|
expectProxy: false,
|
||||||
|
expectErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
server := &ProxyServiceServer{
|
||||||
|
reverseProxyManager: &mockReverseProxyManager{
|
||||||
|
proxiesByAccount: tt.proxiesByAccount,
|
||||||
|
err: tt.err,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy, err := server.getAccountServiceByDomain(context.Background(), tt.accountID, tt.domain)
|
||||||
|
|
||||||
|
if tt.expectErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, proxy)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, proxy)
|
||||||
|
assert.Equal(t, tt.domain, proxy.Domain)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
232
management/internals/shared/grpc/proxy_test.go
Normal file
232
management/internals/shared/grpc/proxy_test.go
Normal file
@@ -0,0 +1,232 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// registerFakeProxy adds a fake proxy connection to the server's internal maps
|
||||||
|
// and returns the channel where messages will be received.
|
||||||
|
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.ProxyMapping {
|
||||||
|
ch := make(chan *proto.ProxyMapping, 10)
|
||||||
|
conn := &proxyConnection{
|
||||||
|
proxyID: proxyID,
|
||||||
|
address: clusterAddr,
|
||||||
|
sendChan: ch,
|
||||||
|
}
|
||||||
|
s.connectedProxies.Store(proxyID, conn)
|
||||||
|
|
||||||
|
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
|
||||||
|
proxySet.(*sync.Map).Store(proxyID, struct{}{})
|
||||||
|
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|
||||||
|
func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping {
|
||||||
|
select {
|
||||||
|
case msg := <-ch:
|
||||||
|
return msg
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
||||||
|
tokenStore := NewOneTimeTokenStore(time.Hour)
|
||||||
|
defer tokenStore.Close()
|
||||||
|
|
||||||
|
s := &ProxyServiceServer{
|
||||||
|
tokenStore: tokenStore,
|
||||||
|
updatesChan: make(chan *proto.ProxyMapping, 100),
|
||||||
|
}
|
||||||
|
|
||||||
|
const cluster = "proxy.example.com"
|
||||||
|
const numProxies = 3
|
||||||
|
|
||||||
|
channels := make([]chan *proto.ProxyMapping, numProxies)
|
||||||
|
for i := range numProxies {
|
||||||
|
id := "proxy-" + string(rune('a'+i))
|
||||||
|
channels[i] = registerFakeProxy(s, id, cluster)
|
||||||
|
}
|
||||||
|
|
||||||
|
update := &proto.ProxyMapping{
|
||||||
|
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||||
|
Id: "service-1",
|
||||||
|
AccountId: "account-1",
|
||||||
|
Domain: "test.example.com",
|
||||||
|
Path: []*proto.PathMapping{
|
||||||
|
{Path: "/", Target: "http://10.0.0.1:8080/"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s.SendServiceUpdateToCluster(update, cluster)
|
||||||
|
|
||||||
|
tokens := make([]string, numProxies)
|
||||||
|
for i, ch := range channels {
|
||||||
|
msg := drainChannel(ch)
|
||||||
|
require.NotNil(t, msg, "proxy %d should receive a message", i)
|
||||||
|
assert.Equal(t, update.Domain, msg.Domain)
|
||||||
|
assert.Equal(t, update.Id, msg.Id)
|
||||||
|
assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i)
|
||||||
|
tokens[i] = msg.AuthToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// All tokens must be unique
|
||||||
|
tokenSet := make(map[string]struct{})
|
||||||
|
for i, tok := range tokens {
|
||||||
|
_, exists := tokenSet[tok]
|
||||||
|
assert.False(t, exists, "proxy %d got duplicate token", i)
|
||||||
|
tokenSet[tok] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Each token must be independently consumable
|
||||||
|
for i, tok := range tokens {
|
||||||
|
err := tokenStore.ValidateAndConsume(tok, "account-1", "service-1")
|
||||||
|
assert.NoError(t, err, "proxy %d token should validate successfully", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
||||||
|
tokenStore := NewOneTimeTokenStore(time.Hour)
|
||||||
|
defer tokenStore.Close()
|
||||||
|
|
||||||
|
s := &ProxyServiceServer{
|
||||||
|
tokenStore: tokenStore,
|
||||||
|
updatesChan: make(chan *proto.ProxyMapping, 100),
|
||||||
|
}
|
||||||
|
|
||||||
|
const cluster = "proxy.example.com"
|
||||||
|
ch1 := registerFakeProxy(s, "proxy-a", cluster)
|
||||||
|
ch2 := registerFakeProxy(s, "proxy-b", cluster)
|
||||||
|
|
||||||
|
update := &proto.ProxyMapping{
|
||||||
|
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
|
||||||
|
Id: "service-1",
|
||||||
|
AccountId: "account-1",
|
||||||
|
Domain: "test.example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
s.SendServiceUpdateToCluster(update, cluster)
|
||||||
|
|
||||||
|
msg1 := drainChannel(ch1)
|
||||||
|
msg2 := drainChannel(ch2)
|
||||||
|
require.NotNil(t, msg1)
|
||||||
|
require.NotNil(t, msg2)
|
||||||
|
|
||||||
|
// Delete operations should not generate tokens
|
||||||
|
assert.Empty(t, msg1.AuthToken)
|
||||||
|
assert.Empty(t, msg2.AuthToken)
|
||||||
|
|
||||||
|
// No tokens should have been created
|
||||||
|
assert.Equal(t, 0, tokenStore.GetTokenCount())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
|
||||||
|
tokenStore := NewOneTimeTokenStore(time.Hour)
|
||||||
|
defer tokenStore.Close()
|
||||||
|
|
||||||
|
s := &ProxyServiceServer{
|
||||||
|
tokenStore: tokenStore,
|
||||||
|
updatesChan: make(chan *proto.ProxyMapping, 100),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register proxies in different clusters (SendServiceUpdate broadcasts to all)
|
||||||
|
ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")
|
||||||
|
ch2 := registerFakeProxy(s, "proxy-b", "cluster-b")
|
||||||
|
|
||||||
|
update := &proto.ProxyMapping{
|
||||||
|
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||||
|
Id: "service-1",
|
||||||
|
AccountId: "account-1",
|
||||||
|
Domain: "test.example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
s.SendServiceUpdate(update)
|
||||||
|
|
||||||
|
msg1 := drainChannel(ch1)
|
||||||
|
msg2 := drainChannel(ch2)
|
||||||
|
require.NotNil(t, msg1)
|
||||||
|
require.NotNil(t, msg2)
|
||||||
|
|
||||||
|
assert.NotEmpty(t, msg1.AuthToken)
|
||||||
|
assert.NotEmpty(t, msg2.AuthToken)
|
||||||
|
assert.NotEqual(t, msg1.AuthToken, msg2.AuthToken, "tokens must be unique per proxy")
|
||||||
|
|
||||||
|
// Both tokens should validate
|
||||||
|
assert.NoError(t, tokenStore.ValidateAndConsume(msg1.AuthToken, "account-1", "service-1"))
|
||||||
|
assert.NoError(t, tokenStore.ValidateAndConsume(msg2.AuthToken, "account-1", "service-1"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateState creates a state using the same format as GetOIDCURL.
|
||||||
|
func generateState(s *ProxyServiceServer, redirectURL string) string {
|
||||||
|
nonce := make([]byte, 16)
|
||||||
|
_, _ = rand.Read(nonce)
|
||||||
|
nonceB64 := base64.URLEncoding.EncodeToString(nonce)
|
||||||
|
|
||||||
|
payload := redirectURL + "|" + nonceB64
|
||||||
|
hmacSum := s.generateHMAC(payload)
|
||||||
|
return base64.URLEncoding.EncodeToString([]byte(redirectURL)) + "|" + nonceB64 + "|" + hmacSum
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOAuthState_NeverTheSame(t *testing.T) {
|
||||||
|
s := &ProxyServiceServer{
|
||||||
|
oidcConfig: ProxyOIDCConfig{
|
||||||
|
HMACKey: []byte("test-hmac-key"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURL := "https://app.example.com/callback"
|
||||||
|
|
||||||
|
// Generate 100 states for the same redirect URL
|
||||||
|
states := make(map[string]bool)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
state := generateState(s, redirectURL)
|
||||||
|
|
||||||
|
// State must have 3 parts: base64(url)|nonce|hmac
|
||||||
|
parts := strings.Split(state, "|")
|
||||||
|
require.Equal(t, 3, len(parts), "state must have 3 parts")
|
||||||
|
|
||||||
|
// State must be unique
|
||||||
|
require.False(t, states[state], "state %d is a duplicate", i)
|
||||||
|
states[state] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
|
||||||
|
s := &ProxyServiceServer{
|
||||||
|
oidcConfig: ProxyOIDCConfig{
|
||||||
|
HMACKey: []byte("test-hmac-key"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Old format had only 2 parts: base64(url)|hmac
|
||||||
|
s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
|
||||||
|
|
||||||
|
_, _, err := s.ValidateState("base64url|hmac")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid state format")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
|
||||||
|
s := &ProxyServiceServer{
|
||||||
|
oidcConfig: ProxyOIDCConfig{
|
||||||
|
HMACKey: []byte("test-hmac-key"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store with tampered HMAC
|
||||||
|
s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
|
||||||
|
|
||||||
|
_, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid state signature")
|
||||||
|
}
|
||||||
@@ -300,7 +300,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
metahash := metaHash(peerMeta, realIP.String())
|
metahash := metaHash(peerMeta, realIP.String())
|
||||||
s.loginFilter.addLogin(peerKey.String(), metahash)
|
s.loginFilter.addLogin(peerKey.String(), metahash)
|
||||||
|
|
||||||
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
|
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, reqStart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
@@ -311,7 +311,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer)
|
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -319,7 +319,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
|
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer)
|
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -336,7 +336,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
|
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
|
|
||||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, reqStart)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
|
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
|
||||||
@@ -404,11 +404,20 @@ func (s *Server) sendJobsLoop(ctx context.Context, accountID string, peerKey wgt
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
||||||
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
// It implements a backpressure mechanism that sends the first update immediately,
|
||||||
|
// then debounces subsequent rapid updates, ensuring only the latest update is sent
|
||||||
|
// after a quiet period.
|
||||||
|
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
|
||||||
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
|
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
|
||||||
|
|
||||||
|
// Create a debouncer for this peer connection
|
||||||
|
debouncer := NewUpdateDebouncer(1000 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
// condition when there are some updates
|
// condition when there are some updates
|
||||||
|
// todo set the updates channel size to 1
|
||||||
case update, open := <-updates:
|
case update, open := <-updates:
|
||||||
if s.appMetrics != nil {
|
if s.appMetrics != nil {
|
||||||
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1)
|
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1)
|
||||||
@@ -416,20 +425,38 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
|
|||||||
|
|
||||||
if !open {
|
if !open {
|
||||||
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
|
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
|
||||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
||||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
|
if debouncer.ProcessUpdate(update) {
|
||||||
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
|
// Send immediately (first update or after quiet period)
|
||||||
return err
|
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
|
||||||
|
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Timer expired - quiet period reached, send pending updates if any
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
if len(pendingUpdates) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Debugf("sending %d debounced update(s) for peer %s", len(pendingUpdates), peerKey.String())
|
||||||
|
for _, pendingUpdate := range pendingUpdates {
|
||||||
|
if err := s.sendUpdate(ctx, accountID, peerKey, peer, pendingUpdate, srv, streamStartTime); err != nil {
|
||||||
|
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// condition when client <-> server connection has been terminated
|
// condition when client <-> server connection has been terminated
|
||||||
case <-srv.Context().Done():
|
case <-srv.Context().Done():
|
||||||
// happens when connection drops, e.g. client disconnects
|
// happens when connection drops, e.g. client disconnects
|
||||||
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
|
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
|
||||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||||
return srv.Context().Err()
|
return srv.Context().Err()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -437,16 +464,16 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
|
|||||||
|
|
||||||
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
||||||
// then sends the encrypted message to the connected peer via the sync server.
|
// then sends the encrypted message to the connected peer via the sync server.
|
||||||
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
|
||||||
key, err := s.secretsManager.GetWGKey()
|
key, err := s.secretsManager.GetWGKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||||
return status.Errorf(codes.Internal, "failed processing update message")
|
return status.Errorf(codes.Internal, "failed processing update message")
|
||||||
}
|
}
|
||||||
|
|
||||||
encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update)
|
encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||||
return status.Errorf(codes.Internal, "failed processing update message")
|
return status.Errorf(codes.Internal, "failed processing update message")
|
||||||
}
|
}
|
||||||
err = srv.Send(&proto.EncryptedMessage{
|
err = srv.Send(&proto.EncryptedMessage{
|
||||||
@@ -454,7 +481,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
|
|||||||
Body: encryptedResp,
|
Body: encryptedResp,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||||
return status.Errorf(codes.Internal, "failed sending update message")
|
return status.Errorf(codes.Internal, "failed sending update message")
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
||||||
@@ -486,15 +513,15 @@ func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Even
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
|
||||||
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer)
|
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
|
||||||
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key, streamStartTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
|
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -242,7 +242,10 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont
|
|||||||
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
|
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
|
||||||
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
|
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
|
||||||
|
Update: update,
|
||||||
|
MessageType: network_map.MessageTypeControlConfig,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) {
|
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) {
|
||||||
@@ -266,7 +269,10 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac
|
|||||||
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
|
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
|
||||||
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
|
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
|
||||||
|
Update: update,
|
||||||
|
MessageType: network_map.MessageTypeControlConfig,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {
|
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {
|
||||||
|
|||||||
103
management/internals/shared/grpc/update_debouncer.go
Normal file
103
management/internals/shared/grpc/update_debouncer.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UpdateDebouncer implements a backpressure mechanism that:
|
||||||
|
// - Sends the first update immediately
|
||||||
|
// - Coalesces rapid subsequent network map updates (only latest matters)
|
||||||
|
// - Queues control/config updates (all must be delivered)
|
||||||
|
// - Preserves the order of messages (important for control configs between network maps)
|
||||||
|
// - Ensures pending updates are sent after a quiet period
|
||||||
|
type UpdateDebouncer struct {
|
||||||
|
debounceInterval time.Duration
|
||||||
|
timer *time.Timer
|
||||||
|
pendingUpdates []*network_map.UpdateMessage // Queue that preserves order
|
||||||
|
timerC <-chan time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUpdateDebouncer creates a new debouncer with the specified interval
|
||||||
|
func NewUpdateDebouncer(interval time.Duration) *UpdateDebouncer {
|
||||||
|
return &UpdateDebouncer{
|
||||||
|
debounceInterval: interval,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessUpdate handles an incoming update and returns whether it should be sent immediately
|
||||||
|
func (d *UpdateDebouncer) ProcessUpdate(update *network_map.UpdateMessage) bool {
|
||||||
|
if d.timer == nil {
|
||||||
|
// No active debounce timer, signal to send immediately
|
||||||
|
// and start the debounce period
|
||||||
|
d.startTimer()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Already in debounce period, accumulate this update preserving order
|
||||||
|
// Check if we should coalesce with the last pending update
|
||||||
|
if len(d.pendingUpdates) > 0 &&
|
||||||
|
update.MessageType == network_map.MessageTypeNetworkMap &&
|
||||||
|
d.pendingUpdates[len(d.pendingUpdates)-1].MessageType == network_map.MessageTypeNetworkMap {
|
||||||
|
// Replace the last network map with this one (coalesce consecutive network maps)
|
||||||
|
d.pendingUpdates[len(d.pendingUpdates)-1] = update
|
||||||
|
} else {
|
||||||
|
// Append to the queue (preserves order for control configs and non-consecutive network maps)
|
||||||
|
d.pendingUpdates = append(d.pendingUpdates, update)
|
||||||
|
}
|
||||||
|
d.resetTimer()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TimerChannel returns the timer channel for select statements
|
||||||
|
func (d *UpdateDebouncer) TimerChannel() <-chan time.Time {
|
||||||
|
if d.timer == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return d.timerC
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPendingUpdates returns and clears all pending updates after timer expiration.
|
||||||
|
// Updates are returned in the order they were received, with consecutive network maps
|
||||||
|
// already coalesced to only the latest one.
|
||||||
|
// If there were pending updates, it restarts the timer to continue debouncing.
|
||||||
|
// If there were no pending updates, it clears the timer (true quiet period).
|
||||||
|
func (d *UpdateDebouncer) GetPendingUpdates() []*network_map.UpdateMessage {
|
||||||
|
updates := d.pendingUpdates
|
||||||
|
d.pendingUpdates = nil
|
||||||
|
|
||||||
|
if len(updates) > 0 {
|
||||||
|
// There were pending updates, so updates are still coming rapidly
|
||||||
|
// Restart the timer to continue debouncing mode
|
||||||
|
if d.timer != nil {
|
||||||
|
d.timer.Reset(d.debounceInterval)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// No pending updates means true quiet period - return to immediate mode
|
||||||
|
d.timer = nil
|
||||||
|
d.timerC = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return updates
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the debouncer and cleans up resources
|
||||||
|
func (d *UpdateDebouncer) Stop() {
|
||||||
|
if d.timer != nil {
|
||||||
|
d.timer.Stop()
|
||||||
|
d.timer = nil
|
||||||
|
d.timerC = nil
|
||||||
|
}
|
||||||
|
d.pendingUpdates = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *UpdateDebouncer) startTimer() {
|
||||||
|
d.timer = time.NewTimer(d.debounceInterval)
|
||||||
|
d.timerC = d.timer.C
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *UpdateDebouncer) resetTimer() {
|
||||||
|
d.timer.Stop()
|
||||||
|
d.timer.Reset(d.debounceInterval)
|
||||||
|
}
|
||||||
587
management/internals/shared/grpc/update_debouncer_test.go
Normal file
587
management/internals/shared/grpc/update_debouncer_test.go
Normal file
@@ -0,0 +1,587 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_FirstUpdateSentImmediately(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
update := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
shouldSend := debouncer.ProcessUpdate(update)
|
||||||
|
|
||||||
|
if !shouldSend {
|
||||||
|
t.Error("First update should be sent immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
if debouncer.TimerChannel() == nil {
|
||||||
|
t.Error("Timer should be started after first update")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_RapidUpdatesCoalesced(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
update1 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
update2 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
update3 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
// First update should be sent immediately
|
||||||
|
if !debouncer.ProcessUpdate(update1) {
|
||||||
|
t.Error("First update should be sent immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rapid subsequent updates should be coalesced
|
||||||
|
if debouncer.ProcessUpdate(update2) {
|
||||||
|
t.Error("Second rapid update should not be sent immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
if debouncer.ProcessUpdate(update3) {
|
||||||
|
t.Error("Third rapid update should not be sent immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for debounce period
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
if len(pendingUpdates) != 1 {
|
||||||
|
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||||
|
}
|
||||||
|
if pendingUpdates[0] != update3 {
|
||||||
|
t.Error("Should get the last update (update3)")
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_LastUpdateAlwaysSent(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
update1 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
update2 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send first update
|
||||||
|
debouncer.ProcessUpdate(update1)
|
||||||
|
|
||||||
|
// Send second update within debounce period
|
||||||
|
debouncer.ProcessUpdate(update2)
|
||||||
|
|
||||||
|
// Wait for timer
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
if len(pendingUpdates) != 1 {
|
||||||
|
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||||
|
}
|
||||||
|
if pendingUpdates[0] != update2 {
|
||||||
|
t.Error("Should get the last update")
|
||||||
|
}
|
||||||
|
if pendingUpdates[0] == update1 {
|
||||||
|
t.Error("Should not get the first update")
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_TimerResetOnNewUpdate(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
update1 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
update2 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
update3 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send first update
|
||||||
|
debouncer.ProcessUpdate(update1)
|
||||||
|
|
||||||
|
// Wait a bit, but not the full debounce period
|
||||||
|
time.Sleep(30 * time.Millisecond)
|
||||||
|
|
||||||
|
// Send second update - should reset timer
|
||||||
|
debouncer.ProcessUpdate(update2)
|
||||||
|
|
||||||
|
// Wait a bit more
|
||||||
|
time.Sleep(30 * time.Millisecond)
|
||||||
|
|
||||||
|
// Send third update - should reset timer again
|
||||||
|
debouncer.ProcessUpdate(update3)
|
||||||
|
|
||||||
|
// Now wait for the timer (should fire after last update's reset)
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
if len(pendingUpdates) != 1 {
|
||||||
|
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||||
|
}
|
||||||
|
if pendingUpdates[0] != update3 {
|
||||||
|
t.Error("Should get the last update (update3)")
|
||||||
|
}
|
||||||
|
// Timer should be restarted since there was a pending update
|
||||||
|
if debouncer.TimerChannel() == nil {
|
||||||
|
t.Error("Timer should be restarted after sending pending update")
|
||||||
|
}
|
||||||
|
case <-time.After(150 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_TimerRestartsAfterPendingUpdateSent(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
update1 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
update2 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
update3 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
// First update sent immediately
|
||||||
|
debouncer.ProcessUpdate(update1)
|
||||||
|
|
||||||
|
// Second update coalesced
|
||||||
|
debouncer.ProcessUpdate(update2)
|
||||||
|
|
||||||
|
// Wait for timer to expire
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
|
||||||
|
if len(pendingUpdates) == 0 {
|
||||||
|
t.Fatal("Should have pending update")
|
||||||
|
}
|
||||||
|
|
||||||
|
// After sending pending update, timer is restarted, so next update is NOT immediate
|
||||||
|
if debouncer.ProcessUpdate(update3) {
|
||||||
|
t.Error("Update after debounced send should not be sent immediately (timer restarted)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the restarted timer and verify update3 is pending
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
finalUpdates := debouncer.GetPendingUpdates()
|
||||||
|
if len(finalUpdates) != 1 || finalUpdates[0] != update3 {
|
||||||
|
t.Error("Should get update3 as pending")
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired for restarted timer")
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_StopCleansUp(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
|
||||||
|
update := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send update to start timer
|
||||||
|
debouncer.ProcessUpdate(update)
|
||||||
|
|
||||||
|
// Stop should clean up
|
||||||
|
debouncer.Stop()
|
||||||
|
|
||||||
|
// Multiple stops should be safe
|
||||||
|
debouncer.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_HighFrequencyUpdates(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
// Simulate high-frequency updates
|
||||||
|
var lastUpdate *network_map.UpdateMessage
|
||||||
|
sentImmediately := 0
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
update := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{
|
||||||
|
NetworkMap: &proto.NetworkMap{
|
||||||
|
Serial: uint64(i),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
lastUpdate = update
|
||||||
|
if debouncer.ProcessUpdate(update) {
|
||||||
|
sentImmediately++
|
||||||
|
}
|
||||||
|
time.Sleep(1 * time.Millisecond) // Very rapid updates
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only first update should be sent immediately
|
||||||
|
if sentImmediately != 1 {
|
||||||
|
t.Errorf("Expected only 1 update sent immediately, got %d", sentImmediately)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for debounce period
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
if len(pendingUpdates) != 1 {
|
||||||
|
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||||
|
}
|
||||||
|
if pendingUpdates[0] != lastUpdate {
|
||||||
|
t.Error("Should get the very last update")
|
||||||
|
}
|
||||||
|
if pendingUpdates[0].Update.NetworkMap.Serial != 99 {
|
||||||
|
t.Errorf("Expected serial 99, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_NoUpdatesAfterFirst(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
update := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send first update
|
||||||
|
if !debouncer.ProcessUpdate(update) {
|
||||||
|
t.Error("First update should be sent immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for timer to expire with no additional updates (true quiet period)
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
if len(pendingUpdates) != 0 {
|
||||||
|
t.Error("Should have no pending updates")
|
||||||
|
}
|
||||||
|
// After true quiet period, timer should be cleared
|
||||||
|
if debouncer.TimerChannel() != nil {
|
||||||
|
t.Error("Timer should be cleared after quiet period")
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_IntermediateUpdatesDropped(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
updates := make([]*network_map.UpdateMessage, 5)
|
||||||
|
for i := range updates {
|
||||||
|
updates[i] = &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{
|
||||||
|
NetworkMap: &proto.NetworkMap{
|
||||||
|
Serial: uint64(i),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// First update sent immediately
|
||||||
|
debouncer.ProcessUpdate(updates[0])
|
||||||
|
|
||||||
|
// Send updates 1, 2, 3, 4 rapidly - only last one should remain pending
|
||||||
|
debouncer.ProcessUpdate(updates[1])
|
||||||
|
debouncer.ProcessUpdate(updates[2])
|
||||||
|
debouncer.ProcessUpdate(updates[3])
|
||||||
|
debouncer.ProcessUpdate(updates[4])
|
||||||
|
|
||||||
|
// Wait for debounce
|
||||||
|
<-debouncer.TimerChannel()
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
|
||||||
|
if len(pendingUpdates) != 1 {
|
||||||
|
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||||
|
}
|
||||||
|
if pendingUpdates[0].Update.NetworkMap.Serial != 4 {
|
||||||
|
t.Errorf("Expected only the last update (serial 4), got serial %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_TrueQuietPeriodResetsToImmediateMode(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
update1 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
update2 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
// First update sent immediately
|
||||||
|
if !debouncer.ProcessUpdate(update1) {
|
||||||
|
t.Error("First update should be sent immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for timer without sending any more updates (true quiet period)
|
||||||
|
<-debouncer.TimerChannel()
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
|
||||||
|
if len(pendingUpdates) != 0 {
|
||||||
|
t.Error("Should have no pending updates during quiet period")
|
||||||
|
}
|
||||||
|
|
||||||
|
// After true quiet period, next update should be sent immediately
|
||||||
|
if !debouncer.ProcessUpdate(update2) {
|
||||||
|
t.Error("Update after true quiet period should be sent immediately")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_ContinuousHighFrequencyStaysInDebounceMode(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
// Simulate continuous high-frequency updates
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
update := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{
|
||||||
|
NetworkMap: &proto.NetworkMap{
|
||||||
|
Serial: uint64(i),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
if i == 0 {
|
||||||
|
// First one sent immediately
|
||||||
|
if !debouncer.ProcessUpdate(update) {
|
||||||
|
t.Error("First update should be sent immediately")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// All others should be coalesced (not sent immediately)
|
||||||
|
if debouncer.ProcessUpdate(update) {
|
||||||
|
t.Errorf("Update %d should not be sent immediately", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait a bit but send next update before debounce expires
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now wait for final debounce
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
if len(pendingUpdates) == 0 {
|
||||||
|
t.Fatal("Should have the last update pending")
|
||||||
|
}
|
||||||
|
if pendingUpdates[0].Update.NetworkMap.Serial != 9 {
|
||||||
|
t.Errorf("Expected serial 9, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_ControlConfigMessagesQueued(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
netmapUpdate := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
tokenUpdate1 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
||||||
|
MessageType: network_map.MessageTypeControlConfig,
|
||||||
|
}
|
||||||
|
tokenUpdate2 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
||||||
|
MessageType: network_map.MessageTypeControlConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
// First update sent immediately
|
||||||
|
debouncer.ProcessUpdate(netmapUpdate)
|
||||||
|
|
||||||
|
// Send multiple control config updates - they should all be queued
|
||||||
|
debouncer.ProcessUpdate(tokenUpdate1)
|
||||||
|
debouncer.ProcessUpdate(tokenUpdate2)
|
||||||
|
|
||||||
|
// Wait for debounce period
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
// Should get both control config updates
|
||||||
|
if len(pendingUpdates) != 2 {
|
||||||
|
t.Errorf("Expected 2 control config updates, got %d", len(pendingUpdates))
|
||||||
|
}
|
||||||
|
// Control configs should come first
|
||||||
|
if pendingUpdates[0] != tokenUpdate1 {
|
||||||
|
t.Error("First pending update should be tokenUpdate1")
|
||||||
|
}
|
||||||
|
if pendingUpdates[1] != tokenUpdate2 {
|
||||||
|
t.Error("Second pending update should be tokenUpdate2")
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_MixedMessageTypes(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
netmapUpdate1 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
netmapUpdate2 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 2}},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
tokenUpdate := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
||||||
|
MessageType: network_map.MessageTypeControlConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
// First update sent immediately
|
||||||
|
debouncer.ProcessUpdate(netmapUpdate1)
|
||||||
|
|
||||||
|
// Send token update and network map update
|
||||||
|
debouncer.ProcessUpdate(tokenUpdate)
|
||||||
|
debouncer.ProcessUpdate(netmapUpdate2)
|
||||||
|
|
||||||
|
// Wait for debounce period
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
// Should get 2 updates in order: token, then network map
|
||||||
|
if len(pendingUpdates) != 2 {
|
||||||
|
t.Errorf("Expected 2 pending updates, got %d", len(pendingUpdates))
|
||||||
|
}
|
||||||
|
// Token update should come first (preserves order)
|
||||||
|
if pendingUpdates[0] != tokenUpdate {
|
||||||
|
t.Error("First pending update should be tokenUpdate")
|
||||||
|
}
|
||||||
|
// Network map update should come second
|
||||||
|
if pendingUpdates[1] != netmapUpdate2 {
|
||||||
|
t.Error("Second pending update should be netmapUpdate2")
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_OrderPreservation(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
// Simulate: 50 network maps -> 1 control config -> 50 network maps
|
||||||
|
// Expected result: 3 messages (netmap, controlConfig, netmap)
|
||||||
|
|
||||||
|
// Send first network map immediately
|
||||||
|
firstNetmap := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 0}},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
if !debouncer.ProcessUpdate(firstNetmap) {
|
||||||
|
t.Error("First update should be sent immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send 49 more network maps (will be coalesced to last one)
|
||||||
|
var lastNetmapBatch1 *network_map.UpdateMessage
|
||||||
|
for i := 1; i < 50; i++ {
|
||||||
|
lastNetmapBatch1 = &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
debouncer.ProcessUpdate(lastNetmapBatch1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send 1 control config
|
||||||
|
controlConfig := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
||||||
|
MessageType: network_map.MessageTypeControlConfig,
|
||||||
|
}
|
||||||
|
debouncer.ProcessUpdate(controlConfig)
|
||||||
|
|
||||||
|
// Send 50 more network maps (will be coalesced to last one)
|
||||||
|
var lastNetmapBatch2 *network_map.UpdateMessage
|
||||||
|
for i := 50; i < 100; i++ {
|
||||||
|
lastNetmapBatch2 = &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
debouncer.ProcessUpdate(lastNetmapBatch2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for debounce period
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
// Should get exactly 3 updates: netmap, controlConfig, netmap
|
||||||
|
if len(pendingUpdates) != 3 {
|
||||||
|
t.Errorf("Expected 3 pending updates, got %d", len(pendingUpdates))
|
||||||
|
}
|
||||||
|
// First should be the last netmap from batch 1
|
||||||
|
if pendingUpdates[0] != lastNetmapBatch1 {
|
||||||
|
t.Error("First pending update should be last netmap from batch 1")
|
||||||
|
}
|
||||||
|
if pendingUpdates[0].Update.NetworkMap.Serial != 49 {
|
||||||
|
t.Errorf("Expected serial 49, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
||||||
|
}
|
||||||
|
// Second should be the control config
|
||||||
|
if pendingUpdates[1] != controlConfig {
|
||||||
|
t.Error("Second pending update should be control config")
|
||||||
|
}
|
||||||
|
// Third should be the last netmap from batch 2
|
||||||
|
if pendingUpdates[2] != lastNetmapBatch2 {
|
||||||
|
t.Error("Third pending update should be last netmap from batch 2")
|
||||||
|
}
|
||||||
|
if pendingUpdates[2].Update.NetworkMap.Serial != 99 {
|
||||||
|
t.Errorf("Expected serial 99, got %d", pendingUpdates[2].Update.NetworkMap.Serial)
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
304
management/internals/shared/grpc/validate_session_test.go
Normal file
304
management/internals/shared/grpc/validate_session_test.go
Normal file
@@ -0,0 +1,304 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/proxy/auth"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type validateSessionTestSetup struct {
|
||||||
|
proxyService *ProxyServiceServer
|
||||||
|
store store.Store
|
||||||
|
cleanup func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
proxyManager := &testValidateSessionProxyManager{store: testStore}
|
||||||
|
usersManager := &testValidateSessionUsersManager{store: testStore}
|
||||||
|
|
||||||
|
proxyService := NewProxyServiceServer(nil, NewOneTimeTokenStore(time.Minute), ProxyOIDCConfig{}, nil, usersManager)
|
||||||
|
proxyService.SetProxyManager(proxyManager)
|
||||||
|
|
||||||
|
createTestProxies(t, ctx, testStore)
|
||||||
|
|
||||||
|
return &validateSessionTestSetup{
|
||||||
|
proxyService: proxyService,
|
||||||
|
store: testStore,
|
||||||
|
cleanup: storeCleanup,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
pubKey, privKey := generateSessionKeyPair(t)
|
||||||
|
|
||||||
|
testProxy := &reverseproxy.Service{
|
||||||
|
ID: "testProxyId",
|
||||||
|
AccountID: "testAccountId",
|
||||||
|
Name: "Test Proxy",
|
||||||
|
Domain: "test-proxy.example.com",
|
||||||
|
Enabled: true,
|
||||||
|
SessionPrivateKey: privKey,
|
||||||
|
SessionPublicKey: pubKey,
|
||||||
|
Auth: reverseproxy.AuthConfig{
|
||||||
|
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.NoError(t, testStore.CreateService(ctx, testProxy))
|
||||||
|
|
||||||
|
restrictedProxy := &reverseproxy.Service{
|
||||||
|
ID: "restrictedProxyId",
|
||||||
|
AccountID: "testAccountId",
|
||||||
|
Name: "Restricted Proxy",
|
||||||
|
Domain: "restricted-proxy.example.com",
|
||||||
|
Enabled: true,
|
||||||
|
SessionPrivateKey: privKey,
|
||||||
|
SessionPublicKey: pubKey,
|
||||||
|
Auth: reverseproxy.AuthConfig{
|
||||||
|
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
DistributionGroups: []string{"allowedGroupId"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateSessionKeyPair(t *testing.T) (string, string) {
|
||||||
|
t.Helper()
|
||||||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return base64.StdEncoding.EncodeToString(pub), base64.StdEncoding.EncodeToString(priv)
|
||||||
|
}
|
||||||
|
|
||||||
|
func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string {
|
||||||
|
t.Helper()
|
||||||
|
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSession_UserAllowed(t *testing.T) {
|
||||||
|
setup := setupValidateSessionTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token := createSessionToken(t, proxy.SessionPrivateKey, "allowedUserId", "test-proxy.example.com")
|
||||||
|
|
||||||
|
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
||||||
|
Domain: "test-proxy.example.com",
|
||||||
|
SessionToken: token,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, resp.Valid, "User should be allowed access")
|
||||||
|
assert.Equal(t, "allowedUserId", resp.UserId)
|
||||||
|
assert.Empty(t, resp.DeniedReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
|
||||||
|
setup := setupValidateSessionTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "restrictedProxyId")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token := createSessionToken(t, proxy.SessionPrivateKey, "nonGroupUserId", "restricted-proxy.example.com")
|
||||||
|
|
||||||
|
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
||||||
|
Domain: "restricted-proxy.example.com",
|
||||||
|
SessionToken: token,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, resp.Valid, "User not in group should be denied")
|
||||||
|
assert.Equal(t, "not_in_group", resp.DeniedReason)
|
||||||
|
assert.Equal(t, "nonGroupUserId", resp.UserId)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSession_UserInDifferentAccount(t *testing.T) {
|
||||||
|
setup := setupValidateSessionTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token := createSessionToken(t, proxy.SessionPrivateKey, "otherAccountUserId", "test-proxy.example.com")
|
||||||
|
|
||||||
|
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
||||||
|
Domain: "test-proxy.example.com",
|
||||||
|
SessionToken: token,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, resp.Valid, "User in different account should be denied")
|
||||||
|
assert.Equal(t, "account_mismatch", resp.DeniedReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSession_UserNotFound(t *testing.T) {
|
||||||
|
setup := setupValidateSessionTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token := createSessionToken(t, proxy.SessionPrivateKey, "nonExistentUserId", "test-proxy.example.com")
|
||||||
|
|
||||||
|
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
||||||
|
Domain: "test-proxy.example.com",
|
||||||
|
SessionToken: token,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, resp.Valid, "Non-existent user should be denied")
|
||||||
|
assert.Equal(t, "user_not_found", resp.DeniedReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSession_ProxyNotFound(t *testing.T) {
|
||||||
|
setup := setupValidateSessionTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token := createSessionToken(t, proxy.SessionPrivateKey, "allowedUserId", "unknown-proxy.example.com")
|
||||||
|
|
||||||
|
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
||||||
|
Domain: "unknown-proxy.example.com",
|
||||||
|
SessionToken: token,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, resp.Valid, "Unknown proxy should be denied")
|
||||||
|
assert.Equal(t, "proxy_not_found", resp.DeniedReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSession_InvalidToken(t *testing.T) {
|
||||||
|
setup := setupValidateSessionTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
||||||
|
Domain: "test-proxy.example.com",
|
||||||
|
SessionToken: "invalid-token",
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, resp.Valid, "Invalid token should be denied")
|
||||||
|
assert.Equal(t, "invalid_token", resp.DeniedReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSession_MissingDomain(t *testing.T) {
|
||||||
|
setup := setupValidateSessionTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
||||||
|
SessionToken: "some-token",
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, resp.Valid)
|
||||||
|
assert.Contains(t, resp.DeniedReason, "missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSession_MissingToken(t *testing.T) {
|
||||||
|
setup := setupValidateSessionTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
||||||
|
Domain: "test-proxy.example.com",
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, resp.Valid)
|
||||||
|
assert.Contains(t, resp.DeniedReason, "missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
type testValidateSessionProxyManager struct {
|
||||||
|
store store.Store
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionProxyManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionProxyManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionProxyManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionProxyManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionProxyManager) DeleteService(_ context.Context, _, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionProxyManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionProxyManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionProxyManager) ReloadService(_ context.Context, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
||||||
|
return m.store.GetServices(ctx, store.LockingStrengthNone)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionProxyManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
|
||||||
|
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||||
|
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type testValidateSessionUsersManager struct {
|
||||||
|
store store.Store
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
|
||||||
|
return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||||
|
}
|
||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
"github.com/netbirdio/netbird/management/server/job"
|
"github.com/netbirdio/netbird/management/server/job"
|
||||||
"github.com/netbirdio/netbird/shared/auth"
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
|
|
||||||
@@ -82,8 +83,9 @@ type DefaultAccountManager struct {
|
|||||||
|
|
||||||
requestBuffer *AccountRequestBuffer
|
requestBuffer *AccountRequestBuffer
|
||||||
|
|
||||||
proxyController port_forwarding.Controller
|
proxyController port_forwarding.Controller
|
||||||
settingsManager settings.Manager
|
settingsManager settings.Manager
|
||||||
|
reverseProxyManager reverseproxy.Manager
|
||||||
|
|
||||||
// config contains the management server configuration
|
// config contains the management server configuration
|
||||||
config *nbconfig.Config
|
config *nbconfig.Config
|
||||||
@@ -113,6 +115,10 @@ type DefaultAccountManager struct {
|
|||||||
|
|
||||||
var _ account.Manager = (*DefaultAccountManager)(nil)
|
var _ account.Manager = (*DefaultAccountManager)(nil)
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
|
||||||
|
am.reverseProxyManager = serviceManager
|
||||||
|
}
|
||||||
|
|
||||||
func isUniqueConstraintError(err error) bool {
|
func isUniqueConstraintError(err error) bool {
|
||||||
switch {
|
switch {
|
||||||
case strings.Contains(err.Error(), "(SQLSTATE 23505)"),
|
case strings.Contains(err.Error(), "(SQLSTATE 23505)"),
|
||||||
@@ -321,6 +327,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
|
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if err = am.reverseProxyManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err)
|
||||||
|
}
|
||||||
updateAccountPeers = true
|
updateAccountPeers = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1670,13 +1679,13 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
|
|||||||
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
|
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||||
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID)
|
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID, syncTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||||
}
|
}
|
||||||
@@ -1684,8 +1693,20 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
|
|||||||
return peer, netMap, postureChecks, dnsfwdPort, nil
|
return peer, netMap, postureChecks, dnsfwdPort, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
|
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
|
||||||
err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID)
|
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to get peer %s for disconnect check: %v", peerPubKey, err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.Status.LastSeen.After(streamStartTime) {
|
||||||
|
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s > streamStart=%s), skipping disconnect",
|
||||||
|
peerPubKey, peer.Status.LastSeen.Format(time.RFC3339), streamStartTime.Format(time.RFC3339))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID, time.Now().UTC())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
|
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
"github.com/netbirdio/netbird/shared/auth"
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
@@ -58,7 +59,7 @@ type Manager interface {
|
|||||||
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||||
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
||||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||||
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
|
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
|
||||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||||
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||||
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
|
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
|
||||||
@@ -114,8 +115,8 @@ type Manager interface {
|
|||||||
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
|
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
|
||||||
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||||
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
|
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
|
||||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||||
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
|
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error
|
||||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
||||||
@@ -139,4 +140,5 @@ type Manager interface {
|
|||||||
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
|
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
|
||||||
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
|
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
|
||||||
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
|
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
|
||||||
|
SetServiceManager(serviceManager reverseproxy.Manager)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,6 +27,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
|
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||||
@@ -1800,6 +1802,14 @@ func TestAccount_Copy(t *testing.T) {
|
|||||||
Address: "172.12.6.1/24",
|
Address: "172.12.6.1/24",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
Services: []*reverseproxy.Service{
|
||||||
|
{
|
||||||
|
ID: "service1",
|
||||||
|
Name: "test-service",
|
||||||
|
AccountID: "account1",
|
||||||
|
Targets: []*reverseproxy.Target{},
|
||||||
|
},
|
||||||
|
},
|
||||||
NetworkMapCache: &types.NetworkMapBuilder{},
|
NetworkMapCache: &types.NetworkMapBuilder{},
|
||||||
}
|
}
|
||||||
account.InitOnce()
|
account.InitOnce()
|
||||||
@@ -1881,7 +1891,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
|||||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||||
require.NoError(t, err, "unable to get the account")
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
|
||||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
|
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
|
||||||
require.NoError(t, err, "unable to mark peer connected")
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
|
|
||||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||||
@@ -1952,7 +1962,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
|||||||
require.NoError(t, err, "unable to get the account")
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
|
||||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
|
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
|
||||||
require.NoError(t, err, "unable to mark peer connected")
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
|
|
||||||
failed := waitTimeout(wg, time.Second)
|
failed := waitTimeout(wg, time.Second)
|
||||||
@@ -1961,6 +1971,82 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||||
|
manager, _, err := createManager(t)
|
||||||
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
|
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||||
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
|
key, err := wgtypes.GenerateKey()
|
||||||
|
require.NoError(t, err, "unable to generate WireGuard key")
|
||||||
|
peerPubKey := key.PublicKey().String()
|
||||||
|
|
||||||
|
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||||
|
Key: peerPubKey,
|
||||||
|
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||||
|
}, false)
|
||||||
|
require.NoError(t, err, "unable to add peer")
|
||||||
|
|
||||||
|
t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) {
|
||||||
|
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
|
||||||
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
|
|
||||||
|
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||||
|
require.NoError(t, err, "unable to get peer")
|
||||||
|
require.True(t, peer.Status.Connected, "peer should be connected")
|
||||||
|
|
||||||
|
streamStartTime := time.Now().UTC()
|
||||||
|
|
||||||
|
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, peer.Status.Connected, "peer should be disconnected")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) {
|
||||||
|
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
|
||||||
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
|
|
||||||
|
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, peer.Status.Connected, "peer should be connected")
|
||||||
|
|
||||||
|
streamStartTime := peer.Status.LastSeen.Add(-1 * time.Hour)
|
||||||
|
|
||||||
|
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, peer.Status.Connected,
|
||||||
|
"peer should remain connected because LastSeen > streamStartTime (zombie stream protection)")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("skip stale connect when peer already has newer LastSeen (blocked goroutine protection)", func(t *testing.T) {
|
||||||
|
node2SyncTime := time.Now().UTC()
|
||||||
|
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node2SyncTime)
|
||||||
|
require.NoError(t, err, "node 2 should connect peer")
|
||||||
|
|
||||||
|
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, peer.Status.Connected, "peer should be connected")
|
||||||
|
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), "LastSeen should be node2SyncTime")
|
||||||
|
|
||||||
|
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
|
||||||
|
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node1StaleSyncTime)
|
||||||
|
require.NoError(t, err, "stale connect should not return error")
|
||||||
|
|
||||||
|
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, peer.Status.Connected, "peer should still be connected")
|
||||||
|
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(),
|
||||||
|
"LastSeen should NOT be overwritten by stale syncTime from blocked goroutine")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
|
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
|
||||||
manager, _, err := createManager(t)
|
manager, _, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
@@ -1983,7 +2069,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
|||||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
require.NoError(t, err, "unable to get the account")
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
|
||||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
|
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
|
||||||
require.NoError(t, err, "unable to mark peer connected")
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
|
|
||||||
wg := &sync.WaitGroup{}
|
wg := &sync.WaitGroup{}
|
||||||
@@ -3036,6 +3122,8 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
|||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, nil, nil))
|
||||||
|
|
||||||
return manager, updateManager, nil
|
return manager, updateManager, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3176,7 +3264,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
|
|||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
|
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}, time.Now().UTC())
|
||||||
assert.NoError(b, err)
|
assert.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -204,6 +204,10 @@ const (
|
|||||||
UserInviteLinkRegenerated Activity = 106
|
UserInviteLinkRegenerated Activity = 106
|
||||||
UserInviteLinkDeleted Activity = 107
|
UserInviteLinkDeleted Activity = 107
|
||||||
|
|
||||||
|
ServiceCreated Activity = 108
|
||||||
|
ServiceUpdated Activity = 109
|
||||||
|
ServiceDeleted Activity = 110
|
||||||
|
|
||||||
AccountDeleted Activity = 99999
|
AccountDeleted Activity = 99999
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -337,6 +341,10 @@ var activityMap = map[Activity]Code{
|
|||||||
UserInviteLinkAccepted: {"User invite link accepted", "user.invite.link.accept"},
|
UserInviteLinkAccepted: {"User invite link accepted", "user.invite.link.accept"},
|
||||||
UserInviteLinkRegenerated: {"User invite link regenerated", "user.invite.link.regenerate"},
|
UserInviteLinkRegenerated: {"User invite link regenerated", "user.invite.link.regenerate"},
|
||||||
UserInviteLinkDeleted: {"User invite link deleted", "user.invite.link.delete"},
|
UserInviteLinkDeleted: {"User invite link deleted", "user.invite.link.delete"},
|
||||||
|
|
||||||
|
ServiceCreated: {"Service created", "service.create"},
|
||||||
|
ServiceUpdated: {"Service updated", "service.update"},
|
||||||
|
ServiceDeleted: {"Service deleted", "service.delete"},
|
||||||
}
|
}
|
||||||
|
|
||||||
// StringCode returns a string code of the activity
|
// StringCode returns a string code of the activity
|
||||||
|
|||||||
@@ -703,7 +703,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
|||||||
t.Run("saving group linked to network router", func(t *testing.T) {
|
t.Run("saving group linked to network router", func(t *testing.T) {
|
||||||
permissionsManager := permissions.NewManager(manager.Store)
|
permissionsManager := permissions.NewManager(manager.Store)
|
||||||
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
|
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
|
||||||
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager)
|
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.reverseProxyManager)
|
||||||
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
|
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
|
||||||
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
|
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
|
||||||
|
|
||||||
|
|||||||
@@ -4,16 +4,28 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
|
||||||
"github.com/rs/cors"
|
"github.com/rs/cors"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
|
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
||||||
|
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||||
@@ -25,6 +37,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
|
||||||
|
|
||||||
nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers"
|
nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
"github.com/netbirdio/netbird/management/server/auth"
|
"github.com/netbirdio/netbird/management/server/auth"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
@@ -59,7 +73,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager) (http.Handler, error) {
|
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
|
||||||
|
|
||||||
// Register bypass paths for unauthenticated endpoints
|
// Register bypass paths for unauthenticated endpoints
|
||||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||||
@@ -75,6 +89,10 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
|||||||
if err := bypass.AddBypassPath("/api/users/invites/nbi_*/accept"); err != nil {
|
if err := bypass.AddBypassPath("/api/users/invites/nbi_*/accept"); err != nil {
|
||||||
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
||||||
}
|
}
|
||||||
|
// OAuth callback for proxy authentication
|
||||||
|
if err := bypass.AddBypassPath(types.ProxyCallbackEndpointFull); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
var rateLimitingConfig *middleware.RateLimiterConfig
|
var rateLimitingConfig *middleware.RateLimiterConfig
|
||||||
if os.Getenv(rateLimitingEnabledKey) == "true" {
|
if os.Getenv(rateLimitingEnabledKey) == "true" {
|
||||||
@@ -137,7 +155,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
|||||||
}
|
}
|
||||||
|
|
||||||
accounts.AddEndpoints(accountManager, settingsManager, router)
|
accounts.AddEndpoints(accountManager, settingsManager, router)
|
||||||
peers.AddEndpoints(accountManager, router, networkMapController)
|
peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager)
|
||||||
users.AddEndpoints(accountManager, router)
|
users.AddEndpoints(accountManager, router)
|
||||||
users.AddInvitesEndpoints(accountManager, router)
|
users.AddInvitesEndpoints(accountManager, router)
|
||||||
users.AddPublicInvitesEndpoints(accountManager, router)
|
users.AddPublicInvitesEndpoints(accountManager, router)
|
||||||
@@ -155,6 +173,15 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
|||||||
idp.AddEndpoints(accountManager, router)
|
idp.AddEndpoints(accountManager, router)
|
||||||
instance.AddEndpoints(instanceManager, router)
|
instance.AddEndpoints(instanceManager, router)
|
||||||
instance.AddVersionEndpoint(instanceManager, router)
|
instance.AddVersionEndpoint(instanceManager, router)
|
||||||
|
if reverseProxyManager != nil && reverseProxyDomainManager != nil {
|
||||||
|
reverseproxymanager.RegisterEndpoints(reverseProxyManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register OAuth callback handler for proxy authentication
|
||||||
|
if proxyGRPCServer != nil {
|
||||||
|
oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies)
|
||||||
|
oauthHandler.RegisterEndpoints(router)
|
||||||
|
}
|
||||||
|
|
||||||
// Mount embedded IdP handler at /oauth2 path if configured
|
// Mount embedded IdP handler at /oauth2 path if configured
|
||||||
if embeddedIdpEnabled {
|
if embeddedIdpEnabled {
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
@@ -26,11 +27,12 @@ import (
|
|||||||
// Handler is a handler that returns peers of the account
|
// Handler is a handler that returns peers of the account
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
accountManager account.Manager
|
accountManager account.Manager
|
||||||
|
permissionsManager permissions.Manager
|
||||||
networkMapController network_map.Controller
|
networkMapController network_map.Controller
|
||||||
}
|
}
|
||||||
|
|
||||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller) {
|
func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller, permissionsManager permissions.Manager) {
|
||||||
peersHandler := NewHandler(accountManager, networkMapController)
|
peersHandler := NewHandler(accountManager, networkMapController, permissionsManager)
|
||||||
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
|
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
|
||||||
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
|
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
|
||||||
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
||||||
@@ -42,10 +44,11 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMap
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewHandler creates a new peers Handler
|
// NewHandler creates a new peers Handler
|
||||||
func NewHandler(accountManager account.Manager, networkMapController network_map.Controller) *Handler {
|
func NewHandler(accountManager account.Manager, networkMapController network_map.Controller, permissionsManager permissions.Manager) *Handler {
|
||||||
return &Handler{
|
return &Handler{
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
networkMapController: networkMapController,
|
networkMapController: networkMapController,
|
||||||
|
permissionsManager: permissionsManager,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -149,6 +152,11 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if peer.ProxyMeta.Embedded {
|
||||||
|
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "not allowed to read peer"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
|
settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(ctx, err, w)
|
util.WriteError(ctx, err, w)
|
||||||
@@ -316,6 +324,9 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
|||||||
grpsInfoMap := groups.ToGroupsInfoMap(grps, len(peers))
|
grpsInfoMap := groups.ToGroupsInfoMap(grps, len(peers))
|
||||||
respBody := make([]*api.PeerBatch, 0, len(peers))
|
respBody := make([]*api.PeerBatch, 0, len(peers))
|
||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
|
if peer.ProxyMeta.Embedded {
|
||||||
|
continue
|
||||||
|
}
|
||||||
respBody = append(respBody, toPeerListItemResponse(peer, grpsInfoMap[peer.ID], dnsDomain, 0))
|
respBody = append(respBody, toPeerListItemResponse(peer, grpsInfoMap[peer.ID], dnsDomain, 0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -359,13 +370,19 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator)
|
user, err := h.accountManager.GetUserByID(r.Context(), userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := h.accountManager.GetUserByID(r.Context(), userID)
|
err = h.permissionsManager.ValidateAccountAccess(r.Context(), accountID, user, false)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), status.NewPermissionDeniedError(), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -13,13 +13,15 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"go.uber.org/mock/gomock"
|
ugomock "go.uber.org/mock/gomock"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/shared/auth"
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
@@ -102,7 +104,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
ctrl := ugomock.NewController(t)
|
||||||
|
|
||||||
networkMapController := network_map.NewMockController(ctrl)
|
networkMapController := network_map.NewMockController(ctrl)
|
||||||
networkMapController.EXPECT().
|
networkMapController.EXPECT().
|
||||||
@@ -110,6 +112,10 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
|
|||||||
Return("domain").
|
Return("domain").
|
||||||
AnyTimes()
|
AnyTimes()
|
||||||
|
|
||||||
|
ctrl2 := gomock.NewController(t)
|
||||||
|
permissionsManager := permissions.NewMockManager(ctrl2)
|
||||||
|
permissionsManager.EXPECT().ValidateAccountAccess(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||||
|
|
||||||
return &Handler{
|
return &Handler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||||
@@ -199,6 +205,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
networkMapController: networkMapController,
|
networkMapController: networkMapController,
|
||||||
|
permissionsManager: permissionsManager,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
208
management/server/http/handlers/proxy/auth.go
Normal file
208
management/server/http/handlers/proxy/auth.go
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/proxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthCallbackHandler handles OAuth callbacks for proxy authentication.
|
||||||
|
type AuthCallbackHandler struct {
|
||||||
|
proxyService *nbgrpc.ProxyServiceServer
|
||||||
|
rateLimiter *middleware.APIRateLimiter
|
||||||
|
trustedProxies []netip.Prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuthCallbackHandler creates a new OAuth callback handler.
|
||||||
|
func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer, trustedProxies []netip.Prefix) *AuthCallbackHandler {
|
||||||
|
rateLimiterConfig := &middleware.RateLimiterConfig{
|
||||||
|
RequestsPerMinute: 10,
|
||||||
|
Burst: 15,
|
||||||
|
CleanupInterval: 5 * time.Minute,
|
||||||
|
LimiterTTL: 10 * time.Minute,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &AuthCallbackHandler{
|
||||||
|
proxyService: proxyService,
|
||||||
|
rateLimiter: middleware.NewAPIRateLimiter(rateLimiterConfig),
|
||||||
|
trustedProxies: trustedProxies,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterEndpoints registers the OAuth callback endpoint.
|
||||||
|
func (h *AuthCallbackHandler) RegisterEndpoints(router *mux.Router) {
|
||||||
|
router.HandleFunc(types.ProxyCallbackEndpoint, h.handleCallback).Methods(http.MethodGet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
|
clientIP := h.resolveClientIP(r)
|
||||||
|
if !h.rateLimiter.Allow(clientIP) {
|
||||||
|
log.WithField("client_ip", clientIP).Warn("OAuth callback rate limit exceeded")
|
||||||
|
http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
state := r.URL.Query().Get("state")
|
||||||
|
|
||||||
|
codeVerifier, originalURL, err := h.proxyService.ValidateState(state)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Error("OAuth callback state validation failed")
|
||||||
|
http.Error(w, "Invalid state parameter", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURL, err := url.Parse(originalURL)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Error("Failed to parse redirect URL")
|
||||||
|
http.Error(w, "Invalid redirect URL", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
oidcConfig := h.proxyService.GetOIDCConfig()
|
||||||
|
|
||||||
|
provider, err := oidc.NewProvider(r.Context(), oidcConfig.Issuer)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Error("Failed to create OIDC provider")
|
||||||
|
http.Error(w, "Failed to create OIDC provider", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := (&oauth2.Config{
|
||||||
|
ClientID: oidcConfig.ClientID,
|
||||||
|
Endpoint: provider.Endpoint(),
|
||||||
|
RedirectURL: oidcConfig.CallbackURL,
|
||||||
|
}).Exchange(r.Context(), r.URL.Query().Get("code"), oauth2.VerifierOption(codeVerifier))
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Error("Failed to exchange code for token")
|
||||||
|
http.Error(w, "Failed to exchange code for token", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userID := extractUserIDFromToken(r.Context(), provider, oidcConfig, token)
|
||||||
|
if userID == "" {
|
||||||
|
log.Error("Failed to extract user ID from OIDC token")
|
||||||
|
http.Error(w, "Failed to validate token", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Group validation is performed by the proxy via ValidateSession gRPC call.
|
||||||
|
// This allows the proxy to show 403 pages directly without redirect dance.
|
||||||
|
|
||||||
|
sessionToken, err := h.proxyService.GenerateSessionToken(r.Context(), redirectURL.Hostname(), userID, auth.MethodOIDC)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Error("Failed to create session token")
|
||||||
|
redirectURL.Scheme = "https"
|
||||||
|
query := redirectURL.Query()
|
||||||
|
query.Set("error", "access_denied")
|
||||||
|
query.Set("error_description", "Service configuration error")
|
||||||
|
redirectURL.RawQuery = query.Encode()
|
||||||
|
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURL.Scheme = "https"
|
||||||
|
|
||||||
|
query := redirectURL.Query()
|
||||||
|
query.Set("session_token", sessionToken)
|
||||||
|
redirectURL.RawQuery = query.Encode()
|
||||||
|
|
||||||
|
log.WithField("redirect", redirectURL.Host).Debug("OAuth callback: redirecting user with session token")
|
||||||
|
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractUserIDFromToken(ctx context.Context, provider *oidc.Provider, config nbgrpc.ProxyOIDCConfig, token *oauth2.Token) string {
|
||||||
|
rawIDToken, ok := token.Extra("id_token").(string)
|
||||||
|
if !ok {
|
||||||
|
log.Warn("No id_token in OIDC response")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
verifier := provider.Verifier(&oidc.Config{
|
||||||
|
ClientID: config.ClientID,
|
||||||
|
})
|
||||||
|
|
||||||
|
idToken, err := verifier.Verify(ctx, rawIDToken)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Warn("Failed to verify ID token")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var claims struct {
|
||||||
|
Subject string `json:"sub"`
|
||||||
|
}
|
||||||
|
if err := idToken.Claims(&claims); err != nil {
|
||||||
|
log.WithError(err).Warn("Failed to extract claims from ID token")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return claims.Subject
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveClientIP extracts the real client IP from the request.
|
||||||
|
// When trustedProxies is non-empty and the direct peer is trusted,
|
||||||
|
// it walks X-Forwarded-For right-to-left skipping trusted IPs.
|
||||||
|
// Otherwise it returns RemoteAddr directly.
|
||||||
|
func (h *AuthCallbackHandler) resolveClientIP(r *http.Request) string {
|
||||||
|
remoteIP := extractHost(r.RemoteAddr)
|
||||||
|
|
||||||
|
if len(h.trustedProxies) == 0 || !isTrustedProxy(remoteIP, h.trustedProxies) {
|
||||||
|
return remoteIP
|
||||||
|
}
|
||||||
|
|
||||||
|
xff := r.Header.Get("X-Forwarded-For")
|
||||||
|
if xff == "" {
|
||||||
|
return remoteIP
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(xff, ",")
|
||||||
|
for i := len(parts) - 1; i >= 0; i-- {
|
||||||
|
ip := strings.TrimSpace(parts[i])
|
||||||
|
if ip == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !isTrustedProxy(ip, h.trustedProxies) {
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// All IPs in XFF are trusted; return the leftmost as best guess.
|
||||||
|
if first := strings.TrimSpace(parts[0]); first != "" {
|
||||||
|
return first
|
||||||
|
}
|
||||||
|
return remoteIP
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractHost(remoteAddr string) string {
|
||||||
|
host, _, err := net.SplitHostPort(remoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
return remoteAddr
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
func isTrustedProxy(ipStr string, trusted []netip.Prefix) bool {
|
||||||
|
addr, err := netip.ParseAddr(ipStr)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, prefix := range trusted {
|
||||||
|
if prefix.Contains(addr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -0,0 +1,523 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/management/server/users"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// fakeOIDCServer creates a minimal OIDC provider for testing.
|
||||||
|
type fakeOIDCServer struct {
|
||||||
|
server *httptest.Server
|
||||||
|
issuer string
|
||||||
|
signingKey ed25519.PrivateKey
|
||||||
|
publicKey ed25519.PublicKey
|
||||||
|
keyID string
|
||||||
|
tokenSubject string
|
||||||
|
tokenExpiry time.Duration
|
||||||
|
failExchange bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFakeOIDCServer() *fakeOIDCServer {
|
||||||
|
pub, priv, _ := ed25519.GenerateKey(rand.Reader)
|
||||||
|
f := &fakeOIDCServer{
|
||||||
|
signingKey: priv,
|
||||||
|
publicKey: pub,
|
||||||
|
keyID: "test-key-1",
|
||||||
|
tokenExpiry: time.Hour,
|
||||||
|
}
|
||||||
|
f.server = httptest.NewServer(f)
|
||||||
|
f.issuer = f.server.URL
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeOIDCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/.well-known/openid-configuration":
|
||||||
|
f.handleDiscovery(w, r)
|
||||||
|
case "/token":
|
||||||
|
f.handleToken(w, r)
|
||||||
|
case "/keys":
|
||||||
|
f.handleJWKS(w, r)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeOIDCServer) handleDiscovery(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
discovery := map[string]interface{}{
|
||||||
|
"issuer": f.issuer,
|
||||||
|
"authorization_endpoint": f.issuer + "/auth",
|
||||||
|
"token_endpoint": f.issuer + "/token",
|
||||||
|
"jwks_uri": f.issuer + "/keys",
|
||||||
|
"response_types_supported": []string{
|
||||||
|
"code",
|
||||||
|
"id_token",
|
||||||
|
"token id_token",
|
||||||
|
},
|
||||||
|
"subject_types_supported": []string{"public"},
|
||||||
|
"id_token_signing_alg_values_supported": []string{"EdDSA"},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(discovery)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeOIDCServer) handleToken(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if f.failExchange {
|
||||||
|
http.Error(w, "invalid_grant", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
http.Error(w, "bad request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
idToken := f.createIDToken()
|
||||||
|
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"access_token": "test-access-token",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"id_token": idToken,
|
||||||
|
"refresh_token": "test-refresh-token",
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeOIDCServer) createIDToken() string {
|
||||||
|
now := time.Now()
|
||||||
|
claims := jwt.MapClaims{
|
||||||
|
"iss": f.issuer,
|
||||||
|
"sub": f.tokenSubject,
|
||||||
|
"aud": "test-client-id",
|
||||||
|
"exp": now.Add(f.tokenExpiry).Unix(),
|
||||||
|
"iat": now.Unix(),
|
||||||
|
"nbf": now.Unix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
|
||||||
|
token.Header["kid"] = f.keyID
|
||||||
|
signed, _ := token.SignedString(f.signingKey)
|
||||||
|
return signed
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeOIDCServer) handleJWKS(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
jwks := map[string]interface{}{
|
||||||
|
"keys": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"kty": "OKP",
|
||||||
|
"crv": "Ed25519",
|
||||||
|
"kid": f.keyID,
|
||||||
|
"x": base64.RawURLEncoding.EncodeToString(f.publicKey),
|
||||||
|
"use": "sig",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(jwks)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeOIDCServer) Close() {
|
||||||
|
f.server.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// testSetup contains all test dependencies.
|
||||||
|
type testSetup struct {
|
||||||
|
store store.Store
|
||||||
|
oidcServer *fakeOIDCServer
|
||||||
|
proxyService *nbgrpc.ProxyServiceServer
|
||||||
|
handler *AuthCallbackHandler
|
||||||
|
router *mux.Router
|
||||||
|
cleanup func()
|
||||||
|
}
|
||||||
|
|
||||||
|
// testAccessLogManager is a minimal mock for accesslogs.Manager.
|
||||||
|
type testAccessLogManager struct{}
|
||||||
|
|
||||||
|
func (m *testAccessLogManager) SaveAccessLog(_ context.Context, _ *accesslogs.AccessLogEntry) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string, _ *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
|
||||||
|
return nil, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupAuthCallbackTest(t *testing.T) *testSetup {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
createTestAccountsAndUsers(t, ctx, testStore)
|
||||||
|
createTestReverseProxies(t, ctx, testStore)
|
||||||
|
|
||||||
|
oidcServer := newFakeOIDCServer()
|
||||||
|
|
||||||
|
tokenStore := nbgrpc.NewOneTimeTokenStore(time.Minute)
|
||||||
|
|
||||||
|
usersManager := users.NewManager(testStore)
|
||||||
|
|
||||||
|
oidcConfig := nbgrpc.ProxyOIDCConfig{
|
||||||
|
Issuer: oidcServer.issuer,
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
|
CallbackURL: "https://management.example.com/reverse-proxy/callback",
|
||||||
|
HMACKey: []byte("test-hmac-key-for-state-signing"),
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyService := nbgrpc.NewProxyServiceServer(
|
||||||
|
&testAccessLogManager{},
|
||||||
|
tokenStore,
|
||||||
|
oidcConfig,
|
||||||
|
nil,
|
||||||
|
usersManager,
|
||||||
|
)
|
||||||
|
|
||||||
|
proxyService.SetProxyManager(&testServiceManager{store: testStore})
|
||||||
|
|
||||||
|
handler := NewAuthCallbackHandler(proxyService, nil)
|
||||||
|
|
||||||
|
router := mux.NewRouter()
|
||||||
|
handler.RegisterEndpoints(router)
|
||||||
|
|
||||||
|
return &testSetup{
|
||||||
|
store: testStore,
|
||||||
|
oidcServer: oidcServer,
|
||||||
|
proxyService: proxyService,
|
||||||
|
handler: handler,
|
||||||
|
router: router,
|
||||||
|
cleanup: func() {
|
||||||
|
cleanup()
|
||||||
|
oidcServer.Close()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store.Store) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pubKey := base64.StdEncoding.EncodeToString(pub)
|
||||||
|
privKey := base64.StdEncoding.EncodeToString(priv)
|
||||||
|
|
||||||
|
testProxy := &reverseproxy.Service{
|
||||||
|
ID: "testProxyId",
|
||||||
|
AccountID: "testAccountId",
|
||||||
|
Name: "Test Proxy",
|
||||||
|
Domain: "test-proxy.example.com",
|
||||||
|
Targets: []*reverseproxy.Target{{
|
||||||
|
Path: strPtr("/"),
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 8080,
|
||||||
|
Protocol: "http",
|
||||||
|
TargetId: "peer1",
|
||||||
|
TargetType: "peer",
|
||||||
|
Enabled: true,
|
||||||
|
}},
|
||||||
|
Enabled: true,
|
||||||
|
Auth: reverseproxy.AuthConfig{
|
||||||
|
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
DistributionGroups: []string{"allowedGroupId"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SessionPrivateKey: privKey,
|
||||||
|
SessionPublicKey: pubKey,
|
||||||
|
}
|
||||||
|
require.NoError(t, testStore.CreateService(ctx, testProxy))
|
||||||
|
|
||||||
|
restrictedProxy := &reverseproxy.Service{
|
||||||
|
ID: "restrictedProxyId",
|
||||||
|
AccountID: "testAccountId",
|
||||||
|
Name: "Restricted Proxy",
|
||||||
|
Domain: "restricted-proxy.example.com",
|
||||||
|
Targets: []*reverseproxy.Target{{
|
||||||
|
Path: strPtr("/"),
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 8080,
|
||||||
|
Protocol: "http",
|
||||||
|
TargetId: "peer1",
|
||||||
|
TargetType: "peer",
|
||||||
|
Enabled: true,
|
||||||
|
}},
|
||||||
|
Enabled: true,
|
||||||
|
Auth: reverseproxy.AuthConfig{
|
||||||
|
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
DistributionGroups: []string{"restrictedGroupId"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SessionPrivateKey: privKey,
|
||||||
|
SessionPublicKey: pubKey,
|
||||||
|
}
|
||||||
|
require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
|
||||||
|
|
||||||
|
noAuthProxy := &reverseproxy.Service{
|
||||||
|
ID: "noAuthProxyId",
|
||||||
|
AccountID: "testAccountId",
|
||||||
|
Name: "No Auth Proxy",
|
||||||
|
Domain: "no-auth-proxy.example.com",
|
||||||
|
Targets: []*reverseproxy.Target{{
|
||||||
|
Path: strPtr("/"),
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 8080,
|
||||||
|
Protocol: "http",
|
||||||
|
TargetId: "peer1",
|
||||||
|
TargetType: "peer",
|
||||||
|
Enabled: true,
|
||||||
|
}},
|
||||||
|
Enabled: true,
|
||||||
|
Auth: reverseproxy.AuthConfig{
|
||||||
|
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||||
|
Enabled: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SessionPrivateKey: privKey,
|
||||||
|
SessionPublicKey: pubKey,
|
||||||
|
}
|
||||||
|
require.NoError(t, testStore.CreateService(ctx, noAuthProxy))
|
||||||
|
}
|
||||||
|
|
||||||
|
func strPtr(s string) *string {
|
||||||
|
return &s
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestAccountsAndUsers(t *testing.T, ctx context.Context, testStore store.Store) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
testAccount := &types.Account{
|
||||||
|
Id: "testAccountId",
|
||||||
|
Domain: "test.com",
|
||||||
|
DomainCategory: "private",
|
||||||
|
IsDomainPrimaryAccount: true,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
require.NoError(t, testStore.SaveAccount(ctx, testAccount))
|
||||||
|
|
||||||
|
allowedGroup := &types.Group{
|
||||||
|
ID: "allowedGroupId",
|
||||||
|
AccountID: "testAccountId",
|
||||||
|
Name: "Allowed Group",
|
||||||
|
Issued: "api",
|
||||||
|
}
|
||||||
|
require.NoError(t, testStore.CreateGroup(ctx, allowedGroup))
|
||||||
|
|
||||||
|
allowedUser := &types.User{
|
||||||
|
Id: "allowedUserId",
|
||||||
|
AccountID: "testAccountId",
|
||||||
|
Role: types.UserRoleUser,
|
||||||
|
AutoGroups: []string{"allowedGroupId"},
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
Issued: "api",
|
||||||
|
}
|
||||||
|
require.NoError(t, testStore.SaveUser(ctx, allowedUser))
|
||||||
|
}
|
||||||
|
|
||||||
|
// testServiceManager is a minimal implementation for testing.
|
||||||
|
type testServiceManager struct {
|
||||||
|
store store.Store
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testServiceManager) DeleteService(_ context.Context, _, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testServiceManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testServiceManager) ReloadService(_ context.Context, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
||||||
|
return m.store.GetServices(ctx, store.LockingStrengthNone)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
|
||||||
|
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||||
|
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
resp, err := ps.GetOIDCURL(context.Background(), &proto.GetOIDCURLRequest{
|
||||||
|
RedirectUrl: redirectURL,
|
||||||
|
AccountId: "testAccountId",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
parsedURL, err := url.Parse(resp.Url)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return parsedURL.Query().Get("state")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthCallback_UserAllowedToLogin(t *testing.T) {
|
||||||
|
setup := setupAuthCallbackTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
setup.oidcServer.tokenSubject = "allowedUserId"
|
||||||
|
|
||||||
|
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/dashboard")
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
setup.router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusFound, rec.Code)
|
||||||
|
|
||||||
|
location := rec.Header().Get("Location")
|
||||||
|
require.NotEmpty(t, location)
|
||||||
|
|
||||||
|
parsedLocation, err := url.Parse(location)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, "test-proxy.example.com", parsedLocation.Host)
|
||||||
|
require.NotEmpty(t, parsedLocation.Query().Get("session_token"), "Should include session token")
|
||||||
|
require.Empty(t, parsedLocation.Query().Get("error"), "Should not have error parameter")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthCallback_ProxyNotFound(t *testing.T) {
|
||||||
|
setup := setupAuthCallbackTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
setup.oidcServer.tokenSubject = "allowedUserId"
|
||||||
|
|
||||||
|
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")
|
||||||
|
|
||||||
|
require.NoError(t, setup.store.DeleteService(context.Background(), "testAccountId", "testProxyId"))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
setup.router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusFound, rec.Code)
|
||||||
|
|
||||||
|
location := rec.Header().Get("Location")
|
||||||
|
parsedLocation, err := url.Parse(location)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, "access_denied", parsedLocation.Query().Get("error"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthCallback_InvalidToken(t *testing.T) {
|
||||||
|
setup := setupAuthCallbackTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
setup.oidcServer.failExchange = true
|
||||||
|
|
||||||
|
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=invalid-code&state="+url.QueryEscape(state), nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
setup.router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||||
|
require.Contains(t, rec.Body.String(), "Failed to exchange code")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthCallback_ExpiredToken(t *testing.T) {
|
||||||
|
setup := setupAuthCallbackTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
setup.oidcServer.tokenSubject = "allowedUserId"
|
||||||
|
setup.oidcServer.tokenExpiry = -time.Hour
|
||||||
|
|
||||||
|
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
setup.router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||||
|
require.Contains(t, rec.Body.String(), "Failed to validate token")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthCallback_InvalidState(t *testing.T) {
|
||||||
|
setup := setupAuthCallbackTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state=invalid-state", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
setup.router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
|
require.Contains(t, rec.Body.String(), "Invalid state")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthCallback_MissingState(t *testing.T) {
|
||||||
|
setup := setupAuthCallbackTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
setup.router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
|
}
|
||||||
185
management/server/http/handlers/proxy/auth_test.go
Normal file
185
management/server/http/handlers/proxy/auth_test.go
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthCallbackHandler_RateLimiting(t *testing.T) {
|
||||||
|
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}, nil)
|
||||||
|
require.NotNil(t, handler.rateLimiter, "Rate limiter should be initialized")
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/callback?state=test&code=test", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.100:12345"
|
||||||
|
|
||||||
|
t.Run("allows requests under limit", func(t *testing.T) {
|
||||||
|
for i := 0; i < 15; i++ {
|
||||||
|
allowed := handler.rateLimiter.Allow("192.168.1.100")
|
||||||
|
assert.True(t, allowed, "Request %d should be allowed", i+1)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("blocks requests over limit", func(t *testing.T) {
|
||||||
|
handler.rateLimiter.Reset("192.168.1.200")
|
||||||
|
|
||||||
|
for i := 0; i < 15; i++ {
|
||||||
|
handler.rateLimiter.Allow("192.168.1.200")
|
||||||
|
}
|
||||||
|
|
||||||
|
allowed := handler.rateLimiter.Allow("192.168.1.200")
|
||||||
|
assert.False(t, allowed, "Request over limit should be blocked")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("different IPs have separate limits", func(t *testing.T) {
|
||||||
|
ip1 := "192.168.1.201"
|
||||||
|
ip2 := "192.168.1.202"
|
||||||
|
|
||||||
|
handler.rateLimiter.Reset(ip1)
|
||||||
|
handler.rateLimiter.Reset(ip2)
|
||||||
|
|
||||||
|
for i := 0; i < 15; i++ {
|
||||||
|
handler.rateLimiter.Allow(ip1)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.False(t, handler.rateLimiter.Allow(ip1), "IP1 should be blocked")
|
||||||
|
|
||||||
|
assert.True(t, handler.rateLimiter.Allow(ip2), "IP2 should be allowed")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthCallbackHandler_RateLimitInHandleCallback(t *testing.T) {
|
||||||
|
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}, nil)
|
||||||
|
testIP := "10.0.0.50"
|
||||||
|
|
||||||
|
handler.rateLimiter.Reset(testIP)
|
||||||
|
|
||||||
|
t.Run("returns 429 when rate limited", func(t *testing.T) {
|
||||||
|
for i := 0; i < 15; i++ {
|
||||||
|
handler.rateLimiter.Allow(testIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/callback?state=test&code=test", nil)
|
||||||
|
req.RemoteAddr = testIP + ":12345"
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.handleCallback(rr, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Should return 429 status code")
|
||||||
|
assert.Contains(t, rr.Body.String(), "Too many requests", "Should contain rate limit message")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveClientIP(t *testing.T) {
|
||||||
|
trusted := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
|
netip.MustParsePrefix("172.16.0.0/12"),
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
remoteAddr string
|
||||||
|
xForwardedFor string
|
||||||
|
trustedProxy []netip.Prefix
|
||||||
|
expectedIP string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no trusted proxies returns RemoteAddr",
|
||||||
|
remoteAddr: "203.0.113.50:9999",
|
||||||
|
xForwardedFor: "1.2.3.4",
|
||||||
|
trustedProxy: nil,
|
||||||
|
expectedIP: "203.0.113.50",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "untrusted RemoteAddr ignores XFF",
|
||||||
|
remoteAddr: "203.0.113.50:9999",
|
||||||
|
xForwardedFor: "1.2.3.4, 10.0.0.1",
|
||||||
|
trustedProxy: trusted,
|
||||||
|
expectedIP: "203.0.113.50",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "trusted RemoteAddr with single client in XFF",
|
||||||
|
remoteAddr: "10.0.0.1:5000",
|
||||||
|
xForwardedFor: "203.0.113.50",
|
||||||
|
trustedProxy: trusted,
|
||||||
|
expectedIP: "203.0.113.50",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "trusted RemoteAddr walks past trusted entries in XFF",
|
||||||
|
remoteAddr: "10.0.0.1:5000",
|
||||||
|
xForwardedFor: "203.0.113.50, 10.0.0.2, 172.16.0.5",
|
||||||
|
trustedProxy: trusted,
|
||||||
|
expectedIP: "203.0.113.50",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "trusted RemoteAddr with empty XFF falls back to RemoteAddr",
|
||||||
|
remoteAddr: "10.0.0.1:5000",
|
||||||
|
trustedProxy: trusted,
|
||||||
|
expectedIP: "10.0.0.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all XFF IPs trusted returns leftmost",
|
||||||
|
remoteAddr: "10.0.0.1:5000",
|
||||||
|
xForwardedFor: "10.0.0.2, 172.16.0.1, 10.0.0.3",
|
||||||
|
trustedProxy: trusted,
|
||||||
|
expectedIP: "10.0.0.2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "XFF with whitespace",
|
||||||
|
remoteAddr: "10.0.0.1:5000",
|
||||||
|
xForwardedFor: " 203.0.113.50 , 10.0.0.2 ",
|
||||||
|
trustedProxy: trusted,
|
||||||
|
expectedIP: "203.0.113.50",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multi-hop with mixed trust",
|
||||||
|
remoteAddr: "10.0.0.1:5000",
|
||||||
|
xForwardedFor: "8.8.8.8, 203.0.113.50, 172.16.0.1",
|
||||||
|
trustedProxy: trusted,
|
||||||
|
expectedIP: "203.0.113.50",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RemoteAddr without port",
|
||||||
|
remoteAddr: "192.168.1.100",
|
||||||
|
expectedIP: "192.168.1.100",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}, tt.trustedProxy)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
req.RemoteAddr = tt.remoteAddr
|
||||||
|
if tt.xForwardedFor != "" {
|
||||||
|
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
||||||
|
}
|
||||||
|
|
||||||
|
ip := handler.resolveClientIP(req)
|
||||||
|
assert.Equal(t, tt.expectedIP, ip)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthCallbackHandler_RateLimiterConfiguration(t *testing.T) {
|
||||||
|
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}, nil)
|
||||||
|
|
||||||
|
require.NotNil(t, handler.rateLimiter, "Rate limiter should be initialized")
|
||||||
|
|
||||||
|
testIP := "192.168.1.250"
|
||||||
|
handler.rateLimiter.Reset(testIP)
|
||||||
|
|
||||||
|
for i := 0; i < 15; i++ {
|
||||||
|
allowed := handler.rateLimiter.Allow(testIP)
|
||||||
|
assert.True(t, allowed, "Should allow request %d within burst limit", i+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
allowed := handler.rateLimiter.Allow(testIP)
|
||||||
|
assert.False(t, allowed, "Should block request that exceeds burst limit")
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user