mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-10 18:59:55 +00:00
Compare commits
279 Commits
v0.64.2
...
dn-reverse
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
76fb153d76 | ||
|
|
eee4d75932 | ||
|
|
62b8875f67 | ||
|
|
47a5478964 | ||
|
|
9922d6f953 | ||
|
|
f9bab22f61 | ||
|
|
3d8fdb7a89 | ||
|
|
fb10153ab8 | ||
|
|
57d3ee5aac | ||
|
|
cfdfdecc14 | ||
|
|
b00babb8b1 | ||
|
|
ac995bae6d | ||
|
|
41a5509ce0 | ||
|
|
db5e26db94 | ||
|
|
fe975fb834 | ||
|
|
e368d2995b | ||
|
|
a3241d8376 | ||
|
|
6dfc5772ba | ||
|
|
f70925178c | ||
|
|
9554934b92 | ||
|
|
7fdb824a37 | ||
|
|
412407adc0 | ||
|
|
e0874d7de7 | ||
|
|
8df1536cbb | ||
|
|
fcbacc62ec | ||
|
|
ee2ae45653 | ||
|
|
3bc8cbb13f | ||
|
|
bf7bdf6c4f | ||
|
|
6f2f0f9ae4 | ||
|
|
c37ebc6fb3 | ||
|
|
23abb5743c | ||
|
|
0a895ffc22 | ||
|
|
b87aa0bc15 | ||
|
|
69d4b5d821 | ||
|
|
f1a65d732d | ||
|
|
a3c0ea3e71 | ||
|
|
abaf061c2a | ||
|
|
e531fb54b1 | ||
|
|
5fcfed5b16 | ||
|
|
b81837a364 | ||
|
|
5f43449f67 | ||
|
|
6796601aa6 | ||
|
|
1fc25c301b | ||
|
|
08ae281b2d | ||
|
|
3dfa97dcbd | ||
|
|
1ddc9ce2bf | ||
|
|
bd47f44c63 | ||
|
|
381260911b | ||
|
|
38db42e7d6 | ||
|
|
5d606d909d | ||
|
|
d689718b50 | ||
|
|
54a73c6649 | ||
|
|
418377842e | ||
|
|
15ef56e03d | ||
|
|
917035f8e8 | ||
|
|
963e3f5457 | ||
|
|
e20b969188 | ||
|
|
1c7059ee67 | ||
|
|
22a3365658 | ||
|
|
2de1949018 | ||
|
|
08ab1e3478 | ||
|
|
ebb1f4007d | ||
|
|
acb53ece93 | ||
|
|
e020950cfd | ||
|
|
9dba262a20 | ||
|
|
5bcdf36377 | ||
|
|
1ffe8deb10 | ||
|
|
d069145bd1 | ||
|
|
f3493ee042 | ||
|
|
b782ac6f56 | ||
|
|
bf48044e5c | ||
|
|
fb4cc37a4a | ||
|
|
55b8d89a79 | ||
|
|
6968a32a5a | ||
|
|
cfe6753349 | ||
|
|
5ae15b3af3 | ||
|
|
b79adb706c | ||
|
|
f22497d5da | ||
|
|
95d672c9df | ||
|
|
7d08a609e6 | ||
|
|
eea6120cd0 | ||
|
|
fc88399c23 | ||
|
|
0cb02bd906 | ||
|
|
08d3867f41 | ||
|
|
b16d63643c | ||
|
|
940d01bdea | ||
|
|
ba9158d159 | ||
|
|
ca9a7e11ef | ||
|
|
a803f47685 | ||
|
|
79fed32f01 | ||
|
|
6b00bb0a66 | ||
|
|
e2adef1eea | ||
|
|
9e5fa11792 | ||
|
|
1ff75acb31 | ||
|
|
1754160686 | ||
|
|
423f6266fb | ||
|
|
16d1b4a14a | ||
|
|
7c14056faf | ||
|
|
62e37dc2e2 | ||
|
|
6a08695ee8 | ||
|
|
9a67a8e427 | ||
|
|
73aa0785ba | ||
|
|
53c1016a8e | ||
|
|
fd442138e6 | ||
|
|
be5f30225a | ||
|
|
7467e9fb8c | ||
|
|
2390c2e46e | ||
|
|
6981fdce7e | ||
|
|
08403f64aa | ||
|
|
391221a986 | ||
|
|
778c223176 | ||
|
|
36cd0dd85c | ||
|
|
09a1d5a02d | ||
|
|
7c996ac9b5 | ||
|
|
cf9fd5d960 | ||
|
|
1c5ab7cb8f | ||
|
|
aaad3b25a7 | ||
|
|
9904235a2f | ||
|
|
780e9f57a5 | ||
|
|
a8db73285b | ||
|
|
3b43c00d12 | ||
|
|
2f390e1794 | ||
|
|
3630ebb3ae | ||
|
|
260c46df04 | ||
|
|
7f11e3205d | ||
|
|
1c8f92a96f | ||
|
|
7b6294b624 | ||
|
|
156d0b1fef | ||
|
|
2cf00dba58 | ||
|
|
d2a7f3ae36 | ||
|
|
6a64d4e4dd | ||
|
|
51e63c246b | ||
|
|
99e6b1eda4 | ||
|
|
dc26a5a436 | ||
|
|
3883b2fb41 | ||
|
|
ed58659a01 | ||
|
|
5190923c70 | ||
|
|
7c647dd160 | ||
|
|
07e59b2708 | ||
|
|
0a3a9f977d | ||
|
|
2f263bf7e6 | ||
|
|
f65f4fc280 | ||
|
|
7bc85107eb | ||
|
|
3be16d19a0 | ||
|
|
af8f730bda | ||
|
|
adbd7ab4c3 | ||
|
|
0419834482 | ||
|
|
c3f176f348 | ||
|
|
0119f3e9f4 | ||
|
|
f797d2d9cb | ||
|
|
5ae7efe8f7 | ||
|
|
d6e35bd0fe | ||
|
|
0e00f1c8f7 | ||
|
|
1b96648d4d | ||
|
|
4433f44a12 | ||
|
|
7504e718d7 | ||
|
|
9b0387e7ee | ||
|
|
d2f9653cea | ||
|
|
5ccce1ab3f | ||
|
|
e366fe340e | ||
|
|
b01809f8e3 | ||
|
|
790ef39187 | ||
|
|
3af16cf333 | ||
|
|
194a986926 | ||
|
|
d09c69f303 | ||
|
|
096d4ac529 | ||
|
|
f7732557fa | ||
|
|
8fafde614a | ||
|
|
694ae13418 | ||
|
|
b5b7dd4f53 | ||
|
|
476785b122 | ||
|
|
907677f835 | ||
|
|
7d844b9410 | ||
|
|
eeabc64a73 | ||
|
|
5da2b0fdcc | ||
|
|
a0005a604e | ||
|
|
a89bb807a6 | ||
|
|
28f3354ffa | ||
|
|
562923c600 | ||
|
|
d488f58311 | ||
|
|
0dd0c67b3b | ||
|
|
ca33849f31 | ||
|
|
18cd0f1480 | ||
|
|
b02982f6b1 | ||
|
|
4d89ae27ef | ||
|
|
733ea77c5c | ||
|
|
92f72bfce6 | ||
|
|
6fdc00ff41 | ||
|
|
bffb25bea7 | ||
|
|
3af4543e80 | ||
|
|
146774860b | ||
|
|
5243481316 | ||
|
|
76a39c1dcb | ||
|
|
02ce918114 | ||
|
|
30cfc22cb6 | ||
|
|
3168afbfcb | ||
|
|
a73ee47557 | ||
|
|
fa6ff005f2 | ||
|
|
095379fa60 | ||
|
|
30572fe1b8 | ||
|
|
b20d484972 | ||
|
|
8931293343 | ||
|
|
7b830d8f72 | ||
|
|
3a0cf230a1 | ||
|
|
3a6f364b03 | ||
|
|
5345d716ee | ||
|
|
f882c36e0a | ||
|
|
0c990ab662 | ||
|
|
101c813e98 | ||
|
|
e95cfa1a00 | ||
|
|
5333e55a81 | ||
|
|
0d480071b6 | ||
|
|
8e0b7b6c25 | ||
|
|
81c11df103 | ||
|
|
f204da0d68 | ||
|
|
7d74904d62 | ||
|
|
760ac5e07d | ||
|
|
f74bc48d16 | ||
|
|
4352228797 | ||
|
|
0169e4540f | ||
|
|
74c770609c | ||
|
|
f4ca36ed7e | ||
|
|
c86da92fc6 | ||
|
|
3f0c577456 | ||
|
|
717da8c7b7 | ||
|
|
a0a61d4f47 | ||
|
|
cead3f38ee | ||
|
|
5b1fced872 | ||
|
|
c98dcf5ef9 | ||
|
|
57cb6bfccb | ||
|
|
95bf97dc3c | ||
|
|
3d116c9d33 | ||
|
|
b55262d4a2 | ||
|
|
a9ce9f8d5a | ||
|
|
10b981a855 | ||
|
|
7700b4333d | ||
|
|
7d0131111e | ||
|
|
1daea35e4b | ||
|
|
f97544af0d | ||
|
|
231e80cc15 | ||
|
|
a4c1362bff | ||
|
|
b611d4a751 | ||
|
|
2248ff392f | ||
|
|
2c9decfa55 | ||
|
|
3c5ac17e2f | ||
|
|
ae42bbb898 | ||
|
|
b86722394b | ||
|
|
a103f69767 | ||
|
|
73fbb3fc62 | ||
|
|
7b3523e25e | ||
|
|
6e4e1386e7 | ||
|
|
671e9af6eb | ||
|
|
50f42caf94 | ||
|
|
b7eeefc102 | ||
|
|
8dd22f3a4f | ||
|
|
4b89427447 | ||
|
|
b71e2860cf | ||
|
|
160b27bc60 | ||
|
|
c084386b88 | ||
|
|
6889047350 | ||
|
|
245bbb4acf | ||
|
|
2b2fc02d83 | ||
|
|
703ef29199 | ||
|
|
b0b60b938a | ||
|
|
e3a026bf1c | ||
|
|
94503465ee | ||
|
|
8d959b0abc | ||
|
|
1d8390b935 | ||
|
|
2851e38a1f | ||
|
|
51261fe7a9 | ||
|
|
304321d019 | ||
|
|
f8c3295645 | ||
|
|
183619d1e1 | ||
|
|
3b832d1f21 | ||
|
|
fcb849698f | ||
|
|
7527e0ebdb | ||
|
|
ed5f98da5b | ||
|
|
12b38e25da | ||
|
|
626e892e3b |
6
.dockerignore
Normal file
6
.dockerignore
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
.env
|
||||||
|
.env.*
|
||||||
|
*.pem
|
||||||
|
*.key
|
||||||
|
*.crt
|
||||||
|
*.p12
|
||||||
10
.github/workflows/check-license-dependencies.yml
vendored
10
.github/workflows/check-license-dependencies.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Check for problematic license dependencies
|
- name: Check for problematic license dependencies
|
||||||
run: |
|
run: |
|
||||||
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
echo "Checking for dependencies on management/, signal/, relay/, and proxy/ packages..."
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# Find all directories except the problematic ones and system dirs
|
# Find all directories except the problematic ones and system dirs
|
||||||
@@ -31,7 +31,7 @@ jobs:
|
|||||||
while IFS= read -r dir; do
|
while IFS= read -r dir; do
|
||||||
echo "=== Checking $dir ==="
|
echo "=== Checking $dir ==="
|
||||||
# Search for problematic imports, excluding test files
|
# Search for problematic imports, excluding test files
|
||||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||||
if [ -n "$RESULTS" ]; then
|
if [ -n "$RESULTS" ]; then
|
||||||
echo "❌ Found problematic dependencies:"
|
echo "❌ Found problematic dependencies:"
|
||||||
echo "$RESULTS"
|
echo "$RESULTS"
|
||||||
@@ -39,11 +39,11 @@ jobs:
|
|||||||
else
|
else
|
||||||
echo "✓ No problematic dependencies found"
|
echo "✓ No problematic dependencies found"
|
||||||
fi
|
fi
|
||||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
|
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name ".git*" | sort)
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
echo "❌ Found dependencies on management/, signal/, relay/, or proxy/ packages"
|
||||||
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
||||||
exit 1
|
exit 1
|
||||||
else
|
else
|
||||||
@@ -88,7 +88,7 @@ jobs:
|
|||||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||||
|
|
||||||
# Check if any importer is NOT in management/signal/relay
|
# Check if any importer is NOT in management/signal/relay
|
||||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
|
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" | head -1)
|
||||||
|
|
||||||
if [ -n "$BSD_IMPORTER" ]; then
|
if [ -n "$BSD_IMPORTER" ]; then
|
||||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||||
|
|||||||
2
.github/workflows/golang-test-darwin.yml
vendored
2
.github/workflows/golang-test-darwin.yml
vendored
@@ -43,5 +43,5 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
|
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy)
|
||||||
|
|
||||||
|
|||||||
1
.github/workflows/golang-test-freebsd.yml
vendored
1
.github/workflows/golang-test-freebsd.yml
vendored
@@ -46,6 +46,5 @@ jobs:
|
|||||||
time go test -timeout 1m -failfast ./client/iface/...
|
time go test -timeout 1m -failfast ./client/iface/...
|
||||||
time go test -timeout 1m -failfast ./route/...
|
time go test -timeout 1m -failfast ./route/...
|
||||||
time go test -timeout 1m -failfast ./sharedsock/...
|
time go test -timeout 1m -failfast ./sharedsock/...
|
||||||
time go test -timeout 1m -failfast ./signal/...
|
|
||||||
time go test -timeout 1m -failfast ./util/...
|
time go test -timeout 1m -failfast ./util/...
|
||||||
time go test -timeout 1m -failfast ./version/...
|
time go test -timeout 1m -failfast ./version/...
|
||||||
|
|||||||
51
.github/workflows/golang-test-linux.yml
vendored
51
.github/workflows/golang-test-linux.yml
vendored
@@ -144,7 +144,7 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy)
|
||||||
|
|
||||||
test_client_on_docker:
|
test_client_on_docker:
|
||||||
name: "Client (Docker) / Unit"
|
name: "Client (Docker) / Unit"
|
||||||
@@ -204,7 +204,7 @@ jobs:
|
|||||||
sh -c ' \
|
sh -c ' \
|
||||||
apk update; apk add --no-cache \
|
apk update; apk add --no-cache \
|
||||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
|
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /client/ui -e /upload-server)
|
||||||
'
|
'
|
||||||
|
|
||||||
test_relay:
|
test_relay:
|
||||||
@@ -261,6 +261,53 @@ jobs:
|
|||||||
-exec 'sudo' \
|
-exec 'sudo' \
|
||||||
-timeout 10m -p 1 ./relay/... ./shared/relay/...
|
-timeout 10m -p 1 ./relay/... ./shared/relay/...
|
||||||
|
|
||||||
|
test_proxy:
|
||||||
|
name: "Proxy / Unit"
|
||||||
|
needs: [build-cache]
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
arch: [ '386','amd64' ]
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: "go.mod"
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
run: |
|
||||||
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ env.cache }}
|
||||||
|
${{ env.modcache }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
|
- name: Install modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: |
|
||||||
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
|
go test -timeout 10m -p 1 ./proxy/...
|
||||||
|
|
||||||
test_signal:
|
test_signal:
|
||||||
name: "Signal / Unit"
|
name: "Signal / Unit"
|
||||||
needs: [build-cache]
|
needs: [build-cache]
|
||||||
|
|||||||
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@@ -63,7 +63,7 @@ jobs:
|
|||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
||||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV
|
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' })" >> $env:GITHUB_ENV
|
||||||
|
|
||||||
- name: test
|
- name: test
|
||||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
||||||
|
|||||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
|||||||
uses: codespell-project/actions-codespell@v2
|
uses: codespell-project/actions-codespell@v2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
|
||||||
skip: go.mod,go.sum
|
skip: go.mod,go.sum,**/proxy/web/**
|
||||||
golangci:
|
golangci:
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|||||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.1.0"
|
SIGN_PIPE_VER: "v0.1.1"
|
||||||
GORELEASER_VER: "v2.3.2"
|
GORELEASER_VER: "v2.3.2"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "NetBird GmbH"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -2,6 +2,7 @@
|
|||||||
.run
|
.run
|
||||||
*.iml
|
*.iml
|
||||||
dist/
|
dist/
|
||||||
|
!proxy/web/dist/
|
||||||
bin/
|
bin/
|
||||||
.env
|
.env
|
||||||
conf.json
|
conf.json
|
||||||
|
|||||||
@@ -60,8 +60,8 @@
|
|||||||
|
|
||||||
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||||
|
|
||||||
### NetBird on Lawrence Systems (Video)
|
### Self-Host NetBird (Video)
|
||||||
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
[](https://youtu.be/bZAgpT6nzaQ)
|
||||||
|
|
||||||
### Key features
|
### Key features
|
||||||
|
|
||||||
|
|||||||
@@ -282,13 +282,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
|||||||
}
|
}
|
||||||
defer authClient.Close()
|
defer authClient.Close()
|
||||||
|
|
||||||
needsLogin := false
|
needsLogin, err := authClient.IsLoginRequired(ctx)
|
||||||
|
if err != nil {
|
||||||
err, isAuthError := authClient.Login(ctx, "", "")
|
return fmt.Errorf("check login required: %v", err)
|
||||||
if isAuthError {
|
|
||||||
needsLogin = true
|
|
||||||
} else if err != nil {
|
|
||||||
return fmt.Errorf("login check failed: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
|
|||||||
@@ -31,6 +31,14 @@ var (
|
|||||||
ErrConfigNotInitialized = errors.New("config not initialized")
|
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PeerConnStatus is a peer's connection status.
|
||||||
|
type PeerConnStatus = peer.ConnStatus
|
||||||
|
|
||||||
|
const (
|
||||||
|
// PeerStatusConnected indicates the peer is in connected state.
|
||||||
|
PeerStatusConnected = peer.StatusConnected
|
||||||
|
)
|
||||||
|
|
||||||
// Client manages a netbird embedded client instance.
|
// Client manages a netbird embedded client instance.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
deviceName string
|
deviceName string
|
||||||
@@ -69,6 +77,10 @@ type Options struct {
|
|||||||
StatePath string
|
StatePath string
|
||||||
// DisableClientRoutes disables the client routes
|
// DisableClientRoutes disables the client routes
|
||||||
DisableClientRoutes bool
|
DisableClientRoutes bool
|
||||||
|
// BlockInbound blocks all inbound connections from peers
|
||||||
|
BlockInbound bool
|
||||||
|
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
|
||||||
|
WireguardPort *int
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateCredentials checks that exactly one credential type is provided
|
// validateCredentials checks that exactly one credential type is provided
|
||||||
@@ -137,6 +149,8 @@ func New(opts Options) (*Client, error) {
|
|||||||
PreSharedKey: &opts.PreSharedKey,
|
PreSharedKey: &opts.PreSharedKey,
|
||||||
DisableServerRoutes: &t,
|
DisableServerRoutes: &t,
|
||||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||||
|
BlockInbound: &opts.BlockInbound,
|
||||||
|
WireguardPort: opts.WireguardPort,
|
||||||
}
|
}
|
||||||
if opts.ConfigPath != "" {
|
if opts.ConfigPath != "" {
|
||||||
config, err = profilemanager.UpdateOrCreateConfig(input)
|
config, err = profilemanager.UpdateOrCreateConfig(input)
|
||||||
@@ -156,6 +170,7 @@ func New(opts Options) (*Client, error) {
|
|||||||
setupKey: opts.SetupKey,
|
setupKey: opts.SetupKey,
|
||||||
jwtToken: opts.JWTToken,
|
jwtToken: opts.JWTToken,
|
||||||
config: config,
|
config: config,
|
||||||
|
recorder: peer.NewRecorder(config.ManagementURL.String()),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -177,6 +192,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
|
|
||||||
// nolint:staticcheck
|
// nolint:staticcheck
|
||||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||||
|
|
||||||
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
|
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create auth client: %w", err)
|
return fmt.Errorf("create auth client: %w", err)
|
||||||
@@ -186,10 +202,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
||||||
return fmt.Errorf("login: %w", err)
|
return fmt.Errorf("login: %w", err)
|
||||||
}
|
}
|
||||||
|
client := internal.NewConnectClient(ctx, c.config, c.recorder, false)
|
||||||
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
|
||||||
c.recorder = recorder
|
|
||||||
client := internal.NewConnectClient(ctx, c.config, recorder, false)
|
|
||||||
client.SetSyncResponsePersistence(true)
|
client.SetSyncResponsePersistence(true)
|
||||||
|
|
||||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||||
@@ -342,14 +355,9 @@ func (c *Client) NewHTTPClient() *http.Client {
|
|||||||
// Status returns the current status of the client.
|
// Status returns the current status of the client.
|
||||||
func (c *Client) Status() (peer.FullStatus, error) {
|
func (c *Client) Status() (peer.FullStatus, error) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
recorder := c.recorder
|
|
||||||
connect := c.connect
|
connect := c.connect
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
|
|
||||||
if recorder == nil {
|
|
||||||
return peer.FullStatus{}, errors.New("client not started")
|
|
||||||
}
|
|
||||||
|
|
||||||
if connect != nil {
|
if connect != nil {
|
||||||
engine := connect.Engine()
|
engine := connect.Engine()
|
||||||
if engine != nil {
|
if engine != nil {
|
||||||
@@ -357,7 +365,7 @@ func (c *Client) Status() (peer.FullStatus, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return recorder.GetFullStatus(), nil
|
return c.recorder.GetFullStatus(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLatestSyncResponse returns the latest sync response from the management server.
|
// GetLatestSyncResponse returns the latest sync response from the management server.
|
||||||
|
|||||||
@@ -483,7 +483,12 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if nftRule.Handle == 0 {
|
if nftRule.Handle == 0 {
|
||||||
return fmt.Errorf("route rule %s has no handle", ruleKey)
|
log.Warnf("route rule %s has no handle, removing stale entry", ruleKey)
|
||||||
|
if err := r.decrementSetCounter(nftRule); err != nil {
|
||||||
|
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
||||||
@@ -660,13 +665,32 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
// TODO: rollback ipset counter
|
r.rollbackRules(pair)
|
||||||
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
|
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
|
||||||
|
func (r *router) rollbackRules(pair firewall.RouterPair) {
|
||||||
|
keys := []string{
|
||||||
|
firewall.GenKey(firewall.ForwardingFormat, pair),
|
||||||
|
firewall.GenKey(firewall.PreroutingFormat, pair),
|
||||||
|
firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair)),
|
||||||
|
}
|
||||||
|
for _, key := range keys {
|
||||||
|
rule, ok := r.rules[key]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
log.Warnf("rollback set counter for %s: %v", key, err)
|
||||||
|
}
|
||||||
|
delete(r.rules, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// addNatRule inserts a nftables rule to the conn client flush queue
|
// addNatRule inserts a nftables rule to the conn client flush queue
|
||||||
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||||
@@ -928,18 +952,30 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
|||||||
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
rule, exists := r.rules[ruleKey]
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
if !exists {
|
||||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
|
||||||
|
|
||||||
delete(r.rules, ruleKey)
|
|
||||||
|
|
||||||
|
if rule.Handle == 0 {
|
||||||
|
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleKey)
|
||||||
if err := r.decrementSetCounter(rule); err != nil {
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
return fmt.Errorf("decrement set counter: %w", err)
|
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||||
}
|
}
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||||
|
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement set counter: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -1329,65 +1365,89 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
if pair.Masquerade {
|
if pair.Masquerade {
|
||||||
if err := r.removeNatRule(pair); err != nil {
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||||
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set counters are decremented in the sub-methods above before flush. If flush fails,
|
||||||
|
// counters will be off until the next successful removal or refresh cycle.
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
// TODO: rollback set counter
|
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
|
||||||
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
rule, exists := r.rules[ruleKey]
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
if !exists {
|
||||||
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
|
||||||
|
|
||||||
delete(r.rules, ruleKey)
|
|
||||||
|
|
||||||
if err := r.decrementSetCounter(rule); err != nil {
|
|
||||||
return fmt.Errorf("decrement set counter: %w", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.Debugf("prerouting rule %s not found", ruleKey)
|
log.Debugf("prerouting rule %s not found", ruleKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if rule.Handle == 0 {
|
||||||
|
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleKey)
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||||
|
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement set counter: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
|
// refreshRulesMap rebuilds the rule map from the kernel. This removes stale entries
|
||||||
// duplicates and to get missing attributes that we don't have when adding new rules
|
// (e.g. from failed flushes) and updates handles for all existing rules.
|
||||||
func (r *router) refreshRulesMap() error {
|
func (r *router) refreshRulesMap() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
newRules := make(map[string]*nftables.Rule)
|
||||||
for _, chain := range r.chains {
|
for _, chain := range r.chains {
|
||||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("list rules: %w", err)
|
merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
|
||||||
|
// preserve existing entries for this chain since we can't verify their state
|
||||||
|
for k, v := range r.rules {
|
||||||
|
if v.Chain != nil && v.Chain.Name == chain.Name {
|
||||||
|
newRules[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 {
|
if len(rule.UserData) > 0 {
|
||||||
r.rules[string(rule.UserData)] = rule
|
newRules[string(rule.UserData)] = rule
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
r.rules = newRules
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
@@ -1629,20 +1689,34 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
var needsFlush bool
|
||||||
|
|
||||||
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||||
if err := r.conn.DelRule(dnatRule); err != nil {
|
if dnatRule.Handle == 0 {
|
||||||
|
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleKey+dnatSuffix)
|
||||||
|
delete(r.rules, ruleKey+dnatSuffix)
|
||||||
|
} else if err := r.conn.DelRule(dnatRule); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
|
||||||
|
} else {
|
||||||
|
needsFlush = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
||||||
if err := r.conn.DelRule(masqRule); err != nil {
|
if masqRule.Handle == 0 {
|
||||||
|
log.Warnf("snat rule %s has no handle, removing stale entry", ruleKey+snatSuffix)
|
||||||
|
delete(r.rules, ruleKey+snatSuffix)
|
||||||
|
} else if err := r.conn.DelRule(masqRule); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
|
||||||
|
} else {
|
||||||
|
needsFlush = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if needsFlush {
|
||||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if merr == nil {
|
if merr == nil {
|
||||||
@@ -1757,16 +1831,25 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
|||||||
|
|
||||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleID]; exists {
|
rule, exists := r.rules[ruleID]
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
if !exists {
|
||||||
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
return nil
|
||||||
}
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
|
||||||
}
|
|
||||||
delete(r.rules, ruleID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if rule.Handle == 0 {
|
||||||
|
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
||||||
|
}
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/test"
|
"github.com/netbirdio/netbird/client/firewall/test"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -719,3 +720,137 @@ func deleteWorkTable() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
workTable, err := createWorkTable()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer deleteWorkTable()
|
||||||
|
|
||||||
|
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, r.init(workTable))
|
||||||
|
defer func() { require.NoError(t, r.Reset()) }()
|
||||||
|
|
||||||
|
// Add a real rule to the kernel
|
||||||
|
ruleKey, err := r.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
||||||
|
firewall.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&firewall.Port{Values: []uint16{80}},
|
||||||
|
firewall.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, r.DeleteRouteRule(ruleKey))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
|
||||||
|
staleKey := "stale-rule-that-does-not-exist"
|
||||||
|
r.rules[staleKey] = &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingFw],
|
||||||
|
Handle: 0,
|
||||||
|
UserData: []byte(staleKey),
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Contains(t, r.rules, staleKey, "stale entry should be in map before refresh")
|
||||||
|
|
||||||
|
err = r.refreshRulesMap()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.NotContains(t, r.rules, staleKey, "stale entry should be removed after refresh")
|
||||||
|
|
||||||
|
realRule, ok := r.rules[ruleKey.ID()]
|
||||||
|
assert.True(t, ok, "real rule should still exist after refresh")
|
||||||
|
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
workTable, err := createWorkTable()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer deleteWorkTable()
|
||||||
|
|
||||||
|
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, r.init(workTable))
|
||||||
|
defer func() { require.NoError(t, r.Reset()) }()
|
||||||
|
|
||||||
|
// Inject a stale entry with Handle=0
|
||||||
|
staleKey := "stale-route-rule"
|
||||||
|
r.rules[staleKey] = &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingFw],
|
||||||
|
Handle: 0,
|
||||||
|
UserData: []byte(staleKey),
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRouteRule should not return an error for stale handles
|
||||||
|
err = r.DeleteRouteRule(id.RuleID(staleKey))
|
||||||
|
assert.NoError(t, err, "deleting a stale rule should not error")
|
||||||
|
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
pair := firewall.RouterPair{
|
||||||
|
ID: "staletest",
|
||||||
|
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||||
|
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||||
|
Masquerade: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
rtr := manager.router
|
||||||
|
|
||||||
|
// First add succeeds
|
||||||
|
err = rtr.AddNatRule(pair)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, rtr.RemoveNatRule(pair))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Corrupt the handle to simulate stale state
|
||||||
|
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||||
|
if rule, exists := rtr.rules[natRuleKey]; exists {
|
||||||
|
rule.Handle = 0
|
||||||
|
}
|
||||||
|
inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
|
||||||
|
if rule, exists := rtr.rules[inverseKey]; exists {
|
||||||
|
rule.Handle = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adding the same rule again should succeed despite stale handles
|
||||||
|
err = rtr.AddNatRule(pair)
|
||||||
|
assert.NoError(t, err, "AddNatRule should succeed even with stale entries")
|
||||||
|
|
||||||
|
// Verify rules exist in kernel
|
||||||
|
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
found := 0
|
||||||
|
for _, rule := range rules {
|
||||||
|
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||||
|
found++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.Equal(t, 1, found, "NAT rule should exist in kernel")
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,12 +3,6 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/netip"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -17,33 +11,7 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
m.resetState()
|
||||||
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
|
|
||||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
|
||||||
|
|
||||||
if m.udpTracker != nil {
|
|
||||||
m.udpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
|
||||||
m.icmpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
|
||||||
m.tcpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if fwder := m.forwarder.Load(); fwder != nil {
|
|
||||||
fwder.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.logger != nil {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := m.logger.Stop(ctx); err != nil {
|
|
||||||
log.Errorf("failed to shutdown logger: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.Close(stateManager)
|
return m.nativeFirewall.Close(stateManager)
|
||||||
|
|||||||
@@ -1,12 +1,9 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -26,33 +23,7 @@ func (m *Manager) Close(*statemanager.Manager) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
m.resetState()
|
||||||
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
|
|
||||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
|
||||||
|
|
||||||
if m.udpTracker != nil {
|
|
||||||
m.udpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
|
||||||
m.icmpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
|
||||||
m.tcpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if fwder := m.forwarder.Load(); fwder != nil {
|
|
||||||
fwder.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.logger != nil {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := m.logger.Stop(ctx); err != nil {
|
|
||||||
log.Errorf("failed to shutdown logger: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isWindowsFirewallReachable() {
|
if !isWindowsFirewallReachable() {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -115,6 +115,17 @@ func (t *TCPConnTrack) IsTombstone() bool {
|
|||||||
return t.tombstone.Load()
|
return t.tombstone.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsSupersededBy returns true if this connection should be replaced by a new one
|
||||||
|
// carrying the given flags. Tombstoned connections are always superseded; TIME-WAIT
|
||||||
|
// connections are superseded by a pure SYN (a new connection attempt for the same
|
||||||
|
// four-tuple, as contemplated by RFC 1122 §4.2.2.13 and RFC 6191).
|
||||||
|
func (t *TCPConnTrack) IsSupersededBy(flags uint8) bool {
|
||||||
|
if t.tombstone.Load() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return flags&TCPSyn != 0 && flags&TCPAck == 0 && TCPState(t.state.Load()) == TCPStateTimeWait
|
||||||
|
}
|
||||||
|
|
||||||
// SetTombstone safely marks the connection for deletion
|
// SetTombstone safely marks the connection for deletion
|
||||||
func (t *TCPConnTrack) SetTombstone() {
|
func (t *TCPConnTrack) SetTombstone() {
|
||||||
t.tombstone.Store(true)
|
t.tombstone.Store(true)
|
||||||
@@ -169,7 +180,7 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
|||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if exists {
|
if exists && !conn.IsSupersededBy(flags) {
|
||||||
t.updateState(key, conn, flags, direction, size)
|
t.updateState(key, conn, flags, direction, size)
|
||||||
return key, uint16(conn.DNATOrigPort.Load()), true
|
return key, uint16(conn.DNATOrigPort.Load()), true
|
||||||
}
|
}
|
||||||
@@ -241,7 +252,7 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
|||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if !exists || conn.IsTombstone() {
|
if !exists || conn.IsSupersededBy(flags) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -485,6 +485,261 @@ func TestTCPAbnormalSequences(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestTCPPortReuseTombstone verifies that a new connection on a port with a
|
||||||
|
// tombstoned (closed) conntrack entry is properly tracked. Without the fix,
|
||||||
|
// updateIfExists treats tombstoned entries as live, causing track() to skip
|
||||||
|
// creating a new connection. The subsequent SYN-ACK then fails IsValidInbound
|
||||||
|
// because the entry is tombstoned, and the response packet gets dropped by ACL.
|
||||||
|
func TestTCPPortReuseTombstone(t *testing.T) {
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
t.Run("Outbound port reuse after graceful close", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish and gracefully close a connection (server-initiated close)
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Server sends FIN
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
// Client sends FIN-ACK
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
|
||||||
|
// Server sends final ACK
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
// Connection should be tombstoned
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn, "old connection should still be in map")
|
||||||
|
require.True(t, conn.IsTombstone(), "old connection should be tombstoned")
|
||||||
|
|
||||||
|
// Now reuse the same port for a new connection
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||||
|
|
||||||
|
// The old tombstoned entry should be replaced with a new one
|
||||||
|
newConn := tracker.connections[key]
|
||||||
|
require.NotNil(t, newConn, "new connection should exist")
|
||||||
|
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
|
||||||
|
require.Equal(t, TCPStateSynSent, newConn.GetState())
|
||||||
|
|
||||||
|
// SYN-ACK for the new connection should be valid
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||||
|
require.True(t, valid, "SYN-ACK for new connection on reused port should be accepted")
|
||||||
|
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||||
|
|
||||||
|
// Data transfer should work
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 500)
|
||||||
|
require.True(t, valid, "data should be allowed on new connection")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Outbound port reuse after RST", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish and RST a connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.True(t, conn.IsTombstone(), "RST connection should be tombstoned")
|
||||||
|
|
||||||
|
// Reuse the same port
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||||
|
|
||||||
|
newConn := tracker.connections[key]
|
||||||
|
require.NotNil(t, newConn)
|
||||||
|
require.False(t, newConn.IsTombstone())
|
||||||
|
require.Equal(t, TCPStateSynSent, newConn.GetState())
|
||||||
|
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||||
|
require.True(t, valid, "SYN-ACK should be accepted after RST tombstone")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Inbound port reuse after close", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
clientIP := srcIP
|
||||||
|
serverIP := dstIP
|
||||||
|
clientPort := srcPort
|
||||||
|
serverPort := dstPort
|
||||||
|
key := ConnKey{SrcIP: clientIP, DstIP: serverIP, SrcPort: clientPort, DstPort: serverPort}
|
||||||
|
|
||||||
|
// Inbound connection: client SYN → server SYN-ACK → client ACK
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||||
|
|
||||||
|
// Server-initiated close to reach Closed/tombstoned:
|
||||||
|
// Server FIN (opposite dir) → CloseWait
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPFin|TCPAck, 100)
|
||||||
|
require.Equal(t, TCPStateCloseWait, conn.GetState())
|
||||||
|
// Client FIN-ACK (same dir as conn) → LastAck
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPFin|TCPAck, nil, 100, 0)
|
||||||
|
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||||
|
// Server final ACK (opposite dir) → Closed → tombstoned
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
|
||||||
|
|
||||||
|
require.True(t, conn.IsTombstone())
|
||||||
|
|
||||||
|
// New inbound connection on same ports
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
|
||||||
|
|
||||||
|
newConn := tracker.connections[key]
|
||||||
|
require.NotNil(t, newConn)
|
||||||
|
require.False(t, newConn.IsTombstone())
|
||||||
|
require.Equal(t, TCPStateSynReceived, newConn.GetState())
|
||||||
|
|
||||||
|
// Complete handshake: server SYN-ACK, then client ACK
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||||
|
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Late ACK on tombstoned connection is harmless", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish and close via passive close (server-initiated FIN → Closed → tombstoned)
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) // CloseWait
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // LastAck
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) // Closed
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.True(t, conn.IsTombstone())
|
||||||
|
|
||||||
|
// Late ACK should be rejected (tombstoned)
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.False(t, valid, "late ACK on tombstoned connection should be rejected")
|
||||||
|
|
||||||
|
// Late outbound ACK should not create a new connection (not a SYN)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
require.True(t, tracker.connections[key].IsTombstone(), "late outbound ACK should not replace tombstoned entry")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPPortReuseTimeWait(t *testing.T) {
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
t.Run("Outbound port reuse during TIME-WAIT (active close)", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Active close: client (outbound initiator) sends FIN first
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||||
|
|
||||||
|
// Server ACKs the FIN
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||||
|
|
||||||
|
// Server sends its own FIN
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
|
||||||
|
// Client sends final ACK (TIME-WAIT stays, not tombstoned)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
require.False(t, conn.IsTombstone(), "TIME-WAIT should not be tombstoned")
|
||||||
|
|
||||||
|
// New outbound SYN on the same port (port reuse during TIME-WAIT)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||||
|
|
||||||
|
// Per RFC 1122/6191, new SYN during TIME-WAIT should start a new connection
|
||||||
|
newConn := tracker.connections[key]
|
||||||
|
require.NotNil(t, newConn, "new connection should exist")
|
||||||
|
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
|
||||||
|
require.Equal(t, TCPStateSynSent, newConn.GetState(), "new connection should be in SYN-SENT")
|
||||||
|
|
||||||
|
// SYN-ACK for new connection should be valid
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||||
|
require.True(t, valid, "SYN-ACK for new connection should be accepted")
|
||||||
|
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Inbound SYN during TIME-WAIT falls through to normal tracking", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish outbound connection and close via active close → TIME-WAIT
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
|
||||||
|
// Inbound SYN on same ports during TIME-WAIT: IsValidInbound returns false
|
||||||
|
// so the filter falls through to ACL check + TrackInbound (which creates
|
||||||
|
// a new connection via track() → updateIfExists skips TIME-WAIT for SYN)
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, 0)
|
||||||
|
require.False(t, valid, "inbound SYN during TIME-WAIT should fail conntrack validation")
|
||||||
|
|
||||||
|
// Simulate what the filter does next: TrackInbound via the normal path
|
||||||
|
tracker.TrackInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, nil, 100, 0)
|
||||||
|
|
||||||
|
// The new inbound connection uses the inverted key (dst→src becomes src→dst in track)
|
||||||
|
invertedKey := ConnKey{SrcIP: dstIP, DstIP: srcIP, SrcPort: dstPort, DstPort: srcPort}
|
||||||
|
newConn := tracker.connections[invertedKey]
|
||||||
|
require.NotNil(t, newConn, "new inbound connection should be tracked")
|
||||||
|
require.Equal(t, TCPStateSynReceived, newConn.GetState())
|
||||||
|
require.False(t, newConn.IsTombstone())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Late retransmit during TIME-WAIT still allowed", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish and active close → TIME-WAIT
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
|
||||||
|
// Late ACK retransmits during TIME-WAIT should still be accepted
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid, "retransmitted ACK during TIME-WAIT should be accepted")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestTCPTimeoutHandling(t *testing.T) {
|
func TestTCPTimeoutHandling(t *testing.T) {
|
||||||
// Create tracker with a very short timeout for testing
|
// Create tracker with a very short timeout for testing
|
||||||
shortTimeout := 100 * time.Millisecond
|
shortTimeout := 100 * time.Millisecond
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -12,11 +13,13 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
@@ -24,6 +27,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
@@ -89,6 +93,7 @@ type Manager struct {
|
|||||||
incomingDenyRules map[netip.Addr]RuleSet
|
incomingDenyRules map[netip.Addr]RuleSet
|
||||||
incomingRules map[netip.Addr]RuleSet
|
incomingRules map[netip.Addr]RuleSet
|
||||||
routeRules RouteRules
|
routeRules RouteRules
|
||||||
|
routeRulesMap map[nbid.RuleID]*RouteRule
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
wgIface common.IFaceMapper
|
wgIface common.IFaceMapper
|
||||||
nativeFirewall firewall.Manager
|
nativeFirewall firewall.Manager
|
||||||
@@ -229,6 +234,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
flowLogger: flowLogger,
|
flowLogger: flowLogger,
|
||||||
netstack: netstack.IsEnabled(),
|
netstack: netstack.IsEnabled(),
|
||||||
localForwarding: enableLocalForwarding,
|
localForwarding: enableLocalForwarding,
|
||||||
|
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
|
||||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||||
portDNATRules: []portDNATRule{},
|
portDNATRules: []portDNATRule{},
|
||||||
netstackServices: make(map[serviceKey]struct{}),
|
netstackServices: make(map[serviceKey]struct{}),
|
||||||
@@ -480,11 +486,15 @@ func (m *Manager) addRouteFiltering(
|
|||||||
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleID := uuid.New().String()
|
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||||
|
|
||||||
|
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
|
||||||
|
return existingRule, nil
|
||||||
|
}
|
||||||
|
|
||||||
rule := RouteRule{
|
rule := RouteRule{
|
||||||
// TODO: consolidate these IDs
|
// TODO: consolidate these IDs
|
||||||
id: ruleID,
|
id: string(ruleKey),
|
||||||
mgmtId: id,
|
mgmtId: id,
|
||||||
sources: sources,
|
sources: sources,
|
||||||
dstSet: destination.Set,
|
dstSet: destination.Set,
|
||||||
@@ -499,6 +509,7 @@ func (m *Manager) addRouteFiltering(
|
|||||||
|
|
||||||
m.routeRules = append(m.routeRules, &rule)
|
m.routeRules = append(m.routeRules, &rule)
|
||||||
m.routeRules.Sort()
|
m.routeRules.Sort()
|
||||||
|
m.routeRulesMap[ruleKey] = &rule
|
||||||
|
|
||||||
return &rule, nil
|
return &rule, nil
|
||||||
}
|
}
|
||||||
@@ -515,15 +526,20 @@ func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
|
|||||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleID := rule.ID()
|
ruleKey := nbid.RuleID(rule.ID())
|
||||||
|
if _, ok := m.routeRulesMap[ruleKey]; !ok {
|
||||||
|
return fmt.Errorf("route rule not found: %s", ruleKey)
|
||||||
|
}
|
||||||
|
|
||||||
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
||||||
return r.id == ruleID
|
return r.id == string(ruleKey)
|
||||||
})
|
})
|
||||||
if idx < 0 {
|
if idx < 0 {
|
||||||
return fmt.Errorf("route rule not found: %s", ruleID)
|
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
||||||
|
delete(m.routeRulesMap, ruleKey)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -570,6 +586,40 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
|||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
|
// resetState clears all firewall rules and closes connection trackers.
|
||||||
|
// Must be called with m.mutex held.
|
||||||
|
func (m *Manager) resetState() {
|
||||||
|
maps.Clear(m.outgoingRules)
|
||||||
|
maps.Clear(m.incomingDenyRules)
|
||||||
|
maps.Clear(m.incomingRules)
|
||||||
|
maps.Clear(m.routeRulesMap)
|
||||||
|
m.routeRules = m.routeRules[:0]
|
||||||
|
|
||||||
|
if m.udpTracker != nil {
|
||||||
|
m.udpTracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.icmpTracker != nil {
|
||||||
|
m.icmpTracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.tcpTracker != nil {
|
||||||
|
m.tcpTracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if fwder := m.forwarder.Load(); fwder != nil {
|
||||||
|
fwder.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.logger != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := m.logger.Stop(ctx); err != nil {
|
||||||
|
log.Errorf("failed to shutdown logger: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||||
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||||
if m.nativeFirewall == nil {
|
if m.nativeFirewall == nil {
|
||||||
|
|||||||
376
client/firewall/uspfilter/filter_routeacl_test.go
Normal file
376
client/firewall/uspfilter/filter_routeacl_test.go
Normal file
@@ -0,0 +1,376 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestAddRouteFilteringReturnsExistingRule verifies that adding the same route
|
||||||
|
// filtering rule twice returns the same rule ID (idempotent behavior).
|
||||||
|
func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("100.64.1.0/24"),
|
||||||
|
netip.MustParsePrefix("100.64.2.0/24"),
|
||||||
|
}
|
||||||
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
|
// Add rule first time
|
||||||
|
rule1, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule1)
|
||||||
|
|
||||||
|
// Add the same rule again
|
||||||
|
rule2, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule2)
|
||||||
|
|
||||||
|
// These should be the same (idempotent) like nftables/iptables implementations
|
||||||
|
assert.Equal(t, rule1.ID(), rule2.ID(),
|
||||||
|
"Adding the same rule twice should return the same rule ID (idempotent)")
|
||||||
|
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 2, ruleCount,
|
||||||
|
"Should have exactly 2 rules (1 user rule + 1 block rule)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAddRouteFilteringDifferentRulesGetDifferentIDs verifies that rules with
|
||||||
|
// different parameters get distinct IDs.
|
||||||
|
func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||||
|
|
||||||
|
// Add first rule
|
||||||
|
rule1, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Add different rule (different destination)
|
||||||
|
rule2, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-2"),
|
||||||
|
sources,
|
||||||
|
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.NotEqual(t, rule1.ID(), rule2.ID(),
|
||||||
|
"Different rules should have different IDs")
|
||||||
|
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 3, ruleCount, "Should have 3 rules (2 user rules + 1 block rule)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRouteRuleUpdateDoesNotCauseGap verifies that re-adding the same route
|
||||||
|
// rule during a network map update does not disrupt existing traffic.
|
||||||
|
func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||||
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
|
rule1, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.1.5")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.10")
|
||||||
|
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||||
|
require.True(t, pass, "Traffic should pass with rule in place")
|
||||||
|
|
||||||
|
// Re-add same rule (simulates network map update)
|
||||||
|
rule2, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Idempotent IDs mean rule1.ID() == rule2.ID(), so the ACL manager
|
||||||
|
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
|
||||||
|
// would remove the only matching rule and cause a traffic gap.
|
||||||
|
if rule1.ID() != rule2.ID() {
|
||||||
|
err = manager.DeleteRouteRule(rule1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, passAfter := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||||
|
assert.True(t, passAfter,
|
||||||
|
"Traffic should still pass after rule update - no gap should occur")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
|
||||||
|
// exactly one drop rule for the WireGuard network prefix, and calling it again
|
||||||
|
// returns the same rule without duplicating.
|
||||||
|
func TestBlockInvalidRoutedIdempotent(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
dev := mocks.NewMockDevice(ctrl)
|
||||||
|
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||||
|
|
||||||
|
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||||
|
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: wgNet.Addr(),
|
||||||
|
Network: wgNet,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
GetDeviceFunc: func() *device.FilteredDevice {
|
||||||
|
return &device.FilteredDevice{Device: dev}
|
||||||
|
},
|
||||||
|
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||||
|
return &wgdevice.Device{}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Call blockInvalidRouted directly multiple times
|
||||||
|
rule1, err := manager.blockInvalidRouted(ifaceMock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule1)
|
||||||
|
|
||||||
|
rule2, err := manager.blockInvalidRouted(ifaceMock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule2)
|
||||||
|
|
||||||
|
rule3, err := manager.blockInvalidRouted(ifaceMock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule3)
|
||||||
|
|
||||||
|
// All should return the same rule
|
||||||
|
assert.Equal(t, rule1.ID(), rule2.ID(), "Second call should return same rule")
|
||||||
|
assert.Equal(t, rule2.ID(), rule3.ID(), "Third call should return same rule")
|
||||||
|
|
||||||
|
// Should have exactly 1 route rule
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 1, ruleCount, "Should have exactly 1 block rule after 3 calls")
|
||||||
|
|
||||||
|
// Verify the rule blocks traffic to the WG network
|
||||||
|
srcIP := netip.MustParseAddr("10.0.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.50")
|
||||||
|
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
|
||||||
|
assert.False(t, pass, "Block rule should deny traffic to WG prefix")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBlockRuleNotAccumulatedOnRepeatedEnableRouting verifies that calling
|
||||||
|
// EnableRouting multiple times (as happens on each route update) does not
|
||||||
|
// accumulate duplicate block rules in the routeRules slice.
|
||||||
|
func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
dev := mocks.NewMockDevice(ctrl)
|
||||||
|
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||||
|
|
||||||
|
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||||
|
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: wgNet.Addr(),
|
||||||
|
Network: wgNet,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
GetDeviceFunc: func() *device.FilteredDevice {
|
||||||
|
return &device.FilteredDevice{Device: dev}
|
||||||
|
},
|
||||||
|
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||||
|
return &wgdevice.Device{}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Call EnableRouting multiple times (simulating repeated route updates)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
require.NoError(t, manager.EnableRouting())
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 1, ruleCount,
|
||||||
|
"Repeated EnableRouting should not accumulate block rules")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRouteRuleCountStableAcrossUpdates verifies that adding the same route
|
||||||
|
// rule multiple times does not create duplicate entries.
|
||||||
|
func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||||
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
|
// Simulate 5 network map updates with the same route rule
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 2, ruleCount,
|
||||||
|
"Should have exactly 2 rules (1 user rule + 1 block rule) after 5 updates")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDeleteRouteRuleAfterIdempotentAdd verifies that deleting a route rule
|
||||||
|
// after adding it multiple times works correctly.
|
||||||
|
func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||||
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
|
// Add same rule twice
|
||||||
|
rule1, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rule2, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
|
||||||
|
|
||||||
|
// Delete using first reference
|
||||||
|
err = manager.DeleteRouteRule(rule1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify traffic no longer passes
|
||||||
|
srcIP := netip.MustParseAddr("100.64.1.5")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.10")
|
||||||
|
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||||
|
assert.False(t, pass, "Traffic should not pass after rule deletion")
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTestManager(t *testing.T) *Manager {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
dev := mocks.NewMockDevice(ctrl)
|
||||||
|
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||||
|
|
||||||
|
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||||
|
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: wgNet.Addr(),
|
||||||
|
Network: wgNet,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
GetDeviceFunc: func() *device.FilteredDevice {
|
||||||
|
return &device.FilteredDevice{Device: dev}
|
||||||
|
},
|
||||||
|
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||||
|
return &wgdevice.Device{}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, manager.EnableRouting())
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
return manager
|
||||||
|
}
|
||||||
@@ -263,6 +263,158 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
|
||||||
|
// to the deny map and can be cleanly deleted without leaving orphans.
|
||||||
|
func TestPeerRuleLifecycleDenyRules(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, m.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
addr := netip.MustParseAddr("192.168.1.1")
|
||||||
|
|
||||||
|
// Add multiple deny rules for different ports
|
||||||
|
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
denyCount := len(m.incomingDenyRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
|
||||||
|
|
||||||
|
// Delete the first deny rule
|
||||||
|
err = m.DeletePeerRule(rule1[0])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
denyCount = len(m.incomingDenyRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
|
||||||
|
|
||||||
|
// Delete the second deny rule
|
||||||
|
err = m.DeletePeerRule(rule2[0])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
_, exists := m.incomingDenyRules[addr]
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
|
||||||
|
// peer rules (simulating network map updates) does not leak rules in the maps.
|
||||||
|
func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, m.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
addr := netip.MustParseAddr("192.168.1.1")
|
||||||
|
|
||||||
|
// Simulate 10 network map updates: add rule, delete old, add new
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
// Add a deny rule
|
||||||
|
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Add an allow rule
|
||||||
|
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Delete them (simulating ACL manager cleanup)
|
||||||
|
for _, r := range rules {
|
||||||
|
require.NoError(t, m.DeletePeerRule(r))
|
||||||
|
}
|
||||||
|
for _, r := range allowRules {
|
||||||
|
require.NoError(t, m.DeletePeerRule(r))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
denyCount := len(m.incomingDenyRules[addr])
|
||||||
|
allowCount := len(m.incomingRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
|
||||||
|
require.Equal(t, 0, allowCount, "No allow rules should remain after cleanup")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMixedAllowDenyRulesSameIP verifies that allow and deny rules for the same
|
||||||
|
// IP are stored in separate maps and don't interfere with each other.
|
||||||
|
func TestMixedAllowDenyRulesSameIP(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, m.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
|
||||||
|
// Add allow rule for port 80
|
||||||
|
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Add deny rule for port 22
|
||||||
|
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
addr := netip.MustParseAddr("192.168.1.1")
|
||||||
|
m.mutex.RLock()
|
||||||
|
allowCount := len(m.incomingRules[addr])
|
||||||
|
denyCount := len(m.incomingDenyRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
|
||||||
|
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
|
||||||
|
|
||||||
|
// Delete allow rule should not affect deny rule
|
||||||
|
err = m.DeletePeerRule(allowRule[0])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
denyCountAfter := len(m.incomingDenyRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
|
||||||
|
|
||||||
|
// Delete deny rule
|
||||||
|
err = m.DeletePeerRule(denyRule[0])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
_, denyExists := m.incomingDenyRules[addr]
|
||||||
|
_, allowExists := m.incomingRules[addr]
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
require.False(t, denyExists, "Deny rules should be empty")
|
||||||
|
require.False(t, allowExists, "Allow rules should be empty")
|
||||||
|
}
|
||||||
|
|
||||||
func TestManagerReset(t *testing.T) {
|
func TestManagerReset(t *testing.T) {
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -16,9 +18,18 @@ const (
|
|||||||
maxBatchSize = 1024 * 16
|
maxBatchSize = 1024 * 16
|
||||||
maxMessageSize = 1024 * 2
|
maxMessageSize = 1024 * 2
|
||||||
defaultFlushInterval = 2 * time.Second
|
defaultFlushInterval = 2 * time.Second
|
||||||
logChannelSize = 1000
|
defaultLogChanSize = 1000
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func getLogChannelSize() int {
|
||||||
|
if v := os.Getenv("NB_USPFILTER_LOG_BUFFER"); v != "" {
|
||||||
|
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultLogChanSize
|
||||||
|
}
|
||||||
|
|
||||||
type Level uint32
|
type Level uint32
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -69,7 +80,7 @@ type Logger struct {
|
|||||||
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||||
l := &Logger{
|
l := &Logger{
|
||||||
output: logrusLogger.Out,
|
output: logrusLogger.Out,
|
||||||
msgChannel: make(chan logMessage, logChannelSize),
|
msgChannel: make(chan logMessage, getLogChannelSize()),
|
||||||
shutdown: make(chan struct{}),
|
shutdown: make(chan struct{}),
|
||||||
bufPool: sync.Pool{
|
bufPool: sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
|
|||||||
@@ -558,7 +558,7 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
|
host, portStr, err := net.SplitHostPort(val)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to parse endpoint: %v", err)
|
log.Errorf("failed to parse endpoint: %v", err)
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -29,8 +29,9 @@ type PacketFilter interface {
|
|||||||
type FilteredDevice struct {
|
type FilteredDevice struct {
|
||||||
tun.Device
|
tun.Device
|
||||||
|
|
||||||
filter PacketFilter
|
filter PacketFilter
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
closeOnce sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
// newDeviceFilter constructor function
|
// newDeviceFilter constructor function
|
||||||
@@ -40,6 +41,20 @@ func newDeviceFilter(device tun.Device) *FilteredDevice {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close closes the underlying tun device exactly once.
|
||||||
|
// wireguard-go's netTun.Close() panics on double-close due to a bare close(channel),
|
||||||
|
// and multiple code paths can trigger Close on the same device.
|
||||||
|
func (d *FilteredDevice) Close() error {
|
||||||
|
var err error
|
||||||
|
d.closeOnce.Do(func() {
|
||||||
|
err = d.Device.Close()
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Read wraps read method with filtering feature
|
// Read wraps read method with filtering feature
|
||||||
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||||
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
|
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
|
||||||
|
|||||||
@@ -82,7 +82,9 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
|
|||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tunIface.Close()
|
if cErr := tunIface.Close(); cErr != nil {
|
||||||
|
log.Debugf("failed to close tun device: %v", cErr)
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("error configuring interface: %s", err)
|
return nil, fmt.Errorf("error configuring interface: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/errors"
|
"github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
@@ -228,6 +229,10 @@ func (w *WGIface) Close() error {
|
|||||||
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if nbnetstack.IsEnabled() {
|
||||||
|
return errors.FormatErrorOrNil(result)
|
||||||
|
}
|
||||||
|
|
||||||
if err := w.waitUntilRemoved(); err != nil {
|
if err := w.waitUntilRemoved(); err != nil {
|
||||||
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
|
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
|
||||||
if err := w.Destroy(); err != nil {
|
if err := w.Destroy(); err != nil {
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return nsTunDev, tunNet, nil
|
return t.tundev, tunNet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *NetStackTun) Close() error {
|
func (t *NetStackTun) Close() error {
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -26,16 +24,6 @@ const (
|
|||||||
loopbackAddr = "127.0.0.1"
|
loopbackAddr = "127.0.0.1"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
localHostNetIPv4 = net.ParseIP("127.0.0.1")
|
|
||||||
localHostNetIPv6 = net.ParseIP("::1")
|
|
||||||
|
|
||||||
serializeOpts = gopacket.SerializeOptions{
|
|
||||||
ComputeChecksums: true,
|
|
||||||
FixLengths: true,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
// WGEBPFProxy definition for proxy with EBPF support
|
// WGEBPFProxy definition for proxy with EBPF support
|
||||||
type WGEBPFProxy struct {
|
type WGEBPFProxy struct {
|
||||||
localWGListenPort int
|
localWGListenPort int
|
||||||
@@ -253,63 +241,3 @@ generatePort:
|
|||||||
}
|
}
|
||||||
return p.lastUsedPort, nil
|
return p.lastUsedPort, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
|
|
||||||
|
|
||||||
var ipH gopacket.SerializableLayer
|
|
||||||
var networkLayer gopacket.NetworkLayer
|
|
||||||
var dstIP net.IP
|
|
||||||
var rawConn net.PacketConn
|
|
||||||
|
|
||||||
if endpointAddr.IP.To4() != nil {
|
|
||||||
// IPv4 path
|
|
||||||
ipv4 := &layers.IPv4{
|
|
||||||
DstIP: localHostNetIPv4,
|
|
||||||
SrcIP: endpointAddr.IP,
|
|
||||||
Version: 4,
|
|
||||||
TTL: 64,
|
|
||||||
Protocol: layers.IPProtocolUDP,
|
|
||||||
}
|
|
||||||
ipH = ipv4
|
|
||||||
networkLayer = ipv4
|
|
||||||
dstIP = localHostNetIPv4
|
|
||||||
rawConn = p.rawConnIPv4
|
|
||||||
} else {
|
|
||||||
// IPv6 path
|
|
||||||
if p.rawConnIPv6 == nil {
|
|
||||||
return fmt.Errorf("IPv6 raw socket not available")
|
|
||||||
}
|
|
||||||
ipv6 := &layers.IPv6{
|
|
||||||
DstIP: localHostNetIPv6,
|
|
||||||
SrcIP: endpointAddr.IP,
|
|
||||||
Version: 6,
|
|
||||||
HopLimit: 64,
|
|
||||||
NextHeader: layers.IPProtocolUDP,
|
|
||||||
}
|
|
||||||
ipH = ipv6
|
|
||||||
networkLayer = ipv6
|
|
||||||
dstIP = localHostNetIPv6
|
|
||||||
rawConn = p.rawConnIPv6
|
|
||||||
}
|
|
||||||
|
|
||||||
udpH := &layers.UDP{
|
|
||||||
SrcPort: layers.UDPPort(endpointAddr.Port),
|
|
||||||
DstPort: layers.UDPPort(p.localWGListenPort),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil {
|
|
||||||
return fmt.Errorf("set network layer for checksum: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
layerBuffer := gopacket.NewSerializeBuffer()
|
|
||||||
payload := gopacket.Payload(data)
|
|
||||||
|
|
||||||
if err := gopacket.SerializeLayers(layerBuffer, serializeOpts, ipH, udpH, payload); err != nil {
|
|
||||||
return fmt.Errorf("serialize layers: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: dstIP}); err != nil {
|
|
||||||
return fmt.Errorf("write to raw conn: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -10,12 +10,89 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bufsize"
|
"github.com/netbirdio/netbird/client/iface/bufsize"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errIPv6ConnNotAvailable = errors.New("IPv6 endpoint but rawConnIPv6 is not available")
|
||||||
|
errIPv4ConnNotAvailable = errors.New("IPv4 endpoint but rawConnIPv4 is not available")
|
||||||
|
|
||||||
|
localHostNetIPv4 = net.ParseIP("127.0.0.1")
|
||||||
|
localHostNetIPv6 = net.ParseIP("::1")
|
||||||
|
|
||||||
|
serializeOpts = gopacket.SerializeOptions{
|
||||||
|
ComputeChecksums: true,
|
||||||
|
FixLengths: true,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// PacketHeaders holds pre-created headers and buffers for efficient packet sending
|
||||||
|
type PacketHeaders struct {
|
||||||
|
ipH gopacket.SerializableLayer
|
||||||
|
udpH *layers.UDP
|
||||||
|
layerBuffer gopacket.SerializeBuffer
|
||||||
|
localHostAddr net.IP
|
||||||
|
isIPv4 bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPacketHeaders(localWGListenPort int, endpoint *net.UDPAddr) (*PacketHeaders, error) {
|
||||||
|
var ipH gopacket.SerializableLayer
|
||||||
|
var networkLayer gopacket.NetworkLayer
|
||||||
|
var localHostAddr net.IP
|
||||||
|
var isIPv4 bool
|
||||||
|
|
||||||
|
// Check if source address is IPv4 or IPv6
|
||||||
|
if endpoint.IP.To4() != nil {
|
||||||
|
// IPv4 path
|
||||||
|
ipv4 := &layers.IPv4{
|
||||||
|
DstIP: localHostNetIPv4,
|
||||||
|
SrcIP: endpoint.IP,
|
||||||
|
Version: 4,
|
||||||
|
TTL: 64,
|
||||||
|
Protocol: layers.IPProtocolUDP,
|
||||||
|
}
|
||||||
|
ipH = ipv4
|
||||||
|
networkLayer = ipv4
|
||||||
|
localHostAddr = localHostNetIPv4
|
||||||
|
isIPv4 = true
|
||||||
|
} else {
|
||||||
|
// IPv6 path
|
||||||
|
ipv6 := &layers.IPv6{
|
||||||
|
DstIP: localHostNetIPv6,
|
||||||
|
SrcIP: endpoint.IP,
|
||||||
|
Version: 6,
|
||||||
|
HopLimit: 64,
|
||||||
|
NextHeader: layers.IPProtocolUDP,
|
||||||
|
}
|
||||||
|
ipH = ipv6
|
||||||
|
networkLayer = ipv6
|
||||||
|
localHostAddr = localHostNetIPv6
|
||||||
|
isIPv4 = false
|
||||||
|
}
|
||||||
|
|
||||||
|
udpH := &layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(endpoint.Port),
|
||||||
|
DstPort: layers.UDPPort(localWGListenPort),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil {
|
||||||
|
return nil, fmt.Errorf("set network layer for checksum: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &PacketHeaders{
|
||||||
|
ipH: ipH,
|
||||||
|
udpH: udpH,
|
||||||
|
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||||
|
localHostAddr: localHostAddr,
|
||||||
|
isIPv4: isIPv4,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||||
type ProxyWrapper struct {
|
type ProxyWrapper struct {
|
||||||
wgeBPFProxy *WGEBPFProxy
|
wgeBPFProxy *WGEBPFProxy
|
||||||
@@ -24,8 +101,10 @@ type ProxyWrapper struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
|
|
||||||
wgRelayedEndpointAddr *net.UDPAddr
|
wgRelayedEndpointAddr *net.UDPAddr
|
||||||
wgEndpointCurrentUsedAddr *net.UDPAddr
|
headers *PacketHeaders
|
||||||
|
headerCurrentUsed *PacketHeaders
|
||||||
|
rawConn net.PacketConn
|
||||||
|
|
||||||
paused bool
|
paused bool
|
||||||
pausedCond *sync.Cond
|
pausedCond *sync.Cond
|
||||||
@@ -41,15 +120,32 @@ func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
|
|||||||
closeListener: listener.NewCloseListener(),
|
closeListener: listener.NewCloseListener(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
|
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
|
||||||
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
|
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("add turn conn: %w", err)
|
return fmt.Errorf("add turn conn: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
headers, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, addr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create packet sender: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if required raw connection is available
|
||||||
|
if !headers.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
|
||||||
|
return errIPv6ConnNotAvailable
|
||||||
|
}
|
||||||
|
if headers.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
|
||||||
|
return errIPv4ConnNotAvailable
|
||||||
|
}
|
||||||
|
|
||||||
p.remoteConn = remoteConn
|
p.remoteConn = remoteConn
|
||||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||||
p.wgRelayedEndpointAddr = addr
|
p.wgRelayedEndpointAddr = addr
|
||||||
return err
|
p.headers = headers
|
||||||
|
p.rawConn = p.selectRawConn(headers)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
||||||
@@ -68,7 +164,8 @@ func (p *ProxyWrapper) Work() {
|
|||||||
p.pausedCond.L.Lock()
|
p.pausedCond.L.Lock()
|
||||||
p.paused = false
|
p.paused = false
|
||||||
|
|
||||||
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
|
p.headerCurrentUsed = p.headers
|
||||||
|
p.rawConn = p.selectRawConn(p.headerCurrentUsed)
|
||||||
|
|
||||||
if !p.isStarted {
|
if !p.isStarted {
|
||||||
p.isStarted = true
|
p.isStarted = true
|
||||||
@@ -95,10 +192,28 @@ func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
|
|||||||
log.Errorf("failed to start package redirection, endpoint is nil")
|
log.Errorf("failed to start package redirection, endpoint is nil")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
header, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, endpoint)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to create packet headers: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if required raw connection is available
|
||||||
|
if !header.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
|
||||||
|
log.Error(errIPv6ConnNotAvailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if header.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
|
||||||
|
log.Error(errIPv4ConnNotAvailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
p.pausedCond.L.Lock()
|
p.pausedCond.L.Lock()
|
||||||
p.paused = false
|
p.paused = false
|
||||||
|
|
||||||
p.wgEndpointCurrentUsedAddr = endpoint
|
p.headerCurrentUsed = header
|
||||||
|
p.rawConn = p.selectRawConn(header)
|
||||||
|
|
||||||
p.pausedCond.Signal()
|
p.pausedCond.Signal()
|
||||||
p.pausedCond.L.Unlock()
|
p.pausedCond.L.Unlock()
|
||||||
@@ -140,7 +255,7 @@ func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
|
|||||||
p.pausedCond.Wait()
|
p.pausedCond.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
|
err = p.sendPkg(buf[:n], p.headerCurrentUsed)
|
||||||
p.pausedCond.L.Unlock()
|
p.pausedCond.L.Unlock()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -166,3 +281,29 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
|
|||||||
}
|
}
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) sendPkg(data []byte, header *PacketHeaders) error {
|
||||||
|
defer func() {
|
||||||
|
if err := header.layerBuffer.Clear(); err != nil {
|
||||||
|
log.Errorf("failed to clear layer buffer: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
payload := gopacket.Payload(data)
|
||||||
|
|
||||||
|
if err := gopacket.SerializeLayers(header.layerBuffer, serializeOpts, header.ipH, header.udpH, payload); err != nil {
|
||||||
|
return fmt.Errorf("serialize layers: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := p.rawConn.WriteTo(header.layerBuffer.Bytes(), &net.IPAddr{IP: header.localHostAddr}); err != nil {
|
||||||
|
return fmt.Errorf("write to raw conn: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) selectRawConn(header *PacketHeaders) net.PacketConn {
|
||||||
|
if header.isIPv4 {
|
||||||
|
return p.wgeBPFProxy.rawConnIPv4
|
||||||
|
}
|
||||||
|
return p.wgeBPFProxy.rawConnIPv6
|
||||||
|
}
|
||||||
|
|||||||
@@ -189,6 +189,212 @@ func TestDefaultManagerStateless(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestDenyRulesNotAccumulatedOnRepeatedApply verifies that applying the same
|
||||||
|
// deny rules repeatedly does not accumulate duplicate rules in the uspfilter.
|
||||||
|
// This tests the full ACL manager -> uspfilter integration.
|
||||||
|
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
|
||||||
|
networkMap := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "22",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "80",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FirewallRulesIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
|
IP: network.Addr(),
|
||||||
|
Network: network,
|
||||||
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, fw.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
|
// Apply the same rules 5 times (simulating repeated network map updates)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The ACL manager should track exactly 3 rule pairs (2 deny + 1 accept inbound)
|
||||||
|
assert.Equal(t, 3, len(acl.peerRulesPairs),
|
||||||
|
"Should have exactly 3 rule pairs after 5 identical updates")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDenyRulesCleanedUpOnRemoval verifies that deny rules are properly cleaned
|
||||||
|
// up when they're removed from the network map in a subsequent update.
|
||||||
|
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
|
IP: network.Addr(),
|
||||||
|
Network: network,
|
||||||
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, fw.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
|
// First update: add deny and accept rules
|
||||||
|
networkMap1 := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "22",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FirewallRulesIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
acl.ApplyFiltering(networkMap1, false)
|
||||||
|
assert.Equal(t, 2, len(acl.peerRulesPairs), "Should have 2 rules after first update")
|
||||||
|
|
||||||
|
// Second update: remove the deny rule, keep only accept
|
||||||
|
networkMap2 := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FirewallRulesIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
acl.ApplyFiltering(networkMap2, false)
|
||||||
|
assert.Equal(t, 1, len(acl.peerRulesPairs),
|
||||||
|
"Should have 1 rule after removing deny rule")
|
||||||
|
|
||||||
|
// Third update: remove all rules
|
||||||
|
networkMap3 := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{},
|
||||||
|
FirewallRulesIsEmpty: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
acl.ApplyFiltering(networkMap3, false)
|
||||||
|
assert.Equal(t, 0, len(acl.peerRulesPairs),
|
||||||
|
"Should have 0 rules after removing all rules")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRuleUpdateChangingAction verifies that when a rule's action changes from
|
||||||
|
// accept to deny (or vice versa), the old rule is properly removed and the new
|
||||||
|
// one added without leaking.
|
||||||
|
func TestRuleUpdateChangingAction(t *testing.T) {
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
|
IP: network.Addr(),
|
||||||
|
Network: network,
|
||||||
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, fw.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
|
// First update: accept rule
|
||||||
|
networkMap := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "22",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FirewallRulesIsEmpty: false,
|
||||||
|
}
|
||||||
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
assert.Equal(t, 1, len(acl.peerRulesPairs))
|
||||||
|
|
||||||
|
// Second update: change to deny (same IP/port/proto, different action)
|
||||||
|
networkMap.FirewallRules = []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "22",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
|
||||||
|
// Should still have exactly 1 rule (the old accept removed, new deny added)
|
||||||
|
assert.Equal(t, 1, len(acl.peerRulesPairs),
|
||||||
|
"Changing action should result in exactly 1 rule, not 2")
|
||||||
|
}
|
||||||
|
|
||||||
func TestPortInfoEmpty(t *testing.T) {
|
func TestPortInfoEmpty(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
@@ -244,7 +245,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
localPeerState := peer.LocalPeerState{
|
localPeerState := peer.LocalPeerState{
|
||||||
IP: loginResp.GetPeerConfig().GetAddress(),
|
IP: loginResp.GetPeerConfig().GetAddress(),
|
||||||
PubKey: myPrivateKey.PublicKey().String(),
|
PubKey: myPrivateKey.PublicKey().String(),
|
||||||
KernelInterface: device.WireGuardModuleIsLoaded(),
|
KernelInterface: device.WireGuardModuleIsLoaded() && !netstack.IsEnabled(),
|
||||||
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
||||||
}
|
}
|
||||||
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
||||||
|
|||||||
@@ -112,6 +112,54 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
|||||||
matchSubdomains: false,
|
matchSubdomains: false,
|
||||||
shouldMatch: false,
|
shouldMatch: false,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "single letter TLD exact match",
|
||||||
|
handlerDomain: "example.x.",
|
||||||
|
queryDomain: "example.x.",
|
||||||
|
isWildcard: false,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single letter TLD subdomain match",
|
||||||
|
handlerDomain: "example.x.",
|
||||||
|
queryDomain: "sub.example.x.",
|
||||||
|
isWildcard: false,
|
||||||
|
matchSubdomains: true,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single letter TLD wildcard match",
|
||||||
|
handlerDomain: "*.example.x.",
|
||||||
|
queryDomain: "sub.example.x.",
|
||||||
|
isWildcard: true,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two letter domain labels",
|
||||||
|
handlerDomain: "a.b.",
|
||||||
|
queryDomain: "a.b.",
|
||||||
|
isWildcard: false,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single character domain",
|
||||||
|
handlerDomain: "x.",
|
||||||
|
queryDomain: "x.",
|
||||||
|
isWildcard: false,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single character domain with subdomain match",
|
||||||
|
handlerDomain: "x.",
|
||||||
|
queryDomain: "sub.x.",
|
||||||
|
isWildcard: false,
|
||||||
|
matchSubdomains: true,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
|||||||
@@ -9,8 +9,10 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
@@ -38,6 +40,9 @@ const (
|
|||||||
type systemConfigurator struct {
|
type systemConfigurator struct {
|
||||||
createdKeys map[string]struct{}
|
createdKeys map[string]struct{}
|
||||||
systemDNSSettings SystemDNSSettings
|
systemDNSSettings SystemDNSSettings
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
origNameservers []netip.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager() (*systemConfigurator, error) {
|
func newHostManager() (*systemConfigurator, error) {
|
||||||
@@ -218,6 +223,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var dnsSettings SystemDNSSettings
|
var dnsSettings SystemDNSSettings
|
||||||
|
var serverAddresses []netip.Addr
|
||||||
inSearchDomainsArray := false
|
inSearchDomainsArray := false
|
||||||
inServerAddressesArray := false
|
inServerAddressesArray := false
|
||||||
|
|
||||||
@@ -244,9 +250,12 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
|||||||
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
|
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
|
||||||
} else if inServerAddressesArray {
|
} else if inServerAddressesArray {
|
||||||
address := strings.Split(line, " : ")[1]
|
address := strings.Split(line, " : ")[1]
|
||||||
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
|
if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() {
|
||||||
dnsSettings.ServerIP = ip.Unmap()
|
ip = ip.Unmap()
|
||||||
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
|
serverAddresses = append(serverAddresses, ip)
|
||||||
|
if !dnsSettings.ServerIP.IsValid() && ip.Is4() {
|
||||||
|
dnsSettings.ServerIP = ip
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -258,9 +267,19 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
|||||||
// default to 53 port
|
// default to 53 port
|
||||||
dnsSettings.ServerPort = DefaultPort
|
dnsSettings.ServerPort = DefaultPort
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.origNameservers = serverAddresses
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
return dnsSettings, nil
|
return dnsSettings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return slices.Clone(s.origNameservers)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
|
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
|
||||||
err := s.addDNSState(key, domains, ip, port, true)
|
err := s.addDNSState(key, domains, ip, port, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -109,3 +109,169 @@ func removeTestDNSKey(key string) error {
|
|||||||
_, err := cmd.CombinedOutput()
|
_, err := cmd.CombinedOutput()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetOriginalNameservers(t *testing.T) {
|
||||||
|
configurator := &systemConfigurator{
|
||||||
|
createdKeys: make(map[string]struct{}),
|
||||||
|
origNameservers: []netip.Addr{
|
||||||
|
netip.MustParseAddr("8.8.8.8"),
|
||||||
|
netip.MustParseAddr("1.1.1.1"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
servers := configurator.getOriginalNameservers()
|
||||||
|
assert.Len(t, servers, 2)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("8.8.8.8"), servers[0])
|
||||||
|
assert.Equal(t, netip.MustParseAddr("1.1.1.1"), servers[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOriginalNameserversFromSystem(t *testing.T) {
|
||||||
|
configurator := &systemConfigurator{
|
||||||
|
createdKeys: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := configurator.getSystemDNSSettings()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
servers := configurator.getOriginalNameservers()
|
||||||
|
|
||||||
|
require.NotEmpty(t, servers, "expected at least one DNS server from system configuration")
|
||||||
|
|
||||||
|
for _, server := range servers {
|
||||||
|
assert.True(t, server.IsValid(), "server address should be valid")
|
||||||
|
assert.False(t, server.IsUnspecified(), "server address should not be unspecified")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("found %d original nameservers: %v", len(servers), servers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Manager, func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
stateFile := filepath.Join(tmpDir, "state.json")
|
||||||
|
sm := statemanager.New(stateFile)
|
||||||
|
sm.RegisterState(&ShutdownState{})
|
||||||
|
sm.Start()
|
||||||
|
|
||||||
|
configurator := &systemConfigurator{
|
||||||
|
createdKeys: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||||
|
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||||
|
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||||
|
|
||||||
|
cleanup := func() {
|
||||||
|
_ = sm.Stop(context.Background())
|
||||||
|
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||||
|
_ = removeTestDNSKey(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return configurator, sm, cleanup
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOriginalNameserversNoTransition(t *testing.T) {
|
||||||
|
netbirdIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
routeAll bool
|
||||||
|
}{
|
||||||
|
{"routeall_false", false},
|
||||||
|
{"routeall_true", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
configurator, sm, cleanup := setupTestConfigurator(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
_, err := configurator.getSystemDNSSettings()
|
||||||
|
require.NoError(t, err)
|
||||||
|
initialServers := configurator.getOriginalNameservers()
|
||||||
|
t.Logf("Initial servers: %v", initialServers)
|
||||||
|
require.NotEmpty(t, initialServers)
|
||||||
|
|
||||||
|
for _, srv := range initialServers {
|
||||||
|
require.NotEqual(t, netbirdIP, srv, "initial servers should not contain NetBird IP")
|
||||||
|
}
|
||||||
|
|
||||||
|
config := HostDNSConfig{
|
||||||
|
ServerIP: netbirdIP,
|
||||||
|
ServerPort: 53,
|
||||||
|
RouteAll: tc.routeAll,
|
||||||
|
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 1; i <= 2; i++ {
|
||||||
|
err = configurator.applyDNSConfig(config, sm)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
servers := configurator.getOriginalNameservers()
|
||||||
|
t.Logf("After apply %d (RouteAll=%v): %v", i, tc.routeAll, servers)
|
||||||
|
assert.Equal(t, initialServers, servers)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOriginalNameserversRouteAllTransition(t *testing.T) {
|
||||||
|
netbirdIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
initialRoute bool
|
||||||
|
}{
|
||||||
|
{"start_with_routeall_false", false},
|
||||||
|
{"start_with_routeall_true", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
configurator, sm, cleanup := setupTestConfigurator(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
_, err := configurator.getSystemDNSSettings()
|
||||||
|
require.NoError(t, err)
|
||||||
|
initialServers := configurator.getOriginalNameservers()
|
||||||
|
t.Logf("Initial servers: %v", initialServers)
|
||||||
|
require.NotEmpty(t, initialServers)
|
||||||
|
|
||||||
|
config := HostDNSConfig{
|
||||||
|
ServerIP: netbirdIP,
|
||||||
|
ServerPort: 53,
|
||||||
|
RouteAll: tc.initialRoute,
|
||||||
|
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
|
||||||
|
}
|
||||||
|
|
||||||
|
// First apply
|
||||||
|
err = configurator.applyDNSConfig(config, sm)
|
||||||
|
require.NoError(t, err)
|
||||||
|
servers := configurator.getOriginalNameservers()
|
||||||
|
t.Logf("After first apply (RouteAll=%v): %v", tc.initialRoute, servers)
|
||||||
|
assert.Equal(t, initialServers, servers)
|
||||||
|
|
||||||
|
// Toggle RouteAll
|
||||||
|
config.RouteAll = !tc.initialRoute
|
||||||
|
err = configurator.applyDNSConfig(config, sm)
|
||||||
|
require.NoError(t, err)
|
||||||
|
servers = configurator.getOriginalNameservers()
|
||||||
|
t.Logf("After toggle (RouteAll=%v): %v", config.RouteAll, servers)
|
||||||
|
assert.Equal(t, initialServers, servers)
|
||||||
|
|
||||||
|
// Toggle back
|
||||||
|
config.RouteAll = tc.initialRoute
|
||||||
|
err = configurator.applyDNSConfig(config, sm)
|
||||||
|
require.NoError(t, err)
|
||||||
|
servers = configurator.getOriginalNameservers()
|
||||||
|
t.Logf("After toggle back (RouteAll=%v): %v", config.RouteAll, servers)
|
||||||
|
assert.Equal(t, initialServers, servers)
|
||||||
|
|
||||||
|
for _, srv := range servers {
|
||||||
|
assert.NotEqual(t, netbirdIP, srv, "servers should not contain NetBird IP")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,7 +6,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -27,6 +29,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const envSkipDNSProbe = "NB_SKIP_DNS_PROBE"
|
||||||
|
|
||||||
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||||
type ReadyListener interface {
|
type ReadyListener interface {
|
||||||
OnReady()
|
OnReady()
|
||||||
@@ -439,6 +443,17 @@ func (s *DefaultServer) SearchDomains() []string {
|
|||||||
// ProbeAvailability tests each upstream group's servers for availability
|
// ProbeAvailability tests each upstream group's servers for availability
|
||||||
// and deactivates the group if no server responds
|
// and deactivates the group if no server responds
|
||||||
func (s *DefaultServer) ProbeAvailability() {
|
func (s *DefaultServer) ProbeAvailability() {
|
||||||
|
if val := os.Getenv(envSkipDNSProbe); val != "" {
|
||||||
|
skipProbe, err := strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", envSkipDNSProbe, err)
|
||||||
|
}
|
||||||
|
if skipProbe {
|
||||||
|
log.Infof("skipping DNS probe due to %s", envSkipDNSProbe)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for _, mux := range s.dnsMuxMap {
|
for _, mux := range s.dnsMuxMap {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
@@ -615,7 +630,7 @@ func (s *DefaultServer) applyHostConfig() {
|
|||||||
s.registerFallback(config)
|
s.registerFallback(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
// registerFallback registers original nameservers as low-priority fallback handlers
|
// registerFallback registers original nameservers as low-priority fallback handlers.
|
||||||
func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
||||||
hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS)
|
hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -624,6 +639,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
|||||||
|
|
||||||
originalNameservers := hostMgrWithNS.getOriginalNameservers()
|
originalNameservers := hostMgrWithNS.getOriginalNameservers()
|
||||||
if len(originalNameservers) == 0 {
|
if len(originalNameservers) == 0 {
|
||||||
|
s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,15 +8,21 @@ import (
|
|||||||
|
|
||||||
type MockResponseWriter struct {
|
type MockResponseWriter struct {
|
||||||
WriteMsgFunc func(m *dns.Msg) error
|
WriteMsgFunc func(m *dns.Msg) error
|
||||||
|
lastResponse *dns.Msg
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
|
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
|
||||||
|
rw.lastResponse = m
|
||||||
if rw.WriteMsgFunc != nil {
|
if rw.WriteMsgFunc != nil {
|
||||||
return rw.WriteMsgFunc(m)
|
return rw.WriteMsgFunc(m)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rw *MockResponseWriter) GetLastResponse() *dns.Msg {
|
||||||
|
return rw.lastResponse
|
||||||
|
}
|
||||||
|
|
||||||
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
|
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
|
||||||
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
|
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||||
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
||||||
|
|||||||
@@ -190,50 +190,75 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
|
|||||||
return nberrors.FormatErrorOrNil(result)
|
return nberrors.FormatErrorOrNil(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
|
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg, startTime time.Time) {
|
||||||
if len(query.Question) == 0 {
|
if len(query.Question) == 0 {
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
question := query.Question[0]
|
question := query.Question[0]
|
||||||
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
|
qname := strings.ToLower(question.Name)
|
||||||
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
|
||||||
|
|
||||||
domain := strings.ToLower(question.Name)
|
logger.Tracef("question: domain=%s type=%s class=%s",
|
||||||
|
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
||||||
|
|
||||||
resp := query.SetReply(query)
|
resp := query.SetReply(query)
|
||||||
network := resutil.NetworkForQtype(question.Qtype)
|
network := resutil.NetworkForQtype(question.Qtype)
|
||||||
if network == "" {
|
if network == "" {
|
||||||
resp.Rcode = dns.RcodeNotImplemented
|
resp.Rcode = dns.RcodeNotImplemented
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
f.writeResponse(logger, w, resp, qname, startTime)
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
return
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
|
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
|
||||||
// query doesn't match any configured domain
|
|
||||||
if mostSpecificResId == "" {
|
if mostSpecificResId == "" {
|
||||||
resp.Rcode = dns.RcodeRefused
|
resp.Rcode = dns.RcodeRefused
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
f.writeResponse(logger, w, resp, qname, startTime)
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
return
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
|
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
|
||||||
if result.Err != nil {
|
if result.Err != nil {
|
||||||
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
|
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
|
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
|
||||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
|
resp.Answer = append(resp.Answer, resutil.IPsToRRs(qname, result.IPs, f.ttl)...)
|
||||||
f.cache.set(domain, question.Qtype, result.IPs)
|
f.cache.set(qname, question.Qtype, result.IPs)
|
||||||
|
|
||||||
return resp
|
f.writeResponse(logger, w, resp, qname, startTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
logger.Errorf("failed to write DNS response: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||||
|
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||||
|
}
|
||||||
|
|
||||||
|
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
|
||||||
|
type udpResponseWriter struct {
|
||||||
|
dns.ResponseWriter
|
||||||
|
query *dns.Msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
|
||||||
|
opt := u.query.IsEdns0()
|
||||||
|
maxSize := dns.MinMsgSize
|
||||||
|
if opt != nil {
|
||||||
|
maxSize = int(opt.UDPSize())
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Len() > maxSize {
|
||||||
|
resp.Truncate(maxSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
return u.ResponseWriter.WriteMsg(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
@@ -243,30 +268,7 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||||
})
|
})
|
||||||
|
|
||||||
resp := f.handleDNSQuery(logger, w, query)
|
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
|
||||||
if resp == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
opt := query.IsEdns0()
|
|
||||||
maxSize := dns.MinMsgSize
|
|
||||||
if opt != nil {
|
|
||||||
// client advertised a larger EDNS0 buffer
|
|
||||||
maxSize = int(opt.UDPSize())
|
|
||||||
}
|
|
||||||
|
|
||||||
// if our response is too big, truncate and set the TC bit
|
|
||||||
if resp.Len() > maxSize {
|
|
||||||
resp.Truncate(maxSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
|
||||||
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
@@ -276,18 +278,7 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||||
})
|
})
|
||||||
|
|
||||||
resp := f.handleDNSQuery(logger, w, query)
|
f.handleDNSQuery(logger, w, query, startTime)
|
||||||
if resp == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
|
||||||
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
||||||
@@ -334,6 +325,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
resp *dns.Msg,
|
resp *dns.Msg,
|
||||||
domain string,
|
domain string,
|
||||||
result resutil.LookupResult,
|
result resutil.LookupResult,
|
||||||
|
startTime time.Time,
|
||||||
) {
|
) {
|
||||||
qType := question.Qtype
|
qType := question.Qtype
|
||||||
qTypeName := dns.TypeToString[qType]
|
qTypeName := dns.TypeToString[qType]
|
||||||
@@ -343,9 +335,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
// NotFound: cache negative result and respond
|
// NotFound: cache negative result and respond
|
||||||
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
|
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
|
||||||
f.cache.set(domain, question.Qtype, nil)
|
f.cache.set(domain, question.Qtype, nil)
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -355,9 +345,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
||||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
|
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
|
||||||
resp.Rcode = dns.RcodeSuccess
|
resp.Rcode = dns.RcodeSuccess
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
logger.Errorf("failed to write cached DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -365,9 +353,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
|
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
|
||||||
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
|
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
|
||||||
resp.Rcode = verifyResult.Rcode
|
resp.Rcode = verifyResult.Rcode
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -375,15 +361,12 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
// No cache or verification failed. Log with or without the server field for more context.
|
// No cache or verification failed. Log with or without the server field for more context.
|
||||||
var dnsErr *net.DNSError
|
var dnsErr *net.DNSError
|
||||||
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
|
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
|
||||||
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
|
logger.Warnf("upstream failure: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
|
||||||
} else {
|
} else {
|
||||||
logger.Warnf(errResolveFailed, domain, result.Err)
|
logger.Warnf(errResolveFailed, domain, result.Err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write final failure response.
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
|
||||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getMatchingEntries retrieves the resource IDs for a given domain.
|
// getMatchingEntries retrieves the resource IDs for a given domain.
|
||||||
|
|||||||
@@ -318,8 +318,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
|||||||
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
if tt.shouldResolve {
|
if tt.shouldResolve {
|
||||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
|
||||||
@@ -329,10 +330,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
|||||||
mockFirewall.AssertExpectations(t)
|
mockFirewall.AssertExpectations(t)
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
} else {
|
} else {
|
||||||
if resp != nil {
|
require.NotNil(t, resp, "Expected response")
|
||||||
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
||||||
"Unauthorized domain should not return successful answers")
|
"Unauthorized domain should not return successful answers")
|
||||||
}
|
|
||||||
mockFirewall.AssertNotCalled(t, "UpdateSet")
|
mockFirewall.AssertNotCalled(t, "UpdateSet")
|
||||||
mockResolver.AssertNotCalled(t, "LookupNetIP")
|
mockResolver.AssertNotCalled(t, "LookupNetIP")
|
||||||
}
|
}
|
||||||
@@ -466,14 +466,16 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
|
|||||||
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery, time.Now())
|
||||||
|
|
||||||
// Verify response
|
// Verify response
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
if tt.shouldResolve {
|
if tt.shouldResolve {
|
||||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
require.NotEmpty(t, resp.Answer)
|
require.NotEmpty(t, resp.Answer)
|
||||||
} else if resp != nil {
|
} else {
|
||||||
|
require.NotNil(t, resp, "Expected response")
|
||||||
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
|
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
|
||||||
"Unauthorized domain should be refused or have no answers")
|
"Unauthorized domain should be refused or have no answers")
|
||||||
}
|
}
|
||||||
@@ -528,9 +530,10 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
|||||||
query.SetQuestion("example.com.", dns.TypeA)
|
query.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
// Verify response contains all IPs
|
// Verify response contains all IPs
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
require.NotNil(t, resp)
|
require.NotNil(t, resp)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
|
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
|
||||||
@@ -605,7 +608,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
// Check the response written to the writer
|
// Check the response written to the writer
|
||||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||||
@@ -675,7 +678,8 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
|||||||
q1 := &dns.Msg{}
|
q1 := &dns.Msg{}
|
||||||
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||||
w1 := &test.MockResponseWriter{}
|
w1 := &test.MockResponseWriter{}
|
||||||
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
|
||||||
|
resp1 := w1.GetLastResponse()
|
||||||
require.NotNil(t, resp1)
|
require.NotNil(t, resp1)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||||
require.Len(t, resp1.Answer, 1)
|
require.Len(t, resp1.Answer, 1)
|
||||||
@@ -683,13 +687,13 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
|||||||
// Second query: serve from cache after upstream failure
|
// Second query: serve from cache after upstream failure
|
||||||
q2 := &dns.Msg{}
|
q2 := &dns.Msg{}
|
||||||
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||||
var writtenResp *dns.Msg
|
w2 := &test.MockResponseWriter{}
|
||||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
|
||||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
|
||||||
|
|
||||||
require.NotNil(t, writtenResp, "expected response to be written")
|
resp2 := w2.GetLastResponse()
|
||||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
require.NotNil(t, resp2, "expected response to be written")
|
||||||
require.Len(t, writtenResp.Answer, 1)
|
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
|
||||||
|
require.Len(t, resp2.Answer, 1)
|
||||||
|
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
@@ -715,7 +719,8 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
|||||||
q1 := &dns.Msg{}
|
q1 := &dns.Msg{}
|
||||||
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
||||||
w1 := &test.MockResponseWriter{}
|
w1 := &test.MockResponseWriter{}
|
||||||
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
|
||||||
|
resp1 := w1.GetLastResponse()
|
||||||
require.NotNil(t, resp1)
|
require.NotNil(t, resp1)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||||
require.Len(t, resp1.Answer, 1)
|
require.Len(t, resp1.Answer, 1)
|
||||||
@@ -727,13 +732,13 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
|||||||
|
|
||||||
q2 := &dns.Msg{}
|
q2 := &dns.Msg{}
|
||||||
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
||||||
var writtenResp *dns.Msg
|
w2 := &test.MockResponseWriter{}
|
||||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
|
||||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
|
||||||
|
|
||||||
require.NotNil(t, writtenResp)
|
resp2 := w2.GetLastResponse()
|
||||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
require.NotNil(t, resp2)
|
||||||
require.Len(t, writtenResp.Answer, 1)
|
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
|
||||||
|
require.Len(t, resp2.Answer, 1)
|
||||||
|
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
@@ -784,8 +789,9 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
|||||||
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
require.NotNil(t, resp)
|
require.NotNil(t, resp)
|
||||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
|
||||||
@@ -897,26 +903,15 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
|||||||
query := &dns.Msg{}
|
query := &dns.Msg{}
|
||||||
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
|
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
|
||||||
|
|
||||||
var writtenResp *dns.Msg
|
mockWriter := &test.MockResponseWriter{}
|
||||||
mockWriter := &test.MockResponseWriter{
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
WriteMsgFunc: func(m *dns.Msg) error {
|
|
||||||
writtenResp = m
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
resp := mockWriter.GetLastResponse()
|
||||||
|
require.NotNil(t, resp, "Expected response to be written")
|
||||||
// If a response was returned, it means it should be written (happens in wrapper functions)
|
assert.Equal(t, tt.expectedCode, resp.Rcode, tt.description)
|
||||||
if resp != nil && writtenResp == nil {
|
|
||||||
writtenResp = resp
|
|
||||||
}
|
|
||||||
|
|
||||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
|
||||||
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
|
|
||||||
|
|
||||||
if tt.expectNoAnswer {
|
if tt.expectNoAnswer {
|
||||||
assert.Empty(t, writtenResp.Answer, "Response should have no answer records")
|
assert.Empty(t, resp.Answer, "Response should have no answer records")
|
||||||
}
|
}
|
||||||
|
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
@@ -931,15 +926,8 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
|||||||
query := &dns.Msg{}
|
query := &dns.Msg{}
|
||||||
// Don't set any question
|
// Don't set any question
|
||||||
|
|
||||||
writeCalled := false
|
mockWriter := &test.MockResponseWriter{}
|
||||||
mockWriter := &test.MockResponseWriter{
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
WriteMsgFunc: func(m *dns.Msg) error {
|
|
||||||
writeCalled = true
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
|
||||||
|
|
||||||
assert.Nil(t, resp, "Should return nil for empty query")
|
assert.Nil(t, mockWriter.GetLastResponse(), "Should not write response for empty query")
|
||||||
assert.False(t, writeCalled, "Should not write response for empty query")
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
@@ -543,11 +544,12 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
// monitor WireGuard interface lifecycle and restart engine on changes
|
// monitor WireGuard interface lifecycle and restart engine on changes
|
||||||
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
||||||
e.shutdownWg.Add(1)
|
e.shutdownWg.Add(1)
|
||||||
|
wgIfaceName := e.wgInterface.Name()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer e.shutdownWg.Done()
|
defer e.shutdownWg.Done()
|
||||||
|
|
||||||
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
|
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, wgIfaceName); shouldRestart {
|
||||||
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
||||||
e.triggerClientRestart()
|
e.triggerClientRestart()
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
@@ -573,9 +575,11 @@ func (e *Engine) createFirewall() error {
|
|||||||
|
|
||||||
var err error
|
var err error
|
||||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
||||||
if err != nil || e.firewall == nil {
|
if err != nil {
|
||||||
log.Errorf("failed creating firewall manager: %s", err)
|
return fmt.Errorf("create firewall manager: %w", err)
|
||||||
return nil
|
}
|
||||||
|
if e.firewall == nil {
|
||||||
|
return fmt.Errorf("create firewall manager: received nil manager")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := e.initFirewall(); err != nil {
|
if err := e.initFirewall(); err != nil {
|
||||||
@@ -826,6 +830,10 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||||
|
started := time.Now()
|
||||||
|
defer func() {
|
||||||
|
log.Infof("sync finished in %s", time.Since(started))
|
||||||
|
}()
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
@@ -1015,7 +1023,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
state := e.statusRecorder.GetLocalPeerState()
|
state := e.statusRecorder.GetLocalPeerState()
|
||||||
state.IP = e.wgInterface.Address().String()
|
state.IP = e.wgInterface.Address().String()
|
||||||
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
||||||
state.KernelInterface = device.WireGuardModuleIsLoaded()
|
state.KernelInterface = !e.wgInterface.IsUserspaceBind()
|
||||||
state.FQDN = conf.GetFqdn()
|
state.FQDN = conf.GetFqdn()
|
||||||
|
|
||||||
e.statusRecorder.UpdateLocalPeerState(state)
|
e.statusRecorder.UpdateLocalPeerState(state)
|
||||||
@@ -1916,7 +1924,7 @@ func (e *Engine) triggerClientRestart() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) startNetworkMonitor() {
|
func (e *Engine) startNetworkMonitor() {
|
||||||
if !e.config.NetworkMonitor {
|
if !e.config.NetworkMonitor || nbnetstack.IsEnabled() {
|
||||||
log.Infof("Network monitor is disabled, not starting")
|
log.Infof("Network monitor is disabled, not starting")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||||
@@ -94,6 +95,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
|||||||
|
|
||||||
// updateSSHClientConfig updates the SSH client configuration with peer information
|
// updateSSHClientConfig updates the SSH client configuration with peer information
|
||||||
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
|
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
peerInfo := e.extractPeerSSHInfo(remotePeers)
|
peerInfo := e.extractPeerSSHInfo(remotePeers)
|
||||||
if len(peerInfo) == 0 {
|
if len(peerInfo) == 0 {
|
||||||
log.Debug("no SSH-enabled peers found, skipping SSH config update")
|
log.Debug("no SSH-enabled peers found, skipping SSH config update")
|
||||||
@@ -216,6 +221,10 @@ func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
|
|||||||
|
|
||||||
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
|
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
|
||||||
func (e *Engine) cleanupSSHConfig() {
|
func (e *Engine) cleanupSSHConfig() {
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
configMgr := sshconfig.New()
|
configMgr := sshconfig.New()
|
||||||
|
|
||||||
if err := configMgr.RemoveSSHClientConfig(); err != nil {
|
if err := configMgr.RemoveSSHClientConfig(); err != nil {
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
@@ -74,12 +75,13 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error)
|
|||||||
return NewUDPListener(m.wgIface, peerCfg)
|
return NewUDPListener(m.wgIface, peerCfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BindListener is only used on Windows and JS platforms:
|
// BindListener is used on Windows, JS, and netstack platforms:
|
||||||
// - JS: Cannot listen to UDP sockets
|
// - JS: Cannot listen to UDP sockets
|
||||||
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
|
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
|
||||||
// gateway points to, preventing them from reaching the loopback interface.
|
// gateway points to, preventing them from reaching the loopback interface.
|
||||||
// BindListener bypasses this by passing data directly through the bind.
|
// - Netstack: Allows multiple instances on the same host without port conflicts.
|
||||||
if runtime.GOOS != "windows" && runtime.GOOS != "js" {
|
// BindListener bypasses these issues by passing data directly through the bind.
|
||||||
|
if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() {
|
||||||
return NewUDPListener(m.wgIface, peerCfg)
|
return NewUDPListener(m.wgIface, peerCfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -390,6 +390,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
|||||||
}
|
}
|
||||||
|
|
||||||
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
|
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
|
||||||
|
conn.enableWgWatcherIfNeeded()
|
||||||
|
|
||||||
presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey)
|
presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey)
|
||||||
if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil {
|
if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil {
|
||||||
conn.handleConfigurationFailure(err, wgProxy)
|
conn.handleConfigurationFailure(err, wgProxy)
|
||||||
@@ -402,8 +404,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
|||||||
conn.wgProxyRelay.RedirectAs(ep)
|
conn.wgProxyRelay.RedirectAs(ep)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.enableWgWatcherIfNeeded()
|
|
||||||
|
|
||||||
conn.currentConnPriority = priority
|
conn.currentConnPriority = priority
|
||||||
conn.statusICE.SetConnected()
|
conn.statusICE.SetConnected()
|
||||||
conn.updateIceState(iceConnInfo)
|
conn.updateIceState(iceConnInfo)
|
||||||
@@ -501,6 +501,9 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
|
|
||||||
wgProxy.Work()
|
wgProxy.Work()
|
||||||
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
|
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
|
||||||
|
|
||||||
|
conn.enableWgWatcherIfNeeded()
|
||||||
|
|
||||||
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
|
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
|
||||||
if err := wgProxy.CloseConn(); err != nil {
|
if err := wgProxy.CloseConn(); err != nil {
|
||||||
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
||||||
@@ -509,8 +512,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.enableWgWatcherIfNeeded()
|
|
||||||
|
|
||||||
wgConfigWorkaround()
|
wgConfigWorkaround()
|
||||||
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
||||||
conn.currentConnPriority = conntype.Relay
|
conn.currentConnPriority = conntype.Relay
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package ice
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -32,24 +33,6 @@ type ThreadSafeAgent struct {
|
|||||||
once sync.Once
|
once sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *ThreadSafeAgent) Close() error {
|
|
||||||
var err error
|
|
||||||
a.once.Do(func() {
|
|
||||||
done := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
done <- a.Agent.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err = <-done:
|
|
||||||
case <-time.After(iceAgentCloseTimeout):
|
|
||||||
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
||||||
iceKeepAlive := iceKeepAlive()
|
iceKeepAlive := iceKeepAlive()
|
||||||
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
||||||
@@ -93,9 +76,41 @@ func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, c
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if agent == nil {
|
||||||
|
return nil, fmt.Errorf("ice.NewAgent returned nil agent without error")
|
||||||
|
}
|
||||||
|
|
||||||
return &ThreadSafeAgent{Agent: agent}, nil
|
return &ThreadSafeAgent{Agent: agent}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *ThreadSafeAgent) Close() error {
|
||||||
|
var err error
|
||||||
|
a.once.Do(func() {
|
||||||
|
// Defensive check to prevent nil pointer dereference
|
||||||
|
// This can happen during sleep/wake transitions or memory corruption scenarios
|
||||||
|
// github.com/netbirdio/netbird/client/internal/peer/ice.(*ThreadSafeAgent).Close(0x40006883f0?)
|
||||||
|
// [signal 0xc0000005 code=0x0 addr=0x0 pc=0x7ff7e73af83c]
|
||||||
|
agent := a.Agent
|
||||||
|
if agent == nil {
|
||||||
|
log.Warnf("ICE agent is nil during close, skipping")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- agent.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err = <-done:
|
||||||
|
case <-time.After(iceAgentCloseTimeout):
|
||||||
|
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func GenerateICECredentials() (string, string, error) {
|
func GenerateICECredentials() (string, string, error) {
|
||||||
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
|
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -107,8 +107,10 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
}
|
}
|
||||||
w.log.Debugf("agent already exists, recreate the connection")
|
w.log.Debugf("agent already exists, recreate the connection")
|
||||||
w.agentDialerCancel()
|
w.agentDialerCancel()
|
||||||
if err := w.agent.Close(); err != nil {
|
if w.agent != nil {
|
||||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
if err := w.agent.Close(); err != nil {
|
||||||
|
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionID, err := NewICESessionID()
|
sessionID, err := NewICESessionID()
|
||||||
|
|||||||
@@ -252,7 +252,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if config.AdminURL == nil {
|
if config.AdminURL == nil {
|
||||||
log.Infof("using default Admin URL %s", DefaultManagementURL)
|
log.Infof("using default Admin URL %s", DefaultAdminURL)
|
||||||
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
|
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
|
|||||||
@@ -173,12 +173,21 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
||||||
|
var once sync.Once
|
||||||
|
var wgIface *net.Interface
|
||||||
|
toInterface := func() *net.Interface {
|
||||||
|
once.Do(func() {
|
||||||
|
wgIface = m.wgInterface.ToInterface()
|
||||||
|
})
|
||||||
|
return wgIface
|
||||||
|
}
|
||||||
|
|
||||||
m.routeRefCounter = refcounter.New(
|
m.routeRefCounter = refcounter.New(
|
||||||
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
|
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
|
||||||
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface())
|
return struct{}{}, m.sysOps.AddVPNRoute(prefix, toInterface())
|
||||||
},
|
},
|
||||||
func(prefix netip.Prefix, _ struct{}) error {
|
func(prefix netip.Prefix, _ struct{}) error {
|
||||||
return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface())
|
return m.sysOps.RemoveVPNRoute(prefix, toInterface())
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,16 +4,17 @@ package systemops
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
||||||
func filterRoutesByFlags(routeMessageFlags int) bool {
|
func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||||
if routeMessageFlags&syscall.RTF_UP == 0 {
|
if routeMessageFlags&unix.RTF_UP == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
|
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE|unix.RTF_WASCLONED) != 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -24,42 +25,51 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
|
|||||||
func formatBSDFlags(flags int) string {
|
func formatBSDFlags(flags int) string {
|
||||||
var flagStrs []string
|
var flagStrs []string
|
||||||
|
|
||||||
if flags&syscall.RTF_UP != 0 {
|
if flags&unix.RTF_UP != 0 {
|
||||||
flagStrs = append(flagStrs, "U")
|
flagStrs = append(flagStrs, "U")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_GATEWAY != 0 {
|
if flags&unix.RTF_GATEWAY != 0 {
|
||||||
flagStrs = append(flagStrs, "G")
|
flagStrs = append(flagStrs, "G")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_HOST != 0 {
|
if flags&unix.RTF_HOST != 0 {
|
||||||
flagStrs = append(flagStrs, "H")
|
flagStrs = append(flagStrs, "H")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_REJECT != 0 {
|
if flags&unix.RTF_REJECT != 0 {
|
||||||
flagStrs = append(flagStrs, "R")
|
flagStrs = append(flagStrs, "R")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_DYNAMIC != 0 {
|
if flags&unix.RTF_DYNAMIC != 0 {
|
||||||
flagStrs = append(flagStrs, "D")
|
flagStrs = append(flagStrs, "D")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_MODIFIED != 0 {
|
if flags&unix.RTF_MODIFIED != 0 {
|
||||||
flagStrs = append(flagStrs, "M")
|
flagStrs = append(flagStrs, "M")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_STATIC != 0 {
|
if flags&unix.RTF_STATIC != 0 {
|
||||||
flagStrs = append(flagStrs, "S")
|
flagStrs = append(flagStrs, "S")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_LLINFO != 0 {
|
if flags&unix.RTF_LLINFO != 0 {
|
||||||
flagStrs = append(flagStrs, "L")
|
flagStrs = append(flagStrs, "L")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_LOCAL != 0 {
|
if flags&unix.RTF_LOCAL != 0 {
|
||||||
flagStrs = append(flagStrs, "l")
|
flagStrs = append(flagStrs, "l")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_BLACKHOLE != 0 {
|
if flags&unix.RTF_BLACKHOLE != 0 {
|
||||||
flagStrs = append(flagStrs, "B")
|
flagStrs = append(flagStrs, "B")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_CLONING != 0 {
|
if flags&unix.RTF_CLONING != 0 {
|
||||||
flagStrs = append(flagStrs, "C")
|
flagStrs = append(flagStrs, "C")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_WASCLONED != 0 {
|
if flags&unix.RTF_WASCLONED != 0 {
|
||||||
flagStrs = append(flagStrs, "W")
|
flagStrs = append(flagStrs, "W")
|
||||||
}
|
}
|
||||||
|
if flags&unix.RTF_PROTO1 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "1")
|
||||||
|
}
|
||||||
|
if flags&unix.RTF_PROTO2 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "2")
|
||||||
|
}
|
||||||
|
if flags&unix.RTF_PROTO3 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "3")
|
||||||
|
}
|
||||||
|
|
||||||
if len(flagStrs) == 0 {
|
if len(flagStrs) == 0 {
|
||||||
return "-"
|
return "-"
|
||||||
|
|||||||
@@ -4,17 +4,18 @@ package systemops
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
||||||
func filterRoutesByFlags(routeMessageFlags int) bool {
|
func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||||
if routeMessageFlags&syscall.RTF_UP == 0 {
|
if routeMessageFlags&unix.RTF_UP == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0
|
// NOTE: RTF_WASCLONED deprecated in FreeBSD 8.0
|
||||||
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 {
|
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE) != 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -25,37 +26,46 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
|
|||||||
func formatBSDFlags(flags int) string {
|
func formatBSDFlags(flags int) string {
|
||||||
var flagStrs []string
|
var flagStrs []string
|
||||||
|
|
||||||
if flags&syscall.RTF_UP != 0 {
|
if flags&unix.RTF_UP != 0 {
|
||||||
flagStrs = append(flagStrs, "U")
|
flagStrs = append(flagStrs, "U")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_GATEWAY != 0 {
|
if flags&unix.RTF_GATEWAY != 0 {
|
||||||
flagStrs = append(flagStrs, "G")
|
flagStrs = append(flagStrs, "G")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_HOST != 0 {
|
if flags&unix.RTF_HOST != 0 {
|
||||||
flagStrs = append(flagStrs, "H")
|
flagStrs = append(flagStrs, "H")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_REJECT != 0 {
|
if flags&unix.RTF_REJECT != 0 {
|
||||||
flagStrs = append(flagStrs, "R")
|
flagStrs = append(flagStrs, "R")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_DYNAMIC != 0 {
|
if flags&unix.RTF_DYNAMIC != 0 {
|
||||||
flagStrs = append(flagStrs, "D")
|
flagStrs = append(flagStrs, "D")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_MODIFIED != 0 {
|
if flags&unix.RTF_MODIFIED != 0 {
|
||||||
flagStrs = append(flagStrs, "M")
|
flagStrs = append(flagStrs, "M")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_STATIC != 0 {
|
if flags&unix.RTF_STATIC != 0 {
|
||||||
flagStrs = append(flagStrs, "S")
|
flagStrs = append(flagStrs, "S")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_LLINFO != 0 {
|
if flags&unix.RTF_LLINFO != 0 {
|
||||||
flagStrs = append(flagStrs, "L")
|
flagStrs = append(flagStrs, "L")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_LOCAL != 0 {
|
if flags&unix.RTF_LOCAL != 0 {
|
||||||
flagStrs = append(flagStrs, "l")
|
flagStrs = append(flagStrs, "l")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_BLACKHOLE != 0 {
|
if flags&unix.RTF_BLACKHOLE != 0 {
|
||||||
flagStrs = append(flagStrs, "B")
|
flagStrs = append(flagStrs, "B")
|
||||||
}
|
}
|
||||||
// Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0
|
// Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0
|
||||||
|
if flags&unix.RTF_PROTO1 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "1")
|
||||||
|
}
|
||||||
|
if flags&unix.RTF_PROTO2 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "2")
|
||||||
|
}
|
||||||
|
if flags&unix.RTF_PROTO3 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "3")
|
||||||
|
}
|
||||||
|
|
||||||
if len(flagStrs) == 0 {
|
if len(flagStrs) == 0 {
|
||||||
return "-"
|
return "-"
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine
|
// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine
|
||||||
@@ -35,6 +37,11 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes
|
|||||||
return false, errors.New("not supported on mobile platforms")
|
return false, errors.New("not supported on mobile platforms")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
log.Debugf("Interface monitor: skipped in netstack mode")
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
if ifaceName == "" {
|
if ifaceName == "" {
|
||||||
log.Debugf("Interface monitor: empty interface name, skipping monitor")
|
log.Debugf("Interface monitor: empty interface name, skipping monitor")
|
||||||
return false, errors.New("empty interface name")
|
return false, errors.New("empty interface name")
|
||||||
|
|||||||
4
go.mod
4
go.mod
@@ -42,6 +42,7 @@ require (
|
|||||||
github.com/cilium/ebpf v0.15.0
|
github.com/cilium/ebpf v0.15.0
|
||||||
github.com/coder/websocket v1.8.13
|
github.com/coder/websocket v1.8.13
|
||||||
github.com/coreos/go-iptables v0.7.0
|
github.com/coreos/go-iptables v0.7.0
|
||||||
|
github.com/coreos/go-oidc/v3 v3.14.1
|
||||||
github.com/creack/pty v1.1.24
|
github.com/creack/pty v1.1.24
|
||||||
github.com/dexidp/dex v0.0.0-00010101000000-000000000000
|
github.com/dexidp/dex v0.0.0-00010101000000-000000000000
|
||||||
github.com/dexidp/dex/api/v2 v2.4.0
|
github.com/dexidp/dex/api/v2 v2.4.0
|
||||||
@@ -68,7 +69,7 @@ require (
|
|||||||
github.com/mdlayher/socket v0.5.1
|
github.com/mdlayher/socket v0.5.1
|
||||||
github.com/miekg/dns v1.1.59
|
github.com/miekg/dns v1.1.59
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
||||||
github.com/oapi-codegen/runtime v1.1.2
|
github.com/oapi-codegen/runtime v1.1.2
|
||||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||||
@@ -167,7 +168,6 @@ require (
|
|||||||
github.com/containerd/containerd v1.7.29 // indirect
|
github.com/containerd/containerd v1.7.29 // indirect
|
||||||
github.com/containerd/log v0.1.0 // indirect
|
github.com/containerd/log v0.1.0 // indirect
|
||||||
github.com/containerd/platforms v0.2.1 // indirect
|
github.com/containerd/platforms v0.2.1 // indirect
|
||||||
github.com/coreos/go-oidc/v3 v3.14.1 // indirect
|
|
||||||
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -406,8 +406,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
|||||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
||||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f h1:CTBf0je/FpKr2lVSMZLak7m8aaWcS6ur4SOfhSSazFI=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25 h1:iwAq/Ncaq0etl4uAlVsbNBzC1yY52o0AmY7uCm2AMTs=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||||
|
|||||||
@@ -327,6 +327,60 @@ func ensureLocalConnector(ctx context.Context, stor storage.Storage) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HasNonLocalConnectors checks if there are any connectors other than the local connector.
|
||||||
|
func (p *Provider) HasNonLocalConnectors(ctx context.Context) (bool, error) {
|
||||||
|
connectors, err := p.storage.ListConnectors(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to list connectors: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.logger.Info("checking for non-local connectors", "total_connectors", len(connectors))
|
||||||
|
for _, conn := range connectors {
|
||||||
|
p.logger.Info("found connector in storage", "id", conn.ID, "type", conn.Type, "name", conn.Name)
|
||||||
|
if conn.ID != "local" || conn.Type != "local" {
|
||||||
|
p.logger.Info("found non-local connector", "id", conn.ID)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
p.logger.Info("no non-local connectors found")
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableLocalAuth removes the local (password) connector.
|
||||||
|
// Returns an error if no other connectors are configured.
|
||||||
|
func (p *Provider) DisableLocalAuth(ctx context.Context) error {
|
||||||
|
hasOthers, err := p.HasNonLocalConnectors(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !hasOthers {
|
||||||
|
return fmt.Errorf("cannot disable local authentication: no other identity providers configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if local connector exists
|
||||||
|
_, err = p.storage.GetConnector(ctx, "local")
|
||||||
|
if errors.Is(err, storage.ErrNotFound) {
|
||||||
|
// Already disabled
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to check local connector: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete the local connector
|
||||||
|
if err := p.storage.DeleteConnector(ctx, "local"); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete local connector: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.logger.Info("local authentication disabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnableLocalAuth creates the local (password) connector if it doesn't exist.
|
||||||
|
func (p *Provider) EnableLocalAuth(ctx context.Context) error {
|
||||||
|
return ensureLocalConnector(ctx, p.storage)
|
||||||
|
}
|
||||||
|
|
||||||
// ensureStaticConnectors creates or updates static connectors in storage
|
// ensureStaticConnectors creates or updates static connectors in storage
|
||||||
func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error {
|
func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error {
|
||||||
for _, conn := range connectors {
|
for _, conn := range connectors {
|
||||||
|
|||||||
@@ -91,16 +91,17 @@ read_reverse_proxy_type() {
|
|||||||
echo " [3] Nginx Proxy Manager (generates config + instructions)" > /dev/stderr
|
echo " [3] Nginx Proxy Manager (generates config + instructions)" > /dev/stderr
|
||||||
echo " [4] External Caddy (generates Caddyfile snippet)" > /dev/stderr
|
echo " [4] External Caddy (generates Caddyfile snippet)" > /dev/stderr
|
||||||
echo " [5] Other/Manual (displays setup documentation)" > /dev/stderr
|
echo " [5] Other/Manual (displays setup documentation)" > /dev/stderr
|
||||||
|
echo " [6] Traefik TCP Proxy (single port 443 + STUN)" > /dev/stderr
|
||||||
echo "" > /dev/stderr
|
echo "" > /dev/stderr
|
||||||
echo -n "Enter choice [0-5] (default: 0): " > /dev/stderr
|
echo -n "Enter choice [0-6] (default: 0): " > /dev/stderr
|
||||||
read -r CHOICE < /dev/tty
|
read -r CHOICE < /dev/tty
|
||||||
|
|
||||||
if [[ -z "$CHOICE" ]]; then
|
if [[ -z "$CHOICE" ]]; then
|
||||||
CHOICE="0"
|
CHOICE="0"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [[ ! "$CHOICE" =~ ^[0-5]$ ]]; then
|
if [[ ! "$CHOICE" =~ ^[0-6]$ ]]; then
|
||||||
echo "Invalid choice. Please enter a number between 0 and 5." > /dev/stderr
|
echo "Invalid choice. Please enter a number between 0 and 6." > /dev/stderr
|
||||||
read_reverse_proxy_type
|
read_reverse_proxy_type
|
||||||
return
|
return
|
||||||
fi
|
fi
|
||||||
@@ -140,6 +141,35 @@ read_traefik_certresolver() {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
read_traefik_tcp_acme_email() {
|
||||||
|
echo "" > /dev/stderr
|
||||||
|
echo "Enter your email for Let's Encrypt certificate notifications." > /dev/stderr
|
||||||
|
echo -n "Email address: " > /dev/stderr
|
||||||
|
read -r EMAIL < /dev/tty
|
||||||
|
if [[ -z "$EMAIL" ]]; then
|
||||||
|
echo "Email is required for Let's Encrypt." > /dev/stderr
|
||||||
|
read_traefik_tcp_acme_email
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
echo "$EMAIL"
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
read_enable_proxy() {
|
||||||
|
echo "" > /dev/stderr
|
||||||
|
echo "Do you want to enable the NetBird Proxy service?" > /dev/stderr
|
||||||
|
echo "The proxy exposes internal NetBird network resources to the internet." > /dev/stderr
|
||||||
|
echo -n "Enable proxy? [y/N]: " > /dev/stderr
|
||||||
|
read -r CHOICE < /dev/tty
|
||||||
|
|
||||||
|
if [[ "$CHOICE" =~ ^[Yy]$ ]]; then
|
||||||
|
echo "true"
|
||||||
|
else
|
||||||
|
echo "false"
|
||||||
|
fi
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
read_port_binding_preference() {
|
read_port_binding_preference() {
|
||||||
echo "" > /dev/stderr
|
echo "" > /dev/stderr
|
||||||
echo "Should container ports be bound to localhost only (127.0.0.1)?" > /dev/stderr
|
echo "Should container ports be bound to localhost only (127.0.0.1)?" > /dev/stderr
|
||||||
@@ -206,6 +236,30 @@ wait_management() {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
wait_management_traefik() {
|
||||||
|
set +e
|
||||||
|
echo -n "Waiting for Management server to become ready"
|
||||||
|
counter=1
|
||||||
|
while true; do
|
||||||
|
# Check the embedded IdP endpoint through Traefik
|
||||||
|
if curl -sk -f -o /dev/null "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/oauth2/.well-known/openid-configuration" 2>/dev/null; then
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
if [[ $counter -eq 60 ]]; then
|
||||||
|
echo ""
|
||||||
|
echo "Taking too long. Checking logs..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND logs --tail=20 traefik
|
||||||
|
$DOCKER_COMPOSE_COMMAND logs --tail=20 management
|
||||||
|
fi
|
||||||
|
echo -n " ."
|
||||||
|
sleep 2
|
||||||
|
counter=$((counter + 1))
|
||||||
|
done
|
||||||
|
echo " done"
|
||||||
|
set -e
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
wait_management_direct() {
|
wait_management_direct() {
|
||||||
set +e
|
set +e
|
||||||
local upstream_host=$(get_upstream_host)
|
local upstream_host=$(get_upstream_host)
|
||||||
@@ -246,10 +300,12 @@ initialize_default_values() {
|
|||||||
|
|
||||||
# Docker images
|
# Docker images
|
||||||
CADDY_IMAGE="caddy"
|
CADDY_IMAGE="caddy"
|
||||||
DASHBOARD_IMAGE="netbirdio/dashboard:latest"
|
#DASHBOARD_IMAGE="netbirdio/dashboard:latest"
|
||||||
|
DASHBOARD_IMAGE="netbirdio/dashboard:pr-552"
|
||||||
SIGNAL_IMAGE="netbirdio/signal:latest"
|
SIGNAL_IMAGE="netbirdio/signal:latest"
|
||||||
RELAY_IMAGE="netbirdio/relay:latest"
|
RELAY_IMAGE="netbirdio/relay:latest"
|
||||||
MANAGEMENT_IMAGE="netbirdio/management:latest"
|
MANAGEMENT_IMAGE="netbirdio/management:latest"
|
||||||
|
PROXY_IMAGE=""
|
||||||
|
|
||||||
# Reverse proxy configuration
|
# Reverse proxy configuration
|
||||||
REVERSE_PROXY_TYPE="0"
|
REVERSE_PROXY_TYPE="0"
|
||||||
@@ -263,6 +319,14 @@ initialize_default_values() {
|
|||||||
RELAY_HOST_PORT="8084"
|
RELAY_HOST_PORT="8084"
|
||||||
BIND_LOCALHOST_ONLY="true"
|
BIND_LOCALHOST_ONLY="true"
|
||||||
EXTERNAL_PROXY_NETWORK=""
|
EXTERNAL_PROXY_NETWORK=""
|
||||||
|
|
||||||
|
# Traefik TCP proxy configuration
|
||||||
|
TRAEFIK_IMAGE="traefik:v3.6"
|
||||||
|
TRAEFIK_TCP_ACME_EMAIL=""
|
||||||
|
|
||||||
|
# NetBird Proxy configuration
|
||||||
|
ENABLE_PROXY="false"
|
||||||
|
PROXY_TOKEN=""
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -293,8 +357,17 @@ configure_reverse_proxy() {
|
|||||||
TRAEFIK_CERTRESOLVER=$(read_traefik_certresolver)
|
TRAEFIK_CERTRESOLVER=$(read_traefik_certresolver)
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# Handle Traefik TCP proxy prompts
|
||||||
|
if [[ "$REVERSE_PROXY_TYPE" == "6" ]]; then
|
||||||
|
TRAEFIK_TCP_ACME_EMAIL=$(read_traefik_tcp_acme_email)
|
||||||
|
|
||||||
|
# Prompt for NetBird Proxy configuration
|
||||||
|
ENABLE_PROXY=$(read_enable_proxy)
|
||||||
|
# Note: PROXY_TOKEN will be auto-generated after Management starts
|
||||||
|
fi
|
||||||
|
|
||||||
# Handle port binding for external proxy options (2-5)
|
# Handle port binding for external proxy options (2-5)
|
||||||
if [[ "$REVERSE_PROXY_TYPE" -ge 2 ]]; then
|
if [[ "$REVERSE_PROXY_TYPE" -ge 2 && "$REVERSE_PROXY_TYPE" -le 5 ]]; then
|
||||||
BIND_LOCALHOST_ONLY=$(read_port_binding_preference)
|
BIND_LOCALHOST_ONLY=$(read_port_binding_preference)
|
||||||
fi
|
fi
|
||||||
|
|
||||||
@@ -313,7 +386,7 @@ check_existing_installation() {
|
|||||||
echo "Generated files already exist, if you want to reinitialize the environment, please remove them first."
|
echo "Generated files already exist, if you want to reinitialize the environment, please remove them first."
|
||||||
echo "You can use the following commands:"
|
echo "You can use the following commands:"
|
||||||
echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes"
|
echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes"
|
||||||
echo " rm -f docker-compose.yml Caddyfile dashboard.env management.json relay.env nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt"
|
echo " rm -f docker-compose.yml Caddyfile dashboard.env management.json relay.env nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt proxy.env"
|
||||||
echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard."
|
echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard."
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
@@ -347,6 +420,15 @@ generate_configuration_files() {
|
|||||||
5)
|
5)
|
||||||
render_docker_compose_exposed_ports > docker-compose.yml
|
render_docker_compose_exposed_ports > docker-compose.yml
|
||||||
;;
|
;;
|
||||||
|
6)
|
||||||
|
render_docker_compose_traefik_tcp > docker-compose.yml
|
||||||
|
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||||
|
# Create placeholder proxy.env so docker-compose can validate
|
||||||
|
# This will be overwritten with the actual token after Management starts
|
||||||
|
echo "# Placeholder - will be updated with token after Management starts" > proxy.env
|
||||||
|
echo "NB_PROXY_TOKEN=placeholder" >> proxy.env
|
||||||
|
fi
|
||||||
|
;;
|
||||||
*)
|
*)
|
||||||
echo "Invalid reverse proxy type: $REVERSE_PROXY_TYPE" > /dev/stderr
|
echo "Invalid reverse proxy type: $REVERSE_PROXY_TYPE" > /dev/stderr
|
||||||
exit 1
|
exit 1
|
||||||
@@ -402,6 +484,50 @@ start_services_and_show_instructions() {
|
|||||||
echo ""
|
echo ""
|
||||||
echo "NetBird containers are running. Configure NPM as shown above, then access:"
|
echo "NetBird containers are running. Configure NPM as shown above, then access:"
|
||||||
echo " $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
|
echo " $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
|
||||||
|
elif [[ "$REVERSE_PROXY_TYPE" == "6" ]]; then
|
||||||
|
# Traefik TCP Proxy - two-phase startup if proxy is enabled
|
||||||
|
echo -e "$MSG_STARTING_SERVICES"
|
||||||
|
|
||||||
|
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||||
|
# Phase 1: Start core services (without proxy)
|
||||||
|
echo "Starting core services..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND up -d traefik dashboard signal relay management
|
||||||
|
|
||||||
|
sleep 3
|
||||||
|
wait_management_traefik
|
||||||
|
|
||||||
|
# Phase 2: Create proxy token and start proxy
|
||||||
|
echo ""
|
||||||
|
echo "Creating proxy access token..."
|
||||||
|
# Use docker exec with bash to run the token command directly
|
||||||
|
# (bypassing the entrypoint which adds 'management' as first arg)
|
||||||
|
PROXY_TOKEN=$($DOCKER_COMPOSE_COMMAND exec -T management \
|
||||||
|
bash -c '/go/bin/netbird-mgmt token create --name "default-proxy" --config /etc/netbird/management.json' 2>/dev/null | grep "^Token:" | awk '{print $2}')
|
||||||
|
|
||||||
|
if [[ -z "$PROXY_TOKEN" ]]; then
|
||||||
|
echo "ERROR: Failed to create proxy token. Check management logs." > /dev/stderr
|
||||||
|
$DOCKER_COMPOSE_COMMAND logs --tail=20 management
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Proxy token created successfully."
|
||||||
|
|
||||||
|
# Generate proxy.env with the token
|
||||||
|
render_proxy_env > proxy.env
|
||||||
|
|
||||||
|
# Start proxy service
|
||||||
|
echo "Starting proxy service..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND up -d proxy
|
||||||
|
else
|
||||||
|
# No proxy - start all services at once
|
||||||
|
$DOCKER_COMPOSE_COMMAND up -d
|
||||||
|
|
||||||
|
sleep 3
|
||||||
|
wait_management_traefik
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo -e "$MSG_DONE"
|
||||||
|
print_post_setup_instructions
|
||||||
else
|
else
|
||||||
# External proxies (nginx, external Caddy, other) - need manual config first
|
# External proxies (nginx, external Caddy, other) - need manual config first
|
||||||
print_post_setup_instructions
|
print_post_setup_instructions
|
||||||
@@ -547,6 +673,29 @@ EOF
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
render_proxy_env() {
|
||||||
|
cat <<EOF
|
||||||
|
# NetBird Proxy Configuration
|
||||||
|
NB_PROXY_DEBUG_LOGS=false
|
||||||
|
# Use internal Docker network to connect to management (avoids hairpin NAT issues)
|
||||||
|
NB_PROXY_MANAGEMENT_ADDRESS=http://management:80
|
||||||
|
# Allow insecure gRPC connection to management (required for internal Docker network)
|
||||||
|
NB_PROXY_ALLOW_INSECURE=true
|
||||||
|
# Public URL where this proxy is reachable (used for cluster registration)
|
||||||
|
NB_PROXY_DOMAIN=$NETBIRD_DOMAIN
|
||||||
|
NB_PROXY_ADDRESS=:8443
|
||||||
|
NB_PROXY_TOKEN=$PROXY_TOKEN
|
||||||
|
NB_PROXY_CERTIFICATE_DIRECTORY=/certs
|
||||||
|
NB_PROXY_ACME_CERTIFICATES=true
|
||||||
|
NB_PROXY_ACME_CHALLENGE_TYPE=tls-alpn-01
|
||||||
|
NB_PROXY_OIDC_CLIENT_ID=netbird-proxy
|
||||||
|
NB_PROXY_OIDC_ENDPOINT=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/oauth2
|
||||||
|
NB_PROXY_OIDC_SCOPES=openid,profile,email
|
||||||
|
NB_PROXY_FORWARDED_PROTO=https
|
||||||
|
EOF
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
render_docker_compose() {
|
render_docker_compose() {
|
||||||
cat <<EOF
|
cat <<EOF
|
||||||
services:
|
services:
|
||||||
@@ -736,7 +885,18 @@ $(if [[ -n "$tls_labels" ]]; then echo " - traefik.http.routers.netbird-rel
|
|||||||
|
|
||||||
# Management (includes embedded IdP)
|
# Management (includes embedded IdP)
|
||||||
management:
|
management:
|
||||||
|
$(if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||||
|
cat <<MGMT_BUILD
|
||||||
|
build:
|
||||||
|
context: ..
|
||||||
|
dockerfile: management/Dockerfile.multistage
|
||||||
|
pull_policy: build
|
||||||
|
MGMT_BUILD
|
||||||
|
else
|
||||||
|
cat <<MGMT_IMAGE
|
||||||
image: $MANAGEMENT_IMAGE
|
image: $MANAGEMENT_IMAGE
|
||||||
|
MGMT_IMAGE
|
||||||
|
fi)
|
||||||
container_name: netbird-management
|
container_name: netbird-management
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
networks: [$network_name]
|
networks: [$network_name]
|
||||||
@@ -1115,6 +1275,258 @@ EOF
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
render_docker_compose_traefik_tcp() {
|
||||||
|
# Generate proxy service section if enabled
|
||||||
|
local proxy_service=""
|
||||||
|
local proxy_volumes=""
|
||||||
|
local proxy_tcp_labels=""
|
||||||
|
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||||
|
proxy_service="
|
||||||
|
# NetBird Proxy - exposes internal resources to the internet
|
||||||
|
proxy:
|
||||||
|
build:
|
||||||
|
context: ../
|
||||||
|
dockerfile: proxy/Dockerfile
|
||||||
|
# Always rebuild to pick up code changes during testing
|
||||||
|
pull_policy: build
|
||||||
|
#image: $PROXY_IMAGE
|
||||||
|
container_name: netbird-proxy
|
||||||
|
# Hairpin NAT fix: route domain back to traefik's static IP within Docker
|
||||||
|
extra_hosts:
|
||||||
|
- \"$NETBIRD_DOMAIN:172.30.0.10\"
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [netbird]
|
||||||
|
depends_on:
|
||||||
|
- signal
|
||||||
|
env_file:
|
||||||
|
- ./proxy.env
|
||||||
|
volumes:
|
||||||
|
- netbird_proxy_certs:/certs
|
||||||
|
labels:
|
||||||
|
# TCP passthrough for any unmatched domain (proxy handles its own TLS)
|
||||||
|
- traefik.enable=true
|
||||||
|
- traefik.tcp.routers.proxy-passthrough.entrypoints=websecure
|
||||||
|
- traefik.tcp.routers.proxy-passthrough.rule=HostSNI(\`*\`)
|
||||||
|
- traefik.tcp.routers.proxy-passthrough.tls.passthrough=true
|
||||||
|
- traefik.tcp.routers.proxy-passthrough.service=proxy-tls
|
||||||
|
- traefik.tcp.routers.proxy-passthrough.priority=1
|
||||||
|
- traefik.tcp.services.proxy-tls.loadbalancer.server.port=8443
|
||||||
|
logging:
|
||||||
|
driver: \"json-file\"
|
||||||
|
options:
|
||||||
|
max-size: \"500m\"
|
||||||
|
max-file: \"2\"
|
||||||
|
"
|
||||||
|
proxy_volumes="
|
||||||
|
netbird_proxy_certs:"
|
||||||
|
fi
|
||||||
|
|
||||||
|
cat <<EOF
|
||||||
|
services:
|
||||||
|
# Traefik - single port 443 entry point with TLS termination
|
||||||
|
traefik:
|
||||||
|
image: $TRAEFIK_IMAGE
|
||||||
|
container_name: netbird-traefik
|
||||||
|
restart: unless-stopped
|
||||||
|
networks:
|
||||||
|
netbird:
|
||||||
|
ipv4_address: 172.30.0.10
|
||||||
|
ports:
|
||||||
|
- '443:443'
|
||||||
|
volumes:
|
||||||
|
- netbird_traefik_data:/data
|
||||||
|
- /var/run/docker.sock:/var/run/docker.sock:ro
|
||||||
|
command:
|
||||||
|
# Logging
|
||||||
|
- --log.level=INFO
|
||||||
|
- --accesslog=true
|
||||||
|
# Docker provider
|
||||||
|
- --providers.docker=true
|
||||||
|
- --providers.docker.exposedbydefault=false
|
||||||
|
- --providers.docker.network=netbird
|
||||||
|
# Entrypoints
|
||||||
|
- --entrypoints.websecure.address=:443
|
||||||
|
- --entrypoints.websecure.allowACMEByPass=true
|
||||||
|
# Disable timeouts for long-lived gRPC streams
|
||||||
|
- --entrypoints.websecure.transport.respondingtimeouts.readtimeout=0s
|
||||||
|
- --entrypoints.websecure.transport.respondingtimeouts.writetimeout=0s
|
||||||
|
- --entrypoints.websecure.transport.respondingtimeouts.idletimeout=0s
|
||||||
|
# Let's Encrypt ACME
|
||||||
|
- --certificatesresolvers.letsencrypt.acme.email=$TRAEFIK_TCP_ACME_EMAIL
|
||||||
|
- --certificatesresolvers.letsencrypt.acme.storage=/data/acme.json
|
||||||
|
- --certificatesresolvers.letsencrypt.acme.tlschallenge=true
|
||||||
|
# gRPC transport settings (disable response timeout for long-lived streams)
|
||||||
|
- --serverstransport.forwardingtimeouts.responseheadertimeout=0s
|
||||||
|
- --serverstransport.forwardingtimeouts.idleconntimeout=0s
|
||||||
|
logging:
|
||||||
|
driver: "json-file"
|
||||||
|
options:
|
||||||
|
max-size: "500m"
|
||||||
|
max-file: "2"
|
||||||
|
|
||||||
|
# UI dashboard
|
||||||
|
dashboard:
|
||||||
|
image: $DASHBOARD_IMAGE
|
||||||
|
container_name: netbird-dashboard
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [netbird]
|
||||||
|
env_file:
|
||||||
|
- ./dashboard.env
|
||||||
|
labels:
|
||||||
|
- traefik.enable=true
|
||||||
|
- traefik.http.routers.netbird-dashboard.entrypoints=websecure
|
||||||
|
- traefik.http.routers.netbird-dashboard.rule=Host(\`$NETBIRD_DOMAIN\`)
|
||||||
|
- traefik.http.routers.netbird-dashboard.tls=true
|
||||||
|
- traefik.http.routers.netbird-dashboard.tls.certresolver=letsencrypt
|
||||||
|
- traefik.http.routers.netbird-dashboard.service=dashboard
|
||||||
|
- traefik.http.routers.netbird-dashboard.priority=1
|
||||||
|
- traefik.http.services.dashboard.loadbalancer.server.port=80
|
||||||
|
logging:
|
||||||
|
driver: "json-file"
|
||||||
|
options:
|
||||||
|
max-size: "500m"
|
||||||
|
max-file: "2"
|
||||||
|
|
||||||
|
# Signal
|
||||||
|
signal:
|
||||||
|
image: $SIGNAL_IMAGE
|
||||||
|
container_name: netbird-signal
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [netbird]
|
||||||
|
labels:
|
||||||
|
- traefik.enable=true
|
||||||
|
# Signal WebSocket
|
||||||
|
- traefik.http.routers.netbird-signal-ws.entrypoints=websecure
|
||||||
|
- traefik.http.routers.netbird-signal-ws.rule=Host(\`$NETBIRD_DOMAIN\`) && PathPrefix(\`/ws-proxy/signal\`)
|
||||||
|
- traefik.http.routers.netbird-signal-ws.tls=true
|
||||||
|
- traefik.http.routers.netbird-signal-ws.tls.certresolver=letsencrypt
|
||||||
|
- traefik.http.routers.netbird-signal-ws.service=signal-ws
|
||||||
|
- traefik.http.routers.netbird-signal-ws.priority=100
|
||||||
|
- traefik.http.services.signal-ws.loadbalancer.server.port=80
|
||||||
|
# Signal gRPC
|
||||||
|
- traefik.http.routers.netbird-signal-grpc.entrypoints=websecure
|
||||||
|
- traefik.http.routers.netbird-signal-grpc.rule=Host(\`$NETBIRD_DOMAIN\`) && PathPrefix(\`/signalexchange.SignalExchange/\`)
|
||||||
|
- traefik.http.routers.netbird-signal-grpc.tls=true
|
||||||
|
- traefik.http.routers.netbird-signal-grpc.tls.certresolver=letsencrypt
|
||||||
|
- traefik.http.routers.netbird-signal-grpc.service=signal-grpc
|
||||||
|
- traefik.http.routers.netbird-signal-grpc.priority=100
|
||||||
|
- traefik.http.services.signal-grpc.loadbalancer.server.port=10000
|
||||||
|
- traefik.http.services.signal-grpc.loadbalancer.server.scheme=h2c
|
||||||
|
logging:
|
||||||
|
driver: "json-file"
|
||||||
|
options:
|
||||||
|
max-size: "500m"
|
||||||
|
max-file: "2"
|
||||||
|
|
||||||
|
# Relay (includes embedded STUN server)
|
||||||
|
relay:
|
||||||
|
image: $RELAY_IMAGE
|
||||||
|
container_name: netbird-relay
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [netbird]
|
||||||
|
ports:
|
||||||
|
- '$NETBIRD_STUN_PORT:$NETBIRD_STUN_PORT/udp'
|
||||||
|
env_file:
|
||||||
|
- ./relay.env
|
||||||
|
labels:
|
||||||
|
- traefik.enable=true
|
||||||
|
- traefik.http.routers.netbird-relay.entrypoints=websecure
|
||||||
|
- traefik.http.routers.netbird-relay.rule=Host(\`$NETBIRD_DOMAIN\`) && PathPrefix(\`/relay\`)
|
||||||
|
- traefik.http.routers.netbird-relay.tls=true
|
||||||
|
- traefik.http.routers.netbird-relay.tls.certresolver=letsencrypt
|
||||||
|
- traefik.http.routers.netbird-relay.service=relay
|
||||||
|
- traefik.http.routers.netbird-relay.priority=100
|
||||||
|
- traefik.http.services.relay.loadbalancer.server.port=80
|
||||||
|
logging:
|
||||||
|
driver: "json-file"
|
||||||
|
options:
|
||||||
|
max-size: "500m"
|
||||||
|
max-file: "2"
|
||||||
|
|
||||||
|
# Management (includes embedded IdP)
|
||||||
|
management:
|
||||||
|
$(if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||||
|
cat <<MGMT_BUILD
|
||||||
|
build:
|
||||||
|
context: ..
|
||||||
|
dockerfile: management/Dockerfile.multistage
|
||||||
|
pull_policy: build
|
||||||
|
MGMT_BUILD
|
||||||
|
else
|
||||||
|
cat <<MGMT_IMAGE
|
||||||
|
image: $MANAGEMENT_IMAGE
|
||||||
|
MGMT_IMAGE
|
||||||
|
fi)
|
||||||
|
container_name: netbird-management
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [netbird]
|
||||||
|
volumes:
|
||||||
|
- netbird_management:/var/lib/netbird
|
||||||
|
- ./management.json:/etc/netbird/management.json
|
||||||
|
command: [
|
||||||
|
"--port", "80",
|
||||||
|
"--log-file", "console",
|
||||||
|
"--log-level", "info",
|
||||||
|
"--disable-anonymous-metrics=false",
|
||||||
|
"--single-account-mode-domain=netbird.selfhosted",
|
||||||
|
"--dns-domain=netbird.selfhosted",
|
||||||
|
"--idp-sign-key-refresh-enabled",
|
||||||
|
]
|
||||||
|
labels:
|
||||||
|
- traefik.enable=true
|
||||||
|
# Management API
|
||||||
|
- traefik.http.routers.netbird-api.entrypoints=websecure
|
||||||
|
- traefik.http.routers.netbird-api.rule=Host(\`$NETBIRD_DOMAIN\`) && PathPrefix(\`/api\`)
|
||||||
|
- traefik.http.routers.netbird-api.tls=true
|
||||||
|
- traefik.http.routers.netbird-api.tls.certresolver=letsencrypt
|
||||||
|
- traefik.http.routers.netbird-api.service=management
|
||||||
|
- traefik.http.routers.netbird-api.priority=100
|
||||||
|
# Management WebSocket
|
||||||
|
- traefik.http.routers.netbird-mgmt-ws.entrypoints=websecure
|
||||||
|
- traefik.http.routers.netbird-mgmt-ws.rule=Host(\`$NETBIRD_DOMAIN\`) && PathPrefix(\`/ws-proxy/management\`)
|
||||||
|
- traefik.http.routers.netbird-mgmt-ws.tls=true
|
||||||
|
- traefik.http.routers.netbird-mgmt-ws.tls.certresolver=letsencrypt
|
||||||
|
- traefik.http.routers.netbird-mgmt-ws.service=management
|
||||||
|
- traefik.http.routers.netbird-mgmt-ws.priority=100
|
||||||
|
# Management gRPC
|
||||||
|
- traefik.http.routers.netbird-mgmt-grpc.entrypoints=websecure
|
||||||
|
- traefik.http.routers.netbird-mgmt-grpc.rule=Host(\`$NETBIRD_DOMAIN\`) && PathPrefix(\`/management.ManagementService/\`)
|
||||||
|
- traefik.http.routers.netbird-mgmt-grpc.tls=true
|
||||||
|
- traefik.http.routers.netbird-mgmt-grpc.tls.certresolver=letsencrypt
|
||||||
|
- traefik.http.routers.netbird-mgmt-grpc.service=management-grpc
|
||||||
|
- traefik.http.routers.netbird-mgmt-grpc.priority=100
|
||||||
|
# OAuth2 (embedded IdP)
|
||||||
|
- traefik.http.routers.netbird-oauth2.entrypoints=websecure
|
||||||
|
- traefik.http.routers.netbird-oauth2.rule=Host(\`$NETBIRD_DOMAIN\`) && PathPrefix(\`/oauth2\`)
|
||||||
|
- traefik.http.routers.netbird-oauth2.tls=true
|
||||||
|
- traefik.http.routers.netbird-oauth2.tls.certresolver=letsencrypt
|
||||||
|
- traefik.http.routers.netbird-oauth2.service=management
|
||||||
|
- traefik.http.routers.netbird-oauth2.priority=100
|
||||||
|
# Services
|
||||||
|
- traefik.http.services.management.loadbalancer.server.port=80
|
||||||
|
- traefik.http.services.management-grpc.loadbalancer.server.port=80
|
||||||
|
- traefik.http.services.management-grpc.loadbalancer.server.scheme=h2c
|
||||||
|
logging:
|
||||||
|
driver: "json-file"
|
||||||
|
options:
|
||||||
|
max-size: "500m"
|
||||||
|
max-file: "2"
|
||||||
|
${proxy_service}
|
||||||
|
volumes:
|
||||||
|
netbird_traefik_data:
|
||||||
|
netbird_management:${proxy_volumes}
|
||||||
|
|
||||||
|
networks:
|
||||||
|
netbird:
|
||||||
|
driver: bridge
|
||||||
|
ipam:
|
||||||
|
config:
|
||||||
|
- subnet: 172.30.0.0/24
|
||||||
|
gateway: 172.30.0.1
|
||||||
|
EOF
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
render_npm_advanced_config() {
|
render_npm_advanced_config() {
|
||||||
local upstream_host=$(get_upstream_host)
|
local upstream_host=$(get_upstream_host)
|
||||||
local relay_addr="${upstream_host}:${RELAY_HOST_PORT}"
|
local relay_addr="${upstream_host}:${RELAY_HOST_PORT}"
|
||||||
@@ -1424,6 +1836,36 @@ print_manual_instructions() {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
print_traefik_tcp_instructions() {
|
||||||
|
echo ""
|
||||||
|
echo "$MSG_SEPARATOR"
|
||||||
|
echo " TRAEFIK TCP PROXY SETUP"
|
||||||
|
echo "$MSG_SEPARATOR"
|
||||||
|
echo ""
|
||||||
|
echo "This configuration uses Traefik as a single entry point on port 443."
|
||||||
|
echo "Traefik handles TLS termination with Let's Encrypt and routes to services."
|
||||||
|
echo ""
|
||||||
|
echo "Open ports:"
|
||||||
|
echo " - 443/tcp (HTTPS - all NetBird services)"
|
||||||
|
echo " - $NETBIRD_STUN_PORT/udp (STUN - required for NAT traversal)"
|
||||||
|
echo ""
|
||||||
|
echo "Generated files:"
|
||||||
|
echo " - docker-compose.yml (container definitions with Traefik labels)"
|
||||||
|
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||||
|
echo " - proxy.env (NetBird Proxy configuration)"
|
||||||
|
echo ""
|
||||||
|
echo "NetBird Proxy:"
|
||||||
|
echo " The proxy service is enabled and will be built from source."
|
||||||
|
echo " Any domain NOT matching $NETBIRD_DOMAIN will be passed through to the proxy."
|
||||||
|
echo " The proxy handles its own TLS certificates via ACME TLS-ALPN-01 challenge."
|
||||||
|
echo " Point your proxy domains (CNAMEs) to this server's IP address."
|
||||||
|
fi
|
||||||
|
echo ""
|
||||||
|
echo "You can access the NetBird dashboard at $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
|
||||||
|
echo "Follow the onboarding steps to set up your NetBird instance."
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
print_post_setup_instructions() {
|
print_post_setup_instructions() {
|
||||||
case "$REVERSE_PROXY_TYPE" in
|
case "$REVERSE_PROXY_TYPE" in
|
||||||
0)
|
0)
|
||||||
@@ -1444,6 +1886,9 @@ print_post_setup_instructions() {
|
|||||||
5)
|
5)
|
||||||
print_manual_instructions
|
print_manual_instructions
|
||||||
;;
|
;;
|
||||||
|
6)
|
||||||
|
print_traefik_tcp_instructions
|
||||||
|
;;
|
||||||
*)
|
*)
|
||||||
echo "Unknown reverse proxy type: $REVERSE_PROXY_TYPE" > /dev/stderr
|
echo "Unknown reverse proxy type: $REVERSE_PROXY_TYPE" > /dev/stderr
|
||||||
;;
|
;;
|
||||||
|
|||||||
17
management/Dockerfile.multistage
Normal file
17
management/Dockerfile.multistage
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
FROM golang:1.25-bookworm AS builder
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Install build dependencies
|
||||||
|
RUN apt-get update && apt-get install -y gcc libc6-dev && rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
COPY go.mod go.sum ./
|
||||||
|
RUN go mod download
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
RUN CGO_ENABLED=1 GOOS=linux go build -ldflags="-s -w" -o netbird-mgmt ./management
|
||||||
|
|
||||||
|
FROM ubuntu:24.04
|
||||||
|
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
|
||||||
|
ENTRYPOINT [ "/go/bin/netbird-mgmt","management"]
|
||||||
|
CMD ["--log-file", "console"]
|
||||||
|
COPY --from=builder /app/netbird-mgmt /go/bin/netbird-mgmt
|
||||||
@@ -16,13 +16,15 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/formatter/hook"
|
"github.com/netbirdio/netbird/formatter/hook"
|
||||||
"github.com/netbirdio/netbird/management/internals/server"
|
"github.com/netbirdio/netbird/management/internals/server"
|
||||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
nbdomain "github.com/netbirdio/netbird/shared/management/domain"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
"github.com/netbirdio/netbird/util/crypt"
|
"github.com/netbirdio/netbird/util/crypt"
|
||||||
)
|
)
|
||||||
@@ -78,9 +80,8 @@ var (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, valid := dns.IsDomainName(dnsDomain)
|
if !nbdomain.IsValidDomainNoWildcard(dnsDomain) {
|
||||||
if !valid || len(dnsDomain) > 192 {
|
return fmt.Errorf("invalid dns-domain: %s", dnsDomain)
|
||||||
return fmt.Errorf("failed parsing the provided dns-domain. Valid status: %t, Length: %d", valid, len(dnsDomain))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -214,11 +215,14 @@ func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
|||||||
// Set HttpConfig values from EmbeddedIdP
|
// Set HttpConfig values from EmbeddedIdP
|
||||||
cfg.HttpConfig.AuthIssuer = issuer
|
cfg.HttpConfig.AuthIssuer = issuer
|
||||||
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
|
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
|
||||||
|
cfg.HttpConfig.AuthClientID = cfg.HttpConfig.AuthAudience
|
||||||
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
|
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
|
||||||
cfg.HttpConfig.AuthUserIDClaim = "sub"
|
cfg.HttpConfig.AuthUserIDClaim = "sub"
|
||||||
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
|
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
|
||||||
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
|
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
|
||||||
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
|
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
|
||||||
|
callbackURL := strings.TrimSuffix(cfg.HttpConfig.AuthIssuer, "/oauth2")
|
||||||
|
cfg.HttpConfig.AuthCallbackURL = callbackURL + types.ProxyCallbackEndpointFull
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -80,4 +80,10 @@ func init() {
|
|||||||
migrationCmd.AddCommand(upCmd)
|
migrationCmd.AddCommand(upCmd)
|
||||||
|
|
||||||
rootCmd.AddCommand(migrationCmd)
|
rootCmd.AddCommand(migrationCmd)
|
||||||
|
|
||||||
|
tokenCmd.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
|
||||||
|
tokenCmd.AddCommand(tokenCreateCmd)
|
||||||
|
tokenCmd.AddCommand(tokenListCmd)
|
||||||
|
tokenCmd.AddCommand(tokenRevokeCmd)
|
||||||
|
rootCmd.AddCommand(tokenCmd)
|
||||||
}
|
}
|
||||||
|
|||||||
209
management/cmd/token.go
Normal file
209
management/cmd/token.go
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"text/tabwriter"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/formatter/hook"
|
||||||
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
tokenName string
|
||||||
|
tokenExpireIn string
|
||||||
|
tokenDatadir string
|
||||||
|
|
||||||
|
tokenCmd = &cobra.Command{
|
||||||
|
Use: "token",
|
||||||
|
Short: "Manage proxy access tokens",
|
||||||
|
Long: "Commands for creating, listing, and revoking proxy access tokens used by reverse proxy instances to authenticate with the management server.",
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenCreateCmd = &cobra.Command{
|
||||||
|
Use: "create",
|
||||||
|
Short: "Create a new proxy access token",
|
||||||
|
Long: "Creates a new proxy access token. The plain text token is displayed only once at creation time.",
|
||||||
|
RunE: tokenCreateRun,
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenListCmd = &cobra.Command{
|
||||||
|
Use: "list",
|
||||||
|
Aliases: []string{"ls"},
|
||||||
|
Short: "List all proxy access tokens",
|
||||||
|
Long: "Lists all proxy access tokens with their IDs, names, creation dates, expiration, and revocation status.",
|
||||||
|
RunE: tokenListRun,
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenRevokeCmd = &cobra.Command{
|
||||||
|
Use: "revoke [token-id]",
|
||||||
|
Short: "Revoke a proxy access token",
|
||||||
|
Long: "Revokes a proxy access token by its ID. Revoked tokens can no longer be used for authentication.",
|
||||||
|
Args: cobra.ExactArgs(1),
|
||||||
|
RunE: tokenRevokeRun,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
tokenCmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
|
||||||
|
|
||||||
|
tokenCreateCmd.Flags().StringVar(&tokenName, "name", "", "Name for the token (required)")
|
||||||
|
tokenCreateCmd.Flags().StringVar(&tokenExpireIn, "expires-in", "", "Token expiration duration (e.g., 365d, 24h, 30d). Empty means no expiration")
|
||||||
|
tokenCreateCmd.MarkFlagRequired("name") //nolint
|
||||||
|
}
|
||||||
|
|
||||||
|
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
|
||||||
|
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||||
|
if err := util.InitLog("error", "console"); err != nil {
|
||||||
|
return fmt.Errorf("init log: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint
|
||||||
|
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource)
|
||||||
|
|
||||||
|
config, err := loadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("load config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
datadir := config.Datadir
|
||||||
|
if tokenDatadir != "" {
|
||||||
|
datadir = tokenDatadir
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create store: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := s.Close(ctx); err != nil {
|
||||||
|
log.Debugf("close store: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return fn(ctx, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func tokenCreateRun(cmd *cobra.Command, _ []string) error {
|
||||||
|
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
|
||||||
|
expiresIn, err := parseDuration(tokenExpireIn)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse expiration: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
generated, err := types.CreateNewProxyAccessToken(tokenName, expiresIn, nil, "CLI")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generate token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.SaveProxyAccessToken(ctx, &generated.ProxyAccessToken); err != nil {
|
||||||
|
return fmt.Errorf("save token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Token created successfully!") //nolint:forbidigo
|
||||||
|
fmt.Printf("Token: %s\n", generated.PlainToken) //nolint:forbidigo
|
||||||
|
fmt.Println() //nolint:forbidigo
|
||||||
|
fmt.Println("IMPORTANT: Save this token now. It will not be shown again.") //nolint:forbidigo
|
||||||
|
fmt.Printf("Token ID: %s\n", generated.ID) //nolint:forbidigo
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func tokenListRun(cmd *cobra.Command, _ []string) error {
|
||||||
|
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
|
||||||
|
tokens, err := s.GetAllProxyAccessTokens(ctx, store.LockingStrengthNone)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("list tokens: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tokens) == 0 {
|
||||||
|
fmt.Println("No proxy access tokens found.") //nolint:forbidigo
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
|
||||||
|
fmt.Fprintln(w, "ID\tNAME\tCREATED\tEXPIRES\tLAST USED\tREVOKED")
|
||||||
|
fmt.Fprintln(w, "--\t----\t-------\t-------\t---------\t-------")
|
||||||
|
|
||||||
|
for _, t := range tokens {
|
||||||
|
expires := "never"
|
||||||
|
if t.ExpiresAt != nil {
|
||||||
|
expires = t.ExpiresAt.Format("2006-01-02")
|
||||||
|
}
|
||||||
|
|
||||||
|
lastUsed := "never"
|
||||||
|
if t.LastUsed != nil {
|
||||||
|
lastUsed = t.LastUsed.Format("2006-01-02 15:04")
|
||||||
|
}
|
||||||
|
|
||||||
|
revoked := "no"
|
||||||
|
if t.Revoked {
|
||||||
|
revoked = "yes"
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
|
||||||
|
t.ID,
|
||||||
|
t.Name,
|
||||||
|
t.CreatedAt.Format("2006-01-02"),
|
||||||
|
expires,
|
||||||
|
lastUsed,
|
||||||
|
revoked,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Flush()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func tokenRevokeRun(cmd *cobra.Command, args []string) error {
|
||||||
|
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
|
||||||
|
tokenID := args[0]
|
||||||
|
|
||||||
|
if err := s.RevokeProxyAccessToken(ctx, tokenID); err != nil {
|
||||||
|
return fmt.Errorf("revoke token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Token %s revoked successfully.\n", tokenID) //nolint:forbidigo
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseDuration parses a duration string with support for days (e.g., "30d", "365d").
|
||||||
|
// An empty string returns zero duration (no expiration).
|
||||||
|
func parseDuration(s string) (time.Duration, error) {
|
||||||
|
if len(s) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if s[len(s)-1] == 'd' {
|
||||||
|
d, err := strconv.Atoi(s[:len(s)-1])
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("invalid day format: %s", s)
|
||||||
|
}
|
||||||
|
if d <= 0 {
|
||||||
|
return 0, fmt.Errorf("duration must be positive: %s", s)
|
||||||
|
}
|
||||||
|
return time.Duration(d) * 24 * time.Hour, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := time.ParseDuration(s)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if d <= 0 {
|
||||||
|
return 0, fmt.Errorf("duration must be positive: %s", s)
|
||||||
|
}
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
101
management/cmd/token_test.go
Normal file
101
management/cmd/token_test.go
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseDuration(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected time.Duration
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty string returns zero",
|
||||||
|
input: "",
|
||||||
|
expected: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "days suffix",
|
||||||
|
input: "30d",
|
||||||
|
expected: 30 * 24 * time.Hour,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "one day",
|
||||||
|
input: "1d",
|
||||||
|
expected: 24 * time.Hour,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "365 days",
|
||||||
|
input: "365d",
|
||||||
|
expected: 365 * 24 * time.Hour,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hours via Go duration",
|
||||||
|
input: "24h",
|
||||||
|
expected: 24 * time.Hour,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "minutes via Go duration",
|
||||||
|
input: "30m",
|
||||||
|
expected: 30 * time.Minute,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex Go duration",
|
||||||
|
input: "1h30m",
|
||||||
|
expected: 90 * time.Minute,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid day format",
|
||||||
|
input: "abcd",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative days",
|
||||||
|
input: "-1d",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero days",
|
||||||
|
input: "0d",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-numeric days",
|
||||||
|
input: "xyzd",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative Go duration",
|
||||||
|
input: "-24h",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero Go duration",
|
||||||
|
input: "0s",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid Go duration",
|
||||||
|
input: "notaduration",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := parseDuration(tt.input)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -174,6 +174,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
semaphore := make(chan struct{}, 10)
|
semaphore := make(chan struct{}, 10)
|
||||||
|
|
||||||
|
account.InjectProxyPolicies(ctx)
|
||||||
dnsCache := &cache.DNSConfigCache{}
|
dnsCache := &cache.DNSConfigCache{}
|
||||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||||
@@ -247,7 +248,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||||
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
||||||
|
|
||||||
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update})
|
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
|
||||||
|
Update: update,
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
})
|
||||||
}(peer)
|
}(peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -323,6 +327,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
return fmt.Errorf("failed to get validated peers: %v", err)
|
return fmt.Errorf("failed to get validated peers: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
account.InjectProxyPolicies(ctx)
|
||||||
dnsCache := &cache.DNSConfigCache{}
|
dnsCache := &cache.DNSConfigCache{}
|
||||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||||
@@ -370,7 +375,10 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||||
|
|
||||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||||
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update})
|
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{
|
||||||
|
Update: update,
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -435,6 +443,8 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
account.InjectProxyPolicies(ctx)
|
||||||
|
|
||||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, 0, err
|
return nil, nil, nil, 0, err
|
||||||
@@ -778,6 +788,7 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
})
|
})
|
||||||
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
||||||
|
|
||||||
@@ -840,6 +851,7 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
|||||||
if c.experimentalNetworkMap(peer.AccountID) {
|
if c.experimentalNetworkMap(peer.AccountID) {
|
||||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
|
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
|
||||||
} else {
|
} else {
|
||||||
|
account.InjectProxyPolicies(ctx)
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
|
|||||||
@@ -25,11 +25,14 @@ func TestCreateChannel(t *testing.T) {
|
|||||||
func TestSendUpdate(t *testing.T) {
|
func TestSendUpdate(t *testing.T) {
|
||||||
peer := "test-sendupdate"
|
peer := "test-sendupdate"
|
||||||
peersUpdater := NewPeersUpdateManager(nil)
|
peersUpdater := NewPeersUpdateManager(nil)
|
||||||
update1 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
|
update1 := &network_map.UpdateMessage{
|
||||||
NetworkMap: &proto.NetworkMap{
|
Update: &proto.SyncResponse{
|
||||||
Serial: 0,
|
NetworkMap: &proto.NetworkMap{
|
||||||
|
Serial: 0,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}}
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
_ = peersUpdater.CreateChannel(context.Background(), peer)
|
_ = peersUpdater.CreateChannel(context.Background(), peer)
|
||||||
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
||||||
t.Error("Error creating the channel")
|
t.Error("Error creating the channel")
|
||||||
@@ -45,11 +48,14 @@ func TestSendUpdate(t *testing.T) {
|
|||||||
peersUpdater.SendUpdate(context.Background(), peer, update1)
|
peersUpdater.SendUpdate(context.Background(), peer, update1)
|
||||||
}
|
}
|
||||||
|
|
||||||
update2 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
|
update2 := &network_map.UpdateMessage{
|
||||||
NetworkMap: &proto.NetworkMap{
|
Update: &proto.SyncResponse{
|
||||||
Serial: 10,
|
NetworkMap: &proto.NetworkMap{
|
||||||
|
Serial: 10,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}}
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
peersUpdater.SendUpdate(context.Background(), peer, update2)
|
peersUpdater.SendUpdate(context.Background(), peer, update2)
|
||||||
timeout := time.After(5 * time.Second)
|
timeout := time.After(5 * time.Second)
|
||||||
|
|||||||
@@ -4,6 +4,19 @@ import (
|
|||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// MessageType indicates the type of update message for debouncing strategy
|
||||||
|
type MessageType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// MessageTypeNetworkMap represents network map updates (peers, routes, DNS, firewall)
|
||||||
|
// These updates can be safely debounced - only the latest state matters
|
||||||
|
MessageTypeNetworkMap MessageType = iota
|
||||||
|
// MessageTypeControlConfig represents control/config updates (tokens, peer expiration)
|
||||||
|
// These updates should not be dropped as they contain time-sensitive information
|
||||||
|
MessageTypeControlConfig
|
||||||
|
)
|
||||||
|
|
||||||
type UpdateMessage struct {
|
type UpdateMessage struct {
|
||||||
Update *proto.SyncResponse
|
Update *proto.SyncResponse
|
||||||
|
MessageType MessageType
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -187,10 +187,10 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for accountID, peerIDs := range peerIDsPerAccount {
|
for accountID, peerIDs := range peerIDsPerAccount {
|
||||||
log.WithContext(ctx).Debugf("delete ephemeral peers for account: %s", accountID)
|
log.WithContext(ctx).Tracef("cleanup: deleting %d ephemeral peers for account %s", len(peerIDs), accountID)
|
||||||
err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true)
|
err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
|
log.WithContext(ctx).Errorf("failed to delete ephemeral peers: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
@@ -32,6 +33,7 @@ type Manager interface {
|
|||||||
SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator)
|
SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator)
|
||||||
SetAccountManager(accountManager account.Manager)
|
SetAccountManager(accountManager account.Manager)
|
||||||
GetPeerID(ctx context.Context, peerKey string) (string, error)
|
GetPeerID(ctx context.Context, peerKey string) (string, error)
|
||||||
|
CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type managerImpl struct {
|
type managerImpl struct {
|
||||||
@@ -108,10 +110,19 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
|
|||||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
peer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
peer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if e, ok := status.FromError(err); ok && e.Type() == status.NotFound {
|
||||||
|
log.WithContext(ctx).Tracef("DeletePeers: peer %s not found, skipping", peerID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if checkConnected && (peer.Status.Connected || peer.Status.LastSeen.After(time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)))) {
|
if checkConnected && (peer.Status.Connected || peer.Status.LastSeen.After(time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)))) {
|
||||||
|
log.WithContext(ctx).Tracef("DeletePeers: peer %s skipped (connected=%t, lastSeen=%s, threshold=%s, ephemeral=%t)",
|
||||||
|
peerID, peer.Status.Connected,
|
||||||
|
peer.Status.LastSeen.Format(time.RFC3339),
|
||||||
|
time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)).Format(time.RFC3339),
|
||||||
|
peer.Ephemeral)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -150,7 +161,8 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
log.WithContext(ctx).Errorf("DeletePeers: failed to delete peer %s: %v", peerID, err)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.integratedPeerValidator != nil {
|
if m.integratedPeerValidator != nil {
|
||||||
@@ -172,3 +184,36 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
|
|||||||
func (m *managerImpl) GetPeerID(ctx context.Context, peerKey string) (string, error) {
|
func (m *managerImpl) GetPeerID(ctx context.Context, peerKey string) (string, error) {
|
||||||
return m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
|
return m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
|
||||||
|
existingPeerID, err := m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
|
||||||
|
if err == nil && existingPeerID != "" {
|
||||||
|
// Peer already exists
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
name := fmt.Sprintf("proxy-%s", xid.New().String())
|
||||||
|
peer := &peer.Peer{
|
||||||
|
Ephemeral: true,
|
||||||
|
ProxyMeta: peer.ProxyMeta{
|
||||||
|
Cluster: cluster,
|
||||||
|
Embedded: true,
|
||||||
|
},
|
||||||
|
Name: name,
|
||||||
|
Key: peerKey,
|
||||||
|
LoginExpirationEnabled: false,
|
||||||
|
InactivityExpirationEnabled: false,
|
||||||
|
Meta: peer.PeerSystemMeta{
|
||||||
|
Hostname: name,
|
||||||
|
GoOS: "proxy",
|
||||||
|
OS: "proxy",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create proxy peer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -162,3 +162,17 @@ func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController
|
|||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateProxyPeer mocks base method.
|
||||||
|
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
|
||||||
|
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,105 @@
|
|||||||
|
package accesslogs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AccessLogEntry is the persisted (GORM) representation of a single proxied
// HTTP request. It is populated from a proto.AccessLog via FromProto and
// exposed through the HTTP API as api.ProxyAccessLog via ToAPIResponse.
type AccessLogEntry struct {
	ID        string    `gorm:"primaryKey"` // log ID from the proxy (proto LogId)
	AccountID string    `gorm:"index"`
	ServiceID string    `gorm:"index"`
	Timestamp time.Time `gorm:"index"`
	// GeoLocation carries the source location; ConnectionIP is set from the
	// proxy-reported source IP when it parses as a valid address (see FromProto).
	GeoLocation peer.Location `gorm:"embedded;embeddedPrefix:location_"`
	Method      string        `gorm:"index"`
	Host        string        `gorm:"index"`
	Path        string        `gorm:"index"`
	Duration    time.Duration `gorm:"index"`
	StatusCode  int           `gorm:"index"`
	// Reason is a short failure description ("Authentication failed" or
	// "Request failed"); empty for successful requests.
	Reason         string
	UserId         string `gorm:"index"` // authenticated user ID; empty when anonymous
	AuthMethodUsed string `gorm:"index"` // auth mechanism reported by the proxy
}
|
||||||
|
|
||||||
|
// FromProto creates an AccessLogEntry from a proto.AccessLog
|
||||||
|
func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) {
|
||||||
|
a.ID = serviceLog.GetLogId()
|
||||||
|
a.ServiceID = serviceLog.GetServiceId()
|
||||||
|
a.Timestamp = serviceLog.GetTimestamp().AsTime()
|
||||||
|
a.Method = serviceLog.GetMethod()
|
||||||
|
a.Host = serviceLog.GetHost()
|
||||||
|
a.Path = serviceLog.GetPath()
|
||||||
|
a.Duration = time.Duration(serviceLog.GetDurationMs()) * time.Millisecond
|
||||||
|
a.StatusCode = int(serviceLog.GetResponseCode())
|
||||||
|
a.UserId = serviceLog.GetUserId()
|
||||||
|
a.AuthMethodUsed = serviceLog.GetAuthMechanism()
|
||||||
|
a.AccountID = serviceLog.GetAccountId()
|
||||||
|
|
||||||
|
if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" {
|
||||||
|
if ip, err := netip.ParseAddr(sourceIP); err == nil {
|
||||||
|
a.GeoLocation.ConnectionIP = net.IP(ip.AsSlice())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !serviceLog.GetAuthSuccess() {
|
||||||
|
a.Reason = "Authentication failed"
|
||||||
|
} else if serviceLog.GetResponseCode() >= 400 {
|
||||||
|
a.Reason = "Request failed"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToAPIResponse converts an AccessLogEntry to the API ProxyAccessLog type
|
||||||
|
func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
|
||||||
|
var sourceIP *string
|
||||||
|
if a.GeoLocation.ConnectionIP != nil {
|
||||||
|
ip := a.GeoLocation.ConnectionIP.String()
|
||||||
|
sourceIP = &ip
|
||||||
|
}
|
||||||
|
|
||||||
|
var reason *string
|
||||||
|
if a.Reason != "" {
|
||||||
|
reason = &a.Reason
|
||||||
|
}
|
||||||
|
|
||||||
|
var userID *string
|
||||||
|
if a.UserId != "" {
|
||||||
|
userID = &a.UserId
|
||||||
|
}
|
||||||
|
|
||||||
|
var authMethod *string
|
||||||
|
if a.AuthMethodUsed != "" {
|
||||||
|
authMethod = &a.AuthMethodUsed
|
||||||
|
}
|
||||||
|
|
||||||
|
var countryCode *string
|
||||||
|
if a.GeoLocation.CountryCode != "" {
|
||||||
|
countryCode = &a.GeoLocation.CountryCode
|
||||||
|
}
|
||||||
|
|
||||||
|
var cityName *string
|
||||||
|
if a.GeoLocation.CityName != "" {
|
||||||
|
cityName = &a.GeoLocation.CityName
|
||||||
|
}
|
||||||
|
|
||||||
|
return &api.ProxyAccessLog{
|
||||||
|
Id: a.ID,
|
||||||
|
ServiceId: a.ServiceID,
|
||||||
|
Timestamp: a.Timestamp,
|
||||||
|
Method: a.Method,
|
||||||
|
Host: a.Host,
|
||||||
|
Path: a.Path,
|
||||||
|
DurationMs: int(a.Duration.Milliseconds()),
|
||||||
|
StatusCode: a.StatusCode,
|
||||||
|
SourceIp: sourceIP,
|
||||||
|
Reason: reason,
|
||||||
|
UserId: userID,
|
||||||
|
AuthMethodUsed: authMethod,
|
||||||
|
CountryCode: countryCode,
|
||||||
|
CityName: cityName,
|
||||||
|
}
|
||||||
|
}
|
||||||
124
management/internals/modules/reverseproxy/accesslogs/filter.go
Normal file
124
management/internals/modules/reverseproxy/accesslogs/filter.go
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
package accesslogs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// DefaultPageSize is the default number of records per page when the
	// request does not specify a valid page_size.
	DefaultPageSize = 50
	// MaxPageSize is the maximum number of records allowed per page;
	// larger requested page sizes are capped to this value.
	MaxPageSize = 100
)
||||||
|
|
||||||
|
// AccessLogFilter holds pagination and filtering parameters for access logs.
// Populate it from an HTTP request with ParseFromRequest; the zero value is
// not meaningful on its own (Page/PageSize would be zero).
type AccessLogFilter struct {
	// Page is the current page number (1-indexed)
	Page int
	// PageSize is the number of records per page
	PageSize int

	// Filtering parameters. A nil pointer means "no filter on this field".
	Search    *string // General search across log ID, host, path, source IP, and user fields
	SourceIP  *string // Filter by source IP address
	Host      *string // Filter by host header
	Path      *string // Filter by request path (supports LIKE pattern)
	UserID    *string // Filter by authenticated user ID
	UserEmail *string // Filter by user email (requires user lookup)
	UserName  *string // Filter by user name (requires user lookup)
	Method    *string // Filter by HTTP method
	Status    *string // Filter by status: "success" (2xx/3xx) or "failed" (1xx/4xx/5xx)
	StatusCode *int   // Filter by HTTP status code
	StartDate *time.Time // Filter by timestamp >= start_date
	EndDate   *time.Time // Filter by timestamp <= end_date
}
||||||
|
|
||||||
|
// ParseFromRequest parses pagination and filter parameters from HTTP request query parameters
|
||||||
|
func (f *AccessLogFilter) ParseFromRequest(r *http.Request) {
|
||||||
|
queryParams := r.URL.Query()
|
||||||
|
|
||||||
|
f.Page = 1
|
||||||
|
if pageStr := queryParams.Get("page"); pageStr != "" {
|
||||||
|
if page, err := strconv.Atoi(pageStr); err == nil && page > 0 {
|
||||||
|
f.Page = page
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f.PageSize = DefaultPageSize
|
||||||
|
if pageSizeStr := queryParams.Get("page_size"); pageSizeStr != "" {
|
||||||
|
if pageSize, err := strconv.Atoi(pageSizeStr); err == nil && pageSize > 0 {
|
||||||
|
f.PageSize = pageSize
|
||||||
|
if f.PageSize > MaxPageSize {
|
||||||
|
f.PageSize = MaxPageSize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if search := queryParams.Get("search"); search != "" {
|
||||||
|
f.Search = &search
|
||||||
|
}
|
||||||
|
|
||||||
|
if sourceIP := queryParams.Get("source_ip"); sourceIP != "" {
|
||||||
|
f.SourceIP = &sourceIP
|
||||||
|
}
|
||||||
|
|
||||||
|
if host := queryParams.Get("host"); host != "" {
|
||||||
|
f.Host = &host
|
||||||
|
}
|
||||||
|
|
||||||
|
if path := queryParams.Get("path"); path != "" {
|
||||||
|
f.Path = &path
|
||||||
|
}
|
||||||
|
|
||||||
|
if userID := queryParams.Get("user_id"); userID != "" {
|
||||||
|
f.UserID = &userID
|
||||||
|
}
|
||||||
|
|
||||||
|
if userEmail := queryParams.Get("user_email"); userEmail != "" {
|
||||||
|
f.UserEmail = &userEmail
|
||||||
|
}
|
||||||
|
|
||||||
|
if userName := queryParams.Get("user_name"); userName != "" {
|
||||||
|
f.UserName = &userName
|
||||||
|
}
|
||||||
|
|
||||||
|
if method := queryParams.Get("method"); method != "" {
|
||||||
|
f.Method = &method
|
||||||
|
}
|
||||||
|
|
||||||
|
if status := queryParams.Get("status"); status != "" {
|
||||||
|
f.Status = &status
|
||||||
|
}
|
||||||
|
|
||||||
|
if statusCodeStr := queryParams.Get("status_code"); statusCodeStr != "" {
|
||||||
|
if statusCode, err := strconv.Atoi(statusCodeStr); err == nil && statusCode > 0 {
|
||||||
|
f.StatusCode = &statusCode
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if startDate := queryParams.Get("start_date"); startDate != "" {
|
||||||
|
parsedStartDate, err := time.Parse(time.RFC3339, startDate)
|
||||||
|
if err == nil {
|
||||||
|
f.StartDate = &parsedStartDate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if endDate := queryParams.Get("end_date"); endDate != "" {
|
||||||
|
parsedEndDate, err := time.Parse(time.RFC3339, endDate)
|
||||||
|
if err == nil {
|
||||||
|
f.EndDate = &parsedEndDate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOffset calculates the database offset for pagination
|
||||||
|
func (f *AccessLogFilter) GetOffset() int {
|
||||||
|
return (f.Page - 1) * f.PageSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLimit returns the page size for database queries (the LIMIT clause value).
func (f *AccessLogFilter) GetLimit() int {
	return f.PageSize
}
||||||
@@ -0,0 +1,161 @@
|
|||||||
|
package accesslogs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAccessLogFilter_ParseFromRequest(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
queryParams map[string]string
|
||||||
|
expectedPage int
|
||||||
|
expectedPageSize int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default values when no params provided",
|
||||||
|
queryParams: map[string]string{},
|
||||||
|
expectedPage: 1,
|
||||||
|
expectedPageSize: DefaultPageSize,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid page and page_size",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "2",
|
||||||
|
"page_size": "25",
|
||||||
|
},
|
||||||
|
expectedPage: 2,
|
||||||
|
expectedPageSize: 25,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "page_size exceeds max, should cap at MaxPageSize",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "1",
|
||||||
|
"page_size": "200",
|
||||||
|
},
|
||||||
|
expectedPage: 1,
|
||||||
|
expectedPageSize: MaxPageSize,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid page number, should use default",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "invalid",
|
||||||
|
"page_size": "10",
|
||||||
|
},
|
||||||
|
expectedPage: 1,
|
||||||
|
expectedPageSize: 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid page_size, should use default",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "2",
|
||||||
|
"page_size": "invalid",
|
||||||
|
},
|
||||||
|
expectedPage: 2,
|
||||||
|
expectedPageSize: DefaultPageSize,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero page number, should use default",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "0",
|
||||||
|
"page_size": "10",
|
||||||
|
},
|
||||||
|
expectedPage: 1,
|
||||||
|
expectedPageSize: 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative page number, should use default",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "-1",
|
||||||
|
"page_size": "10",
|
||||||
|
},
|
||||||
|
expectedPage: 1,
|
||||||
|
expectedPageSize: 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero page_size, should use default",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "1",
|
||||||
|
"page_size": "0",
|
||||||
|
},
|
||||||
|
expectedPage: 1,
|
||||||
|
expectedPageSize: DefaultPageSize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
q := req.URL.Query()
|
||||||
|
for key, value := range tt.queryParams {
|
||||||
|
q.Set(key, value)
|
||||||
|
}
|
||||||
|
req.URL.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
filter := &AccessLogFilter{}
|
||||||
|
filter.ParseFromRequest(req)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedPage, filter.Page, "Page mismatch")
|
||||||
|
assert.Equal(t, tt.expectedPageSize, filter.PageSize, "PageSize mismatch")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogFilter_GetOffset(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
page int
|
||||||
|
pageSize int
|
||||||
|
expectedOffset int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "first page",
|
||||||
|
page: 1,
|
||||||
|
pageSize: 50,
|
||||||
|
expectedOffset: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "second page",
|
||||||
|
page: 2,
|
||||||
|
pageSize: 50,
|
||||||
|
expectedOffset: 50,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "third page with page size 25",
|
||||||
|
page: 3,
|
||||||
|
pageSize: 25,
|
||||||
|
expectedOffset: 50,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "page 10 with page size 10",
|
||||||
|
page: 10,
|
||||||
|
pageSize: 10,
|
||||||
|
expectedOffset: 90,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
filter := &AccessLogFilter{
|
||||||
|
Page: tt.page,
|
||||||
|
PageSize: tt.pageSize,
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := filter.GetOffset()
|
||||||
|
assert.Equal(t, tt.expectedOffset, offset)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogFilter_GetLimit(t *testing.T) {
|
||||||
|
filter := &AccessLogFilter{
|
||||||
|
Page: 2,
|
||||||
|
PageSize: 25,
|
||||||
|
}
|
||||||
|
|
||||||
|
limit := filter.GetLimit()
|
||||||
|
assert.Equal(t, 25, limit, "GetLimit should return PageSize")
|
||||||
|
}
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
package accesslogs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager provides persistence and retrieval of reverse-proxy access logs.
type Manager interface {
	// SaveAccessLog persists a single access log entry.
	SaveAccessLog(ctx context.Context, proxyLog *AccessLogEntry) error
	// GetAllAccessLogs returns the filtered, paginated access logs for the
	// account together with the total number of matching records (used by
	// callers to compute page counts).
	GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *AccessLogFilter) ([]*AccessLogEntry, int64, error)
}
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// handler serves HTTP requests for reverse-proxy access logs.
type handler struct {
	manager accesslogs.Manager
}

// RegisterEndpoints wires the access log HTTP endpoints onto the router.
func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager) {
	h := &handler{
		manager: manager,
	}

	router.HandleFunc("/events/proxy", h.getAccessLogs).Methods("GET", "OPTIONS")
}
|
||||||
|
|
||||||
|
func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var filter accesslogs.AccessLogFilter
|
||||||
|
filter.ParseFromRequest(r)
|
||||||
|
|
||||||
|
logs, totalCount, err := h.manager.GetAllAccessLogs(r.Context(), userAuth.AccountId, userAuth.UserId, &filter)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
apiLogs := make([]api.ProxyAccessLog, 0, len(logs))
|
||||||
|
for _, log := range logs {
|
||||||
|
apiLogs = append(apiLogs, *log.ToAPIResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
response := &api.ProxyAccessLogsResponse{
|
||||||
|
Data: apiLogs,
|
||||||
|
Page: filter.Page,
|
||||||
|
PageSize: filter.PageSize,
|
||||||
|
TotalRecords: int(totalCount),
|
||||||
|
TotalPages: getTotalPageCount(int(totalCount), filter.PageSize),
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTotalPageCount returns how many pages are needed to hold totalCount
// records at pageSize records per page (ceiling division). A non-positive
// pageSize yields 0, avoiding division by zero.
func getTotalPageCount(totalCount, pageSize int) int {
	if pageSize < 1 {
		return 0
	}
	pages := totalCount / pageSize
	if totalCount%pageSize != 0 {
		pages++
	}
	return pages
}
|
||||||
@@ -0,0 +1,108 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// managerImpl is the store-backed implementation of accesslogs.Manager.
type managerImpl struct {
	store              store.Store
	permissionsManager permissions.Manager
	geo                geolocation.Geolocation // may be nil; see SaveAccessLog's nil check
}

// NewManager builds an access log manager backed by the given store.
// geo may be nil, in which case saved entries are not enriched with
// geolocation data.
func NewManager(store store.Store, permissionsManager permissions.Manager, geo geolocation.Geolocation) accesslogs.Manager {
	return &managerImpl{
		store:              store,
		permissionsManager: permissionsManager,
		geo:                geo,
	}
}
|
||||||
|
|
||||||
|
// SaveAccessLog saves an access log entry to the database after enriching it
// with geolocation data (country code, city name, geoname ID) resolved from
// the connection IP. Geolocation failures are logged and ignored; the entry
// is stored regardless.
func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.AccessLogEntry) error {
	// Enrichment is best-effort: skipped when no geolocation service is
	// configured or the entry carries no connection IP.
	if m.geo != nil && logEntry.GeoLocation.ConnectionIP != nil {
		location, err := m.geo.Lookup(logEntry.GeoLocation.ConnectionIP)
		if err != nil {
			log.WithContext(ctx).Warnf("failed to get location for access log source IP [%s]: %v", logEntry.GeoLocation.ConnectionIP.String(), err)
		} else {
			logEntry.GeoLocation.CountryCode = location.Country.ISOCode
			logEntry.GeoLocation.CityName = location.City.Names.En
			logEntry.GeoLocation.GeoNameID = location.City.GeonameID
		}
	}

	if err := m.store.CreateAccessLog(ctx, logEntry); err != nil {
		// Log with identifying fields so a failed write can be traced back
		// to the request that produced it, then surface the error to the caller.
		log.WithContext(ctx).WithFields(log.Fields{
			"service_id": logEntry.ServiceID,
			"method":     logEntry.Method,
			"host":       logEntry.Host,
			"path":       logEntry.Path,
			"status":     logEntry.StatusCode,
		}).Errorf("failed to save access log: %v", err)
		return err
	}

	return nil
}
|
||||||
|
|
||||||
|
// GetAllAccessLogs retrieves access logs for an account with pagination and
// filtering. The caller must hold read permission on the Services module.
// User email/name filters are resolved to a user ID filter before querying;
// a resolution failure is logged but does not abort the request.
func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
	ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
	if err != nil {
		return nil, 0, status.NewPermissionValidationError(err)
	}
	if !ok {
		return nil, 0, status.NewPermissionDeniedError()
	}

	// Best-effort: the query proceeds even if email/name resolution fails.
	if err := m.resolveUserFilters(ctx, accountID, filter); err != nil {
		log.WithContext(ctx).Warnf("failed to resolve user filters: %v", err)
	}

	logs, totalCount, err := m.store.GetAccountAccessLogs(ctx, store.LockingStrengthNone, accountID, *filter)
	if err != nil {
		return nil, 0, err
	}

	return logs, totalCount, nil
}
|
||||||
|
|
||||||
|
// resolveUserFilters converts user email/name filters to user ID filter
|
||||||
|
func (m *managerImpl) resolveUserFilters(ctx context.Context, accountID string, filter *accesslogs.AccessLogFilter) error {
|
||||||
|
if filter.UserEmail == nil && filter.UserName == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
users, err := m.store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var matchingUserIDs []string
|
||||||
|
for _, user := range users {
|
||||||
|
if filter.UserEmail != nil && strings.Contains(strings.ToLower(user.Email), strings.ToLower(*filter.UserEmail)) {
|
||||||
|
matchingUserIDs = append(matchingUserIDs, user.Id)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if filter.UserName != nil && strings.Contains(strings.ToLower(user.Name), strings.ToLower(*filter.UserName)) {
|
||||||
|
matchingUserIDs = append(matchingUserIDs, user.Id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(matchingUserIDs) > 0 {
|
||||||
|
filter.UserID = &matchingUserIDs[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
17
management/internals/modules/reverseproxy/domain/domain.go
Normal file
17
management/internals/modules/reverseproxy/domain/domain.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
// Type discriminates how a domain was provisioned.
type Type string

const (
	// TypeFree is a domain derived from a connected proxy cluster address.
	TypeFree Type = "free"
	// TypeCustom is a customer-owned domain registered for an account.
	TypeCustom Type = "custom"
)

// Domain represents a domain name usable by the reverse proxy for an account.
type Domain struct {
	// NOTE(review): "autoIncrement" on a string primary key looks suspicious —
	// confirm the intended ID generation strategy.
	ID     string `gorm:"unique;primaryKey;autoIncrement"`
	Domain string `gorm:"unique"` // Domain records must be unique, this avoids domain reuse across accounts.
	// AccountID is the owning account.
	AccountID string `gorm:"index"`
	// TargetCluster is the proxy cluster this domain should be validated against.
	TargetCluster string
	// Type (free vs. custom) is derived at read time and not persisted.
	Type Type `gorm:"-"`
	// Validated reports whether the domain's DNS setup has been verified.
	Validated bool
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager exposes listing, CRUD, and validation operations for
// reverse-proxy domains.
type Manager interface {
	// GetDomains lists the free and custom domains available to an account.
	GetDomains(ctx context.Context, accountID, userID string) ([]*Domain, error)
	// CreateDomain registers a custom domain targeting a proxy cluster.
	CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*Domain, error)
	// DeleteDomain removes a custom domain from an account.
	DeleteDomain(ctx context.Context, accountID, userID, domainID string) error
	// ValidateDomain re-checks a custom domain's DNS setup and persists the
	// result. It has no return value; failures are only logged, so callers
	// cannot observe the outcome directly.
	ValidateDomain(ctx context.Context, accountID, userID, domainID string)
}
|
||||||
136
management/internals/modules/reverseproxy/domain/manager/api.go
Normal file
136
management/internals/modules/reverseproxy/domain/manager/api.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
	"context"
	"encoding/json"
	"net/http"

	"github.com/gorilla/mux"

	"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
	nbcontext "github.com/netbirdio/netbird/management/server/context"
	"github.com/netbirdio/netbird/shared/management/http/api"
	"github.com/netbirdio/netbird/shared/management/http/util"
	"github.com/netbirdio/netbird/shared/management/status"
)
|
||||||
|
|
||||||
|
// handler serves the reverse-proxy domain management HTTP endpoints.
type handler struct {
	manager Manager
}

// RegisterEndpoints wires the domain listing, creation, deletion, and
// validation endpoints onto the router.
func RegisterEndpoints(router *mux.Router, manager Manager) {
	h := &handler{
		manager: manager,
	}

	router.HandleFunc("/domains", h.getAllDomains).Methods("GET", "OPTIONS")
	router.HandleFunc("/domains", h.createCustomDomain).Methods("POST", "OPTIONS")
	router.HandleFunc("/domains/{domainId}", h.deleteCustomDomain).Methods("DELETE", "OPTIONS")
	router.HandleFunc("/domains/{domainId}/validate", h.triggerCustomDomainValidation).Methods("GET", "OPTIONS")
}
|
||||||
|
|
||||||
|
func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType {
|
||||||
|
switch t {
|
||||||
|
case domain.TypeCustom:
|
||||||
|
return api.ReverseProxyDomainTypeCustom
|
||||||
|
case domain.TypeFree:
|
||||||
|
return api.ReverseProxyDomainTypeFree
|
||||||
|
}
|
||||||
|
// By default return as a "free" domain as that is more restrictive.
|
||||||
|
// TODO: is this correct?
|
||||||
|
return api.ReverseProxyDomainTypeFree
|
||||||
|
}
|
||||||
|
|
||||||
|
func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
|
||||||
|
resp := api.ReverseProxyDomain{
|
||||||
|
Domain: d.Domain,
|
||||||
|
Id: d.ID,
|
||||||
|
Type: domainTypeToApi(d.Type),
|
||||||
|
Validated: d.Validated,
|
||||||
|
}
|
||||||
|
if d.TargetCluster != "" {
|
||||||
|
resp.TargetCluster = &d.TargetCluster
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
domains, err := h.manager.GetDomains(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ret := make([]api.ReverseProxyDomain, 0)
|
||||||
|
for _, d := range domains {
|
||||||
|
ret = append(ret, domainToApi(d))
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req api.PostApiReverseProxiesDomainsJSONRequestBody
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
domain, err := h.manager.CreateDomain(r.Context(), userAuth.AccountId, userAuth.UserId, req.Domain, req.TargetCluster)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, domainToApi(domain))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
domainID := mux.Vars(r)["domainId"]
|
||||||
|
if domainID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.manager.DeleteDomain(r.Context(), userAuth.AccountId, userAuth.UserId, domainID); err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
domainID := mux.Vars(r)["domainId"]
|
||||||
|
if domainID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go h.manager.ValidateDomain(r.Context(), userAuth.AccountId, userAuth.UserId, domainID)
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusAccepted)
|
||||||
|
}
|
||||||
@@ -0,0 +1,279 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// store is the consumer-side interface over the subset of persistence
// operations the domain Manager needs.
type store interface {
	GetAccount(ctx context.Context, accountID string) (*types.Account, error)

	GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)
	ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
	ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error)
	CreateCustomDomain(ctx context.Context, accountID string, domainName string, targetCluster string, validated bool) (*domain.Domain, error)
	UpdateCustomDomain(ctx context.Context, accountID string, d *domain.Domain) (*domain.Domain, error)
	DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
}
|
||||||
|
|
||||||
|
// proxyURLProvider reports the addresses of currently connected proxy
// clusters; these double as the free-domain bases (see GetDomains).
type proxyURLProvider interface {
	GetConnectedProxyURLs() []string
}
|
||||||
|
|
||||||
|
// Manager implements reverse-proxy domain management: listing free/custom
// domains, creating and deleting custom domains, and validating their DNS
// configuration against connected proxy clusters.
type Manager struct {
	store              store
	validator          domain.Validator
	proxyURLProvider   proxyURLProvider
	permissionsManager permissions.Manager
}

// NewManager builds a domain Manager whose validator uses the default DNS
// resolver. proxyURLProvider may be nil; no connected-cluster domains are
// then reported.
func NewManager(store store, proxyURLProvider proxyURLProvider, permissionsManager permissions.Manager) Manager {
	return Manager{
		store:            store,
		proxyURLProvider: proxyURLProvider,
		validator: domain.Validator{
			Resolver: net.DefaultResolver,
		},
		permissionsManager: permissionsManager,
	}
}
|
||||||
|
|
||||||
|
// GetDomains lists all domains usable by the account: each connected proxy
// cluster is surfaced as a pre-validated "free" domain, followed by the
// account's registered custom domains. Requires read permission on the
// Services module.
func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) {
	ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
	if err != nil {
		return nil, status.NewPermissionValidationError(err)
	}
	if !ok {
		return nil, status.NewPermissionDeniedError()
	}

	domains, err := m.store.ListCustomDomains(ctx, accountID)
	if err != nil {
		return nil, fmt.Errorf("list custom domains: %w", err)
	}

	var ret []*domain.Domain

	// Add connected proxy clusters as free domains.
	// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
	allowList := m.proxyURLAllowList()
	log.WithFields(log.Fields{
		"accountID":      accountID,
		"proxyAllowList": allowList,
	}).Debug("getting domains with proxy allow list")

	for _, cluster := range allowList {
		ret = append(ret, &domain.Domain{
			Domain:    cluster,
			AccountID: accountID,
			Type:      domain.TypeFree,
			// Free domains are always reported as validated.
			Validated: true,
		})
	}

	// Add custom domains.
	for _, d := range domains {
		ret = append(ret, &domain.Domain{
			ID:            d.ID,
			Domain:        d.Domain,
			AccountID:     accountID,
			TargetCluster: d.TargetCluster,
			Type:          domain.TypeCustom,
			Validated:     d.Validated,
		})
	}

	return ret, nil
}
|
||||||
|
|
||||||
|
func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*domain.Domain, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the target cluster is in the available clusters
|
||||||
|
allowList := m.proxyURLAllowList()
|
||||||
|
clusterValid := false
|
||||||
|
for _, cluster := range allowList {
|
||||||
|
if cluster == targetCluster {
|
||||||
|
clusterValid = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !clusterValid {
|
||||||
|
return nil, fmt.Errorf("target cluster %s is not available", targetCluster)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt an initial validation against the specified cluster only
|
||||||
|
var validated bool
|
||||||
|
if m.validator.IsValid(ctx, domainName, []string{targetCluster}) {
|
||||||
|
validated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := m.store.CreateCustomDomain(ctx, accountID, domainName, targetCluster, validated)
|
||||||
|
if err != nil {
|
||||||
|
return d, fmt.Errorf("create domain in store: %w", err)
|
||||||
|
}
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID string) error {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||||
|
if err != nil {
|
||||||
|
return status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.store.DeleteCustomDomain(ctx, accountID, domainID); err != nil {
|
||||||
|
// TODO: check for "no records" type error. Because that is a success condition.
|
||||||
|
return fmt.Errorf("delete domain from store: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID string) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
}).WithError(err).Error("validate domain")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
}).WithError(err).Error("validate domain")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
}).Info("starting domain validation")
|
||||||
|
|
||||||
|
d, err := m.store.GetCustomDomain(context.Background(), accountID, domainID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
}).WithError(err).Error("get custom domain from store")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate only against the domain's target cluster
|
||||||
|
targetCluster := d.TargetCluster
|
||||||
|
if targetCluster == "" {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
"domain": d.Domain,
|
||||||
|
}).Warn("domain has no target cluster set, skipping validation")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
"domain": d.Domain,
|
||||||
|
"targetCluster": targetCluster,
|
||||||
|
}).Info("validating domain against target cluster")
|
||||||
|
|
||||||
|
if m.validator.IsValid(context.Background(), d.Domain, []string{targetCluster}) {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
"domain": d.Domain,
|
||||||
|
}).Info("domain validated successfully")
|
||||||
|
d.Validated = true
|
||||||
|
if _, err := m.store.UpdateCustomDomain(context.Background(), accountID, d); err != nil {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
"domain": d.Domain,
|
||||||
|
}).WithError(err).Error("update custom domain in store")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
"domain": d.Domain,
|
||||||
|
"targetCluster": targetCluster,
|
||||||
|
}).Warn("domain validation failed - CNAME does not match target cluster")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// proxyURLAllowList retrieves a list of currently connected proxies and
|
||||||
|
// their URLs
|
||||||
|
func (m Manager) proxyURLAllowList() []string {
|
||||||
|
var reverseProxyAddresses []string
|
||||||
|
if m.proxyURLProvider != nil {
|
||||||
|
reverseProxyAddresses = m.proxyURLProvider.GetConnectedProxyURLs()
|
||||||
|
}
|
||||||
|
return reverseProxyAddresses
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeriveClusterFromDomain determines the proxy cluster for a given domain.
|
||||||
|
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
|
||||||
|
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
|
||||||
|
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
|
||||||
|
allowList := m.proxyURLAllowList()
|
||||||
|
if len(allowList) == 0 {
|
||||||
|
return "", fmt.Errorf("no proxy clusters available")
|
||||||
|
}
|
||||||
|
|
||||||
|
if cluster, ok := ExtractClusterFromFreeDomain(domain, allowList); ok {
|
||||||
|
return cluster, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
customDomains, err := m.store.ListCustomDomains(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("list custom domains: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
targetCluster, valid := extractClusterFromCustomDomains(domain, customDomains)
|
||||||
|
if valid {
|
||||||
|
return targetCluster, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractClusterFromCustomDomains(domain string, customDomains []*domain.Domain) (string, bool) {
|
||||||
|
for _, customDomain := range customDomains {
|
||||||
|
if strings.HasSuffix(domain, "."+customDomain.Domain) {
|
||||||
|
return customDomain.TargetCluster, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractClusterFromFreeDomain extracts the cluster address from a free domain.
// Free domains have the format <name>.<nonce>.<cluster>
// (e.g. myapp.abc123.eu.proxy.netbird.io). The first available cluster that
// is a dot-separated suffix of the domain is returned.
func ExtractClusterFromFreeDomain(domain string, availableClusters []string) (string, bool) {
	for _, cluster := range availableClusters {
		suffix := "." + cluster
		if strings.HasSuffix(domain, suffix) {
			return cluster, true
		}
	}
	return "", false
}
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// resolver is the minimal DNS interface needed for CNAME validation;
// *net.Resolver satisfies it.
type resolver interface {
	LookupCNAME(context.Context, string) (string, error)
}

// Validator checks custom-domain DNS configuration via CNAME lookups.
type Validator struct {
	Resolver resolver
}

// NewValidator initializes a validator with a specific DNS Resolver.
// If a Validator is used without specifying a Resolver, then it will
// use the net.DefaultResolver.
func NewValidator(resolver resolver) *Validator {
	return &Validator{
		Resolver: resolver,
	}
}
|
||||||
|
|
||||||
|
// IsValid looks up the CNAME record for the passed domain with a prefix
|
||||||
|
// and compares it against the acceptable domains.
|
||||||
|
// If the returned CNAME matches any accepted domain, it will return true,
|
||||||
|
// otherwise, including in the event of a DNS error, it will return false.
|
||||||
|
// The comparison is very simple, so wildcards will not match if included
|
||||||
|
// in the acceptable domain list.
|
||||||
|
func (v *Validator) IsValid(ctx context.Context, domain string, accept []string) bool {
|
||||||
|
_, valid := v.ValidateWithCluster(ctx, domain, accept)
|
||||||
|
return valid
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateWithCluster validates a custom domain and returns the matched cluster address.
|
||||||
|
// Returns the cluster address and true if valid, or empty string and false if invalid.
|
||||||
|
func (v *Validator) ValidateWithCluster(ctx context.Context, domain string, accept []string) (string, bool) {
|
||||||
|
if v.Resolver == nil {
|
||||||
|
v.Resolver = net.DefaultResolver
|
||||||
|
}
|
||||||
|
|
||||||
|
lookupDomain := "validation." + domain
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"domain": domain,
|
||||||
|
"lookupDomain": lookupDomain,
|
||||||
|
"acceptList": accept,
|
||||||
|
}).Debug("looking up CNAME for domain validation")
|
||||||
|
|
||||||
|
cname, err := v.Resolver.LookupCNAME(ctx, lookupDomain)
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"domain": domain,
|
||||||
|
"lookupDomain": lookupDomain,
|
||||||
|
}).WithError(err).Warn("CNAME lookup failed for domain validation")
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
nakedCNAME := strings.TrimSuffix(cname, ".")
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"domain": domain,
|
||||||
|
"cname": cname,
|
||||||
|
"nakedCNAME": nakedCNAME,
|
||||||
|
"acceptList": accept,
|
||||||
|
}).Debug("CNAME lookup result for domain validation")
|
||||||
|
|
||||||
|
for _, acceptDomain := range accept {
|
||||||
|
normalizedAccept := strings.TrimSuffix(acceptDomain, ".")
|
||||||
|
if nakedCNAME == normalizedAccept {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"domain": domain,
|
||||||
|
"cname": nakedCNAME,
|
||||||
|
"cluster": acceptDomain,
|
||||||
|
}).Info("domain CNAME matched cluster")
|
||||||
|
return acceptDomain, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"domain": domain,
|
||||||
|
"cname": nakedCNAME,
|
||||||
|
"acceptList": accept,
|
||||||
|
}).Warn("domain CNAME does not match any accepted cluster")
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
package domain_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
// resolver is a stub DNS resolver that always answers with a fixed CNAME
// and never fails.
type resolver struct {
	// CNAME is returned verbatim from every LookupCNAME call.
	CNAME string
}

// LookupCNAME implements the domain package's resolver interface; the input
// context and name are ignored.
func (r resolver) LookupCNAME(_ context.Context, _ string) (string, error) {
	return r.CNAME, nil
}
|
||||||
|
|
||||||
|
func TestIsValid(t *testing.T) {
|
||||||
|
tests := map[string]struct {
|
||||||
|
resolver interface {
|
||||||
|
LookupCNAME(context.Context, string) (string, error)
|
||||||
|
}
|
||||||
|
domain string
|
||||||
|
accept []string
|
||||||
|
expect bool
|
||||||
|
}{
|
||||||
|
"match": {
|
||||||
|
resolver: resolver{"bar.example.com."}, // Including trailing "." in response.
|
||||||
|
domain: "foo.example.com",
|
||||||
|
accept: []string{"bar.example.com"},
|
||||||
|
expect: true,
|
||||||
|
},
|
||||||
|
"no match": {
|
||||||
|
resolver: resolver{"invalid"},
|
||||||
|
domain: "foo.example.com",
|
||||||
|
accept: []string{"bar.example.com"},
|
||||||
|
expect: false,
|
||||||
|
},
|
||||||
|
"accept trailing dot": {
|
||||||
|
resolver: resolver{"bar.example.com."},
|
||||||
|
domain: "foo.example.com",
|
||||||
|
accept: []string{"bar.example.com."}, // Including trailing "." in accept.
|
||||||
|
expect: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, test := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
validator := domain.NewValidator(test.resolver)
|
||||||
|
actual := validator.IsValid(t.Context(), test.domain, test.accept)
|
||||||
|
if test.expect != actual {
|
||||||
|
t.Errorf("Incorrect return value:\nexpect: %v\nactual: %v", test.expect, actual)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
23
management/internals/modules/reverseproxy/interface.go
Normal file
23
management/internals/modules/reverseproxy/interface.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package reverseproxy
|
||||||
|
|
||||||
|
//go:generate go run github.com/golang/mock/mockgen -package reverseproxy -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager exposes CRUD and lifecycle operations for reverse-proxy services.
//
// Methods that take both accountID and userID carry the userID so the
// implementation can run a user permission check before touching the store.
// The remaining methods are account- or system-scoped lookups and lifecycle
// hooks (statuses, reloads, certificate bookkeeping) — presumably used by
// the proxy control plane; confirm against the concrete implementation.
type Manager interface {
	GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
	GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
	CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
	UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
	DeleteService(ctx context.Context, accountID, userID, serviceID string) error
	SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
	SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error
	ReloadAllServicesForAccount(ctx context.Context, accountID string) error
	ReloadService(ctx context.Context, accountID, serviceID string) error
	GetGlobalServices(ctx context.Context) ([]*Service, error)
	GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error)
	GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
	GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
}
|
||||||
225
management/internals/modules/reverseproxy/interface_mock.go
Normal file
225
management/internals/modules/reverseproxy/interface_mock.go
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
// Code generated by MockGen. DO NOT EDIT.
// Source: ./interface.go

// NOTE(review): generated file — do not hand-edit. Regenerate via the
// go:generate directive in interface.go (mockgen -source=./interface.go).

// Package reverseproxy is a generated GoMock package.
package reverseproxy

import (
	context "context"
	reflect "reflect"

	gomock "github.com/golang/mock/gomock"
)

// MockManager is a mock of Manager interface.
type MockManager struct {
	ctrl     *gomock.Controller
	recorder *MockManagerMockRecorder
}

// MockManagerMockRecorder is the mock recorder for MockManager.
type MockManagerMockRecorder struct {
	mock *MockManager
}

// NewMockManager creates a new mock instance.
func NewMockManager(ctrl *gomock.Controller) *MockManager {
	mock := &MockManager{ctrl: ctrl}
	mock.recorder = &MockManagerMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockManager) EXPECT() *MockManagerMockRecorder {
	return m.recorder
}

// CreateService mocks base method.
func (m *MockManager) CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "CreateService", ctx, accountID, userID, service)
	ret0, _ := ret[0].(*Service)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// CreateService indicates an expected call of CreateService.
func (mr *MockManagerMockRecorder) CreateService(ctx, accountID, userID, service interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service)
}

// DeleteService mocks base method.
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "DeleteService", ctx, accountID, userID, serviceID)
	ret0, _ := ret[0].(error)
	return ret0
}

// DeleteService indicates an expected call of DeleteService.
func (mr *MockManagerMockRecorder) DeleteService(ctx, accountID, userID, serviceID interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteService", reflect.TypeOf((*MockManager)(nil).DeleteService), ctx, accountID, userID, serviceID)
}

// GetAccountServices mocks base method.
func (m *MockManager) GetAccountServices(ctx context.Context, accountID string) ([]*Service, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetAccountServices", ctx, accountID)
	ret0, _ := ret[0].([]*Service)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// GetAccountServices indicates an expected call of GetAccountServices.
func (mr *MockManagerMockRecorder) GetAccountServices(ctx, accountID interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockManager)(nil).GetAccountServices), ctx, accountID)
}

// GetAllServices mocks base method.
func (m *MockManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetAllServices", ctx, accountID, userID)
	ret0, _ := ret[0].([]*Service)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// GetAllServices indicates an expected call of GetAllServices.
func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
}

// GetGlobalServices mocks base method.
func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetGlobalServices", ctx)
	ret0, _ := ret[0].([]*Service)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// GetGlobalServices indicates an expected call of GetGlobalServices.
func (mr *MockManagerMockRecorder) GetGlobalServices(ctx interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGlobalServices", reflect.TypeOf((*MockManager)(nil).GetGlobalServices), ctx)
}

// GetService mocks base method.
func (m *MockManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetService", ctx, accountID, userID, serviceID)
	ret0, _ := ret[0].(*Service)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// GetService indicates an expected call of GetService.
func (mr *MockManagerMockRecorder) GetService(ctx, accountID, userID, serviceID interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetService", reflect.TypeOf((*MockManager)(nil).GetService), ctx, accountID, userID, serviceID)
}

// GetServiceByID mocks base method.
func (m *MockManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetServiceByID", ctx, accountID, serviceID)
	ret0, _ := ret[0].(*Service)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// GetServiceByID indicates an expected call of GetServiceByID.
func (mr *MockManagerMockRecorder) GetServiceByID(ctx, accountID, serviceID interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByID", reflect.TypeOf((*MockManager)(nil).GetServiceByID), ctx, accountID, serviceID)
}

// GetServiceIDByTargetID mocks base method.
func (m *MockManager) GetServiceIDByTargetID(ctx context.Context, accountID, resourceID string) (string, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetServiceIDByTargetID", ctx, accountID, resourceID)
	ret0, _ := ret[0].(string)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// GetServiceIDByTargetID indicates an expected call of GetServiceIDByTargetID.
func (mr *MockManagerMockRecorder) GetServiceIDByTargetID(ctx, accountID, resourceID interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceIDByTargetID", reflect.TypeOf((*MockManager)(nil).GetServiceIDByTargetID), ctx, accountID, resourceID)
}

// ReloadAllServicesForAccount mocks base method.
func (m *MockManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "ReloadAllServicesForAccount", ctx, accountID)
	ret0, _ := ret[0].(error)
	return ret0
}

// ReloadAllServicesForAccount indicates an expected call of ReloadAllServicesForAccount.
func (mr *MockManagerMockRecorder) ReloadAllServicesForAccount(ctx, accountID interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadAllServicesForAccount", reflect.TypeOf((*MockManager)(nil).ReloadAllServicesForAccount), ctx, accountID)
}

// ReloadService mocks base method.
func (m *MockManager) ReloadService(ctx context.Context, accountID, serviceID string) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "ReloadService", ctx, accountID, serviceID)
	ret0, _ := ret[0].(error)
	return ret0
}

// ReloadService indicates an expected call of ReloadService.
func (mr *MockManagerMockRecorder) ReloadService(ctx, accountID, serviceID interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadService", reflect.TypeOf((*MockManager)(nil).ReloadService), ctx, accountID, serviceID)
}

// SetCertificateIssuedAt mocks base method.
func (m *MockManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "SetCertificateIssuedAt", ctx, accountID, serviceID)
	ret0, _ := ret[0].(error)
	return ret0
}

// SetCertificateIssuedAt indicates an expected call of SetCertificateIssuedAt.
func (mr *MockManagerMockRecorder) SetCertificateIssuedAt(ctx, accountID, serviceID interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCertificateIssuedAt", reflect.TypeOf((*MockManager)(nil).SetCertificateIssuedAt), ctx, accountID, serviceID)
}

// SetStatus mocks base method.
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status)
	ret0, _ := ret[0].(error)
	return ret0
}

// SetStatus indicates an expected call of SetStatus.
func (mr *MockManagerMockRecorder) SetStatus(ctx, accountID, serviceID, status interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetStatus", reflect.TypeOf((*MockManager)(nil).SetStatus), ctx, accountID, serviceID, status)
}

// UpdateService mocks base method.
func (m *MockManager) UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "UpdateService", ctx, accountID, userID, service)
	ret0, _ := ret[0].(*Service)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// UpdateService indicates an expected call of UpdateService.
func (mr *MockManagerMockRecorder) UpdateService(ctx, accountID, userID, service interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateService", reflect.TypeOf((*MockManager)(nil).UpdateService), ctx, accountID, userID, service)
}
|
||||||
170
management/internals/modules/reverseproxy/manager/api.go
Normal file
170
management/internals/modules/reverseproxy/manager/api.go
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
|
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||||
|
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||||
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// handler serves the reverse-proxy service HTTP endpoints.
type handler struct {
	// manager implements the service CRUD operations behind the endpoints.
	manager reverseproxy.Manager
}
|
||||||
|
|
||||||
|
// RegisterEndpoints registers all service HTTP endpoints.
// Domain endpoints are mounted on a /reverse-proxies subrouter, access-log
// endpoints on the root router, and the service CRUD routes directly under
// /reverse-proxies/services.
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
	h := &handler{
		manager: manager,
	}

	domainRouter := router.PathPrefix("/reverse-proxies").Subrouter()
	domainmanager.RegisterEndpoints(domainRouter, domainManager)

	accesslogsmanager.RegisterEndpoints(router, accessLogsManager)

	router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
	router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
	router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
	router.HandleFunc("/reverse-proxies/services/{serviceId}", h.updateService).Methods("PUT", "OPTIONS")
	router.HandleFunc("/reverse-proxies/services/{serviceId}", h.deleteService).Methods("DELETE", "OPTIONS")
}
|
||||||
|
|
||||||
|
func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
allServices, err := h.manager.GetAllServices(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
apiServices := make([]*api.Service, 0, len(allServices))
|
||||||
|
for _, service := range allServices {
|
||||||
|
apiServices = append(apiServices, service.ToAPIResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, apiServices)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req api.ServiceRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
service := new(reverseproxy.Service)
|
||||||
|
service.FromAPIRequest(&req, userAuth.AccountId)
|
||||||
|
|
||||||
|
if err = service.Validate(); err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
createdService, err := h.manager.CreateService(r.Context(), userAuth.AccountId, userAuth.UserId, service)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, createdService.ToAPIResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) getService(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serviceID := mux.Vars(r)["serviceId"]
|
||||||
|
if serviceID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
service, err := h.manager.GetService(r.Context(), userAuth.AccountId, userAuth.UserId, serviceID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, service.ToAPIResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serviceID := mux.Vars(r)["serviceId"]
|
||||||
|
if serviceID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req api.ServiceRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
service := new(reverseproxy.Service)
|
||||||
|
service.ID = serviceID
|
||||||
|
service.FromAPIRequest(&req, userAuth.AccountId)
|
||||||
|
|
||||||
|
if err = service.Validate(); err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedService, err := h.manager.UpdateService(r.Context(), userAuth.AccountId, userAuth.UserId, service)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, updatedService.ToAPIResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serviceID := mux.Vars(r)["serviceId"]
|
||||||
|
if serviceID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.manager.DeleteService(r.Context(), userAuth.AccountId, userAuth.UserId, serviceID); err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||||
|
}
|
||||||
500
management/internals/modules/reverseproxy/manager/manager.go
Normal file
500
management/internals/modules/reverseproxy/manager/manager.go
Normal file
@@ -0,0 +1,500 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// unknownHostPlaceholder is written to a target's Host field when the peer
// or resource the target references cannot be looked up.
const unknownHostPlaceholder = "unknown"
|
||||||
|
|
||||||
|
// ClusterDeriver derives the proxy cluster from a domain.
type ClusterDeriver interface {
	// DeriveClusterFromDomain returns the proxy cluster address responsible
	// for the given domain within the account, or an error if it cannot be
	// determined.
	DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error)
}
|
||||||
|
|
||||||
|
// managerImpl is the store-backed implementation of reverseproxy.Manager.
type managerImpl struct {
	store              store.Store
	accountManager     account.Manager
	permissionsManager permissions.Manager
	proxyGRPCServer    *nbgrpc.ProxyServiceServer
	// clusterDeriver may be nil, in which case cluster derivation is skipped
	// on service creation.
	clusterDeriver ClusterDeriver
}
|
||||||
|
|
||||||
|
// NewManager creates a new service manager.
// clusterDeriver may be nil; CreateService then skips cluster derivation and
// leaves the service's proxy cluster empty.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager {
	return &managerImpl{
		store:              store,
		accountManager:     accountManager,
		permissionsManager: permissionsManager,
		proxyGRPCServer:    proxyGRPCServer,
		clusterDeriver:     clusterDeriver,
	}
}
|
||||||
|
|
||||||
|
func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, service := range services {
|
||||||
|
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return services, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// replaceHostByLookup fills in each target's Host field from the entity the
// target references: a peer's IP, a network resource's address, or a network
// resource's domain. Lookup failures are logged and the host is set to
// unknownHostPlaceholder; subnet targets are left untouched. The only hard
// error is an unrecognized target type.
//
// NOTE(review): `target` is the range-clause copy — the Host assignments only
// persist if service.Targets holds pointers; confirm the element type.
func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error {
	for _, target := range service.Targets {
		switch target.TargetType {
		case reverseproxy.TargetTypePeer:
			peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
			if err != nil {
				// Best-effort: a missing peer degrades to a placeholder host.
				log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err)
				target.Host = unknownHostPlaceholder
				continue
			}
			target.Host = peer.IP.String()
		case reverseproxy.TargetTypeHost:
			resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
			if err != nil {
				log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
				target.Host = unknownHostPlaceholder
				continue
			}
			target.Host = resource.Prefix.Addr().String()
		case reverseproxy.TargetTypeDomain:
			resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
			if err != nil {
				log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
				target.Host = unknownHostPlaceholder
				continue
			}
			target.Host = resource.Domain
		case reverseproxy.TargetTypeSubnet:
			// For subnets we do not do any lookups on the resource
		default:
			return fmt.Errorf("unknown target type: %s", target.TargetType)
		}
	}
	return nil
}
|
||||||
|
|
||||||
|
func (m *managerImpl) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||||
|
}
|
||||||
|
return service, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateService creates a new reverse-proxy service for the account after
// validating the caller's Create permission on the Services module.
//
// It derives the responsible proxy cluster from the service domain (when a
// cluster deriver is configured), hashes the configured auth secrets,
// generates a session JWT signing key pair, and persists the service inside a
// transaction that rejects duplicate domains and invalid target references.
// On success it records an activity event, resolves target hosts, notifies
// the proxy cluster of the new mapping and triggers a peer update.
func (m *managerImpl) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
	ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
	if err != nil {
		return nil, status.NewPermissionValidationError(err)
	}
	if !ok {
		return nil, status.NewPermissionDeniedError()
	}

	// Cluster derivation failure is treated as a hard precondition failure:
	// without a cluster the update would have to broadcast to all proxies.
	var proxyCluster string
	if m.clusterDeriver != nil {
		proxyCluster, err = m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
		if err != nil {
			log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain)
			return nil, status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
		}
	}

	service.AccountID = accountID
	service.ProxyCluster = proxyCluster
	// InitNewRecord assigns a fresh ID and resets Meta to pending state.
	service.InitNewRecord()
	// Hash password/PIN secrets before anything is persisted.
	err = service.Auth.HashSecrets()
	if err != nil {
		return nil, fmt.Errorf("hash secrets: %w", err)
	}

	// Generate session JWT signing keys
	keyPair, err := sessionkey.GenerateKeyPair()
	if err != nil {
		return nil, fmt.Errorf("generate session keys: %w", err)
	}
	service.SessionPrivateKey = keyPair.PrivateKey
	service.SessionPublicKey = keyPair.PublicKey

	err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
		// Check for duplicate domain
		existingService, err := transaction.GetServiceByDomain(ctx, accountID, service.Domain)
		if err != nil {
			// NotFound is the expected outcome; any other error aborts.
			if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
				return fmt.Errorf("failed to check existing service: %w", err)
			}
		}
		if existingService != nil {
			return status.Errorf(status.AlreadyExists, "service with domain %s already exists", service.Domain)
		}

		// Every target must reference an existing peer or network resource.
		if err = validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
			return err
		}

		if err = transaction.CreateService(ctx, service); err != nil {
			return fmt.Errorf("failed to create service: %w", err)
		}

		return nil
	})
	if err != nil {
		return nil, err
	}

	m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta())

	// Resolve target hosts (peer/resource IDs -> addresses) before notifying proxies.
	err = m.replaceHostByLookup(ctx, accountID, service)
	if err != nil {
		return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
	}

	m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)

	m.accountManager.UpdateAccountPeers(ctx, accountID)

	return service, nil
}
|
||||||
|
|
||||||
|
// UpdateService updates an existing reverse-proxy service after validating the
// caller's Update permission.
//
// Domain changes trigger a duplicate-domain check and (when a cluster deriver
// is configured) a re-derivation of the proxy cluster; enable/disable changes
// and cluster moves are translated into the appropriate Create/Delete/Update
// notifications to the affected proxy clusters. Secrets left empty in the
// request keep the previously stored (already hashed) values, and server-side
// fields (Meta, session keys) are always carried over from the stored record.
func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
	ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
	if err != nil {
		return nil, status.NewPermissionValidationError(err)
	}
	if !ok {
		return nil, status.NewPermissionDeniedError()
	}

	// Captured inside the transaction, used afterwards to pick the right
	// proxy notifications.
	var oldCluster string
	var domainChanged bool
	var serviceEnabledChanged bool

	// Hash any newly supplied secrets; empty secrets are restored from the
	// stored record inside the transaction below.
	err = service.Auth.HashSecrets()
	if err != nil {
		return nil, fmt.Errorf("hash secrets: %w", err)
	}

	err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
		existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
		if err != nil {
			return err
		}

		oldCluster = existingService.ProxyCluster

		if existingService.Domain != service.Domain {
			domainChanged = true
			// The new domain must not collide with another service.
			conflictService, err := transaction.GetServiceByDomain(ctx, accountID, service.Domain)
			if err != nil {
				if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
					return fmt.Errorf("check existing service: %w", err)
				}
			}
			if conflictService != nil && conflictService.ID != service.ID {
				return status.Errorf(status.AlreadyExists, "service with domain %s already exists", service.Domain)
			}

			if m.clusterDeriver != nil {
				// NOTE(review): unlike CreateService, a derivation failure here
				// is only logged and newCluster stays "" — confirm intended.
				newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
				if err != nil {
					log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
				}
				service.ProxyCluster = newCluster
			}
		} else {
			service.ProxyCluster = existingService.ProxyCluster
		}

		// An empty password with password auth enabled on both sides means
		// "keep the stored (hashed) password".
		if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
			existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
			service.Auth.PasswordAuth.Password == "" {
			service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
		}

		// Same keep-existing semantics for the PIN.
		if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
			existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
			service.Auth.PinAuth.Pin == "" {
			service.Auth.PinAuth = existingService.Auth.PinAuth
		}

		// Server-managed fields are never client-settable on update.
		service.Meta = existingService.Meta
		service.SessionPrivateKey = existingService.SessionPrivateKey
		service.SessionPublicKey = existingService.SessionPublicKey
		serviceEnabledChanged = existingService.Enabled != service.Enabled

		if err = validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
			return err
		}

		if err = transaction.UpdateService(ctx, service); err != nil {
			return fmt.Errorf("update service: %w", err)
		}

		return nil
	})
	if err != nil {
		return nil, err
	}

	m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceUpdated, service.EventMeta())

	err = m.replaceHostByLookup(ctx, accountID, service)
	if err != nil {
		return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
	}

	oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
	switch {
	case domainChanged && oldCluster != service.ProxyCluster:
		// Moved clusters: remove from the old one, create on the new one.
		m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), oldCluster)
		m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
	case !service.Enabled && serviceEnabledChanged:
		// Just disabled: tear down the mapping.
		m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), service.ProxyCluster)
	case service.Enabled && serviceEnabledChanged:
		// Just enabled: create the mapping.
		m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
	default:
		m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", oidcCfg), service.ProxyCluster)
	}
	m.accountManager.UpdateAccountPeers(ctx, accountID)

	return service, nil
}
|
||||||
|
|
||||||
|
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
// A missing reference is reported as InvalidArgument; any other store error is
// wrapped and returned as-is. Target types other than the known peer/host/
// subnet/domain kinds are not validated here (Service.Validate rejects them).
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*reverseproxy.Target) error {
	for _, target := range targets {
		switch target.TargetType {
		case reverseproxy.TargetTypePeer:
			// Peer targets must point at an existing peer in this account.
			if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
				if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
					return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
				}
				return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
			}
		case reverseproxy.TargetTypeHost, reverseproxy.TargetTypeSubnet, reverseproxy.TargetTypeDomain:
			// Host/subnet/domain targets reference a network resource.
			if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
				if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
					return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
				}
				return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
			}
		}
	}
	return nil
}
|
||||||
|
|
||||||
|
// DeleteService removes a reverse-proxy service after validating the caller's
// Delete permission. The service is loaded and deleted in one transaction so
// its metadata is still available afterwards for the activity event and the
// Delete notification sent to its proxy cluster.
func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
	ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
	if err != nil {
		return status.NewPermissionValidationError(err)
	}
	if !ok {
		return status.NewPermissionDeniedError()
	}

	// Captured inside the transaction for post-delete event/notification use.
	var service *reverseproxy.Service
	err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
		var err error
		service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
		if err != nil {
			return err
		}

		if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
			return fmt.Errorf("failed to delete service: %w", err)
		}

		return nil
	})
	if err != nil {
		return err
	}

	m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta())

	// Tell the owning proxy cluster to drop the mapping.
	m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)

	m.accountManager.UpdateAccountPeers(ctx, accountID)

	return nil
}
|
||||||
|
|
||||||
|
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
|
||||||
|
// Call this when receiving a gRPC notification that the certificate was issued.
|
||||||
|
func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
||||||
|
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
service.Meta.CertificateIssuedAt = time.Now()
|
||||||
|
|
||||||
|
if err = transaction.UpdateService(ctx, service); err != nil {
|
||||||
|
return fmt.Errorf("failed to update service certificate timestamp: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
|
||||||
|
func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
|
||||||
|
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
service.Meta.Status = string(status)
|
||||||
|
|
||||||
|
if err = transaction.UpdateService(ctx, service); err != nil {
|
||||||
|
return fmt.Errorf("failed to update service status: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReloadService re-resolves the service's target hosts and re-broadcasts the
// current mapping (as an Update) to its proxy cluster, then triggers a peer
// update for the account. No permission check is performed — callers are
// expected to be internal/system paths.
func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID string) error {
	service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
	if err != nil {
		return fmt.Errorf("failed to get service: %w", err)
	}

	err = m.replaceHostByLookup(ctx, accountID, service)
	if err != nil {
		return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
	}

	m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)

	m.accountManager.UpdateAccountPeers(ctx, accountID)

	return nil
}
|
||||||
|
|
||||||
|
// ReloadAllServicesForAccount re-resolves target hosts for every service in
// the account and re-broadcasts each mapping (as an Update) to its proxy
// cluster, then triggers a single peer update. Processing stops at the first
// service whose host lookup fails; remaining services are not reloaded.
func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
	services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
	if err != nil {
		return fmt.Errorf("failed to get services: %w", err)
	}

	for _, service := range services {
		err = m.replaceHostByLookup(ctx, accountID, service)
		if err != nil {
			return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
		}
		m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
	}

	m.accountManager.UpdateAccountPeers(ctx, accountID)

	return nil
}
|
||||||
|
|
||||||
|
func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
||||||
|
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, service := range services {
|
||||||
|
err = m.replaceHostByLookup(ctx, service.AccountID, service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return services, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
|
||||||
|
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return service, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||||
|
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, service := range services {
|
||||||
|
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return services, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetServiceIDByTargetID returns the ID of the service that has a target
// referencing the given peer/resource ID. A missing mapping is not an error:
// both a NotFound store result and a nil target yield ("", nil).
func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
	target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
	if err != nil {
		if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
			return "", nil
		}
		return "", fmt.Errorf("failed to get service target by resource ID: %w", err)
	}

	// Defensive: some store implementations may return (nil, nil).
	if target == nil {
		return "", nil
	}

	return target.ServiceID, nil
}
|
||||||
463
management/internals/modules/reverseproxy/reverseproxy.go
Normal file
463
management/internals/modules/reverseproxy/reverseproxy.go
Normal file
@@ -0,0 +1,463 @@
|
|||||||
|
package reverseproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/xid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||||
|
"github.com/netbirdio/netbird/util/crypt"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Operation identifies the kind of change being propagated to proxy servers
// for a service mapping.
type Operation string

const (
	// Create announces a new mapping to a proxy cluster.
	Create Operation = "create"
	// Update modifies an existing mapping.
	Update Operation = "update"
	// Delete removes a mapping from a proxy cluster.
	Delete Operation = "delete"
)
|
||||||
|
|
||||||
|
// ProxyStatus describes the operational state of a reverse-proxy service as
// reported back by the proxy infrastructure.
type ProxyStatus string

const (
	StatusPending            ProxyStatus = "pending"             // newly created, not yet provisioned
	StatusActive             ProxyStatus = "active"              // serving traffic
	StatusTunnelNotCreated   ProxyStatus = "tunnel_not_created"  // proxy could not establish the tunnel
	StatusCertificatePending ProxyStatus = "certificate_pending" // TLS certificate issuance in progress
	StatusCertificateFailed  ProxyStatus = "certificate_failed"  // TLS certificate issuance failed
	StatusError              ProxyStatus = "error"               // generic failure state

	// Target type discriminators for Target.TargetType.
	TargetTypePeer   = "peer"
	TargetTypeHost   = "host"
	TargetTypeDomain = "domain"
	TargetTypeSubnet = "subnet"
)
|
||||||
|
|
||||||
|
// Target describes one backend endpoint a reverse-proxy service routes to.
// TargetId references a peer or network resource depending on TargetType.
type Target struct {
	ID        uint    `gorm:"primaryKey" json:"-"`                         // surrogate DB key, not exposed
	AccountID string  `gorm:"index:idx_target_account;not null" json:"-"`  // owning account
	ServiceID string  `gorm:"index:idx_service_targets;not null" json:"-"` // owning Service
	Path      *string `json:"path,omitempty"`                              // optional request path prefix; nil means "/"
	Host      string  `json:"host"`                                        // the Host field is only used for subnet targets, otherwise ignored
	Port      int     `gorm:"index:idx_target_port" json:"port"`
	Protocol  string  `gorm:"index:idx_target_protocol" json:"protocol"` // backend scheme, e.g. "http"/"https"
	TargetId  string  `gorm:"index:idx_target_id" json:"target_id"`      // referenced peer or resource ID
	TargetType string `gorm:"index:idx_target_type" json:"target_type"`  // one of the TargetType* constants
	Enabled   bool    `gorm:"index:idx_target_enabled" json:"enabled"`   // disabled targets are excluded from proxy mappings
}
|
||||||
|
|
||||||
|
// PasswordAuthConfig enables password-based access to a proxied service.
// Password holds the plaintext on input and the argon2id hash after
// HashSecrets has run.
type PasswordAuthConfig struct {
	Enabled  bool   `json:"enabled"`
	Password string `json:"password"`
}
|
||||||
|
|
||||||
|
// PINAuthConfig enables PIN-based access to a proxied service.
// Pin holds the plaintext on input and the argon2id hash after HashSecrets.
type PINAuthConfig struct {
	Enabled bool   `json:"enabled"`
	Pin     string `json:"pin"`
}
|
||||||
|
|
||||||
|
// BearerAuthConfig enables OIDC bearer-token access to a proxied service,
// optionally restricted to members of the listed distribution groups.
type BearerAuthConfig struct {
	Enabled            bool     `json:"enabled"`
	DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"`
}
|
||||||
|
|
||||||
|
// AuthConfig aggregates the optional authentication mechanisms of a service.
// A nil field means that mechanism is not configured at all.
type AuthConfig struct {
	PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"`
	PinAuth      *PINAuthConfig      `json:"pin_auth,omitempty" gorm:"serializer:json"`
	BearerAuth   *BearerAuthConfig   `json:"bearer_auth,omitempty" gorm:"serializer:json"`
}
|
||||||
|
|
||||||
|
func (a *AuthConfig) HashSecrets() error {
|
||||||
|
if a.PasswordAuth != nil && a.PasswordAuth.Enabled && a.PasswordAuth.Password != "" {
|
||||||
|
hashedPassword, err := argon2id.Hash(a.PasswordAuth.Password)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("hash password: %w", err)
|
||||||
|
}
|
||||||
|
a.PasswordAuth.Password = hashedPassword
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.PinAuth != nil && a.PinAuth.Enabled && a.PinAuth.Pin != "" {
|
||||||
|
hashedPin, err := argon2id.Hash(a.PinAuth.Pin)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("hash pin: %w", err)
|
||||||
|
}
|
||||||
|
a.PinAuth.Pin = hashedPin
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthConfig) ClearSecrets() {
|
||||||
|
if a.PasswordAuth != nil {
|
||||||
|
a.PasswordAuth.Password = ""
|
||||||
|
}
|
||||||
|
if a.PinAuth != nil {
|
||||||
|
a.PinAuth.Pin = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OIDCValidationConfig carries the settings proxies use to validate OIDC
// bearer tokens.
type OIDCValidationConfig struct {
	Issuer             string   // expected token issuer
	Audiences          []string // accepted audiences
	KeysLocation       string   // JWKS endpoint location
	MaxTokenAgeSeconds int64    // maximum accepted token age
}
|
||||||
|
|
||||||
|
// ServiceMeta holds server-managed bookkeeping for a Service; it is embedded
// into the services table with a "meta_" column prefix.
type ServiceMeta struct {
	CreatedAt           time.Time // set by InitNewRecord
	CertificateIssuedAt time.Time // zero until SetCertificateIssuedAt runs
	Status              string    // a ProxyStatus value, stored as string
}
|
||||||
|
|
||||||
|
// Service is a reverse-proxy service definition: a public domain routed to
// one or more backend targets, with optional authentication and a session
// JWT key pair used by proxies to sign/verify session tokens.
type Service struct {
	ID               string    `gorm:"primaryKey"`
	AccountID        string    `gorm:"index"`
	Name             string
	Domain           string    `gorm:"index"` // public domain; unique per account (enforced by the manager)
	ProxyCluster     string    `gorm:"index"` // cluster responsible for this domain; empty = broadcast
	Targets          []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
	Enabled          bool
	PassHostHeader   bool // forward the original Host header to the backend
	RewriteRedirects bool // rewrite backend redirects to the public domain
	Auth             AuthConfig  `gorm:"serializer:json"`
	Meta             ServiceMeta `gorm:"embedded;embeddedPrefix:meta_"`
	// Session signing key pair; the private key is encrypted at rest via
	// EncryptSensitiveData.
	SessionPrivateKey string `gorm:"column:session_private_key"`
	SessionPublicKey  string `gorm:"column:session_public_key"`
}
|
||||||
|
|
||||||
|
func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service {
|
||||||
|
for _, target := range targets {
|
||||||
|
target.AccountID = accountID
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &Service{
|
||||||
|
AccountID: accountID,
|
||||||
|
Name: name,
|
||||||
|
Domain: domain,
|
||||||
|
ProxyCluster: proxyCluster,
|
||||||
|
Targets: targets,
|
||||||
|
Enabled: enabled,
|
||||||
|
}
|
||||||
|
s.InitNewRecord()
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitNewRecord generates a new unique ID and resets metadata for a newly created
// Service record. This overwrites any existing ID and Meta fields and should
// only be called during initial creation, not for updates.
func (s *Service) InitNewRecord() {
	// xid gives sortable, globally-unique IDs.
	s.ID = xid.New().String()
	s.Meta = ServiceMeta{
		CreatedAt: time.Now(),
		Status:    string(StatusPending),
	}
}
|
||||||
|
|
||||||
|
// ToAPIResponse converts the service into its HTTP API representation.
//
// NOTE(review): this mutates the receiver — ClearSecrets blanks the stored
// (hashed) password/PIN on s itself before the response is built, so the
// Password/Pin fields copied below are always empty strings. Callers that
// need the hashes afterwards must work on a Copy.
func (s *Service) ToAPIResponse() *api.Service {
	s.Auth.ClearSecrets()

	authConfig := api.ServiceAuthConfig{}

	if s.Auth.PasswordAuth != nil {
		authConfig.PasswordAuth = &api.PasswordAuthConfig{
			Enabled:  s.Auth.PasswordAuth.Enabled,
			Password: s.Auth.PasswordAuth.Password, // always "" after ClearSecrets
		}
	}

	if s.Auth.PinAuth != nil {
		authConfig.PinAuth = &api.PINAuthConfig{
			Enabled: s.Auth.PinAuth.Enabled,
			Pin:     s.Auth.PinAuth.Pin, // always "" after ClearSecrets
		}
	}

	if s.Auth.BearerAuth != nil {
		authConfig.BearerAuth = &api.BearerAuthConfig{
			Enabled:            s.Auth.BearerAuth.Enabled,
			DistributionGroups: &s.Auth.BearerAuth.DistributionGroups,
		}
	}

	// Convert internal targets to API targets
	apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
	for _, target := range s.Targets {
		apiTargets = append(apiTargets, api.ServiceTarget{
			Path:       target.Path,
			Host:       &target.Host,
			Port:       target.Port,
			Protocol:   api.ServiceTargetProtocol(target.Protocol),
			TargetId:   target.TargetId,
			TargetType: api.ServiceTargetTargetType(target.TargetType),
			Enabled:    target.Enabled,
		})
	}

	meta := api.ServiceMeta{
		CreatedAt: s.Meta.CreatedAt,
		Status:    api.ServiceMetaStatus(s.Meta.Status),
	}

	// A zero CertificateIssuedAt means "never issued"; omit it from the API.
	if !s.Meta.CertificateIssuedAt.IsZero() {
		meta.CertificateIssuedAt = &s.Meta.CertificateIssuedAt
	}

	resp := &api.Service{
		Id:               s.ID,
		Name:             s.Name,
		Domain:           s.Domain,
		Targets:          apiTargets,
		Enabled:          s.Enabled,
		PassHostHeader:   &s.PassHostHeader,
		RewriteRedirects: &s.RewriteRedirects,
		Auth:             authConfig,
		Meta:             meta,
	}

	if s.ProxyCluster != "" {
		resp.ProxyCluster = &s.ProxyCluster
	}

	return resp
}
|
||||||
|
|
||||||
|
// ToProtoMapping converts the service into the proto mapping sent to proxy
// servers for the given operation. Disabled targets are skipped; default
// ports (80/http, 443/https) are omitted from target URLs.
//
// NOTE(review): the oidcConfig parameter is not used anywhere in this body —
// confirm whether the OIDC validation settings should be forwarded into
// proto.Authentication, or drop the parameter at the call sites.
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping {
	pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
	for _, target := range s.Targets {
		if !target.Enabled {
			continue
		}

		// TODO: Make path prefix stripping configurable per-target.
		// Currently the matching prefix is baked into the target URL path,
		// so the proxy strips-then-re-adds it (effectively a no-op).
		targetURL := url.URL{
			Scheme: target.Protocol,
			Host:   target.Host,
			Path:   "/", // TODO: support service path
		}
		if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
			targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port))
		}

		// Match path defaults to "/" when the target has none.
		path := "/"
		if target.Path != nil {
			path = *target.Path
		}
		pathMappings = append(pathMappings, &proto.PathMapping{
			Path:   path,
			Target: targetURL.String(),
		})
	}

	auth := &proto.Authentication{
		SessionKey:           s.SessionPublicKey,
		MaxSessionAgeSeconds: int64((time.Hour * 24).Seconds()), // fixed 24h session lifetime
	}

	if s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled {
		auth.Password = true
	}

	if s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled {
		auth.Pin = true
	}

	if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
		auth.Oidc = true
	}

	return &proto.ProxyMapping{
		Type:             operationToProtoType(operation),
		Id:               s.ID,
		Domain:           s.Domain,
		Path:             pathMappings,
		AuthToken:        authToken,
		Auth:             auth,
		AccountId:        s.AccountID,
		PassHostHeader:   s.PassHostHeader,
		RewriteRedirects: s.RewriteRedirects,
	}
}
|
||||||
|
|
||||||
|
func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
|
||||||
|
switch op {
|
||||||
|
case Create:
|
||||||
|
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
|
||||||
|
case Update:
|
||||||
|
return proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED
|
||||||
|
case Delete:
|
||||||
|
return proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED
|
||||||
|
default:
|
||||||
|
log.Fatalf("unknown operation type: %v", op)
|
||||||
|
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isDefaultPort reports whether port is the standard default for the given
// scheme (443 for https, 80 for http). Any other scheme has no default.
func isDefaultPort(scheme string, port int) bool {
	switch scheme {
	case "https":
		return port == 443
	case "http":
		return port == 80
	default:
		return false
	}
}
|
||||||
|
|
||||||
|
// FromAPIRequest populates the service from an HTTP API request. Optional
// booleans default to false when absent; targets are stamped with accountID.
// It does not set ID, Meta, ProxyCluster or session keys — those are managed
// server-side.
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
	s.Name = req.Name
	s.Domain = req.Domain
	s.AccountID = accountID

	targets := make([]*Target, 0, len(req.Targets))
	for _, apiTarget := range req.Targets {
		target := &Target{
			AccountID:  accountID,
			Path:       apiTarget.Path,
			Port:       apiTarget.Port,
			Protocol:   string(apiTarget.Protocol),
			TargetId:   apiTarget.TargetId,
			TargetType: string(apiTarget.TargetType),
			Enabled:    apiTarget.Enabled,
		}
		// Host is optional in the API; only meaningful for subnet targets.
		if apiTarget.Host != nil {
			target.Host = *apiTarget.Host
		}
		targets = append(targets, target)
	}
	s.Targets = targets

	s.Enabled = req.Enabled

	if req.PassHostHeader != nil {
		s.PassHostHeader = *req.PassHostHeader
	}

	if req.RewriteRedirects != nil {
		s.RewriteRedirects = *req.RewriteRedirects
	}

	if req.Auth.PasswordAuth != nil {
		s.Auth.PasswordAuth = &PasswordAuthConfig{
			Enabled:  req.Auth.PasswordAuth.Enabled,
			Password: req.Auth.PasswordAuth.Password,
		}
	}

	if req.Auth.PinAuth != nil {
		s.Auth.PinAuth = &PINAuthConfig{
			Enabled: req.Auth.PinAuth.Enabled,
			Pin:     req.Auth.PinAuth.Pin,
		}
	}

	if req.Auth.BearerAuth != nil {
		bearerAuth := &BearerAuthConfig{
			Enabled: req.Auth.BearerAuth.Enabled,
		}
		if req.Auth.BearerAuth.DistributionGroups != nil {
			bearerAuth.DistributionGroups = *req.Auth.BearerAuth.DistributionGroups
		}
		s.Auth.BearerAuth = bearerAuth
	}
}
|
||||||
|
|
||||||
|
func (s *Service) Validate() error {
|
||||||
|
if s.Name == "" {
|
||||||
|
return errors.New("service name is required")
|
||||||
|
}
|
||||||
|
if len(s.Name) > 255 {
|
||||||
|
return errors.New("service name exceeds maximum length of 255 characters")
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.Domain == "" {
|
||||||
|
return errors.New("service domain is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(s.Targets) == 0 {
|
||||||
|
return errors.New("at least one target is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, target := range s.Targets {
|
||||||
|
switch target.TargetType {
|
||||||
|
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
|
||||||
|
// host field will be ignored
|
||||||
|
case TargetTypeSubnet:
|
||||||
|
if target.Host == "" {
|
||||||
|
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
|
||||||
|
}
|
||||||
|
if target.TargetId == "" {
|
||||||
|
return fmt.Errorf("target %d has empty target_id", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) EventMeta() map[string]any {
|
||||||
|
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Copy() *Service {
|
||||||
|
targets := make([]*Target, len(s.Targets))
|
||||||
|
for i, target := range s.Targets {
|
||||||
|
targetCopy := *target
|
||||||
|
targets[i] = &targetCopy
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Service{
|
||||||
|
ID: s.ID,
|
||||||
|
AccountID: s.AccountID,
|
||||||
|
Name: s.Name,
|
||||||
|
Domain: s.Domain,
|
||||||
|
ProxyCluster: s.ProxyCluster,
|
||||||
|
Targets: targets,
|
||||||
|
Enabled: s.Enabled,
|
||||||
|
PassHostHeader: s.PassHostHeader,
|
||||||
|
RewriteRedirects: s.RewriteRedirects,
|
||||||
|
Auth: s.Auth,
|
||||||
|
Meta: s.Meta,
|
||||||
|
SessionPrivateKey: s.SessionPrivateKey,
|
||||||
|
SessionPublicKey: s.SessionPublicKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||||
|
if enc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.SessionPrivateKey != "" {
|
||||||
|
var err error
|
||||||
|
s.SessionPrivateKey, err = enc.Encrypt(s.SessionPrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||||
|
if enc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.SessionPrivateKey != "" {
|
||||||
|
var err error
|
||||||
|
s.SessionPrivateKey, err = enc.Decrypt(s.SessionPrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
405
management/internals/modules/reverseproxy/reverseproxy_test.go
Normal file
405
management/internals/modules/reverseproxy/reverseproxy_test.go
Normal file
@@ -0,0 +1,405 @@
|
|||||||
|
package reverseproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func validProxy() *Service {
|
||||||
|
return &Service{
|
||||||
|
Name: "test",
|
||||||
|
Domain: "example.com",
|
||||||
|
Targets: []*Target{
|
||||||
|
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 80, Protocol: "http", Enabled: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_Valid(t *testing.T) {
|
||||||
|
require.NoError(t, validProxy().Validate())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_EmptyName(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Name = ""
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "name is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_EmptyDomain(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Domain = ""
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "domain is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_NoTargets(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets = nil
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "at least one target")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_EmptyTargetId(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].TargetId = ""
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "empty target_id")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_InvalidTargetType(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].TargetType = "invalid"
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "invalid target_type")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_ResourceTarget(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets = append(rp.Targets, &Target{
|
||||||
|
TargetId: "resource-1",
|
||||||
|
TargetType: TargetTypeHost,
|
||||||
|
Host: "example.org",
|
||||||
|
Port: 443,
|
||||||
|
Protocol: "https",
|
||||||
|
Enabled: true,
|
||||||
|
})
|
||||||
|
require.NoError(t, rp.Validate())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_MultipleTargetsOneInvalid(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets = append(rp.Targets, &Target{
|
||||||
|
TargetId: "",
|
||||||
|
TargetType: TargetTypePeer,
|
||||||
|
Host: "10.0.0.2",
|
||||||
|
Port: 80,
|
||||||
|
Protocol: "http",
|
||||||
|
Enabled: true,
|
||||||
|
})
|
||||||
|
err := rp.Validate()
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "target 1")
|
||||||
|
assert.Contains(t, err.Error(), "empty target_id")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsDefaultPort(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
scheme string
|
||||||
|
port int
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"http", 80, true},
|
||||||
|
{"https", 443, true},
|
||||||
|
{"http", 443, false},
|
||||||
|
{"https", 80, false},
|
||||||
|
{"http", 8080, false},
|
||||||
|
{"https", 8443, false},
|
||||||
|
{"http", 0, false},
|
||||||
|
{"https", 0, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(fmt.Sprintf("%s/%d", tt.scheme, tt.port), func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.want, isDefaultPort(tt.scheme, tt.port))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToProtoMapping_PortInTargetURL(t *testing.T) {
|
||||||
|
oidcConfig := OIDCValidationConfig{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
protocol string
|
||||||
|
host string
|
||||||
|
port int
|
||||||
|
wantTarget string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "http with default port 80 omits port",
|
||||||
|
protocol: "http",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 80,
|
||||||
|
wantTarget: "http://10.0.0.1/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "https with default port 443 omits port",
|
||||||
|
protocol: "https",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 443,
|
||||||
|
wantTarget: "https://10.0.0.1/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "port 0 omits port",
|
||||||
|
protocol: "http",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 0,
|
||||||
|
wantTarget: "http://10.0.0.1/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-default port is included",
|
||||||
|
protocol: "http",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 8080,
|
||||||
|
wantTarget: "http://10.0.0.1:8080/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "https with non-default port is included",
|
||||||
|
protocol: "https",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 8443,
|
||||||
|
wantTarget: "https://10.0.0.1:8443/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "http port 443 is included",
|
||||||
|
protocol: "http",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 443,
|
||||||
|
wantTarget: "http://10.0.0.1:443/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "https port 80 is included",
|
||||||
|
protocol: "https",
|
||||||
|
host: "10.0.0.1",
|
||||||
|
port: 80,
|
||||||
|
wantTarget: "https://10.0.0.1:80/",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rp := &Service{
|
||||||
|
ID: "test-id",
|
||||||
|
AccountID: "acc-1",
|
||||||
|
Domain: "example.com",
|
||||||
|
Targets: []*Target{
|
||||||
|
{
|
||||||
|
TargetId: "peer-1",
|
||||||
|
TargetType: TargetTypePeer,
|
||||||
|
Host: tt.host,
|
||||||
|
Port: tt.port,
|
||||||
|
Protocol: tt.protocol,
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
pm := rp.ToProtoMapping(Create, "token", oidcConfig)
|
||||||
|
require.Len(t, pm.Path, 1, "should have one path mapping")
|
||||||
|
assert.Equal(t, tt.wantTarget, pm.Path[0].Target)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) {
|
||||||
|
rp := &Service{
|
||||||
|
ID: "test-id",
|
||||||
|
AccountID: "acc-1",
|
||||||
|
Domain: "example.com",
|
||||||
|
Targets: []*Target{
|
||||||
|
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 8080, Protocol: "http", Enabled: false},
|
||||||
|
{TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
pm := rp.ToProtoMapping(Create, "token", OIDCValidationConfig{})
|
||||||
|
require.Len(t, pm.Path, 1)
|
||||||
|
assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToProtoMapping_OperationTypes(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
tests := []struct {
|
||||||
|
op Operation
|
||||||
|
want proto.ProxyMappingUpdateType
|
||||||
|
}{
|
||||||
|
{Create, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED},
|
||||||
|
{Update, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED},
|
||||||
|
{Delete, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(string(tt.op), func(t *testing.T) {
|
||||||
|
pm := rp.ToProtoMapping(tt.op, "", OIDCValidationConfig{})
|
||||||
|
assert.Equal(t, tt.want, pm.Type)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthConfig_HashSecrets(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config *AuthConfig
|
||||||
|
wantErr bool
|
||||||
|
validate func(*testing.T, *AuthConfig)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "hash password successfully",
|
||||||
|
config: &AuthConfig{
|
||||||
|
PasswordAuth: &PasswordAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Password: "testPassword123",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
validate: func(t *testing.T, config *AuthConfig) {
|
||||||
|
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
|
||||||
|
t.Errorf("Password not hashed with argon2id, got: %s", config.PasswordAuth.Password)
|
||||||
|
}
|
||||||
|
// Verify the hash can be verified
|
||||||
|
if err := argon2id.Verify("testPassword123", config.PasswordAuth.Password); err != nil {
|
||||||
|
t.Errorf("Hash verification failed: %v", err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hash PIN successfully",
|
||||||
|
config: &AuthConfig{
|
||||||
|
PinAuth: &PINAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Pin: "123456",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
validate: func(t *testing.T, config *AuthConfig) {
|
||||||
|
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
|
||||||
|
t.Errorf("PIN not hashed with argon2id, got: %s", config.PinAuth.Pin)
|
||||||
|
}
|
||||||
|
// Verify the hash can be verified
|
||||||
|
if err := argon2id.Verify("123456", config.PinAuth.Pin); err != nil {
|
||||||
|
t.Errorf("Hash verification failed: %v", err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hash both password and PIN",
|
||||||
|
config: &AuthConfig{
|
||||||
|
PasswordAuth: &PasswordAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Password: "password",
|
||||||
|
},
|
||||||
|
PinAuth: &PINAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Pin: "9999",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
validate: func(t *testing.T, config *AuthConfig) {
|
||||||
|
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
|
||||||
|
t.Errorf("Password not hashed with argon2id")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
|
||||||
|
t.Errorf("PIN not hashed with argon2id")
|
||||||
|
}
|
||||||
|
if err := argon2id.Verify("password", config.PasswordAuth.Password); err != nil {
|
||||||
|
t.Errorf("Password hash verification failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := argon2id.Verify("9999", config.PinAuth.Pin); err != nil {
|
||||||
|
t.Errorf("PIN hash verification failed: %v", err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "skip disabled password auth",
|
||||||
|
config: &AuthConfig{
|
||||||
|
PasswordAuth: &PasswordAuthConfig{
|
||||||
|
Enabled: false,
|
||||||
|
Password: "password",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
validate: func(t *testing.T, config *AuthConfig) {
|
||||||
|
if config.PasswordAuth.Password != "password" {
|
||||||
|
t.Errorf("Disabled password auth should not be hashed")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "skip empty password",
|
||||||
|
config: &AuthConfig{
|
||||||
|
PasswordAuth: &PasswordAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Password: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
validate: func(t *testing.T, config *AuthConfig) {
|
||||||
|
if config.PasswordAuth.Password != "" {
|
||||||
|
t.Errorf("Empty password should remain empty")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "skip nil password auth",
|
||||||
|
config: &AuthConfig{
|
||||||
|
PasswordAuth: nil,
|
||||||
|
PinAuth: &PINAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Pin: "1234",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
validate: func(t *testing.T, config *AuthConfig) {
|
||||||
|
if config.PasswordAuth != nil {
|
||||||
|
t.Errorf("PasswordAuth should remain nil")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
|
||||||
|
t.Errorf("PIN should still be hashed")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := tt.config.HashSecrets()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("HashSecrets() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.validate != nil {
|
||||||
|
tt.validate(t, tt.config)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthConfig_HashSecrets_VerifyIncorrectSecret(t *testing.T) {
|
||||||
|
config := &AuthConfig{
|
||||||
|
PasswordAuth: &PasswordAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Password: "correctPassword",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := config.HashSecrets(); err != nil {
|
||||||
|
t.Fatalf("HashSecrets() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify with wrong password should fail
|
||||||
|
err := argon2id.Verify("wrongPassword", config.PasswordAuth.Password)
|
||||||
|
if !errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
|
||||||
|
t.Errorf("Expected ErrMismatchedHashAndPassword, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthConfig_ClearSecrets(t *testing.T) {
|
||||||
|
config := &AuthConfig{
|
||||||
|
PasswordAuth: &PasswordAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Password: "hashedPassword",
|
||||||
|
},
|
||||||
|
PinAuth: &PINAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Pin: "hashedPin",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
config.ClearSecrets()
|
||||||
|
|
||||||
|
if config.PasswordAuth.Password != "" {
|
||||||
|
t.Errorf("Password not cleared, got: %s", config.PasswordAuth.Password)
|
||||||
|
}
|
||||||
|
if config.PinAuth.Pin != "" {
|
||||||
|
t.Errorf("PIN not cleared, got: %s", config.PinAuth.Pin)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
package sessionkey
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/proxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
type KeyPair struct {
|
||||||
|
PrivateKey string
|
||||||
|
PublicKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Claims struct {
|
||||||
|
jwt.RegisteredClaims
|
||||||
|
Method auth.Method `json:"method"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateKeyPair() (*KeyPair, error) {
|
||||||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generate ed25519 key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &KeyPair{
|
||||||
|
PrivateKey: base64.StdEncoding.EncodeToString(priv),
|
||||||
|
PublicKey: base64.StdEncoding.EncodeToString(pub),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration time.Duration) (string, error) {
|
||||||
|
privKeyBytes, err := base64.StdEncoding.DecodeString(privKeyB64)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("decode private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(privKeyBytes) != ed25519.PrivateKeySize {
|
||||||
|
return "", fmt.Errorf("invalid private key size: got %d, want %d", len(privKeyBytes), ed25519.PrivateKeySize)
|
||||||
|
}
|
||||||
|
|
||||||
|
privKey := ed25519.PrivateKey(privKeyBytes)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
claims := Claims{
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Issuer: auth.SessionJWTIssuer,
|
||||||
|
Subject: userID,
|
||||||
|
Audience: jwt.ClaimStrings{domain},
|
||||||
|
ExpiresAt: jwt.NewNumericDate(now.Add(expiration)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
|
NotBefore: jwt.NewNumericDate(now),
|
||||||
|
},
|
||||||
|
Method: method,
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
|
||||||
|
signedToken, err := token.SignedString(privKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("sign token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return signedToken, nil
|
||||||
|
}
|
||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/util"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -63,7 +63,7 @@ func (r *Record) Validate() error {
|
|||||||
return errors.New("record name is required")
|
return errors.New("record name is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !util.IsValidDomain(r.Name) {
|
if !domain.IsValidDomain(r.Name) {
|
||||||
return errors.New("invalid record name format")
|
return errors.New("invalid record name format")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -81,8 +81,8 @@ func (r *Record) Validate() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
case RecordTypeCNAME:
|
case RecordTypeCNAME:
|
||||||
if !util.IsValidDomain(r.Content) {
|
if !domain.IsValidDomainNoWildcard(r.Content) {
|
||||||
return errors.New("invalid CNAME record format")
|
return errors.New("invalid CNAME target format")
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return errors.New("invalid record type, must be A, AAAA, or CNAME")
|
return errors.New("invalid record type, must be A, AAAA, or CNAME")
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||||
"github.com/netbirdio/netbird/management/server/util"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -73,7 +73,7 @@ func (z *Zone) Validate() error {
|
|||||||
return errors.New("zone name exceeds maximum length of 255 characters")
|
return errors.New("zone name exceeds maximum length of 255 characters")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !util.IsValidDomain(z.Domain) {
|
if !domain.IsValidDomainNoWildcard(z.Domain) {
|
||||||
return errors.New("invalid zone domain format")
|
return errors.New("invalid zone domain format")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,8 @@ import (
|
|||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
"github.com/netbirdio/netbird/formatter/hook"
|
"github.com/netbirdio/netbird/formatter/hook"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
|
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||||
@@ -92,7 +94,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
|||||||
|
|
||||||
func (s *BaseServer) APIHandler() http.Handler {
|
func (s *BaseServer) APIHandler() http.Handler {
|
||||||
return Create(s, func() http.Handler {
|
return Create(s, func() http.Handler {
|
||||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager())
|
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create API handler: %v", err)
|
log.Fatalf("failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
@@ -120,11 +122,13 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
|||||||
realip.WithTrustedProxiesCount(trustedProxiesCount),
|
realip.WithTrustedProxiesCount(trustedProxiesCount),
|
||||||
realip.WithHeaders([]string{realip.XForwardedFor, realip.XRealIp}),
|
realip.WithHeaders([]string{realip.XForwardedFor, realip.XRealIp}),
|
||||||
}
|
}
|
||||||
|
proxyUnary, proxyStream, proxyAuthClose := nbgrpc.NewProxyAuthInterceptors(s.Store())
|
||||||
|
s.proxyAuthClose = proxyAuthClose
|
||||||
gRPCOpts := []grpc.ServerOption{
|
gRPCOpts := []grpc.ServerOption{
|
||||||
grpc.KeepaliveEnforcementPolicy(kaep),
|
grpc.KeepaliveEnforcementPolicy(kaep),
|
||||||
grpc.KeepaliveParams(kasp),
|
grpc.KeepaliveParams(kasp),
|
||||||
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor),
|
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor, proxyUnary),
|
||||||
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
|
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor, proxyStream),
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.Config.HttpConfig.LetsEncryptDomain != "" {
|
if s.Config.HttpConfig.LetsEncryptDomain != "" {
|
||||||
@@ -150,10 +154,53 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
|||||||
}
|
}
|
||||||
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
|
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
|
||||||
|
|
||||||
|
mgmtProto.RegisterProxyServiceServer(gRPCAPIHandler, s.ReverseProxyGRPCServer())
|
||||||
|
log.Info("ProxyService registered on gRPC server")
|
||||||
|
|
||||||
return gRPCAPIHandler
|
return gRPCAPIHandler
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
||||||
|
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
||||||
|
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager())
|
||||||
|
s.AfterInit(func(s *BaseServer) {
|
||||||
|
proxyService.SetProxyManager(s.ReverseProxyManager())
|
||||||
|
})
|
||||||
|
return proxyService
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig {
|
||||||
|
return Create(s, func() nbgrpc.ProxyOIDCConfig {
|
||||||
|
return nbgrpc.ProxyOIDCConfig{
|
||||||
|
Issuer: s.Config.HttpConfig.AuthIssuer,
|
||||||
|
// todo: double check auth clientID value
|
||||||
|
ClientID: s.Config.HttpConfig.AuthClientID, // Reuse dashboard client
|
||||||
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
|
CallbackURL: s.Config.HttpConfig.AuthCallbackURL,
|
||||||
|
HMACKey: []byte(s.Config.DataStoreEncryptionKey), // Use the datastore encryption key for OIDC state HMACs, this should ensure all management instances are using the same key.
|
||||||
|
Audience: s.Config.HttpConfig.AuthAudience,
|
||||||
|
KeysLocation: s.Config.HttpConfig.AuthKeysLocation,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
|
||||||
|
return Create(s, func() *nbgrpc.OneTimeTokenStore {
|
||||||
|
tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute)
|
||||||
|
log.Info("One-time token store initialized for proxy authentication")
|
||||||
|
return tokenStore
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
|
||||||
|
return Create(s, func() accesslogs.Manager {
|
||||||
|
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager())
|
||||||
|
return accessLogManager
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
|
func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
|
||||||
// Load server's certificate and private key
|
// Load server's certificate and private key
|
||||||
serverCert, err := tls.LoadX509KeyPair(certFile, certKey)
|
serverCert, err := tls.LoadX509KeyPair(certFile, certKey)
|
||||||
|
|||||||
@@ -100,6 +100,8 @@ type HttpServerConfig struct {
|
|||||||
CertFile string
|
CertFile string
|
||||||
// CertKey is the location of the certificate private key
|
// CertKey is the location of the certificate private key
|
||||||
CertKey string
|
CertKey string
|
||||||
|
// AuthClientID is the client id used for proxy SSO auth
|
||||||
|
AuthClientID string
|
||||||
// AuthAudience identifies the recipients that the JWT is intended for (aud in JWT)
|
// AuthAudience identifies the recipients that the JWT is intended for (aud in JWT)
|
||||||
AuthAudience string
|
AuthAudience string
|
||||||
// CLIAuthAudience identifies the client app recipients that the JWT is intended for (aud in JWT)
|
// CLIAuthAudience identifies the client app recipients that the JWT is intended for (aud in JWT)
|
||||||
@@ -117,6 +119,8 @@ type HttpServerConfig struct {
|
|||||||
IdpSignKeyRefreshEnabled bool
|
IdpSignKeyRefreshEnabled bool
|
||||||
// Extra audience
|
// Extra audience
|
||||||
ExtraAuthAudience string
|
ExtraAuthAudience string
|
||||||
|
// AuthCallbackDomain contains the callback domain
|
||||||
|
AuthCallbackURL string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Host represents a Netbird host (e.g. STUN, TURN, Signal)
|
// Host represents a Netbird host (e.g. STUN, TURN, Signal)
|
||||||
|
|||||||
@@ -8,6 +8,9 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||||
|
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||||
@@ -69,7 +72,14 @@ func (s *BaseServer) UsersManager() users.Manager {
|
|||||||
func (s *BaseServer) SettingsManager() settings.Manager {
|
func (s *BaseServer) SettingsManager() settings.Manager {
|
||||||
return Create(s, func() settings.Manager {
|
return Create(s, func() settings.Manager {
|
||||||
extraSettingsManager := integrations.NewManager(s.EventStore())
|
extraSettingsManager := integrations.NewManager(s.EventStore())
|
||||||
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager())
|
|
||||||
|
idpConfig := settings.IdpConfig{}
|
||||||
|
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
|
||||||
|
idpConfig.EmbeddedIdpEnabled = true
|
||||||
|
idpConfig.LocalAuthDisabled = s.Config.EmbeddedIdP.LocalAuthDisabled
|
||||||
|
}
|
||||||
|
|
||||||
|
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager(), idpConfig)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -91,6 +101,11 @@ func (s *BaseServer) AccountManager() account.Manager {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create account manager: %v", err)
|
log.Fatalf("failed to create account manager: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.AfterInit(func(s *BaseServer) {
|
||||||
|
accountManager.SetServiceManager(s.ReverseProxyManager())
|
||||||
|
})
|
||||||
|
|
||||||
return accountManager
|
return accountManager
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -147,7 +162,7 @@ func (s *BaseServer) GroupsManager() groups.Manager {
|
|||||||
|
|
||||||
func (s *BaseServer) ResourcesManager() resources.Manager {
|
func (s *BaseServer) ResourcesManager() resources.Manager {
|
||||||
return Create(s, func() resources.Manager {
|
return Create(s, func() resources.Manager {
|
||||||
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager())
|
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ReverseProxyManager())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -174,3 +189,16 @@ func (s *BaseServer) RecordsManager() records.Manager {
|
|||||||
return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager())
|
return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) ReverseProxyManager() reverseproxy.Manager {
|
||||||
|
return Create(s, func() reverseproxy.Manager {
|
||||||
|
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), s.ReverseProxyDomainManager())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
||||||
|
return Create(s, func() *manager.Manager {
|
||||||
|
m := manager.NewManager(s.Store(), s.ReverseProxyGRPCServer(), s.PermissionsManager())
|
||||||
|
return &m
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"go.opentelemetry.io/otel/metric"
|
"go.opentelemetry.io/otel/metric"
|
||||||
"golang.org/x/crypto/acme/autocert"
|
"golang.org/x/crypto/acme/autocert"
|
||||||
@@ -21,6 +20,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
"github.com/netbirdio/netbird/management/server/metrics"
|
"github.com/netbirdio/netbird/management/server/metrics"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/util/wsproxy"
|
"github.com/netbirdio/netbird/util/wsproxy"
|
||||||
@@ -58,6 +58,8 @@ type BaseServer struct {
|
|||||||
mgmtMetricsPort int
|
mgmtMetricsPort int
|
||||||
mgmtPort int
|
mgmtPort int
|
||||||
|
|
||||||
|
proxyAuthClose func()
|
||||||
|
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
certManager *autocert.Manager
|
certManager *autocert.Manager
|
||||||
update *version.Update
|
update *version.Update
|
||||||
@@ -215,6 +217,11 @@ func (s *BaseServer) Stop() error {
|
|||||||
_ = s.certManager.Listener().Close()
|
_ = s.certManager.Listener().Close()
|
||||||
}
|
}
|
||||||
s.GRPCServer().Stop()
|
s.GRPCServer().Stop()
|
||||||
|
s.ReverseProxyGRPCServer().Close()
|
||||||
|
if s.proxyAuthClose != nil {
|
||||||
|
s.proxyAuthClose()
|
||||||
|
s.proxyAuthClose = nil
|
||||||
|
}
|
||||||
_ = s.Store().Close(ctx)
|
_ = s.Store().Close(ctx)
|
||||||
_ = s.EventStore().Close(ctx)
|
_ = s.EventStore().Close(ctx)
|
||||||
if s.update != nil {
|
if s.update != nil {
|
||||||
|
|||||||
167
management/internals/shared/grpc/onetime_token.go
Normal file
167
management/internals/shared/grpc/onetime_token.go
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OneTimeTokenStore manages short-lived, single-use authentication tokens
|
||||||
|
// for proxy-to-management RPC authentication. Tokens are generated when
|
||||||
|
// a service is created and must be used exactly once by the proxy
|
||||||
|
// to authenticate a subsequent RPC call.
|
||||||
|
type OneTimeTokenStore struct {
|
||||||
|
tokens map[string]*tokenMetadata
|
||||||
|
mu sync.RWMutex
|
||||||
|
cleanup *time.Ticker
|
||||||
|
cleanupDone chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tokenMetadata stores information about a one-time token
|
||||||
|
type tokenMetadata struct {
|
||||||
|
ServiceID string
|
||||||
|
AccountID string
|
||||||
|
ExpiresAt time.Time
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOneTimeTokenStore creates a new token store with automatic cleanup
|
||||||
|
// of expired tokens. The cleanupInterval determines how often expired
|
||||||
|
// tokens are removed from memory.
|
||||||
|
func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
|
||||||
|
store := &OneTimeTokenStore{
|
||||||
|
tokens: make(map[string]*tokenMetadata),
|
||||||
|
cleanup: time.NewTicker(cleanupInterval),
|
||||||
|
cleanupDone: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start background cleanup goroutine
|
||||||
|
go store.cleanupExpired()
|
||||||
|
|
||||||
|
return store
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateToken creates a new cryptographically secure one-time token
|
||||||
|
// with the specified TTL. The token is associated with a specific
|
||||||
|
// accountID and serviceID for validation purposes.
|
||||||
|
//
|
||||||
|
// Returns the generated token string or an error if random generation fails.
|
||||||
|
func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) {
|
||||||
|
// Generate 32 bytes (256 bits) of cryptographically secure random data
|
||||||
|
randomBytes := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(randomBytes); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to generate random token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode as URL-safe base64 for easy transmission in gRPC
|
||||||
|
token := base64.URLEncoding.EncodeToString(randomBytes)
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
s.tokens[token] = &tokenMetadata{
|
||||||
|
ServiceID: serviceID,
|
||||||
|
AccountID: accountID,
|
||||||
|
ExpiresAt: time.Now().Add(ttl),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)",
|
||||||
|
serviceID, accountID, ttl)
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateAndConsume verifies the token against the provided accountID and
|
||||||
|
// serviceID, checks expiration, and then deletes it to enforce single-use.
|
||||||
|
//
|
||||||
|
// This method uses constant-time comparison to prevent timing attacks.
|
||||||
|
//
|
||||||
|
// Returns nil on success, or an error if:
|
||||||
|
// - Token doesn't exist
|
||||||
|
// - Token has expired
|
||||||
|
// - Account ID doesn't match
|
||||||
|
// - Reverse proxy ID doesn't match
|
||||||
|
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
metadata, exists := s.tokens[token]
|
||||||
|
if !exists {
|
||||||
|
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)",
|
||||||
|
serviceID, accountID)
|
||||||
|
return fmt.Errorf("invalid token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check expiration
|
||||||
|
if time.Now().After(metadata.ExpiresAt) {
|
||||||
|
delete(s.tokens, token)
|
||||||
|
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)",
|
||||||
|
serviceID, accountID)
|
||||||
|
return fmt.Errorf("token expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate account ID using constant-time comparison (prevents timing attacks)
|
||||||
|
if subtle.ConstantTimeCompare([]byte(metadata.AccountID), []byte(accountID)) != 1 {
|
||||||
|
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)",
|
||||||
|
metadata.AccountID, accountID)
|
||||||
|
return fmt.Errorf("account ID mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate service ID using constant-time comparison
|
||||||
|
if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 {
|
||||||
|
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)",
|
||||||
|
metadata.ServiceID, serviceID)
|
||||||
|
return fmt.Errorf("service ID mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete token immediately to enforce single-use
|
||||||
|
delete(s.tokens, token)
|
||||||
|
|
||||||
|
log.Infof("Token validated and consumed for proxy %s in account %s",
|
||||||
|
serviceID, accountID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupExpired removes expired tokens in the background to prevent memory leaks
|
||||||
|
func (s *OneTimeTokenStore) cleanupExpired() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.cleanup.C:
|
||||||
|
s.mu.Lock()
|
||||||
|
now := time.Now()
|
||||||
|
removed := 0
|
||||||
|
for token, metadata := range s.tokens {
|
||||||
|
if now.After(metadata.ExpiresAt) {
|
||||||
|
delete(s.tokens, token)
|
||||||
|
removed++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if removed > 0 {
|
||||||
|
log.Debugf("Cleaned up %d expired one-time tokens", removed)
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
case <-s.cleanupDone:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops the cleanup goroutine and releases resources
|
||||||
|
func (s *OneTimeTokenStore) Close() {
|
||||||
|
s.cleanup.Stop()
|
||||||
|
close(s.cleanupDone)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTokenCount returns the current number of tokens in the store (for debugging/metrics)
|
||||||
|
func (s *OneTimeTokenStore) GetTokenCount() int {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return len(s.tokens)
|
||||||
|
}
|
||||||
1065
management/internals/shared/grpc/proxy.go
Normal file
1065
management/internals/shared/grpc/proxy.go
Normal file
File diff suppressed because it is too large
Load Diff
234
management/internals/shared/grpc/proxy_auth.go
Normal file
234
management/internals/shared/grpc/proxy_auth.go
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// lastUsedUpdateInterval is the minimum interval between last_used updates for the same token.
|
||||||
|
lastUsedUpdateInterval = time.Minute
|
||||||
|
// lastUsedCleanupInterval is how often stale lastUsed entries are removed.
|
||||||
|
lastUsedCleanupInterval = 2 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
type proxyTokenContextKey struct{}
|
||||||
|
|
||||||
|
// ProxyTokenContextKey is the typed key used to store validated token info in context.
|
||||||
|
var ProxyTokenContextKey = proxyTokenContextKey{}
|
||||||
|
|
||||||
|
// proxyTokenID identifies a proxy access token by its database ID.
|
||||||
|
type proxyTokenID = string
|
||||||
|
|
||||||
|
// proxyTokenStore defines the store interface needed for token validation
|
||||||
|
type proxyTokenStore interface {
|
||||||
|
GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength store.LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error)
|
||||||
|
MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// proxyAuthInterceptor holds state for proxy authentication interceptors.
|
||||||
|
type proxyAuthInterceptor struct {
|
||||||
|
store proxyTokenStore
|
||||||
|
failureLimiter *authFailureLimiter
|
||||||
|
|
||||||
|
// lastUsedMu protects lastUsedTimes
|
||||||
|
lastUsedMu sync.Mutex
|
||||||
|
lastUsedTimes map[proxyTokenID]time.Time
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func newProxyAuthInterceptor(tokenStore proxyTokenStore) *proxyAuthInterceptor {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
i := &proxyAuthInterceptor{
|
||||||
|
store: tokenStore,
|
||||||
|
failureLimiter: newAuthFailureLimiter(),
|
||||||
|
lastUsedTimes: make(map[proxyTokenID]time.Time),
|
||||||
|
cancel: cancel,
|
||||||
|
}
|
||||||
|
go i.lastUsedCleanupLoop(ctx)
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProxyAuthInterceptors creates gRPC unary and stream interceptors that validate proxy access tokens.
|
||||||
|
// They only intercept ProxyService methods. Both interceptors share state for last-used and failure rate limiting.
|
||||||
|
// The returned close function must be called on shutdown to stop background goroutines.
|
||||||
|
func NewProxyAuthInterceptors(tokenStore proxyTokenStore) (grpc.UnaryServerInterceptor, grpc.StreamServerInterceptor, func()) {
|
||||||
|
interceptor := newProxyAuthInterceptor(tokenStore)
|
||||||
|
|
||||||
|
unary := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||||
|
if !strings.HasPrefix(info.FullMethod, "/management.ProxyService/") {
|
||||||
|
return handler(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := interceptor.validateProxyToken(ctx)
|
||||||
|
if err != nil {
|
||||||
|
// Log auth failures explicitly; gRPC doesn't log these by default.
|
||||||
|
log.WithContext(ctx).Warnf("proxy auth failed: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx = context.WithValue(ctx, ProxyTokenContextKey, token)
|
||||||
|
return handler(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
stream := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||||
|
if !strings.HasPrefix(info.FullMethod, "/management.ProxyService/") {
|
||||||
|
return handler(srv, ss)
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := interceptor.validateProxyToken(ss.Context())
|
||||||
|
if err != nil {
|
||||||
|
// Log auth failures explicitly; gRPC doesn't log these by default.
|
||||||
|
log.WithContext(ss.Context()).Warnf("proxy auth failed: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(ss.Context(), ProxyTokenContextKey, token)
|
||||||
|
wrapped := &wrappedServerStream{
|
||||||
|
ServerStream: ss,
|
||||||
|
ctx: ctx,
|
||||||
|
}
|
||||||
|
|
||||||
|
return handler(srv, wrapped)
|
||||||
|
}
|
||||||
|
|
||||||
|
return unary, stream, interceptor.close
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *proxyAuthInterceptor) validateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
|
||||||
|
clientIP := peerIPFromContext(ctx)
|
||||||
|
|
||||||
|
if clientIP != "" && i.failureLimiter.isLimited(clientIP) {
|
||||||
|
return nil, status.Errorf(codes.ResourceExhausted, "too many failed authentication attempts")
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := i.doValidateProxyToken(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if clientIP != "" {
|
||||||
|
i.failureLimiter.recordFailure(clientIP)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
i.maybeUpdateLastUsed(ctx, token.ID)
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *proxyAuthInterceptor) doValidateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
|
||||||
|
md, ok := metadata.FromIncomingContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "missing metadata")
|
||||||
|
}
|
||||||
|
|
||||||
|
authValues := md.Get("authorization")
|
||||||
|
if len(authValues) == 0 {
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "missing authorization header")
|
||||||
|
}
|
||||||
|
|
||||||
|
authValue := authValues[0]
|
||||||
|
if !strings.HasPrefix(authValue, "Bearer ") {
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "invalid authorization format")
|
||||||
|
}
|
||||||
|
|
||||||
|
plainToken := types.PlainProxyToken(strings.TrimPrefix(authValue, "Bearer "))
|
||||||
|
|
||||||
|
if err := plainToken.Validate(); err != nil {
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "invalid token format")
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := i.store.GetProxyAccessTokenByHashedToken(ctx, store.LockingStrengthNone, plainToken.Hash())
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "invalid token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Enforce AccountID scope for "bring your own proxy" feature.
|
||||||
|
// Currently tokens are management-wide; AccountID field is reserved for future use.
|
||||||
|
|
||||||
|
if !token.IsValid() {
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked")
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// maybeUpdateLastUsed updates the last_used timestamp if enough time has passed since the last update.
|
||||||
|
func (i *proxyAuthInterceptor) maybeUpdateLastUsed(ctx context.Context, tokenID string) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
i.lastUsedMu.Lock()
|
||||||
|
lastUpdate, exists := i.lastUsedTimes[tokenID]
|
||||||
|
if exists && now.Sub(lastUpdate) < lastUsedUpdateInterval {
|
||||||
|
i.lastUsedMu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
i.lastUsedTimes[tokenID] = now
|
||||||
|
i.lastUsedMu.Unlock()
|
||||||
|
|
||||||
|
if err := i.store.MarkProxyAccessTokenUsed(ctx, tokenID); err != nil {
|
||||||
|
log.WithContext(ctx).Debugf("failed to mark proxy token as used: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *proxyAuthInterceptor) lastUsedCleanupLoop(ctx context.Context) {
|
||||||
|
ticker := time.NewTicker(lastUsedCleanupInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
i.cleanupStaleLastUsed()
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupStaleLastUsed removes entries older than 2x the update interval.
|
||||||
|
func (i *proxyAuthInterceptor) cleanupStaleLastUsed() {
|
||||||
|
i.lastUsedMu.Lock()
|
||||||
|
defer i.lastUsedMu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
staleThreshold := 2 * lastUsedUpdateInterval
|
||||||
|
for id, lastUpdate := range i.lastUsedTimes {
|
||||||
|
if now.Sub(lastUpdate) > staleThreshold {
|
||||||
|
delete(i.lastUsedTimes, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *proxyAuthInterceptor) close() {
|
||||||
|
i.cancel()
|
||||||
|
i.failureLimiter.stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProxyTokenFromContext retrieves the validated proxy token from the context
|
||||||
|
func GetProxyTokenFromContext(ctx context.Context) *types.ProxyAccessToken {
|
||||||
|
token, ok := ctx.Value(ProxyTokenContextKey).(*types.ProxyAccessToken)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrappedServerStream wraps a grpc.ServerStream to provide a custom context
|
||||||
|
type wrappedServerStream struct {
|
||||||
|
grpc.ServerStream
|
||||||
|
ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wrappedServerStream) Context() context.Context {
|
||||||
|
return w.ctx
|
||||||
|
}
|
||||||
134
management/internals/shared/grpc/proxy_auth_ratelimit.go
Normal file
134
management/internals/shared/grpc/proxy_auth_ratelimit.go
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
"google.golang.org/grpc/peer"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// proxyAuthFailureBurst is the maximum number of failed attempts before rate limiting kicks in.
|
||||||
|
proxyAuthFailureBurst = 5
|
||||||
|
// proxyAuthLimiterCleanup is how often stale limiters are removed.
|
||||||
|
proxyAuthLimiterCleanup = 5 * time.Minute
|
||||||
|
// proxyAuthLimiterTTL is how long a limiter is kept after the last failure.
|
||||||
|
proxyAuthLimiterTTL = 15 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
// defaultProxyAuthFailureRate is the token replenishment rate for failed auth attempts.
|
||||||
|
// One token every 12 seconds = 5 per minute.
|
||||||
|
var defaultProxyAuthFailureRate = rate.Every(12 * time.Second)
|
||||||
|
|
||||||
|
// clientIP identifies a client by its IP address for rate limiting purposes.
|
||||||
|
type clientIP = string
|
||||||
|
|
||||||
|
type limiterEntry struct {
|
||||||
|
limiter *rate.Limiter
|
||||||
|
lastAccess time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// authFailureLimiter tracks per-IP rate limits for failed proxy authentication attempts.
|
||||||
|
type authFailureLimiter struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
limiters map[clientIP]*limiterEntry
|
||||||
|
failureRate rate.Limit
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAuthFailureLimiter() *authFailureLimiter {
|
||||||
|
return newAuthFailureLimiterWithRate(defaultProxyAuthFailureRate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAuthFailureLimiterWithRate(failureRate rate.Limit) *authFailureLimiter {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
l := &authFailureLimiter{
|
||||||
|
limiters: make(map[clientIP]*limiterEntry),
|
||||||
|
failureRate: failureRate,
|
||||||
|
cancel: cancel,
|
||||||
|
}
|
||||||
|
go l.cleanupLoop(ctx)
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
// isLimited returns true if the given IP has exhausted its failure budget.
|
||||||
|
func (l *authFailureLimiter) isLimited(ip clientIP) bool {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
|
||||||
|
entry, exists := l.limiters[ip]
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return entry.limiter.Tokens() < 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// recordFailure consumes a token from the rate limiter for the given IP.
|
||||||
|
func (l *authFailureLimiter) recordFailure(ip clientIP) {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
entry, exists := l.limiters[ip]
|
||||||
|
if !exists {
|
||||||
|
entry = &limiterEntry{
|
||||||
|
limiter: rate.NewLimiter(l.failureRate, proxyAuthFailureBurst),
|
||||||
|
}
|
||||||
|
l.limiters[ip] = entry
|
||||||
|
}
|
||||||
|
entry.lastAccess = now
|
||||||
|
entry.limiter.Allow()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *authFailureLimiter) cleanupLoop(ctx context.Context) {
|
||||||
|
ticker := time.NewTicker(proxyAuthLimiterCleanup)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
l.cleanup()
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *authFailureLimiter) cleanup() {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for ip, entry := range l.limiters {
|
||||||
|
if now.Sub(entry.lastAccess) > proxyAuthLimiterTTL {
|
||||||
|
delete(l.limiters, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *authFailureLimiter) stop() {
|
||||||
|
l.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// peerIPFromContext extracts the client IP from the gRPC context.
|
||||||
|
// Uses realip (from trusted proxy headers) first, falls back to the transport peer address.
|
||||||
|
func peerIPFromContext(ctx context.Context) clientIP {
|
||||||
|
if addr, ok := realip.FromContext(ctx); ok {
|
||||||
|
return addr.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
if p, ok := peer.FromContext(ctx); ok {
|
||||||
|
host, _, err := net.SplitHostPort(p.Addr.String())
|
||||||
|
if err != nil {
|
||||||
|
return p.Addr.String()
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
@@ -0,0 +1,98 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthFailureLimiter_NotLimitedInitially(t *testing.T) {
|
||||||
|
l := newAuthFailureLimiter()
|
||||||
|
defer l.stop()
|
||||||
|
|
||||||
|
assert.False(t, l.isLimited("192.168.1.1"), "new IP should not be rate limited")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthFailureLimiter_LimitedAfterBurst(t *testing.T) {
|
||||||
|
l := newAuthFailureLimiter()
|
||||||
|
defer l.stop()
|
||||||
|
|
||||||
|
ip := "192.168.1.1"
|
||||||
|
for i := 0; i < proxyAuthFailureBurst; i++ {
|
||||||
|
l.recordFailure(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, l.isLimited(ip), "IP should be limited after exhausting burst")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthFailureLimiter_DifferentIPsIndependent(t *testing.T) {
|
||||||
|
l := newAuthFailureLimiter()
|
||||||
|
defer l.stop()
|
||||||
|
|
||||||
|
for i := 0; i < proxyAuthFailureBurst; i++ {
|
||||||
|
l.recordFailure("192.168.1.1")
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, l.isLimited("192.168.1.1"))
|
||||||
|
assert.False(t, l.isLimited("192.168.1.2"), "different IP should not be affected")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthFailureLimiter_RecoveryOverTime(t *testing.T) {
|
||||||
|
l := newAuthFailureLimiterWithRate(rate.Limit(100)) // 100 tokens/sec for fast recovery
|
||||||
|
defer l.stop()
|
||||||
|
|
||||||
|
ip := "10.0.0.1"
|
||||||
|
|
||||||
|
// Exhaust burst
|
||||||
|
for i := 0; i < proxyAuthFailureBurst; i++ {
|
||||||
|
l.recordFailure(ip)
|
||||||
|
}
|
||||||
|
require.True(t, l.isLimited(ip))
|
||||||
|
|
||||||
|
// Wait for token replenishment
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
assert.False(t, l.isLimited(ip), "should recover after tokens replenish")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthFailureLimiter_Cleanup(t *testing.T) {
|
||||||
|
l := newAuthFailureLimiter()
|
||||||
|
defer l.stop()
|
||||||
|
|
||||||
|
l.recordFailure("10.0.0.1")
|
||||||
|
|
||||||
|
l.mu.Lock()
|
||||||
|
require.Len(t, l.limiters, 1)
|
||||||
|
// Backdate the entry so it looks stale
|
||||||
|
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
|
||||||
|
l.mu.Unlock()
|
||||||
|
|
||||||
|
l.cleanup()
|
||||||
|
|
||||||
|
l.mu.Lock()
|
||||||
|
assert.Empty(t, l.limiters, "stale entries should be cleaned up")
|
||||||
|
l.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthFailureLimiter_CleanupKeepsFresh(t *testing.T) {
|
||||||
|
l := newAuthFailureLimiter()
|
||||||
|
defer l.stop()
|
||||||
|
|
||||||
|
l.recordFailure("10.0.0.1")
|
||||||
|
l.recordFailure("10.0.0.2")
|
||||||
|
|
||||||
|
l.mu.Lock()
|
||||||
|
// Only backdate one entry
|
||||||
|
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
|
||||||
|
l.mu.Unlock()
|
||||||
|
|
||||||
|
l.cleanup()
|
||||||
|
|
||||||
|
l.mu.Lock()
|
||||||
|
assert.Len(t, l.limiters, 1, "only stale entries should be removed")
|
||||||
|
assert.Contains(t, l.limiters, "10.0.0.2")
|
||||||
|
l.mu.Unlock()
|
||||||
|
}
|
||||||
381
management/internals/shared/grpc/proxy_group_access_test.go
Normal file
381
management/internals/shared/grpc/proxy_group_access_test.go
Normal file
@@ -0,0 +1,381 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockReverseProxyManager struct {
|
||||||
|
proxiesByAccount map[string][]*reverseproxy.Service
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||||
|
if m.err != nil {
|
||||||
|
return nil, m.err
|
||||||
|
}
|
||||||
|
return m.proxiesByAccount[accountID], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
|
||||||
|
return []*reverseproxy.Service{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.Service, error) {
|
||||||
|
return &reverseproxy.Service{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||||
|
return &reverseproxy.Service{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||||
|
return &reverseproxy.Service{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) ReloadService(ctx context.Context, accountID, reverseProxyID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.Service, error) {
|
||||||
|
return &reverseproxy.Service{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockUsersManager struct {
|
||||||
|
users map[string]*types.User
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
|
||||||
|
if m.err != nil {
|
||||||
|
return nil, m.err
|
||||||
|
}
|
||||||
|
user, ok := m.users[userID]
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("user not found")
|
||||||
|
}
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateUserGroupAccess(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
domain string
|
||||||
|
userID string
|
||||||
|
proxiesByAccount map[string][]*reverseproxy.Service
|
||||||
|
users map[string]*types.User
|
||||||
|
proxyErr error
|
||||||
|
userErr error
|
||||||
|
expectErr bool
|
||||||
|
expectErrMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "user not found",
|
||||||
|
domain: "app.example.com",
|
||||||
|
userID: "unknown-user",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||||
|
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
|
||||||
|
},
|
||||||
|
users: map[string]*types.User{},
|
||||||
|
expectErr: true,
|
||||||
|
expectErrMsg: "user not found",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "proxy not found in user's account",
|
||||||
|
domain: "app.example.com",
|
||||||
|
userID: "user1",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{},
|
||||||
|
users: map[string]*types.User{
|
||||||
|
"user1": {Id: "user1", AccountID: "account1"},
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
expectErrMsg: "service not found",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "proxy exists in different account - not accessible",
|
||||||
|
domain: "app.example.com",
|
||||||
|
userID: "user1",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||||
|
"account2": {{Domain: "app.example.com", AccountID: "account2"}},
|
||||||
|
},
|
||||||
|
users: map[string]*types.User{
|
||||||
|
"user1": {Id: "user1", AccountID: "account1"},
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
expectErrMsg: "service not found",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no bearer auth configured - same account allows access",
|
||||||
|
domain: "app.example.com",
|
||||||
|
userID: "user1",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||||
|
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}},
|
||||||
|
},
|
||||||
|
users: map[string]*types.User{
|
||||||
|
"user1": {Id: "user1", AccountID: "account1"},
|
||||||
|
},
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bearer auth disabled - same account allows access",
|
||||||
|
domain: "app.example.com",
|
||||||
|
userID: "user1",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||||
|
"account1": {{
|
||||||
|
Domain: "app.example.com",
|
||||||
|
AccountID: "account1",
|
||||||
|
Auth: reverseproxy.AuthConfig{
|
||||||
|
BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: false},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
users: map[string]*types.User{
|
||||||
|
"user1": {Id: "user1", AccountID: "account1"},
|
||||||
|
},
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bearer auth enabled but no groups configured - same account allows access",
|
||||||
|
domain: "app.example.com",
|
||||||
|
userID: "user1",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||||
|
"account1": {{
|
||||||
|
Domain: "app.example.com",
|
||||||
|
AccountID: "account1",
|
||||||
|
Auth: reverseproxy.AuthConfig{
|
||||||
|
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
DistributionGroups: []string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
users: map[string]*types.User{
|
||||||
|
"user1": {Id: "user1", AccountID: "account1"},
|
||||||
|
},
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user not in allowed groups",
|
||||||
|
domain: "app.example.com",
|
||||||
|
userID: "user1",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||||
|
"account1": {{
|
||||||
|
Domain: "app.example.com",
|
||||||
|
AccountID: "account1",
|
||||||
|
Auth: reverseproxy.AuthConfig{
|
||||||
|
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
DistributionGroups: []string{"group1", "group2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
users: map[string]*types.User{
|
||||||
|
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group3", "group4"}},
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
expectErrMsg: "not in allowed groups",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user in one of the allowed groups - allow access",
|
||||||
|
domain: "app.example.com",
|
||||||
|
userID: "user1",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||||
|
"account1": {{
|
||||||
|
Domain: "app.example.com",
|
||||||
|
AccountID: "account1",
|
||||||
|
Auth: reverseproxy.AuthConfig{
|
||||||
|
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
DistributionGroups: []string{"group1", "group2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
users: map[string]*types.User{
|
||||||
|
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group2", "group3"}},
|
||||||
|
},
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user in all allowed groups - allow access",
|
||||||
|
domain: "app.example.com",
|
||||||
|
userID: "user1",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||||
|
"account1": {{
|
||||||
|
Domain: "app.example.com",
|
||||||
|
AccountID: "account1",
|
||||||
|
Auth: reverseproxy.AuthConfig{
|
||||||
|
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||||
|
Enabled: true,
|
||||||
|
DistributionGroups: []string{"group1", "group2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
users: map[string]*types.User{
|
||||||
|
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group1", "group2", "group3"}},
|
||||||
|
},
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "proxy manager error",
|
||||||
|
domain: "app.example.com",
|
||||||
|
userID: "user1",
|
||||||
|
proxiesByAccount: nil,
|
||||||
|
proxyErr: errors.New("database error"),
|
||||||
|
users: map[string]*types.User{
|
||||||
|
"user1": {Id: "user1", AccountID: "account1"},
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
expectErrMsg: "get account services",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple proxies in account - finds correct one",
|
||||||
|
domain: "app2.example.com",
|
||||||
|
userID: "user1",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||||
|
"account1": {
|
||||||
|
{Domain: "app1.example.com", AccountID: "account1"},
|
||||||
|
{Domain: "app2.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}},
|
||||||
|
{Domain: "app3.example.com", AccountID: "account1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
users: map[string]*types.User{
|
||||||
|
"user1": {Id: "user1", AccountID: "account1"},
|
||||||
|
},
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
server := &ProxyServiceServer{
|
||||||
|
reverseProxyManager: &mockReverseProxyManager{
|
||||||
|
proxiesByAccount: tt.proxiesByAccount,
|
||||||
|
err: tt.proxyErr,
|
||||||
|
},
|
||||||
|
usersManager: &mockUsersManager{
|
||||||
|
users: tt.users,
|
||||||
|
err: tt.userErr,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := server.ValidateUserGroupAccess(context.Background(), tt.domain, tt.userID)
|
||||||
|
|
||||||
|
if tt.expectErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), tt.expectErrMsg)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAccountProxyByDomain(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
accountID string
|
||||||
|
domain string
|
||||||
|
proxiesByAccount map[string][]*reverseproxy.Service
|
||||||
|
err error
|
||||||
|
expectProxy bool
|
||||||
|
expectErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "proxy found",
|
||||||
|
accountID: "account1",
|
||||||
|
domain: "app.example.com",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||||
|
"account1": {
|
||||||
|
{Domain: "other.example.com", AccountID: "account1"},
|
||||||
|
{Domain: "app.example.com", AccountID: "account1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectProxy: true,
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "proxy not found in account",
|
||||||
|
accountID: "account1",
|
||||||
|
domain: "unknown.example.com",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||||
|
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
|
||||||
|
},
|
||||||
|
expectProxy: false,
|
||||||
|
expectErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty proxy list for account",
|
||||||
|
accountID: "account1",
|
||||||
|
domain: "app.example.com",
|
||||||
|
proxiesByAccount: map[string][]*reverseproxy.Service{},
|
||||||
|
expectProxy: false,
|
||||||
|
expectErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "manager error",
|
||||||
|
accountID: "account1",
|
||||||
|
domain: "app.example.com",
|
||||||
|
proxiesByAccount: nil,
|
||||||
|
err: errors.New("database error"),
|
||||||
|
expectProxy: false,
|
||||||
|
expectErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
server := &ProxyServiceServer{
|
||||||
|
reverseProxyManager: &mockReverseProxyManager{
|
||||||
|
proxiesByAccount: tt.proxiesByAccount,
|
||||||
|
err: tt.err,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy, err := server.getAccountServiceByDomain(context.Background(), tt.accountID, tt.domain)
|
||||||
|
|
||||||
|
if tt.expectErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, proxy)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, proxy)
|
||||||
|
assert.Equal(t, tt.domain, proxy.Domain)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
232
management/internals/shared/grpc/proxy_test.go
Normal file
232
management/internals/shared/grpc/proxy_test.go
Normal file
@@ -0,0 +1,232 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// registerFakeProxy adds a fake proxy connection to the server's internal maps
|
||||||
|
// and returns the channel where messages will be received.
|
||||||
|
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.ProxyMapping {
|
||||||
|
ch := make(chan *proto.ProxyMapping, 10)
|
||||||
|
conn := &proxyConnection{
|
||||||
|
proxyID: proxyID,
|
||||||
|
address: clusterAddr,
|
||||||
|
sendChan: ch,
|
||||||
|
}
|
||||||
|
s.connectedProxies.Store(proxyID, conn)
|
||||||
|
|
||||||
|
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
|
||||||
|
proxySet.(*sync.Map).Store(proxyID, struct{}{})
|
||||||
|
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|
||||||
|
func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping {
|
||||||
|
select {
|
||||||
|
case msg := <-ch:
|
||||||
|
return msg
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
||||||
|
tokenStore := NewOneTimeTokenStore(time.Hour)
|
||||||
|
defer tokenStore.Close()
|
||||||
|
|
||||||
|
s := &ProxyServiceServer{
|
||||||
|
tokenStore: tokenStore,
|
||||||
|
updatesChan: make(chan *proto.ProxyMapping, 100),
|
||||||
|
}
|
||||||
|
|
||||||
|
const cluster = "proxy.example.com"
|
||||||
|
const numProxies = 3
|
||||||
|
|
||||||
|
channels := make([]chan *proto.ProxyMapping, numProxies)
|
||||||
|
for i := range numProxies {
|
||||||
|
id := "proxy-" + string(rune('a'+i))
|
||||||
|
channels[i] = registerFakeProxy(s, id, cluster)
|
||||||
|
}
|
||||||
|
|
||||||
|
update := &proto.ProxyMapping{
|
||||||
|
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||||
|
Id: "service-1",
|
||||||
|
AccountId: "account-1",
|
||||||
|
Domain: "test.example.com",
|
||||||
|
Path: []*proto.PathMapping{
|
||||||
|
{Path: "/", Target: "http://10.0.0.1:8080/"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s.SendServiceUpdateToCluster(update, cluster)
|
||||||
|
|
||||||
|
tokens := make([]string, numProxies)
|
||||||
|
for i, ch := range channels {
|
||||||
|
msg := drainChannel(ch)
|
||||||
|
require.NotNil(t, msg, "proxy %d should receive a message", i)
|
||||||
|
assert.Equal(t, update.Domain, msg.Domain)
|
||||||
|
assert.Equal(t, update.Id, msg.Id)
|
||||||
|
assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i)
|
||||||
|
tokens[i] = msg.AuthToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// All tokens must be unique
|
||||||
|
tokenSet := make(map[string]struct{})
|
||||||
|
for i, tok := range tokens {
|
||||||
|
_, exists := tokenSet[tok]
|
||||||
|
assert.False(t, exists, "proxy %d got duplicate token", i)
|
||||||
|
tokenSet[tok] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Each token must be independently consumable
|
||||||
|
for i, tok := range tokens {
|
||||||
|
err := tokenStore.ValidateAndConsume(tok, "account-1", "service-1")
|
||||||
|
assert.NoError(t, err, "proxy %d token should validate successfully", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
||||||
|
tokenStore := NewOneTimeTokenStore(time.Hour)
|
||||||
|
defer tokenStore.Close()
|
||||||
|
|
||||||
|
s := &ProxyServiceServer{
|
||||||
|
tokenStore: tokenStore,
|
||||||
|
updatesChan: make(chan *proto.ProxyMapping, 100),
|
||||||
|
}
|
||||||
|
|
||||||
|
const cluster = "proxy.example.com"
|
||||||
|
ch1 := registerFakeProxy(s, "proxy-a", cluster)
|
||||||
|
ch2 := registerFakeProxy(s, "proxy-b", cluster)
|
||||||
|
|
||||||
|
update := &proto.ProxyMapping{
|
||||||
|
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
|
||||||
|
Id: "service-1",
|
||||||
|
AccountId: "account-1",
|
||||||
|
Domain: "test.example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
s.SendServiceUpdateToCluster(update, cluster)
|
||||||
|
|
||||||
|
msg1 := drainChannel(ch1)
|
||||||
|
msg2 := drainChannel(ch2)
|
||||||
|
require.NotNil(t, msg1)
|
||||||
|
require.NotNil(t, msg2)
|
||||||
|
|
||||||
|
// Delete operations should not generate tokens
|
||||||
|
assert.Empty(t, msg1.AuthToken)
|
||||||
|
assert.Empty(t, msg2.AuthToken)
|
||||||
|
|
||||||
|
// No tokens should have been created
|
||||||
|
assert.Equal(t, 0, tokenStore.GetTokenCount())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
|
||||||
|
tokenStore := NewOneTimeTokenStore(time.Hour)
|
||||||
|
defer tokenStore.Close()
|
||||||
|
|
||||||
|
s := &ProxyServiceServer{
|
||||||
|
tokenStore: tokenStore,
|
||||||
|
updatesChan: make(chan *proto.ProxyMapping, 100),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register proxies in different clusters (SendServiceUpdate broadcasts to all)
|
||||||
|
ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")
|
||||||
|
ch2 := registerFakeProxy(s, "proxy-b", "cluster-b")
|
||||||
|
|
||||||
|
update := &proto.ProxyMapping{
|
||||||
|
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||||
|
Id: "service-1",
|
||||||
|
AccountId: "account-1",
|
||||||
|
Domain: "test.example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
s.SendServiceUpdate(update)
|
||||||
|
|
||||||
|
msg1 := drainChannel(ch1)
|
||||||
|
msg2 := drainChannel(ch2)
|
||||||
|
require.NotNil(t, msg1)
|
||||||
|
require.NotNil(t, msg2)
|
||||||
|
|
||||||
|
assert.NotEmpty(t, msg1.AuthToken)
|
||||||
|
assert.NotEmpty(t, msg2.AuthToken)
|
||||||
|
assert.NotEqual(t, msg1.AuthToken, msg2.AuthToken, "tokens must be unique per proxy")
|
||||||
|
|
||||||
|
// Both tokens should validate
|
||||||
|
assert.NoError(t, tokenStore.ValidateAndConsume(msg1.AuthToken, "account-1", "service-1"))
|
||||||
|
assert.NoError(t, tokenStore.ValidateAndConsume(msg2.AuthToken, "account-1", "service-1"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateState creates a state using the same format as GetOIDCURL.
|
||||||
|
func generateState(s *ProxyServiceServer, redirectURL string) string {
|
||||||
|
nonce := make([]byte, 16)
|
||||||
|
_, _ = rand.Read(nonce)
|
||||||
|
nonceB64 := base64.URLEncoding.EncodeToString(nonce)
|
||||||
|
|
||||||
|
payload := redirectURL + "|" + nonceB64
|
||||||
|
hmacSum := s.generateHMAC(payload)
|
||||||
|
return base64.URLEncoding.EncodeToString([]byte(redirectURL)) + "|" + nonceB64 + "|" + hmacSum
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOAuthState_NeverTheSame(t *testing.T) {
|
||||||
|
s := &ProxyServiceServer{
|
||||||
|
oidcConfig: ProxyOIDCConfig{
|
||||||
|
HMACKey: []byte("test-hmac-key"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURL := "https://app.example.com/callback"
|
||||||
|
|
||||||
|
// Generate 100 states for the same redirect URL
|
||||||
|
states := make(map[string]bool)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
state := generateState(s, redirectURL)
|
||||||
|
|
||||||
|
// State must have 3 parts: base64(url)|nonce|hmac
|
||||||
|
parts := strings.Split(state, "|")
|
||||||
|
require.Equal(t, 3, len(parts), "state must have 3 parts")
|
||||||
|
|
||||||
|
// State must be unique
|
||||||
|
require.False(t, states[state], "state %d is a duplicate", i)
|
||||||
|
states[state] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
|
||||||
|
s := &ProxyServiceServer{
|
||||||
|
oidcConfig: ProxyOIDCConfig{
|
||||||
|
HMACKey: []byte("test-hmac-key"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Old format had only 2 parts: base64(url)|hmac
|
||||||
|
s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
|
||||||
|
|
||||||
|
_, _, err := s.ValidateState("base64url|hmac")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid state format")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
|
||||||
|
s := &ProxyServiceServer{
|
||||||
|
oidcConfig: ProxyOIDCConfig{
|
||||||
|
HMACKey: []byte("test-hmac-key"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store with tampered HMAC
|
||||||
|
s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
|
||||||
|
|
||||||
|
_, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid state signature")
|
||||||
|
}
|
||||||
@@ -17,13 +17,14 @@ import (
|
|||||||
pb "github.com/golang/protobuf/proto" // nolint
|
pb "github.com/golang/protobuf/proto" // nolint
|
||||||
"github.com/golang/protobuf/ptypes/timestamp"
|
"github.com/golang/protobuf/ptypes/timestamp"
|
||||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/peer"
|
"google.golang.org/grpc/peer"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
@@ -76,8 +77,9 @@ type Server struct {
|
|||||||
|
|
||||||
oAuthConfigProvider idp.OAuthConfigProvider
|
oAuthConfigProvider idp.OAuthConfigProvider
|
||||||
|
|
||||||
syncSem atomic.Int32
|
syncSem atomic.Int32
|
||||||
syncLim int32
|
syncLimEnabled bool
|
||||||
|
syncLim int32
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates a new Management server
|
// NewServer creates a new Management server
|
||||||
@@ -107,6 +109,7 @@ func NewServer(
|
|||||||
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
|
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
|
||||||
|
|
||||||
syncLim := int32(defaultSyncLim)
|
syncLim := int32(defaultSyncLim)
|
||||||
|
syncLimEnabled := true
|
||||||
if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" {
|
if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" {
|
||||||
syncLimParsed, err := strconv.Atoi(syncLimStr)
|
syncLimParsed, err := strconv.Atoi(syncLimStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -114,6 +117,9 @@ func NewServer(
|
|||||||
} else {
|
} else {
|
||||||
//nolint:gosec
|
//nolint:gosec
|
||||||
syncLim = int32(syncLimParsed)
|
syncLim = int32(syncLimParsed)
|
||||||
|
if syncLim < 0 {
|
||||||
|
syncLimEnabled = false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -133,7 +139,8 @@ func NewServer(
|
|||||||
|
|
||||||
loginFilter: newLoginFilter(),
|
loginFilter: newLoginFilter(),
|
||||||
|
|
||||||
syncLim: syncLim,
|
syncLim: syncLim,
|
||||||
|
syncLimEnabled: syncLimEnabled,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -211,7 +218,7 @@ func (s *Server) Job(srv proto.ManagementService_JobServer) error {
|
|||||||
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
||||||
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
||||||
func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
||||||
if s.syncSem.Load() >= s.syncLim {
|
if s.syncLimEnabled && s.syncSem.Load() >= s.syncLim {
|
||||||
return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later")
|
return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later")
|
||||||
}
|
}
|
||||||
s.syncSem.Add(1)
|
s.syncSem.Add(1)
|
||||||
@@ -293,7 +300,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
metahash := metaHash(peerMeta, realIP.String())
|
metahash := metaHash(peerMeta, realIP.String())
|
||||||
s.loginFilter.addLogin(peerKey.String(), metahash)
|
s.loginFilter.addLogin(peerKey.String(), metahash)
|
||||||
|
|
||||||
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
|
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, reqStart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
@@ -304,6 +311,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
|
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -311,7 +319,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
|
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -328,7 +336,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
|
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
|
|
||||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, reqStart)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
|
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
|
||||||
@@ -396,11 +404,20 @@ func (s *Server) sendJobsLoop(ctx context.Context, accountID string, peerKey wgt
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
||||||
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
// It implements a backpressure mechanism that sends the first update immediately,
|
||||||
|
// then debounces subsequent rapid updates, ensuring only the latest update is sent
|
||||||
|
// after a quiet period.
|
||||||
|
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
|
||||||
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
|
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
|
||||||
|
|
||||||
|
// Create a debouncer for this peer connection
|
||||||
|
debouncer := NewUpdateDebouncer(1000 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
// condition when there are some updates
|
// condition when there are some updates
|
||||||
|
// todo set the updates channel size to 1
|
||||||
case update, open := <-updates:
|
case update, open := <-updates:
|
||||||
if s.appMetrics != nil {
|
if s.appMetrics != nil {
|
||||||
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1)
|
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1)
|
||||||
@@ -408,20 +425,38 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
|
|||||||
|
|
||||||
if !open {
|
if !open {
|
||||||
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
|
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
|
||||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
||||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
|
if debouncer.ProcessUpdate(update) {
|
||||||
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
|
// Send immediately (first update or after quiet period)
|
||||||
return err
|
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
|
||||||
|
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Timer expired - quiet period reached, send pending updates if any
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
if len(pendingUpdates) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Debugf("sending %d debounced update(s) for peer %s", len(pendingUpdates), peerKey.String())
|
||||||
|
for _, pendingUpdate := range pendingUpdates {
|
||||||
|
if err := s.sendUpdate(ctx, accountID, peerKey, peer, pendingUpdate, srv, streamStartTime); err != nil {
|
||||||
|
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// condition when client <-> server connection has been terminated
|
// condition when client <-> server connection has been terminated
|
||||||
case <-srv.Context().Done():
|
case <-srv.Context().Done():
|
||||||
// happens when connection drops, e.g. client disconnects
|
// happens when connection drops, e.g. client disconnects
|
||||||
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
|
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
|
||||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||||
return srv.Context().Err()
|
return srv.Context().Err()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -429,16 +464,16 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
|
|||||||
|
|
||||||
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
||||||
// then sends the encrypted message to the connected peer via the sync server.
|
// then sends the encrypted message to the connected peer via the sync server.
|
||||||
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
|
||||||
key, err := s.secretsManager.GetWGKey()
|
key, err := s.secretsManager.GetWGKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||||
return status.Errorf(codes.Internal, "failed processing update message")
|
return status.Errorf(codes.Internal, "failed processing update message")
|
||||||
}
|
}
|
||||||
|
|
||||||
encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update)
|
encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||||
return status.Errorf(codes.Internal, "failed processing update message")
|
return status.Errorf(codes.Internal, "failed processing update message")
|
||||||
}
|
}
|
||||||
err = srv.Send(&proto.EncryptedMessage{
|
err = srv.Send(&proto.EncryptedMessage{
|
||||||
@@ -446,7 +481,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
|
|||||||
Body: encryptedResp,
|
Body: encryptedResp,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
|
||||||
return status.Errorf(codes.Internal, "failed sending update message")
|
return status.Errorf(codes.Internal, "failed sending update message")
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
||||||
@@ -478,11 +513,15 @@ func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Even
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
|
||||||
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
|
||||||
|
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key, streamStartTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
|
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -242,7 +242,10 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont
|
|||||||
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
|
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
|
||||||
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
|
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
|
||||||
|
Update: update,
|
||||||
|
MessageType: network_map.MessageTypeControlConfig,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) {
|
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) {
|
||||||
@@ -266,7 +269,10 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac
|
|||||||
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
|
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
|
||||||
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
|
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
|
||||||
|
Update: update,
|
||||||
|
MessageType: network_map.MessageTypeControlConfig,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {
|
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {
|
||||||
|
|||||||
103
management/internals/shared/grpc/update_debouncer.go
Normal file
103
management/internals/shared/grpc/update_debouncer.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UpdateDebouncer implements a backpressure mechanism that:
|
||||||
|
// - Sends the first update immediately
|
||||||
|
// - Coalesces rapid subsequent network map updates (only latest matters)
|
||||||
|
// - Queues control/config updates (all must be delivered)
|
||||||
|
// - Preserves the order of messages (important for control configs between network maps)
|
||||||
|
// - Ensures pending updates are sent after a quiet period
|
||||||
|
type UpdateDebouncer struct {
|
||||||
|
debounceInterval time.Duration
|
||||||
|
timer *time.Timer
|
||||||
|
pendingUpdates []*network_map.UpdateMessage // Queue that preserves order
|
||||||
|
timerC <-chan time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUpdateDebouncer creates a new debouncer with the specified interval
|
||||||
|
func NewUpdateDebouncer(interval time.Duration) *UpdateDebouncer {
|
||||||
|
return &UpdateDebouncer{
|
||||||
|
debounceInterval: interval,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessUpdate handles an incoming update and returns whether it should be sent immediately
|
||||||
|
func (d *UpdateDebouncer) ProcessUpdate(update *network_map.UpdateMessage) bool {
|
||||||
|
if d.timer == nil {
|
||||||
|
// No active debounce timer, signal to send immediately
|
||||||
|
// and start the debounce period
|
||||||
|
d.startTimer()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Already in debounce period, accumulate this update preserving order
|
||||||
|
// Check if we should coalesce with the last pending update
|
||||||
|
if len(d.pendingUpdates) > 0 &&
|
||||||
|
update.MessageType == network_map.MessageTypeNetworkMap &&
|
||||||
|
d.pendingUpdates[len(d.pendingUpdates)-1].MessageType == network_map.MessageTypeNetworkMap {
|
||||||
|
// Replace the last network map with this one (coalesce consecutive network maps)
|
||||||
|
d.pendingUpdates[len(d.pendingUpdates)-1] = update
|
||||||
|
} else {
|
||||||
|
// Append to the queue (preserves order for control configs and non-consecutive network maps)
|
||||||
|
d.pendingUpdates = append(d.pendingUpdates, update)
|
||||||
|
}
|
||||||
|
d.resetTimer()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TimerChannel returns the timer channel for select statements
|
||||||
|
func (d *UpdateDebouncer) TimerChannel() <-chan time.Time {
|
||||||
|
if d.timer == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return d.timerC
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPendingUpdates returns and clears all pending updates after timer expiration.
|
||||||
|
// Updates are returned in the order they were received, with consecutive network maps
|
||||||
|
// already coalesced to only the latest one.
|
||||||
|
// If there were pending updates, it restarts the timer to continue debouncing.
|
||||||
|
// If there were no pending updates, it clears the timer (true quiet period).
|
||||||
|
func (d *UpdateDebouncer) GetPendingUpdates() []*network_map.UpdateMessage {
|
||||||
|
updates := d.pendingUpdates
|
||||||
|
d.pendingUpdates = nil
|
||||||
|
|
||||||
|
if len(updates) > 0 {
|
||||||
|
// There were pending updates, so updates are still coming rapidly
|
||||||
|
// Restart the timer to continue debouncing mode
|
||||||
|
if d.timer != nil {
|
||||||
|
d.timer.Reset(d.debounceInterval)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// No pending updates means true quiet period - return to immediate mode
|
||||||
|
d.timer = nil
|
||||||
|
d.timerC = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return updates
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the debouncer and cleans up resources
|
||||||
|
func (d *UpdateDebouncer) Stop() {
|
||||||
|
if d.timer != nil {
|
||||||
|
d.timer.Stop()
|
||||||
|
d.timer = nil
|
||||||
|
d.timerC = nil
|
||||||
|
}
|
||||||
|
d.pendingUpdates = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *UpdateDebouncer) startTimer() {
|
||||||
|
d.timer = time.NewTimer(d.debounceInterval)
|
||||||
|
d.timerC = d.timer.C
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *UpdateDebouncer) resetTimer() {
|
||||||
|
d.timer.Stop()
|
||||||
|
d.timer.Reset(d.debounceInterval)
|
||||||
|
}
|
||||||
587
management/internals/shared/grpc/update_debouncer_test.go
Normal file
587
management/internals/shared/grpc/update_debouncer_test.go
Normal file
@@ -0,0 +1,587 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_FirstUpdateSentImmediately(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
update := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
shouldSend := debouncer.ProcessUpdate(update)
|
||||||
|
|
||||||
|
if !shouldSend {
|
||||||
|
t.Error("First update should be sent immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
if debouncer.TimerChannel() == nil {
|
||||||
|
t.Error("Timer should be started after first update")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_RapidUpdatesCoalesced(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
update1 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
update2 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
update3 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
// First update should be sent immediately
|
||||||
|
if !debouncer.ProcessUpdate(update1) {
|
||||||
|
t.Error("First update should be sent immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rapid subsequent updates should be coalesced
|
||||||
|
if debouncer.ProcessUpdate(update2) {
|
||||||
|
t.Error("Second rapid update should not be sent immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
if debouncer.ProcessUpdate(update3) {
|
||||||
|
t.Error("Third rapid update should not be sent immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for debounce period
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
if len(pendingUpdates) != 1 {
|
||||||
|
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||||
|
}
|
||||||
|
if pendingUpdates[0] != update3 {
|
||||||
|
t.Error("Should get the last update (update3)")
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_LastUpdateAlwaysSent(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
update1 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
update2 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send first update
|
||||||
|
debouncer.ProcessUpdate(update1)
|
||||||
|
|
||||||
|
// Send second update within debounce period
|
||||||
|
debouncer.ProcessUpdate(update2)
|
||||||
|
|
||||||
|
// Wait for timer
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
if len(pendingUpdates) != 1 {
|
||||||
|
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||||
|
}
|
||||||
|
if pendingUpdates[0] != update2 {
|
||||||
|
t.Error("Should get the last update")
|
||||||
|
}
|
||||||
|
if pendingUpdates[0] == update1 {
|
||||||
|
t.Error("Should not get the first update")
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_TimerResetOnNewUpdate(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
update1 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
update2 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
update3 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send first update
|
||||||
|
debouncer.ProcessUpdate(update1)
|
||||||
|
|
||||||
|
// Wait a bit, but not the full debounce period
|
||||||
|
time.Sleep(30 * time.Millisecond)
|
||||||
|
|
||||||
|
// Send second update - should reset timer
|
||||||
|
debouncer.ProcessUpdate(update2)
|
||||||
|
|
||||||
|
// Wait a bit more
|
||||||
|
time.Sleep(30 * time.Millisecond)
|
||||||
|
|
||||||
|
// Send third update - should reset timer again
|
||||||
|
debouncer.ProcessUpdate(update3)
|
||||||
|
|
||||||
|
// Now wait for the timer (should fire after last update's reset)
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
if len(pendingUpdates) != 1 {
|
||||||
|
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||||
|
}
|
||||||
|
if pendingUpdates[0] != update3 {
|
||||||
|
t.Error("Should get the last update (update3)")
|
||||||
|
}
|
||||||
|
// Timer should be restarted since there was a pending update
|
||||||
|
if debouncer.TimerChannel() == nil {
|
||||||
|
t.Error("Timer should be restarted after sending pending update")
|
||||||
|
}
|
||||||
|
case <-time.After(150 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_TimerRestartsAfterPendingUpdateSent(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
update1 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
update2 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
update3 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
// First update sent immediately
|
||||||
|
debouncer.ProcessUpdate(update1)
|
||||||
|
|
||||||
|
// Second update coalesced
|
||||||
|
debouncer.ProcessUpdate(update2)
|
||||||
|
|
||||||
|
// Wait for timer to expire
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
|
||||||
|
if len(pendingUpdates) == 0 {
|
||||||
|
t.Fatal("Should have pending update")
|
||||||
|
}
|
||||||
|
|
||||||
|
// After sending pending update, timer is restarted, so next update is NOT immediate
|
||||||
|
if debouncer.ProcessUpdate(update3) {
|
||||||
|
t.Error("Update after debounced send should not be sent immediately (timer restarted)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the restarted timer and verify update3 is pending
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
finalUpdates := debouncer.GetPendingUpdates()
|
||||||
|
if len(finalUpdates) != 1 || finalUpdates[0] != update3 {
|
||||||
|
t.Error("Should get update3 as pending")
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired for restarted timer")
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_StopCleansUp(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
|
||||||
|
update := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send update to start timer
|
||||||
|
debouncer.ProcessUpdate(update)
|
||||||
|
|
||||||
|
// Stop should clean up
|
||||||
|
debouncer.Stop()
|
||||||
|
|
||||||
|
// Multiple stops should be safe
|
||||||
|
debouncer.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_HighFrequencyUpdates(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
// Simulate high-frequency updates
|
||||||
|
var lastUpdate *network_map.UpdateMessage
|
||||||
|
sentImmediately := 0
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
update := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{
|
||||||
|
NetworkMap: &proto.NetworkMap{
|
||||||
|
Serial: uint64(i),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
lastUpdate = update
|
||||||
|
if debouncer.ProcessUpdate(update) {
|
||||||
|
sentImmediately++
|
||||||
|
}
|
||||||
|
time.Sleep(1 * time.Millisecond) // Very rapid updates
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only first update should be sent immediately
|
||||||
|
if sentImmediately != 1 {
|
||||||
|
t.Errorf("Expected only 1 update sent immediately, got %d", sentImmediately)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for debounce period
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
if len(pendingUpdates) != 1 {
|
||||||
|
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||||
|
}
|
||||||
|
if pendingUpdates[0] != lastUpdate {
|
||||||
|
t.Error("Should get the very last update")
|
||||||
|
}
|
||||||
|
if pendingUpdates[0].Update.NetworkMap.Serial != 99 {
|
||||||
|
t.Errorf("Expected serial 99, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_NoUpdatesAfterFirst(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
update := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send first update
|
||||||
|
if !debouncer.ProcessUpdate(update) {
|
||||||
|
t.Error("First update should be sent immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for timer to expire with no additional updates (true quiet period)
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
if len(pendingUpdates) != 0 {
|
||||||
|
t.Error("Should have no pending updates")
|
||||||
|
}
|
||||||
|
// After true quiet period, timer should be cleared
|
||||||
|
if debouncer.TimerChannel() != nil {
|
||||||
|
t.Error("Timer should be cleared after quiet period")
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_IntermediateUpdatesDropped(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
updates := make([]*network_map.UpdateMessage, 5)
|
||||||
|
for i := range updates {
|
||||||
|
updates[i] = &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{
|
||||||
|
NetworkMap: &proto.NetworkMap{
|
||||||
|
Serial: uint64(i),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// First update sent immediately
|
||||||
|
debouncer.ProcessUpdate(updates[0])
|
||||||
|
|
||||||
|
// Send updates 1, 2, 3, 4 rapidly - only last one should remain pending
|
||||||
|
debouncer.ProcessUpdate(updates[1])
|
||||||
|
debouncer.ProcessUpdate(updates[2])
|
||||||
|
debouncer.ProcessUpdate(updates[3])
|
||||||
|
debouncer.ProcessUpdate(updates[4])
|
||||||
|
|
||||||
|
// Wait for debounce
|
||||||
|
<-debouncer.TimerChannel()
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
|
||||||
|
if len(pendingUpdates) != 1 {
|
||||||
|
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||||
|
}
|
||||||
|
if pendingUpdates[0].Update.NetworkMap.Serial != 4 {
|
||||||
|
t.Errorf("Expected only the last update (serial 4), got serial %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_TrueQuietPeriodResetsToImmediateMode(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
update1 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
update2 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
// First update sent immediately
|
||||||
|
if !debouncer.ProcessUpdate(update1) {
|
||||||
|
t.Error("First update should be sent immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for timer without sending any more updates (true quiet period)
|
||||||
|
<-debouncer.TimerChannel()
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
|
||||||
|
if len(pendingUpdates) != 0 {
|
||||||
|
t.Error("Should have no pending updates during quiet period")
|
||||||
|
}
|
||||||
|
|
||||||
|
// After true quiet period, next update should be sent immediately
|
||||||
|
if !debouncer.ProcessUpdate(update2) {
|
||||||
|
t.Error("Update after true quiet period should be sent immediately")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_ContinuousHighFrequencyStaysInDebounceMode(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
// Simulate continuous high-frequency updates
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
update := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{
|
||||||
|
NetworkMap: &proto.NetworkMap{
|
||||||
|
Serial: uint64(i),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
if i == 0 {
|
||||||
|
// First one sent immediately
|
||||||
|
if !debouncer.ProcessUpdate(update) {
|
||||||
|
t.Error("First update should be sent immediately")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// All others should be coalesced (not sent immediately)
|
||||||
|
if debouncer.ProcessUpdate(update) {
|
||||||
|
t.Errorf("Update %d should not be sent immediately", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait a bit but send next update before debounce expires
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now wait for final debounce
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
if len(pendingUpdates) == 0 {
|
||||||
|
t.Fatal("Should have the last update pending")
|
||||||
|
}
|
||||||
|
if pendingUpdates[0].Update.NetworkMap.Serial != 9 {
|
||||||
|
t.Errorf("Expected serial 9, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_ControlConfigMessagesQueued(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
netmapUpdate := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
tokenUpdate1 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
||||||
|
MessageType: network_map.MessageTypeControlConfig,
|
||||||
|
}
|
||||||
|
tokenUpdate2 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
||||||
|
MessageType: network_map.MessageTypeControlConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
// First update sent immediately
|
||||||
|
debouncer.ProcessUpdate(netmapUpdate)
|
||||||
|
|
||||||
|
// Send multiple control config updates - they should all be queued
|
||||||
|
debouncer.ProcessUpdate(tokenUpdate1)
|
||||||
|
debouncer.ProcessUpdate(tokenUpdate2)
|
||||||
|
|
||||||
|
// Wait for debounce period
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
// Should get both control config updates
|
||||||
|
if len(pendingUpdates) != 2 {
|
||||||
|
t.Errorf("Expected 2 control config updates, got %d", len(pendingUpdates))
|
||||||
|
}
|
||||||
|
// Control configs should come first
|
||||||
|
if pendingUpdates[0] != tokenUpdate1 {
|
||||||
|
t.Error("First pending update should be tokenUpdate1")
|
||||||
|
}
|
||||||
|
if pendingUpdates[1] != tokenUpdate2 {
|
||||||
|
t.Error("Second pending update should be tokenUpdate2")
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_MixedMessageTypes(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
netmapUpdate1 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
netmapUpdate2 := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 2}},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
tokenUpdate := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
||||||
|
MessageType: network_map.MessageTypeControlConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
// First update sent immediately
|
||||||
|
debouncer.ProcessUpdate(netmapUpdate1)
|
||||||
|
|
||||||
|
// Send token update and network map update
|
||||||
|
debouncer.ProcessUpdate(tokenUpdate)
|
||||||
|
debouncer.ProcessUpdate(netmapUpdate2)
|
||||||
|
|
||||||
|
// Wait for debounce period
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
// Should get 2 updates in order: token, then network map
|
||||||
|
if len(pendingUpdates) != 2 {
|
||||||
|
t.Errorf("Expected 2 pending updates, got %d", len(pendingUpdates))
|
||||||
|
}
|
||||||
|
// Token update should come first (preserves order)
|
||||||
|
if pendingUpdates[0] != tokenUpdate {
|
||||||
|
t.Error("First pending update should be tokenUpdate")
|
||||||
|
}
|
||||||
|
// Network map update should come second
|
||||||
|
if pendingUpdates[1] != netmapUpdate2 {
|
||||||
|
t.Error("Second pending update should be netmapUpdate2")
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDebouncer_OrderPreservation(t *testing.T) {
|
||||||
|
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||||
|
defer debouncer.Stop()
|
||||||
|
|
||||||
|
// Simulate: 50 network maps -> 1 control config -> 50 network maps
|
||||||
|
// Expected result: 3 messages (netmap, controlConfig, netmap)
|
||||||
|
|
||||||
|
// Send first network map immediately
|
||||||
|
firstNetmap := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 0}},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
if !debouncer.ProcessUpdate(firstNetmap) {
|
||||||
|
t.Error("First update should be sent immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send 49 more network maps (will be coalesced to last one)
|
||||||
|
var lastNetmapBatch1 *network_map.UpdateMessage
|
||||||
|
for i := 1; i < 50; i++ {
|
||||||
|
lastNetmapBatch1 = &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
debouncer.ProcessUpdate(lastNetmapBatch1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send 1 control config
|
||||||
|
controlConfig := &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
||||||
|
MessageType: network_map.MessageTypeControlConfig,
|
||||||
|
}
|
||||||
|
debouncer.ProcessUpdate(controlConfig)
|
||||||
|
|
||||||
|
// Send 50 more network maps (will be coalesced to last one)
|
||||||
|
var lastNetmapBatch2 *network_map.UpdateMessage
|
||||||
|
for i := 50; i < 100; i++ {
|
||||||
|
lastNetmapBatch2 = &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
debouncer.ProcessUpdate(lastNetmapBatch2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for debounce period
|
||||||
|
select {
|
||||||
|
case <-debouncer.TimerChannel():
|
||||||
|
pendingUpdates := debouncer.GetPendingUpdates()
|
||||||
|
// Should get exactly 3 updates: netmap, controlConfig, netmap
|
||||||
|
if len(pendingUpdates) != 3 {
|
||||||
|
t.Errorf("Expected 3 pending updates, got %d", len(pendingUpdates))
|
||||||
|
}
|
||||||
|
// First should be the last netmap from batch 1
|
||||||
|
if pendingUpdates[0] != lastNetmapBatch1 {
|
||||||
|
t.Error("First pending update should be last netmap from batch 1")
|
||||||
|
}
|
||||||
|
if pendingUpdates[0].Update.NetworkMap.Serial != 49 {
|
||||||
|
t.Errorf("Expected serial 49, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
||||||
|
}
|
||||||
|
// Second should be the control config
|
||||||
|
if pendingUpdates[1] != controlConfig {
|
||||||
|
t.Error("Second pending update should be control config")
|
||||||
|
}
|
||||||
|
// Third should be the last netmap from batch 2
|
||||||
|
if pendingUpdates[2] != lastNetmapBatch2 {
|
||||||
|
t.Error("Third pending update should be last netmap from batch 2")
|
||||||
|
}
|
||||||
|
if pendingUpdates[2].Update.NetworkMap.Serial != 99 {
|
||||||
|
t.Errorf("Expected serial 99, got %d", pendingUpdates[2].Update.NetworkMap.Serial)
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Error("Timer should have fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user