mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 08:46:38 +00:00
Compare commits
65 Commits
trigger-pr
...
deploy/pro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5c049f6f09 | ||
|
|
740c726a78 | ||
|
|
3af287ebab | ||
|
|
d4d885d434 | ||
|
|
d212332f5d | ||
|
|
5ca1b64328 | ||
|
|
36752a8cbb | ||
|
|
f117fc7509 | ||
|
|
fc6b93ae59 | ||
|
|
564fa4ab04 | ||
|
|
0e11258e97 | ||
|
|
31ecf8f1f5 | ||
|
|
e2df1fb35e | ||
|
|
a6db88fbd2 | ||
|
|
942cd5dc72 | ||
|
|
4b5294e596 | ||
|
|
a322dce42a | ||
|
|
d1ead2265b | ||
|
|
bbca74476e | ||
|
|
318cf59d66 | ||
|
|
e9b2a6e808 | ||
|
|
2dbdb5c1a7 | ||
|
|
2cdab6d7b7 | ||
|
|
e49c0e8862 | ||
|
|
e7c84d0ead | ||
|
|
1c934cca64 | ||
|
|
4aff4a6424 | ||
|
|
1bd7190954 | ||
|
|
0146e39714 | ||
|
|
baed6e46ec | ||
|
|
0d1ffba75f | ||
|
|
1024d45698 | ||
|
|
e5d4947d60 | ||
|
|
cb9b39b950 | ||
|
|
68c481fa44 | ||
|
|
01a9cd4651 | ||
|
|
f53155562f | ||
|
|
edce11b34d | ||
|
|
841b2d26c6 | ||
|
|
d3eeb6d8ee | ||
|
|
7ebf37ef20 | ||
|
|
64b849c801 | ||
|
|
69d4b5d821 | ||
|
|
3dfa97dcbd | ||
|
|
1ddc9ce2bf | ||
|
|
2de1949018 | ||
|
|
fc88399c23 | ||
|
|
6981fdce7e | ||
|
|
08403f64aa | ||
|
|
391221a986 | ||
|
|
7bc85107eb | ||
|
|
3be16d19a0 | ||
|
|
af8f730bda | ||
|
|
c3f176f348 | ||
|
|
0119f3e9f4 | ||
|
|
1b96648d4d | ||
|
|
d2f9653cea | ||
|
|
194a986926 | ||
|
|
f7732557fa | ||
|
|
d488f58311 | ||
|
|
6fdc00ff41 | ||
|
|
b20d484972 | ||
|
|
8931293343 | ||
|
|
7b830d8f72 | ||
|
|
3a0cf230a1 |
6
.dockerignore
Normal file
6
.dockerignore
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
.env
|
||||||
|
.env.*
|
||||||
|
*.pem
|
||||||
|
*.key
|
||||||
|
*.crt
|
||||||
|
*.p12
|
||||||
10
.github/workflows/check-license-dependencies.yml
vendored
10
.github/workflows/check-license-dependencies.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Check for problematic license dependencies
|
- name: Check for problematic license dependencies
|
||||||
run: |
|
run: |
|
||||||
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
echo "Checking for dependencies on management/, signal/, relay/, and proxy/ packages..."
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# Find all directories except the problematic ones and system dirs
|
# Find all directories except the problematic ones and system dirs
|
||||||
@@ -31,7 +31,7 @@ jobs:
|
|||||||
while IFS= read -r dir; do
|
while IFS= read -r dir; do
|
||||||
echo "=== Checking $dir ==="
|
echo "=== Checking $dir ==="
|
||||||
# Search for problematic imports, excluding test files
|
# Search for problematic imports, excluding test files
|
||||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||||
if [ -n "$RESULTS" ]; then
|
if [ -n "$RESULTS" ]; then
|
||||||
echo "❌ Found problematic dependencies:"
|
echo "❌ Found problematic dependencies:"
|
||||||
echo "$RESULTS"
|
echo "$RESULTS"
|
||||||
@@ -39,11 +39,11 @@ jobs:
|
|||||||
else
|
else
|
||||||
echo "✓ No problematic dependencies found"
|
echo "✓ No problematic dependencies found"
|
||||||
fi
|
fi
|
||||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
|
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name "combined" -not -name ".git*" | sort)
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
echo "❌ Found dependencies on management/, signal/, relay/, or proxy/ packages"
|
||||||
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
||||||
exit 1
|
exit 1
|
||||||
else
|
else
|
||||||
@@ -88,7 +88,7 @@ jobs:
|
|||||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||||
|
|
||||||
# Check if any importer is NOT in management/signal/relay
|
# Check if any importer is NOT in management/signal/relay
|
||||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
|
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\)" | head -1)
|
||||||
|
|
||||||
if [ -n "$BSD_IMPORTER" ]; then
|
if [ -n "$BSD_IMPORTER" ]; then
|
||||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||||
|
|||||||
2
.github/workflows/golang-test-darwin.yml
vendored
2
.github/workflows/golang-test-darwin.yml
vendored
@@ -43,5 +43,5 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
|
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||||
|
|
||||||
|
|||||||
1
.github/workflows/golang-test-freebsd.yml
vendored
1
.github/workflows/golang-test-freebsd.yml
vendored
@@ -46,6 +46,5 @@ jobs:
|
|||||||
time go test -timeout 1m -failfast ./client/iface/...
|
time go test -timeout 1m -failfast ./client/iface/...
|
||||||
time go test -timeout 1m -failfast ./route/...
|
time go test -timeout 1m -failfast ./route/...
|
||||||
time go test -timeout 1m -failfast ./sharedsock/...
|
time go test -timeout 1m -failfast ./sharedsock/...
|
||||||
time go test -timeout 1m -failfast ./signal/...
|
|
||||||
time go test -timeout 1m -failfast ./util/...
|
time go test -timeout 1m -failfast ./util/...
|
||||||
time go test -timeout 1m -failfast ./version/...
|
time go test -timeout 1m -failfast ./version/...
|
||||||
|
|||||||
98
.github/workflows/golang-test-linux.yml
vendored
98
.github/workflows/golang-test-linux.yml
vendored
@@ -97,6 +97,16 @@ jobs:
|
|||||||
working-directory: relay
|
working-directory: relay
|
||||||
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
|
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
|
||||||
|
|
||||||
|
- name: Build combined
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
working-directory: combined
|
||||||
|
run: CGO_ENABLED=1 go build .
|
||||||
|
|
||||||
|
- name: Build combined 386
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
working-directory: combined
|
||||||
|
run: CGO_ENABLED=1 GOARCH=386 go build -o combined-386 .
|
||||||
|
|
||||||
test:
|
test:
|
||||||
name: "Client / Unit"
|
name: "Client / Unit"
|
||||||
needs: [build-cache]
|
needs: [build-cache]
|
||||||
@@ -144,7 +154,7 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||||
|
|
||||||
test_client_on_docker:
|
test_client_on_docker:
|
||||||
name: "Client (Docker) / Unit"
|
name: "Client (Docker) / Unit"
|
||||||
@@ -204,7 +214,7 @@ jobs:
|
|||||||
sh -c ' \
|
sh -c ' \
|
||||||
apk update; apk add --no-cache \
|
apk update; apk add --no-cache \
|
||||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
|
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
||||||
'
|
'
|
||||||
|
|
||||||
test_relay:
|
test_relay:
|
||||||
@@ -261,6 +271,53 @@ jobs:
|
|||||||
-exec 'sudo' \
|
-exec 'sudo' \
|
||||||
-timeout 10m -p 1 ./relay/... ./shared/relay/...
|
-timeout 10m -p 1 ./relay/... ./shared/relay/...
|
||||||
|
|
||||||
|
test_proxy:
|
||||||
|
name: "Proxy / Unit"
|
||||||
|
needs: [build-cache]
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
arch: [ '386','amd64' ]
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: "go.mod"
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
run: |
|
||||||
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ env.cache }}
|
||||||
|
${{ env.modcache }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
|
- name: Install modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: |
|
||||||
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
|
go test -timeout 10m -p 1 ./proxy/...
|
||||||
|
|
||||||
test_signal:
|
test_signal:
|
||||||
name: "Signal / Unit"
|
name: "Signal / Unit"
|
||||||
needs: [build-cache]
|
needs: [build-cache]
|
||||||
@@ -352,12 +409,19 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Login to Docker hub
|
- name: Login to Docker hub
|
||||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||||
uses: docker/login-action@v1
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKER_USER }}
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
password: ${{ secrets.DOCKER_TOKEN }}
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
|
||||||
|
- name: docker login for root user
|
||||||
|
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||||
|
env:
|
||||||
|
DOCKER_USER: ${{ secrets.DOCKER_USER }}
|
||||||
|
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
|
||||||
|
|
||||||
- name: download mysql image
|
- name: download mysql image
|
||||||
if: matrix.store == 'mysql'
|
if: matrix.store == 'mysql'
|
||||||
run: docker pull mlsmaycon/warmed-mysql:8
|
run: docker pull mlsmaycon/warmed-mysql:8
|
||||||
@@ -440,15 +504,18 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Login to Docker hub
|
- name: Login to Docker hub
|
||||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||||
uses: docker/login-action@v1
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKER_USER }}
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
password: ${{ secrets.DOCKER_TOKEN }}
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
|
||||||
- name: download mysql image
|
- name: docker login for root user
|
||||||
if: matrix.store == 'mysql'
|
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||||
run: docker pull mlsmaycon/warmed-mysql:8
|
env:
|
||||||
|
DOCKER_USER: ${{ secrets.DOCKER_USER }}
|
||||||
|
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: |
|
run: |
|
||||||
@@ -529,15 +596,18 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Login to Docker hub
|
- name: Login to Docker hub
|
||||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||||
uses: docker/login-action@v1
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKER_USER }}
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
password: ${{ secrets.DOCKER_TOKEN }}
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
|
||||||
- name: download mysql image
|
- name: docker login for root user
|
||||||
if: matrix.store == 'mysql'
|
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||||
run: docker pull mlsmaycon/warmed-mysql:8
|
env:
|
||||||
|
DOCKER_USER: ${{ secrets.DOCKER_USER }}
|
||||||
|
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@@ -63,7 +63,7 @@ jobs:
|
|||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
||||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV
|
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' })" >> $env:GITHUB_ENV
|
||||||
|
|
||||||
- name: test
|
- name: test
|
||||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
||||||
|
|||||||
4
.github/workflows/golangci-lint.yml
vendored
4
.github/workflows/golangci-lint.yml
vendored
@@ -19,8 +19,8 @@ jobs:
|
|||||||
- name: codespell
|
- name: codespell
|
||||||
uses: codespell-project/actions-codespell@v2
|
uses: codespell-project/actions-codespell@v2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver
|
||||||
skip: go.mod,go.sum
|
skip: go.mod,go.sum,**/proxy/web/**
|
||||||
golangci:
|
golangci:
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|||||||
18
.github/workflows/release.yml
vendored
18
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.1.0"
|
SIGN_PIPE_VER: "v0.1.1"
|
||||||
GORELEASER_VER: "v2.3.2"
|
GORELEASER_VER: "v2.3.2"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "NetBird GmbH"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
@@ -160,7 +160,7 @@ jobs:
|
|||||||
username: ${{ secrets.DOCKER_USER }}
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
password: ${{ secrets.DOCKER_TOKEN }}
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
- name: Log in to the GitHub container registry
|
- name: Log in to the GitHub container registry
|
||||||
if: github.event_name != 'pull_request'
|
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||||
uses: docker/login-action@v3
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
registry: ghcr.io
|
registry: ghcr.io
|
||||||
@@ -176,6 +176,7 @@ jobs:
|
|||||||
- name: Generate windows syso arm64
|
- name: Generate windows syso arm64
|
||||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||||
- name: Run GoReleaser
|
- name: Run GoReleaser
|
||||||
|
id: goreleaser
|
||||||
uses: goreleaser/goreleaser-action@v4
|
uses: goreleaser/goreleaser-action@v4
|
||||||
with:
|
with:
|
||||||
version: ${{ env.GORELEASER_VER }}
|
version: ${{ env.GORELEASER_VER }}
|
||||||
@@ -185,6 +186,19 @@ jobs:
|
|||||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||||
|
- name: Tag and push PR images (amd64 only)
|
||||||
|
if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository
|
||||||
|
run: |
|
||||||
|
PR_TAG="pr-${{ github.event.pull_request.number }}"
|
||||||
|
echo '${{ steps.goreleaser.outputs.artifacts }}' | \
|
||||||
|
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
|
||||||
|
grep '^ghcr.io/' | while read -r SRC; do
|
||||||
|
IMG_NAME="${SRC%%:*}"
|
||||||
|
DST="${IMG_NAME}:${PR_TAG}"
|
||||||
|
echo "Tagging ${SRC} -> ${DST}"
|
||||||
|
docker tag "$SRC" "$DST"
|
||||||
|
docker push "$DST"
|
||||||
|
done
|
||||||
- name: upload non tags for debug purposes
|
- name: upload non tags for debug purposes
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -2,6 +2,7 @@
|
|||||||
.run
|
.run
|
||||||
*.iml
|
*.iml
|
||||||
dist/
|
dist/
|
||||||
|
!proxy/web/dist/
|
||||||
bin/
|
bin/
|
||||||
.env
|
.env
|
||||||
conf.json
|
conf.json
|
||||||
|
|||||||
181
.goreleaser.yaml
181
.goreleaser.yaml
@@ -106,6 +106,26 @@ builds:
|
|||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
|
- id: netbird-server
|
||||||
|
dir: combined
|
||||||
|
env:
|
||||||
|
- CGO_ENABLED=1
|
||||||
|
- >-
|
||||||
|
{{- if eq .Runtime.Goos "linux" }}
|
||||||
|
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
|
||||||
|
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
|
||||||
|
{{- end }}
|
||||||
|
binary: netbird-server
|
||||||
|
goos:
|
||||||
|
- linux
|
||||||
|
goarch:
|
||||||
|
- amd64
|
||||||
|
- arm64
|
||||||
|
- arm
|
||||||
|
ldflags:
|
||||||
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
- id: netbird-upload
|
- id: netbird-upload
|
||||||
dir: upload-server
|
dir: upload-server
|
||||||
env: [CGO_ENABLED=0]
|
env: [CGO_ENABLED=0]
|
||||||
@@ -120,6 +140,20 @@ builds:
|
|||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
|
- id: netbird-proxy
|
||||||
|
dir: proxy/cmd/proxy
|
||||||
|
env: [CGO_ENABLED=0]
|
||||||
|
binary: netbird-proxy
|
||||||
|
goos:
|
||||||
|
- linux
|
||||||
|
goarch:
|
||||||
|
- amd64
|
||||||
|
- arm64
|
||||||
|
- arm
|
||||||
|
ldflags:
|
||||||
|
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
|
||||||
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
universal_binaries:
|
universal_binaries:
|
||||||
- id: netbird
|
- id: netbird
|
||||||
|
|
||||||
@@ -520,6 +554,104 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-amd64
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||||
|
ids:
|
||||||
|
- netbird-server
|
||||||
|
goarch: amd64
|
||||||
|
use: buildx
|
||||||
|
dockerfile: combined/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/amd64"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||||
|
ids:
|
||||||
|
- netbird-server
|
||||||
|
goarch: arm64
|
||||||
|
use: buildx
|
||||||
|
dockerfile: combined/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/arm64"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||||
|
ids:
|
||||||
|
- netbird-server
|
||||||
|
goarch: arm
|
||||||
|
goarm: 6
|
||||||
|
use: buildx
|
||||||
|
dockerfile: combined/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/arm"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||||
|
ids:
|
||||||
|
- netbird-proxy
|
||||||
|
goarch: amd64
|
||||||
|
use: buildx
|
||||||
|
dockerfile: proxy/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/amd64"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||||
|
ids:
|
||||||
|
- netbird-proxy
|
||||||
|
goarch: arm64
|
||||||
|
use: buildx
|
||||||
|
dockerfile: proxy/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/arm64"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||||
|
ids:
|
||||||
|
- netbird-proxy
|
||||||
|
goarch: arm
|
||||||
|
goarm: 6
|
||||||
|
use: buildx
|
||||||
|
dockerfile: proxy/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/arm"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
docker_manifests:
|
docker_manifests:
|
||||||
- name_template: netbirdio/netbird:{{ .Version }}
|
- name_template: netbirdio/netbird:{{ .Version }}
|
||||||
image_templates:
|
image_templates:
|
||||||
@@ -598,6 +730,18 @@ docker_manifests:
|
|||||||
- netbirdio/upload:{{ .Version }}-arm
|
- netbirdio/upload:{{ .Version }}-arm
|
||||||
- netbirdio/upload:{{ .Version }}-amd64
|
- netbirdio/upload:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: netbirdio/netbird-server:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-arm
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: netbirdio/netbird-server:latest
|
||||||
|
image_templates:
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-arm
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-amd64
|
||||||
|
|
||||||
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
|
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
|
||||||
image_templates:
|
image_templates:
|
||||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||||
@@ -675,6 +819,43 @@ docker_manifests:
|
|||||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
||||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
||||||
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/netbird-server:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/netbird-server:latest
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: netbirdio/reverse-proxy:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: netbirdio/reverse-proxy:latest
|
||||||
|
image_templates:
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/reverse-proxy:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/reverse-proxy:latest
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||||
|
|
||||||
brews:
|
brews:
|
||||||
- ids:
|
- ids:
|
||||||
- default
|
- default
|
||||||
|
|||||||
2
LICENSE
2
LICENSE
@@ -1,4 +1,4 @@
|
|||||||
This BSD‑3‑Clause license applies to all parts of the repository except for the directories management/, signal/ and relay/.
|
This BSD‑3‑Clause license applies to all parts of the repository except for the directories management/, signal/, relay/ and combined/.
|
||||||
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
|
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
|
||||||
|
|
||||||
BSD 3-Clause License
|
BSD 3-Clause License
|
||||||
|
|||||||
@@ -60,8 +60,8 @@
|
|||||||
|
|
||||||
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||||
|
|
||||||
### NetBird on Lawrence Systems (Video)
|
### Self-Host NetBird (Video)
|
||||||
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
[](https://youtu.be/bZAgpT6nzaQ)
|
||||||
|
|
||||||
### Key features
|
### Key features
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,19 @@
|
|||||||
package android
|
package android
|
||||||
|
|
||||||
import "github.com/netbirdio/netbird/client/internal/peer"
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// EnvKeyNBForceRelay Exported for Android java client
|
// EnvKeyNBForceRelay Exported for Android java client to force relay connections
|
||||||
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
|
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
|
||||||
|
|
||||||
|
// EnvKeyNBLazyConn Exported for Android java client to configure lazy connection
|
||||||
|
EnvKeyNBLazyConn = lazyconn.EnvEnableLazyConn
|
||||||
|
|
||||||
|
// EnvKeyNBInactivityThreshold Exported for Android java client to configure connection inactivity threshold
|
||||||
|
EnvKeyNBInactivityThreshold = lazyconn.EnvInactivityThreshold
|
||||||
)
|
)
|
||||||
|
|
||||||
// EnvList wraps a Go map for export to Java
|
// EnvList wraps a Go map for export to Java
|
||||||
|
|||||||
@@ -282,13 +282,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
|||||||
}
|
}
|
||||||
defer authClient.Close()
|
defer authClient.Close()
|
||||||
|
|
||||||
needsLogin := false
|
needsLogin, err := authClient.IsLoginRequired(ctx)
|
||||||
|
if err != nil {
|
||||||
err, isAuthError := authClient.Login(ctx, "", "")
|
return fmt.Errorf("check login required: %v", err)
|
||||||
if isAuthError {
|
|
||||||
needsLogin = true
|
|
||||||
} else if err != nil {
|
|
||||||
return fmt.Errorf("login check failed: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
|
|||||||
@@ -31,6 +31,14 @@ var (
|
|||||||
ErrConfigNotInitialized = errors.New("config not initialized")
|
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PeerConnStatus is a peer's connection status.
|
||||||
|
type PeerConnStatus = peer.ConnStatus
|
||||||
|
|
||||||
|
const (
|
||||||
|
// PeerStatusConnected indicates the peer is in connected state.
|
||||||
|
PeerStatusConnected = peer.StatusConnected
|
||||||
|
)
|
||||||
|
|
||||||
// Client manages a netbird embedded client instance.
|
// Client manages a netbird embedded client instance.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
deviceName string
|
deviceName string
|
||||||
@@ -71,6 +79,8 @@ type Options struct {
|
|||||||
DisableClientRoutes bool
|
DisableClientRoutes bool
|
||||||
// BlockInbound blocks all inbound connections from peers
|
// BlockInbound blocks all inbound connections from peers
|
||||||
BlockInbound bool
|
BlockInbound bool
|
||||||
|
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
|
||||||
|
WireguardPort *int
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateCredentials checks that exactly one credential type is provided
|
// validateCredentials checks that exactly one credential type is provided
|
||||||
@@ -140,6 +150,7 @@ func New(opts Options) (*Client, error) {
|
|||||||
DisableServerRoutes: &t,
|
DisableServerRoutes: &t,
|
||||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||||
BlockInbound: &opts.BlockInbound,
|
BlockInbound: &opts.BlockInbound,
|
||||||
|
WireguardPort: opts.WireguardPort,
|
||||||
}
|
}
|
||||||
if opts.ConfigPath != "" {
|
if opts.ConfigPath != "" {
|
||||||
config, err = profilemanager.UpdateOrCreateConfig(input)
|
config, err = profilemanager.UpdateOrCreateConfig(input)
|
||||||
@@ -159,6 +170,7 @@ func New(opts Options) (*Client, error) {
|
|||||||
setupKey: opts.SetupKey,
|
setupKey: opts.SetupKey,
|
||||||
jwtToken: opts.JWTToken,
|
jwtToken: opts.JWTToken,
|
||||||
config: config,
|
config: config,
|
||||||
|
recorder: peer.NewRecorder(config.ManagementURL.String()),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -180,6 +192,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
|
|
||||||
// nolint:staticcheck
|
// nolint:staticcheck
|
||||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||||
|
|
||||||
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
|
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create auth client: %w", err)
|
return fmt.Errorf("create auth client: %w", err)
|
||||||
@@ -189,10 +202,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
||||||
return fmt.Errorf("login: %w", err)
|
return fmt.Errorf("login: %w", err)
|
||||||
}
|
}
|
||||||
|
client := internal.NewConnectClient(ctx, c.config, c.recorder, false)
|
||||||
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
|
||||||
c.recorder = recorder
|
|
||||||
client := internal.NewConnectClient(ctx, c.config, recorder, false)
|
|
||||||
client.SetSyncResponsePersistence(true)
|
client.SetSyncResponsePersistence(true)
|
||||||
|
|
||||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||||
@@ -345,14 +355,9 @@ func (c *Client) NewHTTPClient() *http.Client {
|
|||||||
// Status returns the current status of the client.
|
// Status returns the current status of the client.
|
||||||
func (c *Client) Status() (peer.FullStatus, error) {
|
func (c *Client) Status() (peer.FullStatus, error) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
recorder := c.recorder
|
|
||||||
connect := c.connect
|
connect := c.connect
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
|
|
||||||
if recorder == nil {
|
|
||||||
return peer.FullStatus{}, errors.New("client not started")
|
|
||||||
}
|
|
||||||
|
|
||||||
if connect != nil {
|
if connect != nil {
|
||||||
engine := connect.Engine()
|
engine := connect.Engine()
|
||||||
if engine != nil {
|
if engine != nil {
|
||||||
@@ -360,7 +365,7 @@ func (c *Client) Status() (peer.FullStatus, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return recorder.GetFullStatus(), nil
|
return c.recorder.GetFullStatus(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLatestSyncResponse returns the latest sync response from the management server.
|
// GetLatestSyncResponse returns the latest sync response from the management server.
|
||||||
|
|||||||
@@ -483,7 +483,12 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if nftRule.Handle == 0 {
|
if nftRule.Handle == 0 {
|
||||||
return fmt.Errorf("route rule %s has no handle", ruleKey)
|
log.Warnf("route rule %s has no handle, removing stale entry", ruleKey)
|
||||||
|
if err := r.decrementSetCounter(nftRule); err != nil {
|
||||||
|
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
||||||
@@ -660,13 +665,32 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
// TODO: rollback ipset counter
|
r.rollbackRules(pair)
|
||||||
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
|
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
|
||||||
|
func (r *router) rollbackRules(pair firewall.RouterPair) {
|
||||||
|
keys := []string{
|
||||||
|
firewall.GenKey(firewall.ForwardingFormat, pair),
|
||||||
|
firewall.GenKey(firewall.PreroutingFormat, pair),
|
||||||
|
firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair)),
|
||||||
|
}
|
||||||
|
for _, key := range keys {
|
||||||
|
rule, ok := r.rules[key]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
log.Warnf("rollback set counter for %s: %v", key, err)
|
||||||
|
}
|
||||||
|
delete(r.rules, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// addNatRule inserts a nftables rule to the conn client flush queue
|
// addNatRule inserts a nftables rule to the conn client flush queue
|
||||||
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||||
@@ -928,18 +952,30 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
|||||||
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
rule, exists := r.rules[ruleKey]
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
if !exists {
|
||||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
|
||||||
|
|
||||||
delete(r.rules, ruleKey)
|
|
||||||
|
|
||||||
|
if rule.Handle == 0 {
|
||||||
|
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleKey)
|
||||||
if err := r.decrementSetCounter(rule); err != nil {
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
return fmt.Errorf("decrement set counter: %w", err)
|
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||||
}
|
}
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||||
|
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement set counter: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -1329,65 +1365,89 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
if pair.Masquerade {
|
if pair.Masquerade {
|
||||||
if err := r.removeNatRule(pair); err != nil {
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||||
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set counters are decremented in the sub-methods above before flush. If flush fails,
|
||||||
|
// counters will be off until the next successful removal or refresh cycle.
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
// TODO: rollback set counter
|
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
|
||||||
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
rule, exists := r.rules[ruleKey]
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
if !exists {
|
||||||
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
|
||||||
|
|
||||||
delete(r.rules, ruleKey)
|
|
||||||
|
|
||||||
if err := r.decrementSetCounter(rule); err != nil {
|
|
||||||
return fmt.Errorf("decrement set counter: %w", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.Debugf("prerouting rule %s not found", ruleKey)
|
log.Debugf("prerouting rule %s not found", ruleKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if rule.Handle == 0 {
|
||||||
|
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleKey)
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||||
|
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement set counter: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
|
// refreshRulesMap rebuilds the rule map from the kernel. This removes stale entries
|
||||||
// duplicates and to get missing attributes that we don't have when adding new rules
|
// (e.g. from failed flushes) and updates handles for all existing rules.
|
||||||
func (r *router) refreshRulesMap() error {
|
func (r *router) refreshRulesMap() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
newRules := make(map[string]*nftables.Rule)
|
||||||
for _, chain := range r.chains {
|
for _, chain := range r.chains {
|
||||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("list rules: %w", err)
|
merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
|
||||||
|
// preserve existing entries for this chain since we can't verify their state
|
||||||
|
for k, v := range r.rules {
|
||||||
|
if v.Chain != nil && v.Chain.Name == chain.Name {
|
||||||
|
newRules[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 {
|
if len(rule.UserData) > 0 {
|
||||||
r.rules[string(rule.UserData)] = rule
|
newRules[string(rule.UserData)] = rule
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
r.rules = newRules
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
@@ -1629,20 +1689,34 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
var needsFlush bool
|
||||||
|
|
||||||
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||||
if err := r.conn.DelRule(dnatRule); err != nil {
|
if dnatRule.Handle == 0 {
|
||||||
|
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleKey+dnatSuffix)
|
||||||
|
delete(r.rules, ruleKey+dnatSuffix)
|
||||||
|
} else if err := r.conn.DelRule(dnatRule); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
|
||||||
|
} else {
|
||||||
|
needsFlush = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
||||||
if err := r.conn.DelRule(masqRule); err != nil {
|
if masqRule.Handle == 0 {
|
||||||
|
log.Warnf("snat rule %s has no handle, removing stale entry", ruleKey+snatSuffix)
|
||||||
|
delete(r.rules, ruleKey+snatSuffix)
|
||||||
|
} else if err := r.conn.DelRule(masqRule); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
|
||||||
|
} else {
|
||||||
|
needsFlush = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if needsFlush {
|
||||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if merr == nil {
|
if merr == nil {
|
||||||
@@ -1757,16 +1831,25 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
|||||||
|
|
||||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleID]; exists {
|
rule, exists := r.rules[ruleID]
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
if !exists {
|
||||||
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
return nil
|
||||||
}
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
|
||||||
}
|
|
||||||
delete(r.rules, ruleID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if rule.Handle == 0 {
|
||||||
|
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
||||||
|
}
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/test"
|
"github.com/netbirdio/netbird/client/firewall/test"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -719,3 +720,137 @@ func deleteWorkTable() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
workTable, err := createWorkTable()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer deleteWorkTable()
|
||||||
|
|
||||||
|
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, r.init(workTable))
|
||||||
|
defer func() { require.NoError(t, r.Reset()) }()
|
||||||
|
|
||||||
|
// Add a real rule to the kernel
|
||||||
|
ruleKey, err := r.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
||||||
|
firewall.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&firewall.Port{Values: []uint16{80}},
|
||||||
|
firewall.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, r.DeleteRouteRule(ruleKey))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
|
||||||
|
staleKey := "stale-rule-that-does-not-exist"
|
||||||
|
r.rules[staleKey] = &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingFw],
|
||||||
|
Handle: 0,
|
||||||
|
UserData: []byte(staleKey),
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Contains(t, r.rules, staleKey, "stale entry should be in map before refresh")
|
||||||
|
|
||||||
|
err = r.refreshRulesMap()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.NotContains(t, r.rules, staleKey, "stale entry should be removed after refresh")
|
||||||
|
|
||||||
|
realRule, ok := r.rules[ruleKey.ID()]
|
||||||
|
assert.True(t, ok, "real rule should still exist after refresh")
|
||||||
|
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
workTable, err := createWorkTable()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer deleteWorkTable()
|
||||||
|
|
||||||
|
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, r.init(workTable))
|
||||||
|
defer func() { require.NoError(t, r.Reset()) }()
|
||||||
|
|
||||||
|
// Inject a stale entry with Handle=0
|
||||||
|
staleKey := "stale-route-rule"
|
||||||
|
r.rules[staleKey] = &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingFw],
|
||||||
|
Handle: 0,
|
||||||
|
UserData: []byte(staleKey),
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRouteRule should not return an error for stale handles
|
||||||
|
err = r.DeleteRouteRule(id.RuleID(staleKey))
|
||||||
|
assert.NoError(t, err, "deleting a stale rule should not error")
|
||||||
|
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
pair := firewall.RouterPair{
|
||||||
|
ID: "staletest",
|
||||||
|
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||||
|
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||||
|
Masquerade: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
rtr := manager.router
|
||||||
|
|
||||||
|
// First add succeeds
|
||||||
|
err = rtr.AddNatRule(pair)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, rtr.RemoveNatRule(pair))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Corrupt the handle to simulate stale state
|
||||||
|
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||||
|
if rule, exists := rtr.rules[natRuleKey]; exists {
|
||||||
|
rule.Handle = 0
|
||||||
|
}
|
||||||
|
inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
|
||||||
|
if rule, exists := rtr.rules[inverseKey]; exists {
|
||||||
|
rule.Handle = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adding the same rule again should succeed despite stale handles
|
||||||
|
err = rtr.AddNatRule(pair)
|
||||||
|
assert.NoError(t, err, "AddNatRule should succeed even with stale entries")
|
||||||
|
|
||||||
|
// Verify rules exist in kernel
|
||||||
|
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
found := 0
|
||||||
|
for _, rule := range rules {
|
||||||
|
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||||
|
found++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.Equal(t, 1, found, "NAT rule should exist in kernel")
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,12 +3,6 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/netip"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -17,33 +11,7 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
m.resetState()
|
||||||
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
|
|
||||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
|
||||||
|
|
||||||
if m.udpTracker != nil {
|
|
||||||
m.udpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
|
||||||
m.icmpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
|
||||||
m.tcpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if fwder := m.forwarder.Load(); fwder != nil {
|
|
||||||
fwder.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.logger != nil {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := m.logger.Stop(ctx); err != nil {
|
|
||||||
log.Errorf("failed to shutdown logger: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.Close(stateManager)
|
return m.nativeFirewall.Close(stateManager)
|
||||||
|
|||||||
@@ -1,12 +1,9 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -26,33 +23,7 @@ func (m *Manager) Close(*statemanager.Manager) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
m.resetState()
|
||||||
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
|
|
||||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
|
||||||
|
|
||||||
if m.udpTracker != nil {
|
|
||||||
m.udpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
|
||||||
m.icmpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
|
||||||
m.tcpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if fwder := m.forwarder.Load(); fwder != nil {
|
|
||||||
fwder.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.logger != nil {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := m.logger.Stop(ctx); err != nil {
|
|
||||||
log.Errorf("failed to shutdown logger: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isWindowsFirewallReachable() {
|
if !isWindowsFirewallReachable() {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -115,6 +115,17 @@ func (t *TCPConnTrack) IsTombstone() bool {
|
|||||||
return t.tombstone.Load()
|
return t.tombstone.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsSupersededBy returns true if this connection should be replaced by a new one
|
||||||
|
// carrying the given flags. Tombstoned connections are always superseded; TIME-WAIT
|
||||||
|
// connections are superseded by a pure SYN (a new connection attempt for the same
|
||||||
|
// four-tuple, as contemplated by RFC 1122 §4.2.2.13 and RFC 6191).
|
||||||
|
func (t *TCPConnTrack) IsSupersededBy(flags uint8) bool {
|
||||||
|
if t.tombstone.Load() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return flags&TCPSyn != 0 && flags&TCPAck == 0 && TCPState(t.state.Load()) == TCPStateTimeWait
|
||||||
|
}
|
||||||
|
|
||||||
// SetTombstone safely marks the connection for deletion
|
// SetTombstone safely marks the connection for deletion
|
||||||
func (t *TCPConnTrack) SetTombstone() {
|
func (t *TCPConnTrack) SetTombstone() {
|
||||||
t.tombstone.Store(true)
|
t.tombstone.Store(true)
|
||||||
@@ -169,7 +180,7 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
|||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if exists {
|
if exists && !conn.IsSupersededBy(flags) {
|
||||||
t.updateState(key, conn, flags, direction, size)
|
t.updateState(key, conn, flags, direction, size)
|
||||||
return key, uint16(conn.DNATOrigPort.Load()), true
|
return key, uint16(conn.DNATOrigPort.Load()), true
|
||||||
}
|
}
|
||||||
@@ -241,7 +252,7 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
|||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if !exists || conn.IsTombstone() {
|
if !exists || conn.IsSupersededBy(flags) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -485,6 +485,261 @@ func TestTCPAbnormalSequences(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestTCPPortReuseTombstone verifies that a new connection on a port with a
|
||||||
|
// tombstoned (closed) conntrack entry is properly tracked. Without the fix,
|
||||||
|
// updateIfExists treats tombstoned entries as live, causing track() to skip
|
||||||
|
// creating a new connection. The subsequent SYN-ACK then fails IsValidInbound
|
||||||
|
// because the entry is tombstoned, and the response packet gets dropped by ACL.
|
||||||
|
func TestTCPPortReuseTombstone(t *testing.T) {
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
t.Run("Outbound port reuse after graceful close", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish and gracefully close a connection (server-initiated close)
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Server sends FIN
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
// Client sends FIN-ACK
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
|
||||||
|
// Server sends final ACK
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
// Connection should be tombstoned
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn, "old connection should still be in map")
|
||||||
|
require.True(t, conn.IsTombstone(), "old connection should be tombstoned")
|
||||||
|
|
||||||
|
// Now reuse the same port for a new connection
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||||
|
|
||||||
|
// The old tombstoned entry should be replaced with a new one
|
||||||
|
newConn := tracker.connections[key]
|
||||||
|
require.NotNil(t, newConn, "new connection should exist")
|
||||||
|
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
|
||||||
|
require.Equal(t, TCPStateSynSent, newConn.GetState())
|
||||||
|
|
||||||
|
// SYN-ACK for the new connection should be valid
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||||
|
require.True(t, valid, "SYN-ACK for new connection on reused port should be accepted")
|
||||||
|
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||||
|
|
||||||
|
// Data transfer should work
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 500)
|
||||||
|
require.True(t, valid, "data should be allowed on new connection")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Outbound port reuse after RST", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish and RST a connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.True(t, conn.IsTombstone(), "RST connection should be tombstoned")
|
||||||
|
|
||||||
|
// Reuse the same port
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||||
|
|
||||||
|
newConn := tracker.connections[key]
|
||||||
|
require.NotNil(t, newConn)
|
||||||
|
require.False(t, newConn.IsTombstone())
|
||||||
|
require.Equal(t, TCPStateSynSent, newConn.GetState())
|
||||||
|
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||||
|
require.True(t, valid, "SYN-ACK should be accepted after RST tombstone")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Inbound port reuse after close", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
clientIP := srcIP
|
||||||
|
serverIP := dstIP
|
||||||
|
clientPort := srcPort
|
||||||
|
serverPort := dstPort
|
||||||
|
key := ConnKey{SrcIP: clientIP, DstIP: serverIP, SrcPort: clientPort, DstPort: serverPort}
|
||||||
|
|
||||||
|
// Inbound connection: client SYN → server SYN-ACK → client ACK
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||||
|
|
||||||
|
// Server-initiated close to reach Closed/tombstoned:
|
||||||
|
// Server FIN (opposite dir) → CloseWait
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPFin|TCPAck, 100)
|
||||||
|
require.Equal(t, TCPStateCloseWait, conn.GetState())
|
||||||
|
// Client FIN-ACK (same dir as conn) → LastAck
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPFin|TCPAck, nil, 100, 0)
|
||||||
|
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||||
|
// Server final ACK (opposite dir) → Closed → tombstoned
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
|
||||||
|
|
||||||
|
require.True(t, conn.IsTombstone())
|
||||||
|
|
||||||
|
// New inbound connection on same ports
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
|
||||||
|
|
||||||
|
newConn := tracker.connections[key]
|
||||||
|
require.NotNil(t, newConn)
|
||||||
|
require.False(t, newConn.IsTombstone())
|
||||||
|
require.Equal(t, TCPStateSynReceived, newConn.GetState())
|
||||||
|
|
||||||
|
// Complete handshake: server SYN-ACK, then client ACK
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||||
|
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Late ACK on tombstoned connection is harmless", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish and close via passive close (server-initiated FIN → Closed → tombstoned)
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) // CloseWait
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // LastAck
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) // Closed
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.True(t, conn.IsTombstone())
|
||||||
|
|
||||||
|
// Late ACK should be rejected (tombstoned)
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.False(t, valid, "late ACK on tombstoned connection should be rejected")
|
||||||
|
|
||||||
|
// Late outbound ACK should not create a new connection (not a SYN)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
require.True(t, tracker.connections[key].IsTombstone(), "late outbound ACK should not replace tombstoned entry")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPPortReuseTimeWait(t *testing.T) {
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
t.Run("Outbound port reuse during TIME-WAIT (active close)", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Active close: client (outbound initiator) sends FIN first
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||||
|
|
||||||
|
// Server ACKs the FIN
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||||
|
|
||||||
|
// Server sends its own FIN
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
|
||||||
|
// Client sends final ACK (TIME-WAIT stays, not tombstoned)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
require.False(t, conn.IsTombstone(), "TIME-WAIT should not be tombstoned")
|
||||||
|
|
||||||
|
// New outbound SYN on the same port (port reuse during TIME-WAIT)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||||
|
|
||||||
|
// Per RFC 1122/6191, new SYN during TIME-WAIT should start a new connection
|
||||||
|
newConn := tracker.connections[key]
|
||||||
|
require.NotNil(t, newConn, "new connection should exist")
|
||||||
|
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
|
||||||
|
require.Equal(t, TCPStateSynSent, newConn.GetState(), "new connection should be in SYN-SENT")
|
||||||
|
|
||||||
|
// SYN-ACK for new connection should be valid
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||||
|
require.True(t, valid, "SYN-ACK for new connection should be accepted")
|
||||||
|
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Inbound SYN during TIME-WAIT falls through to normal tracking", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish outbound connection and close via active close → TIME-WAIT
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
|
||||||
|
// Inbound SYN on same ports during TIME-WAIT: IsValidInbound returns false
|
||||||
|
// so the filter falls through to ACL check + TrackInbound (which creates
|
||||||
|
// a new connection via track() → updateIfExists skips TIME-WAIT for SYN)
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, 0)
|
||||||
|
require.False(t, valid, "inbound SYN during TIME-WAIT should fail conntrack validation")
|
||||||
|
|
||||||
|
// Simulate what the filter does next: TrackInbound via the normal path
|
||||||
|
tracker.TrackInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, nil, 100, 0)
|
||||||
|
|
||||||
|
// The new inbound connection uses the inverted key (dst→src becomes src→dst in track)
|
||||||
|
invertedKey := ConnKey{SrcIP: dstIP, DstIP: srcIP, SrcPort: dstPort, DstPort: srcPort}
|
||||||
|
newConn := tracker.connections[invertedKey]
|
||||||
|
require.NotNil(t, newConn, "new inbound connection should be tracked")
|
||||||
|
require.Equal(t, TCPStateSynReceived, newConn.GetState())
|
||||||
|
require.False(t, newConn.IsTombstone())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Late retransmit during TIME-WAIT still allowed", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish and active close → TIME-WAIT
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
|
||||||
|
// Late ACK retransmits during TIME-WAIT should still be accepted
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid, "retransmitted ACK during TIME-WAIT should be accepted")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestTCPTimeoutHandling(t *testing.T) {
|
func TestTCPTimeoutHandling(t *testing.T) {
|
||||||
// Create tracker with a very short timeout for testing
|
// Create tracker with a very short timeout for testing
|
||||||
shortTimeout := 100 * time.Millisecond
|
shortTimeout := 100 * time.Millisecond
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -12,11 +13,13 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
@@ -24,6 +27,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
@@ -89,6 +93,7 @@ type Manager struct {
|
|||||||
incomingDenyRules map[netip.Addr]RuleSet
|
incomingDenyRules map[netip.Addr]RuleSet
|
||||||
incomingRules map[netip.Addr]RuleSet
|
incomingRules map[netip.Addr]RuleSet
|
||||||
routeRules RouteRules
|
routeRules RouteRules
|
||||||
|
routeRulesMap map[nbid.RuleID]*RouteRule
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
wgIface common.IFaceMapper
|
wgIface common.IFaceMapper
|
||||||
nativeFirewall firewall.Manager
|
nativeFirewall firewall.Manager
|
||||||
@@ -229,6 +234,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
flowLogger: flowLogger,
|
flowLogger: flowLogger,
|
||||||
netstack: netstack.IsEnabled(),
|
netstack: netstack.IsEnabled(),
|
||||||
localForwarding: enableLocalForwarding,
|
localForwarding: enableLocalForwarding,
|
||||||
|
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
|
||||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||||
portDNATRules: []portDNATRule{},
|
portDNATRules: []portDNATRule{},
|
||||||
netstackServices: make(map[serviceKey]struct{}),
|
netstackServices: make(map[serviceKey]struct{}),
|
||||||
@@ -480,11 +486,15 @@ func (m *Manager) addRouteFiltering(
|
|||||||
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleID := uuid.New().String()
|
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||||
|
|
||||||
|
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
|
||||||
|
return existingRule, nil
|
||||||
|
}
|
||||||
|
|
||||||
rule := RouteRule{
|
rule := RouteRule{
|
||||||
// TODO: consolidate these IDs
|
// TODO: consolidate these IDs
|
||||||
id: ruleID,
|
id: string(ruleKey),
|
||||||
mgmtId: id,
|
mgmtId: id,
|
||||||
sources: sources,
|
sources: sources,
|
||||||
dstSet: destination.Set,
|
dstSet: destination.Set,
|
||||||
@@ -499,6 +509,7 @@ func (m *Manager) addRouteFiltering(
|
|||||||
|
|
||||||
m.routeRules = append(m.routeRules, &rule)
|
m.routeRules = append(m.routeRules, &rule)
|
||||||
m.routeRules.Sort()
|
m.routeRules.Sort()
|
||||||
|
m.routeRulesMap[ruleKey] = &rule
|
||||||
|
|
||||||
return &rule, nil
|
return &rule, nil
|
||||||
}
|
}
|
||||||
@@ -515,15 +526,20 @@ func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
|
|||||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleID := rule.ID()
|
ruleKey := nbid.RuleID(rule.ID())
|
||||||
|
if _, ok := m.routeRulesMap[ruleKey]; !ok {
|
||||||
|
return fmt.Errorf("route rule not found: %s", ruleKey)
|
||||||
|
}
|
||||||
|
|
||||||
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
||||||
return r.id == ruleID
|
return r.id == string(ruleKey)
|
||||||
})
|
})
|
||||||
if idx < 0 {
|
if idx < 0 {
|
||||||
return fmt.Errorf("route rule not found: %s", ruleID)
|
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
||||||
|
delete(m.routeRulesMap, ruleKey)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -570,6 +586,40 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
|||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
|
// resetState clears all firewall rules and closes connection trackers.
|
||||||
|
// Must be called with m.mutex held.
|
||||||
|
func (m *Manager) resetState() {
|
||||||
|
maps.Clear(m.outgoingRules)
|
||||||
|
maps.Clear(m.incomingDenyRules)
|
||||||
|
maps.Clear(m.incomingRules)
|
||||||
|
maps.Clear(m.routeRulesMap)
|
||||||
|
m.routeRules = m.routeRules[:0]
|
||||||
|
|
||||||
|
if m.udpTracker != nil {
|
||||||
|
m.udpTracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.icmpTracker != nil {
|
||||||
|
m.icmpTracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.tcpTracker != nil {
|
||||||
|
m.tcpTracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if fwder := m.forwarder.Load(); fwder != nil {
|
||||||
|
fwder.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.logger != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := m.logger.Stop(ctx); err != nil {
|
||||||
|
log.Errorf("failed to shutdown logger: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||||
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||||
if m.nativeFirewall == nil {
|
if m.nativeFirewall == nil {
|
||||||
|
|||||||
376
client/firewall/uspfilter/filter_routeacl_test.go
Normal file
376
client/firewall/uspfilter/filter_routeacl_test.go
Normal file
@@ -0,0 +1,376 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestAddRouteFilteringReturnsExistingRule verifies that adding the same route
|
||||||
|
// filtering rule twice returns the same rule ID (idempotent behavior).
|
||||||
|
func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("100.64.1.0/24"),
|
||||||
|
netip.MustParsePrefix("100.64.2.0/24"),
|
||||||
|
}
|
||||||
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
|
// Add rule first time
|
||||||
|
rule1, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule1)
|
||||||
|
|
||||||
|
// Add the same rule again
|
||||||
|
rule2, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule2)
|
||||||
|
|
||||||
|
// These should be the same (idempotent) like nftables/iptables implementations
|
||||||
|
assert.Equal(t, rule1.ID(), rule2.ID(),
|
||||||
|
"Adding the same rule twice should return the same rule ID (idempotent)")
|
||||||
|
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 2, ruleCount,
|
||||||
|
"Should have exactly 2 rules (1 user rule + 1 block rule)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAddRouteFilteringDifferentRulesGetDifferentIDs verifies that rules with
|
||||||
|
// different parameters get distinct IDs.
|
||||||
|
func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||||
|
|
||||||
|
// Add first rule
|
||||||
|
rule1, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Add different rule (different destination)
|
||||||
|
rule2, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-2"),
|
||||||
|
sources,
|
||||||
|
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.NotEqual(t, rule1.ID(), rule2.ID(),
|
||||||
|
"Different rules should have different IDs")
|
||||||
|
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 3, ruleCount, "Should have 3 rules (2 user rules + 1 block rule)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRouteRuleUpdateDoesNotCauseGap verifies that re-adding the same route
|
||||||
|
// rule during a network map update does not disrupt existing traffic.
|
||||||
|
func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||||
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
|
rule1, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.1.5")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.10")
|
||||||
|
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||||
|
require.True(t, pass, "Traffic should pass with rule in place")
|
||||||
|
|
||||||
|
// Re-add same rule (simulates network map update)
|
||||||
|
rule2, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Idempotent IDs mean rule1.ID() == rule2.ID(), so the ACL manager
|
||||||
|
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
|
||||||
|
// would remove the only matching rule and cause a traffic gap.
|
||||||
|
if rule1.ID() != rule2.ID() {
|
||||||
|
err = manager.DeleteRouteRule(rule1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, passAfter := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||||
|
assert.True(t, passAfter,
|
||||||
|
"Traffic should still pass after rule update - no gap should occur")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
|
||||||
|
// exactly one drop rule for the WireGuard network prefix, and calling it again
|
||||||
|
// returns the same rule without duplicating.
|
||||||
|
func TestBlockInvalidRoutedIdempotent(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
dev := mocks.NewMockDevice(ctrl)
|
||||||
|
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||||
|
|
||||||
|
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||||
|
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: wgNet.Addr(),
|
||||||
|
Network: wgNet,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
GetDeviceFunc: func() *device.FilteredDevice {
|
||||||
|
return &device.FilteredDevice{Device: dev}
|
||||||
|
},
|
||||||
|
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||||
|
return &wgdevice.Device{}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Call blockInvalidRouted directly multiple times
|
||||||
|
rule1, err := manager.blockInvalidRouted(ifaceMock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule1)
|
||||||
|
|
||||||
|
rule2, err := manager.blockInvalidRouted(ifaceMock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule2)
|
||||||
|
|
||||||
|
rule3, err := manager.blockInvalidRouted(ifaceMock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule3)
|
||||||
|
|
||||||
|
// All should return the same rule
|
||||||
|
assert.Equal(t, rule1.ID(), rule2.ID(), "Second call should return same rule")
|
||||||
|
assert.Equal(t, rule2.ID(), rule3.ID(), "Third call should return same rule")
|
||||||
|
|
||||||
|
// Should have exactly 1 route rule
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 1, ruleCount, "Should have exactly 1 block rule after 3 calls")
|
||||||
|
|
||||||
|
// Verify the rule blocks traffic to the WG network
|
||||||
|
srcIP := netip.MustParseAddr("10.0.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.50")
|
||||||
|
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
|
||||||
|
assert.False(t, pass, "Block rule should deny traffic to WG prefix")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBlockRuleNotAccumulatedOnRepeatedEnableRouting verifies that calling
|
||||||
|
// EnableRouting multiple times (as happens on each route update) does not
|
||||||
|
// accumulate duplicate block rules in the routeRules slice.
|
||||||
|
func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
dev := mocks.NewMockDevice(ctrl)
|
||||||
|
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||||
|
|
||||||
|
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||||
|
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: wgNet.Addr(),
|
||||||
|
Network: wgNet,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
GetDeviceFunc: func() *device.FilteredDevice {
|
||||||
|
return &device.FilteredDevice{Device: dev}
|
||||||
|
},
|
||||||
|
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||||
|
return &wgdevice.Device{}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Call EnableRouting multiple times (simulating repeated route updates)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
require.NoError(t, manager.EnableRouting())
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 1, ruleCount,
|
||||||
|
"Repeated EnableRouting should not accumulate block rules")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRouteRuleCountStableAcrossUpdates verifies that adding the same route
|
||||||
|
// rule multiple times does not create duplicate entries.
|
||||||
|
func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||||
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
|
// Simulate 5 network map updates with the same route rule
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 2, ruleCount,
|
||||||
|
"Should have exactly 2 rules (1 user rule + 1 block rule) after 5 updates")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDeleteRouteRuleAfterIdempotentAdd verifies that deleting a route rule
|
||||||
|
// after adding it multiple times works correctly.
|
||||||
|
func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||||
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
|
// Add same rule twice
|
||||||
|
rule1, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rule2, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
|
||||||
|
|
||||||
|
// Delete using first reference
|
||||||
|
err = manager.DeleteRouteRule(rule1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify traffic no longer passes
|
||||||
|
srcIP := netip.MustParseAddr("100.64.1.5")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.10")
|
||||||
|
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||||
|
assert.False(t, pass, "Traffic should not pass after rule deletion")
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTestManager(t *testing.T) *Manager {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
dev := mocks.NewMockDevice(ctrl)
|
||||||
|
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||||
|
|
||||||
|
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||||
|
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: wgNet.Addr(),
|
||||||
|
Network: wgNet,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
GetDeviceFunc: func() *device.FilteredDevice {
|
||||||
|
return &device.FilteredDevice{Device: dev}
|
||||||
|
},
|
||||||
|
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||||
|
return &wgdevice.Device{}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, manager.EnableRouting())
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
return manager
|
||||||
|
}
|
||||||
@@ -263,6 +263,158 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
|
||||||
|
// to the deny map and can be cleanly deleted without leaving orphans.
|
||||||
|
func TestPeerRuleLifecycleDenyRules(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, m.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
addr := netip.MustParseAddr("192.168.1.1")
|
||||||
|
|
||||||
|
// Add multiple deny rules for different ports
|
||||||
|
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
denyCount := len(m.incomingDenyRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
|
||||||
|
|
||||||
|
// Delete the first deny rule
|
||||||
|
err = m.DeletePeerRule(rule1[0])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
denyCount = len(m.incomingDenyRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
|
||||||
|
|
||||||
|
// Delete the second deny rule
|
||||||
|
err = m.DeletePeerRule(rule2[0])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
_, exists := m.incomingDenyRules[addr]
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
|
||||||
|
// peer rules (simulating network map updates) does not leak rules in the maps.
|
||||||
|
func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, m.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
addr := netip.MustParseAddr("192.168.1.1")
|
||||||
|
|
||||||
|
// Simulate 10 network map updates: add rule, delete old, add new
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
// Add a deny rule
|
||||||
|
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Add an allow rule
|
||||||
|
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Delete them (simulating ACL manager cleanup)
|
||||||
|
for _, r := range rules {
|
||||||
|
require.NoError(t, m.DeletePeerRule(r))
|
||||||
|
}
|
||||||
|
for _, r := range allowRules {
|
||||||
|
require.NoError(t, m.DeletePeerRule(r))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
denyCount := len(m.incomingDenyRules[addr])
|
||||||
|
allowCount := len(m.incomingRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
|
||||||
|
require.Equal(t, 0, allowCount, "No allow rules should remain after cleanup")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMixedAllowDenyRulesSameIP verifies that allow and deny rules for the same
|
||||||
|
// IP are stored in separate maps and don't interfere with each other.
|
||||||
|
func TestMixedAllowDenyRulesSameIP(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, m.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
|
||||||
|
// Add allow rule for port 80
|
||||||
|
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Add deny rule for port 22
|
||||||
|
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
addr := netip.MustParseAddr("192.168.1.1")
|
||||||
|
m.mutex.RLock()
|
||||||
|
allowCount := len(m.incomingRules[addr])
|
||||||
|
denyCount := len(m.incomingDenyRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
|
||||||
|
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
|
||||||
|
|
||||||
|
// Delete allow rule should not affect deny rule
|
||||||
|
err = m.DeletePeerRule(allowRule[0])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
denyCountAfter := len(m.incomingDenyRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
|
||||||
|
|
||||||
|
// Delete deny rule
|
||||||
|
err = m.DeletePeerRule(denyRule[0])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
_, denyExists := m.incomingDenyRules[addr]
|
||||||
|
_, allowExists := m.incomingRules[addr]
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
require.False(t, denyExists, "Deny rules should be empty")
|
||||||
|
require.False(t, allowExists, "Allow rules should be empty")
|
||||||
|
}
|
||||||
|
|
||||||
func TestManagerReset(t *testing.T) {
|
func TestManagerReset(t *testing.T) {
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -16,9 +18,18 @@ const (
|
|||||||
maxBatchSize = 1024 * 16
|
maxBatchSize = 1024 * 16
|
||||||
maxMessageSize = 1024 * 2
|
maxMessageSize = 1024 * 2
|
||||||
defaultFlushInterval = 2 * time.Second
|
defaultFlushInterval = 2 * time.Second
|
||||||
logChannelSize = 1000
|
defaultLogChanSize = 1000
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func getLogChannelSize() int {
|
||||||
|
if v := os.Getenv("NB_USPFILTER_LOG_BUFFER"); v != "" {
|
||||||
|
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultLogChanSize
|
||||||
|
}
|
||||||
|
|
||||||
type Level uint32
|
type Level uint32
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -69,7 +80,7 @@ type Logger struct {
|
|||||||
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||||
l := &Logger{
|
l := &Logger{
|
||||||
output: logrusLogger.Out,
|
output: logrusLogger.Out,
|
||||||
msgChannel: make(chan logMessage, logChannelSize),
|
msgChannel: make(chan logMessage, getLogChannelSize()),
|
||||||
shutdown: make(chan struct{}),
|
shutdown: make(chan struct{}),
|
||||||
bufPool: sync.Pool{
|
bufPool: sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
|
|||||||
@@ -358,9 +358,9 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
|||||||
// Fast path for IPv4 addresses (4 bytes) - most common case
|
// Fast path for IPv4 addresses (4 bytes) - most common case
|
||||||
if len(oldBytes) == 4 && len(newBytes) == 4 {
|
if len(oldBytes) == 4 && len(newBytes) == 4 {
|
||||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
|
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
|
||||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
|
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4])) //nolint:gosec // length checked above
|
||||||
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
|
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
|
||||||
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
|
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4])) //nolint:gosec // length checked above
|
||||||
} else {
|
} else {
|
||||||
// Fallback for other lengths
|
// Fallback for other lengths
|
||||||
for i := 0; i < len(oldBytes)-1; i += 2 {
|
for i := 0; i < len(oldBytes)-1; i += 2 {
|
||||||
|
|||||||
@@ -29,8 +29,9 @@ type PacketFilter interface {
|
|||||||
type FilteredDevice struct {
|
type FilteredDevice struct {
|
||||||
tun.Device
|
tun.Device
|
||||||
|
|
||||||
filter PacketFilter
|
filter PacketFilter
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
closeOnce sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
// newDeviceFilter constructor function
|
// newDeviceFilter constructor function
|
||||||
@@ -40,6 +41,20 @@ func newDeviceFilter(device tun.Device) *FilteredDevice {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close closes the underlying tun device exactly once.
|
||||||
|
// wireguard-go's netTun.Close() panics on double-close due to a bare close(channel),
|
||||||
|
// and multiple code paths can trigger Close on the same device.
|
||||||
|
func (d *FilteredDevice) Close() error {
|
||||||
|
var err error
|
||||||
|
d.closeOnce.Do(func() {
|
||||||
|
err = d.Device.Close()
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Read wraps read method with filtering feature
|
// Read wraps read method with filtering feature
|
||||||
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||||
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
|
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
|
||||||
|
|||||||
@@ -82,7 +82,9 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
|
|||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tunIface.Close()
|
if cErr := tunIface.Close(); cErr != nil {
|
||||||
|
log.Debugf("failed to close tun device: %v", cErr)
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("error configuring interface: %s", err)
|
return nil, fmt.Errorf("error configuring interface: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/errors"
|
"github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
@@ -228,6 +229,10 @@ func (w *WGIface) Close() error {
|
|||||||
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if nbnetstack.IsEnabled() {
|
||||||
|
return errors.FormatErrorOrNil(result)
|
||||||
|
}
|
||||||
|
|
||||||
if err := w.waitUntilRemoved(); err != nil {
|
if err := w.waitUntilRemoved(); err != nil {
|
||||||
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
|
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
|
||||||
if err := w.Destroy(); err != nil {
|
if err := w.Destroy(); err != nil {
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return nsTunDev, tunNet, nil
|
return t.tundev, tunNet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *NetStackTun) Close() error {
|
func (t *NetStackTun) Close() error {
|
||||||
|
|||||||
@@ -189,6 +189,212 @@ func TestDefaultManagerStateless(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestDenyRulesNotAccumulatedOnRepeatedApply verifies that applying the same
|
||||||
|
// deny rules repeatedly does not accumulate duplicate rules in the uspfilter.
|
||||||
|
// This tests the full ACL manager -> uspfilter integration.
|
||||||
|
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
|
||||||
|
networkMap := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "22",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "80",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FirewallRulesIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
|
IP: network.Addr(),
|
||||||
|
Network: network,
|
||||||
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, fw.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
|
// Apply the same rules 5 times (simulating repeated network map updates)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The ACL manager should track exactly 3 rule pairs (2 deny + 1 accept inbound)
|
||||||
|
assert.Equal(t, 3, len(acl.peerRulesPairs),
|
||||||
|
"Should have exactly 3 rule pairs after 5 identical updates")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDenyRulesCleanedUpOnRemoval verifies that deny rules are properly cleaned
|
||||||
|
// up when they're removed from the network map in a subsequent update.
|
||||||
|
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
|
IP: network.Addr(),
|
||||||
|
Network: network,
|
||||||
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, fw.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
|
// First update: add deny and accept rules
|
||||||
|
networkMap1 := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "22",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FirewallRulesIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
acl.ApplyFiltering(networkMap1, false)
|
||||||
|
assert.Equal(t, 2, len(acl.peerRulesPairs), "Should have 2 rules after first update")
|
||||||
|
|
||||||
|
// Second update: remove the deny rule, keep only accept
|
||||||
|
networkMap2 := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FirewallRulesIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
acl.ApplyFiltering(networkMap2, false)
|
||||||
|
assert.Equal(t, 1, len(acl.peerRulesPairs),
|
||||||
|
"Should have 1 rule after removing deny rule")
|
||||||
|
|
||||||
|
// Third update: remove all rules
|
||||||
|
networkMap3 := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{},
|
||||||
|
FirewallRulesIsEmpty: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
acl.ApplyFiltering(networkMap3, false)
|
||||||
|
assert.Equal(t, 0, len(acl.peerRulesPairs),
|
||||||
|
"Should have 0 rules after removing all rules")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRuleUpdateChangingAction verifies that when a rule's action changes from
|
||||||
|
// accept to deny (or vice versa), the old rule is properly removed and the new
|
||||||
|
// one added without leaking.
|
||||||
|
func TestRuleUpdateChangingAction(t *testing.T) {
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
|
IP: network.Addr(),
|
||||||
|
Network: network,
|
||||||
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, fw.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
|
// First update: accept rule
|
||||||
|
networkMap := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "22",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FirewallRulesIsEmpty: false,
|
||||||
|
}
|
||||||
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
assert.Equal(t, 1, len(acl.peerRulesPairs))
|
||||||
|
|
||||||
|
// Second update: change to deny (same IP/port/proto, different action)
|
||||||
|
networkMap.FirewallRules = []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "22",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
|
||||||
|
// Should still have exactly 1 rule (the old accept removed, new deny added)
|
||||||
|
assert.Equal(t, 1, len(acl.peerRulesPairs),
|
||||||
|
"Changing action should result in exactly 1 rule, not 2")
|
||||||
|
}
|
||||||
|
|
||||||
func TestPortInfoEmpty(t *testing.T) {
|
func TestPortInfoEmpty(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
@@ -244,7 +245,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
localPeerState := peer.LocalPeerState{
|
localPeerState := peer.LocalPeerState{
|
||||||
IP: loginResp.GetPeerConfig().GetAddress(),
|
IP: loginResp.GetPeerConfig().GetAddress(),
|
||||||
PubKey: myPrivateKey.PublicKey().String(),
|
PubKey: myPrivateKey.PublicKey().String(),
|
||||||
KernelInterface: device.WireGuardModuleIsLoaded(),
|
KernelInterface: device.WireGuardModuleIsLoaded() && !netstack.IsEnabled(),
|
||||||
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
||||||
}
|
}
|
||||||
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
@@ -22,6 +24,7 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS"
|
netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS"
|
||||||
|
netbirdDNSStateKeyIndexedFormat = "State:/Network/Service/NetBird-%s-%d/DNS"
|
||||||
globalIPv4State = "State:/Network/Global/IPv4"
|
globalIPv4State = "State:/Network/Global/IPv4"
|
||||||
primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS"
|
primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS"
|
||||||
keySupplementalMatchDomains = "SupplementalMatchDomains"
|
keySupplementalMatchDomains = "SupplementalMatchDomains"
|
||||||
@@ -35,6 +38,14 @@ const (
|
|||||||
searchSuffix = "Search"
|
searchSuffix = "Search"
|
||||||
matchSuffix = "Match"
|
matchSuffix = "Match"
|
||||||
localSuffix = "Local"
|
localSuffix = "Local"
|
||||||
|
|
||||||
|
// maxDomainsPerResolverEntry is the max number of domains per scutil resolver key.
|
||||||
|
// scutil's d.add has maxArgs=101 (key + * + 99 values), so 99 is the hard cap.
|
||||||
|
maxDomainsPerResolverEntry = 50
|
||||||
|
|
||||||
|
// maxDomainBytesPerResolverEntry is the max total bytes of domain strings per key.
|
||||||
|
// scutil has an undocumented ~2048 byte value buffer; we stay well under it.
|
||||||
|
maxDomainBytesPerResolverEntry = 1500
|
||||||
)
|
)
|
||||||
|
|
||||||
type systemConfigurator struct {
|
type systemConfigurator struct {
|
||||||
@@ -84,28 +95,23 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
|||||||
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
|
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
|
||||||
}
|
}
|
||||||
|
|
||||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
if err := s.removeKeysContaining(matchSuffix); err != nil {
|
||||||
var err error
|
log.Warnf("failed to remove old match keys: %v", err)
|
||||||
if len(matchDomains) != 0 {
|
|
||||||
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
|
|
||||||
} else {
|
|
||||||
log.Infof("removing match domains from the system")
|
|
||||||
err = s.removeKeyFromSystemConfig(matchKey)
|
|
||||||
}
|
}
|
||||||
if err != nil {
|
if len(matchDomains) != 0 {
|
||||||
return fmt.Errorf("add match domains: %w", err)
|
if err := s.addBatchedDomains(matchSuffix, matchDomains, config.ServerIP, config.ServerPort, false); err != nil {
|
||||||
|
return fmt.Errorf("add match domains: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
s.updateState(stateManager)
|
s.updateState(stateManager)
|
||||||
|
|
||||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
if err := s.removeKeysContaining(searchSuffix); err != nil {
|
||||||
if len(searchDomains) != 0 {
|
log.Warnf("failed to remove old search keys: %v", err)
|
||||||
err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.ServerIP, config.ServerPort)
|
|
||||||
} else {
|
|
||||||
log.Infof("removing search domains from the system")
|
|
||||||
err = s.removeKeyFromSystemConfig(searchKey)
|
|
||||||
}
|
}
|
||||||
if err != nil {
|
if len(searchDomains) != 0 {
|
||||||
return fmt.Errorf("add search domains: %w", err)
|
if err := s.addBatchedDomains(searchSuffix, searchDomains, config.ServerIP, config.ServerPort, true); err != nil {
|
||||||
|
return fmt.Errorf("add search domains: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
s.updateState(stateManager)
|
s.updateState(stateManager)
|
||||||
|
|
||||||
@@ -149,8 +155,7 @@ func (s *systemConfigurator) restoreHostDNS() error {
|
|||||||
|
|
||||||
func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
|
func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
|
||||||
if len(s.createdKeys) == 0 {
|
if len(s.createdKeys) == 0 {
|
||||||
// return defaults for startup calls
|
return s.discoverExistingKeys()
|
||||||
return []string{getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix), getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
keys := make([]string, 0, len(s.createdKeys))
|
keys := make([]string, 0, len(s.createdKeys))
|
||||||
@@ -160,6 +165,47 @@ func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
|
|||||||
return keys
|
return keys
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// discoverExistingKeys probes scutil for all NetBird DNS keys that may exist.
|
||||||
|
// This handles the case where createdKeys is empty (e.g., state file lost after unclean shutdown).
|
||||||
|
func (s *systemConfigurator) discoverExistingKeys() []string {
|
||||||
|
dnsKeys, err := getSystemDNSKeys()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to get system DNS keys: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var keys []string
|
||||||
|
|
||||||
|
for _, suffix := range []string{searchSuffix, matchSuffix, localSuffix} {
|
||||||
|
key := getKeyWithInput(netbirdDNSStateKeyFormat, suffix)
|
||||||
|
if strings.Contains(dnsKeys, key) {
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, suffix := range []string{searchSuffix, matchSuffix} {
|
||||||
|
for i := 0; ; i++ {
|
||||||
|
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i)
|
||||||
|
if !strings.Contains(dnsKeys, key) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return keys
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSystemDNSKeys gets all DNS keys
|
||||||
|
func getSystemDNSKeys() (string, error) {
|
||||||
|
command := "list .*DNS\nquit\n"
|
||||||
|
out, err := runSystemConfigCommand(command)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(out), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
||||||
line := buildRemoveKeyOperation(key)
|
line := buildRemoveKeyOperation(key)
|
||||||
_, err := runSystemConfigCommand(wrapCommand(line))
|
_, err := runSystemConfigCommand(wrapCommand(line))
|
||||||
@@ -184,12 +230,11 @@ func (s *systemConfigurator) addLocalDNS() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.addSearchDomains(
|
domainsStr := strings.Join(s.systemDNSSettings.Domains, " ")
|
||||||
localKey,
|
if err := s.addDNSState(localKey, domainsStr, s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort, true); err != nil {
|
||||||
strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort,
|
return fmt.Errorf("add local dns state: %w", err)
|
||||||
); err != nil {
|
|
||||||
return fmt.Errorf("add search domains: %w", err)
|
|
||||||
}
|
}
|
||||||
|
s.createdKeys[localKey] = struct{}{}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -280,28 +325,77 @@ func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
|
|||||||
return slices.Clone(s.origNameservers)
|
return slices.Clone(s.origNameservers)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
|
// splitDomainsIntoBatches splits domains into batches respecting both element count and byte size limits.
|
||||||
err := s.addDNSState(key, domains, ip, port, true)
|
func splitDomainsIntoBatches(domains []string) [][]string {
|
||||||
if err != nil {
|
if len(domains) == 0 {
|
||||||
return fmt.Errorf("add dns state: %w", err)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("added %d search domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
|
var batches [][]string
|
||||||
|
var current []string
|
||||||
|
currentBytes := 0
|
||||||
|
|
||||||
s.createdKeys[key] = struct{}{}
|
for _, d := range domains {
|
||||||
|
domainLen := len(d)
|
||||||
|
newBytes := currentBytes + domainLen
|
||||||
|
if currentBytes > 0 {
|
||||||
|
newBytes++ // space separator
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
if len(current) > 0 && (len(current) >= maxDomainsPerResolverEntry || newBytes > maxDomainBytesPerResolverEntry) {
|
||||||
|
batches = append(batches, current)
|
||||||
|
current = nil
|
||||||
|
currentBytes = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
current = append(current, d)
|
||||||
|
if currentBytes > 0 {
|
||||||
|
currentBytes += 1 + domainLen
|
||||||
|
} else {
|
||||||
|
currentBytes = domainLen
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(current) > 0 {
|
||||||
|
batches = append(batches, current)
|
||||||
|
}
|
||||||
|
|
||||||
|
return batches
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) addMatchDomains(key, domains string, dnsServer netip.Addr, port int) error {
|
// removeKeysContaining removes all created keys that contain the given substring.
|
||||||
err := s.addDNSState(key, domains, dnsServer, port, false)
|
func (s *systemConfigurator) removeKeysContaining(suffix string) error {
|
||||||
if err != nil {
|
var toRemove []string
|
||||||
return fmt.Errorf("add dns state: %w", err)
|
for key := range s.createdKeys {
|
||||||
|
if strings.Contains(key, suffix) {
|
||||||
|
toRemove = append(toRemove, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var multiErr *multierror.Error
|
||||||
|
for _, key := range toRemove {
|
||||||
|
if err := s.removeKeyFromSystemConfig(key); err != nil {
|
||||||
|
multiErr = multierror.Append(multiErr, fmt.Errorf("couldn't remove key %s: %w", key, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(multiErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// addBatchedDomains splits domains into batches and creates indexed scutil keys for each batch.
|
||||||
|
func (s *systemConfigurator) addBatchedDomains(suffix string, domains []string, ip netip.Addr, port int, enableSearch bool) error {
|
||||||
|
batches := splitDomainsIntoBatches(domains)
|
||||||
|
|
||||||
|
for i, batch := range batches {
|
||||||
|
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i)
|
||||||
|
domainsStr := strings.Join(batch, " ")
|
||||||
|
|
||||||
|
if err := s.addDNSState(key, domainsStr, ip, port, enableSearch); err != nil {
|
||||||
|
return fmt.Errorf("add dns state for batch %d: %w", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.createdKeys[key] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("added %d match domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
|
log.Infof("added %d %s domains across %d resolver entries", len(domains), suffix, len(batches))
|
||||||
|
|
||||||
s.createdKeys[key] = struct{}{}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -364,7 +458,6 @@ func (s *systemConfigurator) flushDNSCache() error {
|
|||||||
if out, err := cmd.CombinedOutput(); err != nil {
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
return fmt.Errorf("restart mDNSResponder: %w, output: %s", err, out)
|
return fmt.Errorf("restart mDNSResponder: %w, output: %s", err, out)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("flushed DNS cache")
|
log.Info("flushed DNS cache")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,10 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -49,17 +52,22 @@ func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
|
|||||||
|
|
||||||
require.NoError(t, sm.PersistState(context.Background()))
|
require.NoError(t, sm.PersistState(context.Background()))
|
||||||
|
|
||||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
|
||||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
|
||||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||||
|
|
||||||
|
// Collect all created keys for cleanup verification
|
||||||
|
createdKeys := make([]string, 0, len(configurator.createdKeys))
|
||||||
|
for key := range configurator.createdKeys {
|
||||||
|
createdKeys = append(createdKeys, key)
|
||||||
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
for _, key := range createdKeys {
|
||||||
_ = removeTestDNSKey(key)
|
_ = removeTestDNSKey(key)
|
||||||
}
|
}
|
||||||
|
_ = removeTestDNSKey(localKey)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
for _, key := range createdKeys {
|
||||||
exists, err := checkDNSKeyExists(key)
|
exists, err := checkDNSKeyExists(key)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
if exists {
|
if exists {
|
||||||
@@ -83,13 +91,223 @@ func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
|
|||||||
err = shutdownState.Cleanup()
|
err = shutdownState.Cleanup()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
for _, key := range createdKeys {
|
||||||
exists, err := checkDNSKeyExists(key)
|
exists, err := checkDNSKeyExists(key)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
|
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generateShortDomains generates domains like a.com, b.com, ..., aa.com, ab.com, etc.
|
||||||
|
func generateShortDomains(count int) []string {
|
||||||
|
domains := make([]string, 0, count)
|
||||||
|
for i := range count {
|
||||||
|
label := ""
|
||||||
|
n := i
|
||||||
|
for {
|
||||||
|
label = string(rune('a'+n%26)) + label
|
||||||
|
n = n/26 - 1
|
||||||
|
if n < 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
domains = append(domains, label+".com")
|
||||||
|
}
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateLongDomains generates domains like subdomain-000.department.organization-name.example.com
|
||||||
|
func generateLongDomains(count int) []string {
|
||||||
|
domains := make([]string, 0, count)
|
||||||
|
for i := range count {
|
||||||
|
domains = append(domains, fmt.Sprintf("subdomain-%03d.department.organization-name.example.com", i))
|
||||||
|
}
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
|
// readDomainsFromKey reads the SupplementalMatchDomains array back from scutil for a given key.
|
||||||
|
func readDomainsFromKey(t *testing.T, key string) []string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
cmd := exec.Command(scutilPath)
|
||||||
|
cmd.Stdin = strings.NewReader(fmt.Sprintf("open\nshow %s\nquit\n", key))
|
||||||
|
out, err := cmd.Output()
|
||||||
|
require.NoError(t, err, "scutil show should succeed")
|
||||||
|
|
||||||
|
var domains []string
|
||||||
|
inArray := false
|
||||||
|
scanner := bufio.NewScanner(bytes.NewReader(out))
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := strings.TrimSpace(scanner.Text())
|
||||||
|
if strings.HasPrefix(line, "SupplementalMatchDomains") && strings.Contains(line, "<array>") {
|
||||||
|
inArray = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if inArray {
|
||||||
|
if line == "}" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// lines look like: "0 : a.com"
|
||||||
|
parts := strings.SplitN(line, " : ", 2)
|
||||||
|
if len(parts) == 2 {
|
||||||
|
domains = append(domains, parts[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NoError(t, scanner.Err())
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitDomainsIntoBatches(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
domains []string
|
||||||
|
expectedCount int
|
||||||
|
checkAllPresent bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
domains: nil,
|
||||||
|
expectedCount: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "under_limit",
|
||||||
|
domains: generateShortDomains(10),
|
||||||
|
expectedCount: 1,
|
||||||
|
checkAllPresent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "at_element_limit",
|
||||||
|
domains: generateShortDomains(50),
|
||||||
|
expectedCount: 1,
|
||||||
|
checkAllPresent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "over_element_limit",
|
||||||
|
domains: generateShortDomains(51),
|
||||||
|
expectedCount: 2,
|
||||||
|
checkAllPresent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "triple_element_limit",
|
||||||
|
domains: generateShortDomains(150),
|
||||||
|
expectedCount: 3,
|
||||||
|
checkAllPresent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "long_domains_hit_byte_limit",
|
||||||
|
domains: generateLongDomains(50),
|
||||||
|
checkAllPresent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "500_short_domains",
|
||||||
|
domains: generateShortDomains(500),
|
||||||
|
expectedCount: 10,
|
||||||
|
checkAllPresent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "500_long_domains",
|
||||||
|
domains: generateLongDomains(500),
|
||||||
|
checkAllPresent: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
batches := splitDomainsIntoBatches(tc.domains)
|
||||||
|
|
||||||
|
if tc.expectedCount > 0 {
|
||||||
|
assert.Len(t, batches, tc.expectedCount, "expected %d batches", tc.expectedCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify each batch respects limits
|
||||||
|
for i, batch := range batches {
|
||||||
|
assert.LessOrEqual(t, len(batch), maxDomainsPerResolverEntry,
|
||||||
|
"batch %d exceeds element limit", i)
|
||||||
|
|
||||||
|
totalBytes := 0
|
||||||
|
for j, d := range batch {
|
||||||
|
if j > 0 {
|
||||||
|
totalBytes++
|
||||||
|
}
|
||||||
|
totalBytes += len(d)
|
||||||
|
}
|
||||||
|
assert.LessOrEqual(t, totalBytes, maxDomainBytesPerResolverEntry,
|
||||||
|
"batch %d exceeds byte limit (%d bytes)", i, totalBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.checkAllPresent {
|
||||||
|
var all []string
|
||||||
|
for _, batch := range batches {
|
||||||
|
all = append(all, batch...)
|
||||||
|
}
|
||||||
|
assert.Equal(t, tc.domains, all, "all domains should be present in order")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMatchDomainBatching writes increasing numbers of domains via the batching mechanism
|
||||||
|
// and verifies all domains are readable across multiple scutil keys.
|
||||||
|
func TestMatchDomainBatching(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping scutil integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
count int
|
||||||
|
generator func(int) []string
|
||||||
|
}{
|
||||||
|
{"short_10", 10, generateShortDomains},
|
||||||
|
{"short_50", 50, generateShortDomains},
|
||||||
|
{"short_100", 100, generateShortDomains},
|
||||||
|
{"short_200", 200, generateShortDomains},
|
||||||
|
{"short_500", 500, generateShortDomains},
|
||||||
|
{"long_10", 10, generateLongDomains},
|
||||||
|
{"long_50", 50, generateLongDomains},
|
||||||
|
{"long_100", 100, generateLongDomains},
|
||||||
|
{"long_200", 200, generateLongDomains},
|
||||||
|
{"long_500", 500, generateLongDomains},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
configurator := &systemConfigurator{
|
||||||
|
createdKeys: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
for key := range configurator.createdKeys {
|
||||||
|
_ = removeTestDNSKey(key)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
domains := tc.generator(tc.count)
|
||||||
|
err := configurator.addBatchedDomains(matchSuffix, domains, netip.MustParseAddr("100.64.0.1"), 53, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
batches := splitDomainsIntoBatches(domains)
|
||||||
|
t.Logf("wrote %d domains across %d batched keys", tc.count, len(batches))
|
||||||
|
|
||||||
|
// Read back all domains from all batched keys
|
||||||
|
var got []string
|
||||||
|
for i := range batches {
|
||||||
|
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, matchSuffix, i)
|
||||||
|
exists, err := checkDNSKeyExists(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, exists, "key %s should exist", key)
|
||||||
|
|
||||||
|
got = append(got, readDomainsFromKey(t, key)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("read back %d/%d domains from %d keys", len(got), tc.count, len(batches))
|
||||||
|
assert.Equal(t, tc.count, len(got), "all domains should be readable")
|
||||||
|
assert.Equal(t, domains, got, "domains should match in order")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func checkDNSKeyExists(key string) (bool, error) {
|
func checkDNSKeyExists(key string) (bool, error) {
|
||||||
cmd := exec.Command(scutilPath)
|
cmd := exec.Command(scutilPath)
|
||||||
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
|
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
|
||||||
@@ -158,15 +376,15 @@ func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Man
|
|||||||
createdKeys: make(map[string]struct{}),
|
createdKeys: make(map[string]struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
|
||||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
|
||||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
|
||||||
|
|
||||||
cleanup := func() {
|
cleanup := func() {
|
||||||
_ = sm.Stop(context.Background())
|
_ = sm.Stop(context.Background())
|
||||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
for key := range configurator.createdKeys {
|
||||||
_ = removeTestDNSKey(key)
|
_ = removeTestDNSKey(key)
|
||||||
}
|
}
|
||||||
|
// Also clean up old-format keys and local key in case they exist
|
||||||
|
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix))
|
||||||
|
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix))
|
||||||
|
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix))
|
||||||
}
|
}
|
||||||
|
|
||||||
return configurator, sm, cleanup
|
return configurator, sm, cleanup
|
||||||
|
|||||||
@@ -42,6 +42,8 @@ const (
|
|||||||
dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
|
dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
|
||||||
dnsPolicyConfigConfigOptionsValue = 0x8
|
dnsPolicyConfigConfigOptionsValue = 0x8
|
||||||
|
|
||||||
|
nrptMaxDomainsPerRule = 50
|
||||||
|
|
||||||
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
||||||
interfaceConfigNameServerKey = "NameServer"
|
interfaceConfigNameServerKey = "NameServer"
|
||||||
interfaceConfigSearchListKey = "SearchList"
|
interfaceConfigSearchListKey = "SearchList"
|
||||||
@@ -198,10 +200,11 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
|
|
||||||
if len(matchDomains) != 0 {
|
if len(matchDomains) != 0 {
|
||||||
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
|
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
|
||||||
|
// Update count even on error to ensure cleanup covers partially created rules
|
||||||
|
r.nrptEntryCount = count
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("add dns match policy: %w", err)
|
return fmt.Errorf("add dns match policy: %w", err)
|
||||||
}
|
}
|
||||||
r.nrptEntryCount = count
|
|
||||||
} else {
|
} else {
|
||||||
r.nrptEntryCount = 0
|
r.nrptEntryCount = 0
|
||||||
}
|
}
|
||||||
@@ -239,23 +242,33 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
|
|||||||
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) {
|
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) {
|
||||||
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
||||||
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
||||||
for i, domain := range domains {
|
|
||||||
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
|
|
||||||
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
|
|
||||||
|
|
||||||
singleDomain := []string{domain}
|
// We need to batch domains into chunks and create one NRPT rule per batch.
|
||||||
|
ruleIndex := 0
|
||||||
|
for i := 0; i < len(domains); i += nrptMaxDomainsPerRule {
|
||||||
|
end := i + nrptMaxDomainsPerRule
|
||||||
|
if end > len(domains) {
|
||||||
|
end = len(domains)
|
||||||
|
}
|
||||||
|
batchDomains := domains[i:end]
|
||||||
|
|
||||||
if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil {
|
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, ruleIndex)
|
||||||
return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err)
|
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, ruleIndex)
|
||||||
|
|
||||||
|
if err := r.configureDNSPolicy(localPath, batchDomains, ip); err != nil {
|
||||||
|
return ruleIndex, fmt.Errorf("configure DNS Local policy for rule %d: %w", ruleIndex, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Increment immediately so the caller's cleanup path knows about this rule
|
||||||
|
ruleIndex++
|
||||||
|
|
||||||
if r.gpo {
|
if r.gpo {
|
||||||
if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil {
|
if err := r.configureDNSPolicy(gpoPath, batchDomains, ip); err != nil {
|
||||||
return i, fmt.Errorf("configure gpo DNS policy: %w", err)
|
return ruleIndex, fmt.Errorf("configure gpo DNS policy for rule %d: %w", ruleIndex-1, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("added NRPT entry for domain: %s", domain)
|
log.Debugf("added NRPT rule %d with %d domains", ruleIndex-1, len(batchDomains))
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.gpo {
|
if r.gpo {
|
||||||
@@ -264,8 +277,8 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("added %d separate NRPT entries. Domain list: %s", len(domains), domains)
|
log.Infof("added %d NRPT rules for %d domains. Domain list: %v", ruleIndex, len(domains), domains)
|
||||||
return len(domains), nil
|
return ruleIndex, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {
|
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
|
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
|
||||||
// when the number of match domains decreases between configuration changes.
|
// when the number of match domains decreases between configuration changes.
|
||||||
|
// With batching enabled (50 domains per rule), we need enough domains to create multiple rules.
|
||||||
func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
t.Skip("skipping registry integration test in short mode")
|
t.Skip("skipping registry integration test in short mode")
|
||||||
@@ -37,51 +38,60 @@ func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
|||||||
gpo: false,
|
gpo: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
config5 := HostDNSConfig{
|
// Create 125 domains which will result in 3 NRPT rules (50+50+25)
|
||||||
ServerIP: testIP,
|
domains125 := make([]DomainConfig, 125)
|
||||||
Domains: []DomainConfig{
|
for i := 0; i < 125; i++ {
|
||||||
{Domain: "domain1.com", MatchOnly: true},
|
domains125[i] = DomainConfig{
|
||||||
{Domain: "domain2.com", MatchOnly: true},
|
Domain: fmt.Sprintf("domain%d.com", i+1),
|
||||||
{Domain: "domain3.com", MatchOnly: true},
|
MatchOnly: true,
|
||||||
{Domain: "domain4.com", MatchOnly: true},
|
}
|
||||||
{Domain: "domain5.com", MatchOnly: true},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = cfg.applyDNSConfig(config5, nil)
|
config125 := HostDNSConfig{
|
||||||
|
ServerIP: testIP,
|
||||||
|
Domains: domains125,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = cfg.applyDNSConfig(config125, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify all 5 entries exist
|
// Verify 3 NRPT rules exist
|
||||||
for i := 0; i < 5; i++ {
|
assert.Equal(t, 3, cfg.nrptEntryCount, "Should create 3 NRPT rules for 125 domains")
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, exists, "Entry %d should exist after first config", i)
|
assert.True(t, exists, "NRPT rule %d should exist after first config", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
config2 := HostDNSConfig{
|
// Reduce to 75 domains which will result in 2 NRPT rules (50+25)
|
||||||
|
domains75 := make([]DomainConfig, 75)
|
||||||
|
for i := 0; i < 75; i++ {
|
||||||
|
domains75[i] = DomainConfig{
|
||||||
|
Domain: fmt.Sprintf("domain%d.com", i+1),
|
||||||
|
MatchOnly: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config75 := HostDNSConfig{
|
||||||
ServerIP: testIP,
|
ServerIP: testIP,
|
||||||
Domains: []DomainConfig{
|
Domains: domains75,
|
||||||
{Domain: "domain1.com", MatchOnly: true},
|
|
||||||
{Domain: "domain2.com", MatchOnly: true},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = cfg.applyDNSConfig(config2, nil)
|
err = cfg.applyDNSConfig(config75, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify first 2 entries exist
|
// Verify first 2 NRPT rules exist
|
||||||
|
assert.Equal(t, 2, cfg.nrptEntryCount, "Should create 2 NRPT rules for 75 domains")
|
||||||
for i := 0; i < 2; i++ {
|
for i := 0; i < 2; i++ {
|
||||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, exists, "Entry %d should exist after second config", i)
|
assert.True(t, exists, "NRPT rule %d should exist after second config", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify entries 2-4 are cleaned up
|
// Verify rule 2 is cleaned up
|
||||||
for i := 2; i < 5; i++ {
|
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, 2))
|
||||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
require.NoError(t, err)
|
||||||
require.NoError(t, err)
|
assert.False(t, exists, "NRPT rule 2 should NOT exist after reducing to 75 domains")
|
||||||
assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func registryKeyExists(path string) (bool, error) {
|
func registryKeyExists(path string) (bool, error) {
|
||||||
@@ -97,6 +107,106 @@ func registryKeyExists(path string) (bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func cleanupRegistryKeys(*testing.T) {
|
func cleanupRegistryKeys(*testing.T) {
|
||||||
cfg := ®istryConfigurator{nrptEntryCount: 10}
|
// Clean up more entries to account for batching tests with many domains
|
||||||
|
cfg := ®istryConfigurator{nrptEntryCount: 20}
|
||||||
_ = cfg.removeDNSMatchPolicies()
|
_ = cfg.removeDNSMatchPolicies()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestNRPTDomainBatching verifies that domains are correctly batched into NRPT rules.
|
||||||
|
func TestNRPTDomainBatching(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping registry integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
defer cleanupRegistryKeys(t)
|
||||||
|
cleanupRegistryKeys(t)
|
||||||
|
|
||||||
|
testIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
|
||||||
|
// Create a test interface registry key so updateSearchDomains doesn't fail
|
||||||
|
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
|
||||||
|
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
|
||||||
|
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
|
||||||
|
require.NoError(t, err, "Should create test interface registry key")
|
||||||
|
testKey.Close()
|
||||||
|
defer func() {
|
||||||
|
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
|
||||||
|
}()
|
||||||
|
|
||||||
|
cfg := ®istryConfigurator{
|
||||||
|
guid: testGUID,
|
||||||
|
gpo: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
domainCount int
|
||||||
|
expectedRuleCount int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Less than 50 domains (single rule)",
|
||||||
|
domainCount: 30,
|
||||||
|
expectedRuleCount: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Exactly 50 domains (single rule)",
|
||||||
|
domainCount: 50,
|
||||||
|
expectedRuleCount: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "51 domains (two rules)",
|
||||||
|
domainCount: 51,
|
||||||
|
expectedRuleCount: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "100 domains (two rules)",
|
||||||
|
domainCount: 100,
|
||||||
|
expectedRuleCount: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "125 domains (three rules: 50+50+25)",
|
||||||
|
domainCount: 125,
|
||||||
|
expectedRuleCount: 3,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// Clean up before each subtest
|
||||||
|
cleanupRegistryKeys(t)
|
||||||
|
|
||||||
|
// Generate domains
|
||||||
|
domains := make([]DomainConfig, tc.domainCount)
|
||||||
|
for i := 0; i < tc.domainCount; i++ {
|
||||||
|
domains[i] = DomainConfig{
|
||||||
|
Domain: fmt.Sprintf("domain%d.com", i+1),
|
||||||
|
MatchOnly: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config := HostDNSConfig{
|
||||||
|
ServerIP: testIP,
|
||||||
|
Domains: domains,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := cfg.applyDNSConfig(config, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify that exactly expectedRuleCount rules were created
|
||||||
|
assert.Equal(t, tc.expectedRuleCount, cfg.nrptEntryCount,
|
||||||
|
"Should create %d NRPT rules for %d domains", tc.expectedRuleCount, tc.domainCount)
|
||||||
|
|
||||||
|
// Verify all expected rules exist
|
||||||
|
for i := 0; i < tc.expectedRuleCount; i++ {
|
||||||
|
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, exists, "NRPT rule %d should exist", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no extra rules were created
|
||||||
|
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, tc.expectedRuleCount))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, exists, "No NRPT rule should exist at index %d", tc.expectedRuleCount)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -84,3 +84,18 @@ func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
|
|||||||
func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
|
func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BeginBatch mock implementation of BeginBatch from Server interface
|
||||||
|
func (m *MockServer) BeginBatch() {
|
||||||
|
// Mock implementation - no-op
|
||||||
|
}
|
||||||
|
|
||||||
|
// EndBatch mock implementation of EndBatch from Server interface
|
||||||
|
func (m *MockServer) EndBatch() {
|
||||||
|
// Mock implementation - no-op
|
||||||
|
}
|
||||||
|
|
||||||
|
// CancelBatch mock implementation of CancelBatch from Server interface
|
||||||
|
func (m *MockServer) CancelBatch() {
|
||||||
|
// Mock implementation - no-op
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,7 +6,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -27,6 +29,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const envSkipDNSProbe = "NB_SKIP_DNS_PROBE"
|
||||||
|
|
||||||
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||||
type ReadyListener interface {
|
type ReadyListener interface {
|
||||||
OnReady()
|
OnReady()
|
||||||
@@ -41,6 +45,9 @@ type IosDnsManager interface {
|
|||||||
type Server interface {
|
type Server interface {
|
||||||
RegisterHandler(domains domain.List, handler dns.Handler, priority int)
|
RegisterHandler(domains domain.List, handler dns.Handler, priority int)
|
||||||
DeregisterHandler(domains domain.List, priority int)
|
DeregisterHandler(domains domain.List, priority int)
|
||||||
|
BeginBatch()
|
||||||
|
EndBatch()
|
||||||
|
CancelBatch()
|
||||||
Initialize() error
|
Initialize() error
|
||||||
Stop()
|
Stop()
|
||||||
DnsIP() netip.Addr
|
DnsIP() netip.Addr
|
||||||
@@ -83,6 +90,7 @@ type DefaultServer struct {
|
|||||||
currentConfigHash uint64
|
currentConfigHash uint64
|
||||||
handlerChain *HandlerChain
|
handlerChain *HandlerChain
|
||||||
extraDomains map[domain.Domain]int
|
extraDomains map[domain.Domain]int
|
||||||
|
batchMode bool
|
||||||
|
|
||||||
mgmtCacheResolver *mgmt.Resolver
|
mgmtCacheResolver *mgmt.Resolver
|
||||||
|
|
||||||
@@ -230,7 +238,9 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
|
|||||||
// convert to zone with simple ref counter
|
// convert to zone with simple ref counter
|
||||||
s.extraDomains[toZone(domain)]++
|
s.extraDomains[toZone(domain)]++
|
||||||
}
|
}
|
||||||
s.applyHostConfig()
|
if !s.batchMode {
|
||||||
|
s.applyHostConfig()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
||||||
@@ -259,9 +269,41 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
|
|||||||
delete(s.extraDomains, zone)
|
delete(s.extraDomains, zone)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if !s.batchMode {
|
||||||
|
s.applyHostConfig()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeginBatch starts batch mode for DNS handler registration/deregistration.
|
||||||
|
// In batch mode, applyHostConfig() is not called after each handler operation,
|
||||||
|
// allowing multiple handlers to be registered/deregistered efficiently.
|
||||||
|
// Must be followed by EndBatch() to apply the accumulated changes.
|
||||||
|
func (s *DefaultServer) BeginBatch() {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
log.Debugf("DNS batch mode enabled")
|
||||||
|
s.batchMode = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// EndBatch ends batch mode and applies all accumulated DNS configuration changes.
|
||||||
|
func (s *DefaultServer) EndBatch() {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
log.Debugf("DNS batch mode disabled, applying accumulated changes")
|
||||||
|
s.batchMode = false
|
||||||
s.applyHostConfig()
|
s.applyHostConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CancelBatch cancels batch mode without applying accumulated changes.
|
||||||
|
// This is useful when operations fail partway through and you want to
|
||||||
|
// discard partial state rather than applying it.
|
||||||
|
func (s *DefaultServer) CancelBatch() {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
log.Debugf("DNS batch mode cancelled, discarding accumulated changes")
|
||||||
|
s.batchMode = false
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
||||||
log.Debugf("deregistering handler with priority %d for %v", priority, domains)
|
log.Debugf("deregistering handler with priority %d for %v", priority, domains)
|
||||||
|
|
||||||
@@ -439,6 +481,17 @@ func (s *DefaultServer) SearchDomains() []string {
|
|||||||
// ProbeAvailability tests each upstream group's servers for availability
|
// ProbeAvailability tests each upstream group's servers for availability
|
||||||
// and deactivates the group if no server responds
|
// and deactivates the group if no server responds
|
||||||
func (s *DefaultServer) ProbeAvailability() {
|
func (s *DefaultServer) ProbeAvailability() {
|
||||||
|
if val := os.Getenv(envSkipDNSProbe); val != "" {
|
||||||
|
skipProbe, err := strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", envSkipDNSProbe, err)
|
||||||
|
}
|
||||||
|
if skipProbe {
|
||||||
|
log.Infof("skipping DNS probe due to %s", envSkipDNSProbe)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for _, mux := range s.dnsMuxMap {
|
for _, mux := range s.dnsMuxMap {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
@@ -508,6 +561,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
s.currentConfig.RouteAll = false
|
s.currentConfig.RouteAll = false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Always apply host config for management updates, regardless of batch mode
|
||||||
s.applyHostConfig()
|
s.applyHostConfig()
|
||||||
|
|
||||||
s.shutdownWg.Add(1)
|
s.shutdownWg.Add(1)
|
||||||
@@ -872,6 +926,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Always apply host config when nameserver goes down, regardless of batch mode
|
||||||
s.applyHostConfig()
|
s.applyHostConfig()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
@@ -907,6 +962,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Always apply host config when nameserver reactivates, regardless of batch mode
|
||||||
s.applyHostConfig()
|
s.applyHostConfig()
|
||||||
|
|
||||||
s.updateNSState(nsGroup, nil, true)
|
s.updateNSState(nsGroup, nil, true)
|
||||||
|
|||||||
@@ -18,7 +18,12 @@ func TestGetServerDns(t *testing.T) {
|
|||||||
t.Errorf("invalid dns server instance: %s", err)
|
t.Errorf("invalid dns server instance: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if srvB != srv {
|
mockSrvB, ok := srvB.(*MockServer)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("returned server is not a MockServer")
|
||||||
|
}
|
||||||
|
|
||||||
|
if mockSrvB != srv {
|
||||||
t.Errorf("mismatch dns instances")
|
t.Errorf("mismatch dns instances")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -190,50 +190,75 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
|
|||||||
return nberrors.FormatErrorOrNil(result)
|
return nberrors.FormatErrorOrNil(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
|
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg, startTime time.Time) {
|
||||||
if len(query.Question) == 0 {
|
if len(query.Question) == 0 {
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
question := query.Question[0]
|
question := query.Question[0]
|
||||||
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
|
qname := strings.ToLower(question.Name)
|
||||||
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
|
||||||
|
|
||||||
domain := strings.ToLower(question.Name)
|
logger.Tracef("question: domain=%s type=%s class=%s",
|
||||||
|
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
||||||
|
|
||||||
resp := query.SetReply(query)
|
resp := query.SetReply(query)
|
||||||
network := resutil.NetworkForQtype(question.Qtype)
|
network := resutil.NetworkForQtype(question.Qtype)
|
||||||
if network == "" {
|
if network == "" {
|
||||||
resp.Rcode = dns.RcodeNotImplemented
|
resp.Rcode = dns.RcodeNotImplemented
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
f.writeResponse(logger, w, resp, qname, startTime)
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
return
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
|
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
|
||||||
// query doesn't match any configured domain
|
|
||||||
if mostSpecificResId == "" {
|
if mostSpecificResId == "" {
|
||||||
resp.Rcode = dns.RcodeRefused
|
resp.Rcode = dns.RcodeRefused
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
f.writeResponse(logger, w, resp, qname, startTime)
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
return
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
|
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
|
||||||
if result.Err != nil {
|
if result.Err != nil {
|
||||||
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
|
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
|
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
|
||||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
|
resp.Answer = append(resp.Answer, resutil.IPsToRRs(qname, result.IPs, f.ttl)...)
|
||||||
f.cache.set(domain, question.Qtype, result.IPs)
|
f.cache.set(qname, question.Qtype, result.IPs)
|
||||||
|
|
||||||
return resp
|
f.writeResponse(logger, w, resp, qname, startTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
logger.Errorf("failed to write DNS response: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||||
|
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||||
|
}
|
||||||
|
|
||||||
|
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
|
||||||
|
type udpResponseWriter struct {
|
||||||
|
dns.ResponseWriter
|
||||||
|
query *dns.Msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
|
||||||
|
opt := u.query.IsEdns0()
|
||||||
|
maxSize := dns.MinMsgSize
|
||||||
|
if opt != nil {
|
||||||
|
maxSize = int(opt.UDPSize())
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Len() > maxSize {
|
||||||
|
resp.Truncate(maxSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
return u.ResponseWriter.WriteMsg(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
@@ -243,30 +268,7 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||||
})
|
})
|
||||||
|
|
||||||
resp := f.handleDNSQuery(logger, w, query)
|
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
|
||||||
if resp == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
opt := query.IsEdns0()
|
|
||||||
maxSize := dns.MinMsgSize
|
|
||||||
if opt != nil {
|
|
||||||
// client advertised a larger EDNS0 buffer
|
|
||||||
maxSize = int(opt.UDPSize())
|
|
||||||
}
|
|
||||||
|
|
||||||
// if our response is too big, truncate and set the TC bit
|
|
||||||
if resp.Len() > maxSize {
|
|
||||||
resp.Truncate(maxSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
|
||||||
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
@@ -276,18 +278,7 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||||
})
|
})
|
||||||
|
|
||||||
resp := f.handleDNSQuery(logger, w, query)
|
f.handleDNSQuery(logger, w, query, startTime)
|
||||||
if resp == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
|
||||||
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
||||||
@@ -334,6 +325,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
resp *dns.Msg,
|
resp *dns.Msg,
|
||||||
domain string,
|
domain string,
|
||||||
result resutil.LookupResult,
|
result resutil.LookupResult,
|
||||||
|
startTime time.Time,
|
||||||
) {
|
) {
|
||||||
qType := question.Qtype
|
qType := question.Qtype
|
||||||
qTypeName := dns.TypeToString[qType]
|
qTypeName := dns.TypeToString[qType]
|
||||||
@@ -343,9 +335,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
// NotFound: cache negative result and respond
|
// NotFound: cache negative result and respond
|
||||||
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
|
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
|
||||||
f.cache.set(domain, question.Qtype, nil)
|
f.cache.set(domain, question.Qtype, nil)
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -355,9 +345,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
||||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
|
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
|
||||||
resp.Rcode = dns.RcodeSuccess
|
resp.Rcode = dns.RcodeSuccess
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
logger.Errorf("failed to write cached DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -365,9 +353,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
|
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
|
||||||
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
|
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
|
||||||
resp.Rcode = verifyResult.Rcode
|
resp.Rcode = verifyResult.Rcode
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -375,15 +361,12 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
// No cache or verification failed. Log with or without the server field for more context.
|
// No cache or verification failed. Log with or without the server field for more context.
|
||||||
var dnsErr *net.DNSError
|
var dnsErr *net.DNSError
|
||||||
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
|
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
|
||||||
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
|
logger.Warnf("upstream failure: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
|
||||||
} else {
|
} else {
|
||||||
logger.Warnf(errResolveFailed, domain, result.Err)
|
logger.Warnf(errResolveFailed, domain, result.Err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write final failure response.
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
|
||||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getMatchingEntries retrieves the resource IDs for a given domain.
|
// getMatchingEntries retrieves the resource IDs for a given domain.
|
||||||
|
|||||||
@@ -318,8 +318,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
|||||||
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
if tt.shouldResolve {
|
if tt.shouldResolve {
|
||||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
|
||||||
@@ -329,10 +330,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
|||||||
mockFirewall.AssertExpectations(t)
|
mockFirewall.AssertExpectations(t)
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
} else {
|
} else {
|
||||||
if resp != nil {
|
require.NotNil(t, resp, "Expected response")
|
||||||
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
||||||
"Unauthorized domain should not return successful answers")
|
"Unauthorized domain should not return successful answers")
|
||||||
}
|
|
||||||
mockFirewall.AssertNotCalled(t, "UpdateSet")
|
mockFirewall.AssertNotCalled(t, "UpdateSet")
|
||||||
mockResolver.AssertNotCalled(t, "LookupNetIP")
|
mockResolver.AssertNotCalled(t, "LookupNetIP")
|
||||||
}
|
}
|
||||||
@@ -466,14 +466,16 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
|
|||||||
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery, time.Now())
|
||||||
|
|
||||||
// Verify response
|
// Verify response
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
if tt.shouldResolve {
|
if tt.shouldResolve {
|
||||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
require.NotEmpty(t, resp.Answer)
|
require.NotEmpty(t, resp.Answer)
|
||||||
} else if resp != nil {
|
} else {
|
||||||
|
require.NotNil(t, resp, "Expected response")
|
||||||
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
|
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
|
||||||
"Unauthorized domain should be refused or have no answers")
|
"Unauthorized domain should be refused or have no answers")
|
||||||
}
|
}
|
||||||
@@ -528,9 +530,10 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
|||||||
query.SetQuestion("example.com.", dns.TypeA)
|
query.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
// Verify response contains all IPs
|
// Verify response contains all IPs
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
require.NotNil(t, resp)
|
require.NotNil(t, resp)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
|
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
|
||||||
@@ -605,7 +608,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
// Check the response written to the writer
|
// Check the response written to the writer
|
||||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||||
@@ -675,7 +678,8 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
|||||||
q1 := &dns.Msg{}
|
q1 := &dns.Msg{}
|
||||||
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||||
w1 := &test.MockResponseWriter{}
|
w1 := &test.MockResponseWriter{}
|
||||||
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
|
||||||
|
resp1 := w1.GetLastResponse()
|
||||||
require.NotNil(t, resp1)
|
require.NotNil(t, resp1)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||||
require.Len(t, resp1.Answer, 1)
|
require.Len(t, resp1.Answer, 1)
|
||||||
@@ -683,13 +687,13 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
|||||||
// Second query: serve from cache after upstream failure
|
// Second query: serve from cache after upstream failure
|
||||||
q2 := &dns.Msg{}
|
q2 := &dns.Msg{}
|
||||||
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||||
var writtenResp *dns.Msg
|
w2 := &test.MockResponseWriter{}
|
||||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
|
||||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
|
||||||
|
|
||||||
require.NotNil(t, writtenResp, "expected response to be written")
|
resp2 := w2.GetLastResponse()
|
||||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
require.NotNil(t, resp2, "expected response to be written")
|
||||||
require.Len(t, writtenResp.Answer, 1)
|
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
|
||||||
|
require.Len(t, resp2.Answer, 1)
|
||||||
|
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
@@ -715,7 +719,8 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
|||||||
q1 := &dns.Msg{}
|
q1 := &dns.Msg{}
|
||||||
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
||||||
w1 := &test.MockResponseWriter{}
|
w1 := &test.MockResponseWriter{}
|
||||||
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
|
||||||
|
resp1 := w1.GetLastResponse()
|
||||||
require.NotNil(t, resp1)
|
require.NotNil(t, resp1)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||||
require.Len(t, resp1.Answer, 1)
|
require.Len(t, resp1.Answer, 1)
|
||||||
@@ -727,13 +732,13 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
|||||||
|
|
||||||
q2 := &dns.Msg{}
|
q2 := &dns.Msg{}
|
||||||
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
||||||
var writtenResp *dns.Msg
|
w2 := &test.MockResponseWriter{}
|
||||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
|
||||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
|
||||||
|
|
||||||
require.NotNil(t, writtenResp)
|
resp2 := w2.GetLastResponse()
|
||||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
require.NotNil(t, resp2)
|
||||||
require.Len(t, writtenResp.Answer, 1)
|
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
|
||||||
|
require.Len(t, resp2.Answer, 1)
|
||||||
|
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
@@ -784,8 +789,9 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
|||||||
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
require.NotNil(t, resp)
|
require.NotNil(t, resp)
|
||||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
|
||||||
@@ -897,26 +903,15 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
|||||||
query := &dns.Msg{}
|
query := &dns.Msg{}
|
||||||
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
|
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
|
||||||
|
|
||||||
var writtenResp *dns.Msg
|
mockWriter := &test.MockResponseWriter{}
|
||||||
mockWriter := &test.MockResponseWriter{
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
WriteMsgFunc: func(m *dns.Msg) error {
|
|
||||||
writtenResp = m
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
resp := mockWriter.GetLastResponse()
|
||||||
|
require.NotNil(t, resp, "Expected response to be written")
|
||||||
// If a response was returned, it means it should be written (happens in wrapper functions)
|
assert.Equal(t, tt.expectedCode, resp.Rcode, tt.description)
|
||||||
if resp != nil && writtenResp == nil {
|
|
||||||
writtenResp = resp
|
|
||||||
}
|
|
||||||
|
|
||||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
|
||||||
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
|
|
||||||
|
|
||||||
if tt.expectNoAnswer {
|
if tt.expectNoAnswer {
|
||||||
assert.Empty(t, writtenResp.Answer, "Response should have no answer records")
|
assert.Empty(t, resp.Answer, "Response should have no answer records")
|
||||||
}
|
}
|
||||||
|
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
@@ -931,15 +926,8 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
|||||||
query := &dns.Msg{}
|
query := &dns.Msg{}
|
||||||
// Don't set any question
|
// Don't set any question
|
||||||
|
|
||||||
writeCalled := false
|
mockWriter := &test.MockResponseWriter{}
|
||||||
mockWriter := &test.MockResponseWriter{
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
WriteMsgFunc: func(m *dns.Msg) error {
|
|
||||||
writeCalled = true
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
|
||||||
|
|
||||||
assert.Nil(t, resp, "Should return nil for empty query")
|
assert.Nil(t, mockWriter.GetLastResponse(), "Should not write response for empty query")
|
||||||
assert.False(t, writeCalled, "Should not write response for empty query")
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ import (
|
|||||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
"github.com/netbirdio/netbird/client/internal/debug"
|
"github.com/netbirdio/netbird/client/internal/debug"
|
||||||
@@ -543,11 +544,12 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
// monitor WireGuard interface lifecycle and restart engine on changes
|
// monitor WireGuard interface lifecycle and restart engine on changes
|
||||||
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
||||||
e.shutdownWg.Add(1)
|
e.shutdownWg.Add(1)
|
||||||
|
wgIfaceName := e.wgInterface.Name()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer e.shutdownWg.Done()
|
defer e.shutdownWg.Done()
|
||||||
|
|
||||||
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
|
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, wgIfaceName); shouldRestart {
|
||||||
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
||||||
e.triggerClientRestart()
|
e.triggerClientRestart()
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
@@ -828,6 +830,10 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||||
|
started := time.Now()
|
||||||
|
defer func() {
|
||||||
|
log.Infof("sync finished in %s", time.Since(started))
|
||||||
|
}()
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
@@ -1017,7 +1023,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
state := e.statusRecorder.GetLocalPeerState()
|
state := e.statusRecorder.GetLocalPeerState()
|
||||||
state.IP = e.wgInterface.Address().String()
|
state.IP = e.wgInterface.Address().String()
|
||||||
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
||||||
state.KernelInterface = device.WireGuardModuleIsLoaded()
|
state.KernelInterface = !e.wgInterface.IsUserspaceBind()
|
||||||
state.FQDN = conf.GetFqdn()
|
state.FQDN = conf.GetFqdn()
|
||||||
|
|
||||||
e.statusRecorder.UpdateLocalPeerState(state)
|
e.statusRecorder.UpdateLocalPeerState(state)
|
||||||
@@ -1556,8 +1562,10 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
defer e.shutdownWg.Done()
|
defer e.shutdownWg.Done()
|
||||||
// connect to a stream of messages coming from the signal server
|
// connect to a stream of messages coming from the signal server
|
||||||
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
|
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
|
||||||
|
start := time.Now()
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
gotLock := time.Since(start)
|
||||||
|
|
||||||
// Check context INSIDE lock to ensure atomicity with shutdown
|
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||||
if e.ctx.Err() != nil {
|
if e.ctx.Err() != nil {
|
||||||
@@ -1581,6 +1589,8 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debugf("receiveMSG: took %s to get lock for peer %s with session id %s", gotLock, msg.Key, offerAnswer.SessionID)
|
||||||
|
|
||||||
if msg.Body.Type == sProto.Body_OFFER {
|
if msg.Body.Type == sProto.Body_OFFER {
|
||||||
conn.OnRemoteOffer(*offerAnswer)
|
conn.OnRemoteOffer(*offerAnswer)
|
||||||
} else {
|
} else {
|
||||||
@@ -1918,7 +1928,7 @@ func (e *Engine) triggerClientRestart() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) startNetworkMonitor() {
|
func (e *Engine) startNetworkMonitor() {
|
||||||
if !e.config.NetworkMonitor {
|
if !e.config.NetworkMonitor || nbnetstack.IsEnabled() {
|
||||||
log.Infof("Network monitor is disabled, not starting")
|
log.Infof("Network monitor is disabled, not starting")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||||
@@ -94,6 +95,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
|||||||
|
|
||||||
// updateSSHClientConfig updates the SSH client configuration with peer information
|
// updateSSHClientConfig updates the SSH client configuration with peer information
|
||||||
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
|
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
peerInfo := e.extractPeerSSHInfo(remotePeers)
|
peerInfo := e.extractPeerSSHInfo(remotePeers)
|
||||||
if len(peerInfo) == 0 {
|
if len(peerInfo) == 0 {
|
||||||
log.Debug("no SSH-enabled peers found, skipping SSH config update")
|
log.Debug("no SSH-enabled peers found, skipping SSH config update")
|
||||||
@@ -216,6 +221,10 @@ func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
|
|||||||
|
|
||||||
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
|
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
|
||||||
func (e *Engine) cleanupSSHConfig() {
|
func (e *Engine) cleanupSSHConfig() {
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
configMgr := sshconfig.New()
|
configMgr := sshconfig.New()
|
||||||
|
|
||||||
if err := configMgr.RemoveSSHClientConfig(); err != nil {
|
if err := configMgr.RemoveSSHClientConfig(); err != nil {
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
@@ -74,12 +75,13 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error)
|
|||||||
return NewUDPListener(m.wgIface, peerCfg)
|
return NewUDPListener(m.wgIface, peerCfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BindListener is only used on Windows and JS platforms:
|
// BindListener is used on Windows, JS, and netstack platforms:
|
||||||
// - JS: Cannot listen to UDP sockets
|
// - JS: Cannot listen to UDP sockets
|
||||||
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
|
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
|
||||||
// gateway points to, preventing them from reaching the loopback interface.
|
// gateway points to, preventing them from reaching the loopback interface.
|
||||||
// BindListener bypasses this by passing data directly through the bind.
|
// - Netstack: Allows multiple instances on the same host without port conflicts.
|
||||||
if runtime.GOOS != "windows" && runtime.GOOS != "js" {
|
// BindListener bypasses these issues by passing data directly through the bind.
|
||||||
|
if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() {
|
||||||
return NewUDPListener(m.wgIface, peerCfg)
|
return NewUDPListener(m.wgIface, peerCfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -38,11 +37,6 @@ func New() *NetworkMonitor {
|
|||||||
|
|
||||||
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
|
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
|
||||||
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
|
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
|
||||||
if netstack.IsEnabled() {
|
|
||||||
log.Debugf("Network monitor: skipping in netstack mode")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
nw.mu.Lock()
|
nw.mu.Lock()
|
||||||
if nw.cancel != nil {
|
if nw.cancel != nil {
|
||||||
nw.mu.Unlock()
|
nw.mu.Unlock()
|
||||||
|
|||||||
@@ -410,7 +410,7 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
|||||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) onICEStateDisconnected() {
|
func (conn *Conn) onICEStateDisconnected(sessionChanged bool) {
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
@@ -430,14 +430,18 @@ func (conn *Conn) onICEStateDisconnected() {
|
|||||||
if conn.isReadyToUpgrade() {
|
if conn.isReadyToUpgrade() {
|
||||||
conn.Log.Infof("ICE disconnected, set Relay to active connection")
|
conn.Log.Infof("ICE disconnected, set Relay to active connection")
|
||||||
conn.dumpState.SwitchToRelay()
|
conn.dumpState.SwitchToRelay()
|
||||||
|
if sessionChanged {
|
||||||
|
conn.resetEndpoint()
|
||||||
|
}
|
||||||
|
|
||||||
|
// todo consider to move after the ConfigureWGEndpoint
|
||||||
conn.wgProxyRelay.Work()
|
conn.wgProxyRelay.Work()
|
||||||
|
|
||||||
presharedKey := conn.presharedKey(conn.rosenpassRemoteKey)
|
presharedKey := conn.presharedKey(conn.rosenpassRemoteKey)
|
||||||
if err := conn.endpointUpdater.ConfigureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil {
|
if err := conn.endpointUpdater.SwitchWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil {
|
||||||
conn.Log.Errorf("failed to switch to relay conn: %v", err)
|
conn.Log.Errorf("failed to switch to relay conn: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.wgProxyRelay.Work()
|
|
||||||
conn.currentConnPriority = conntype.Relay
|
conn.currentConnPriority = conntype.Relay
|
||||||
} else {
|
} else {
|
||||||
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
|
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
|
||||||
@@ -499,20 +503,22 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
wgProxy.Work()
|
controller := isController(conn.config)
|
||||||
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
|
|
||||||
|
|
||||||
|
if controller {
|
||||||
|
wgProxy.Work()
|
||||||
|
}
|
||||||
conn.enableWgWatcherIfNeeded()
|
conn.enableWgWatcherIfNeeded()
|
||||||
|
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), conn.presharedKey(rci.rosenpassPubKey)); err != nil {
|
||||||
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
|
|
||||||
if err := wgProxy.CloseConn(); err != nil {
|
if err := wgProxy.CloseConn(); err != nil {
|
||||||
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
||||||
}
|
}
|
||||||
conn.Log.Errorf("Failed to update WireGuard peer configuration: %v", err)
|
conn.Log.Errorf("Failed to update WireGuard peer configuration: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if !controller {
|
||||||
wgConfigWorkaround()
|
wgProxy.Work()
|
||||||
|
}
|
||||||
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
||||||
conn.currentConnPriority = conntype.Relay
|
conn.currentConnPriority = conntype.Relay
|
||||||
conn.statusRelay.SetConnected()
|
conn.statusRelay.SetConnected()
|
||||||
@@ -757,6 +763,17 @@ func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
|||||||
return wgProxy, nil
|
return wgProxy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) resetEndpoint() {
|
||||||
|
if !isController(conn.config) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn.Log.Infof("reset wg endpoint")
|
||||||
|
conn.wgWatcher.Reset()
|
||||||
|
if err := conn.endpointUpdater.RemoveEndpointAddress(); err != nil {
|
||||||
|
conn.Log.Warnf("failed to remove endpoint address before update: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (conn *Conn) isReadyToUpgrade() bool {
|
func (conn *Conn) isReadyToUpgrade() bool {
|
||||||
return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay
|
return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay
|
||||||
}
|
}
|
||||||
@@ -862,9 +879,3 @@ func isController(config ConnConfig) bool {
|
|||||||
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
||||||
return remoteRosenpassPubKey != nil
|
return remoteRosenpassPubKey != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// wgConfigWorkaround is a workaround for the issue with WireGuard configuration update
|
|
||||||
// When update a peer configuration in near to each other time, the second update can be ignored by WireGuard
|
|
||||||
func wgConfigWorkaround() {
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -34,28 +34,27 @@ func NewEndpointUpdater(log *logrus.Entry, wgConfig WgConfig, initiator bool) *E
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConfigureWGEndpoint sets up the WireGuard endpoint configuration.
|
|
||||||
// The initiator immediately configures the endpoint, while the non-initiator
|
|
||||||
// waits for a fallback period before configuring to avoid handshake congestion.
|
|
||||||
func (e *EndpointUpdater) ConfigureWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
func (e *EndpointUpdater) ConfigureWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||||
e.mu.Lock()
|
e.mu.Lock()
|
||||||
defer e.mu.Unlock()
|
defer e.mu.Unlock()
|
||||||
|
|
||||||
if e.initiator {
|
if e.initiator {
|
||||||
e.log.Debugf("configure up WireGuard as initiatr")
|
e.log.Debugf("configure up WireGuard as initiator")
|
||||||
return e.updateWireGuardPeer(addr, presharedKey)
|
return e.configureAsInitiator(addr, presharedKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
e.log.Debugf("configure up WireGuard as responder")
|
||||||
|
return e.configureAsResponder(addr, presharedKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *EndpointUpdater) SwitchWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
|
||||||
// prevent to run new update while cancel the previous update
|
// prevent to run new update while cancel the previous update
|
||||||
e.waitForCloseTheDelayedUpdate()
|
e.waitForCloseTheDelayedUpdate()
|
||||||
|
|
||||||
var ctx context.Context
|
return e.updateWireGuardPeer(addr, presharedKey)
|
||||||
ctx, e.cancelFunc = context.WithCancel(context.Background())
|
|
||||||
e.updateWg.Add(1)
|
|
||||||
go e.scheduleDelayedUpdate(ctx, addr, presharedKey)
|
|
||||||
|
|
||||||
e.log.Debugf("configure up WireGuard and wait for handshake")
|
|
||||||
return e.updateWireGuardPeer(nil, presharedKey)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *EndpointUpdater) RemoveWgPeer() error {
|
func (e *EndpointUpdater) RemoveWgPeer() error {
|
||||||
@@ -66,6 +65,38 @@ func (e *EndpointUpdater) RemoveWgPeer() error {
|
|||||||
return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
|
return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *EndpointUpdater) RemoveEndpointAddress() error {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
|
||||||
|
e.waitForCloseTheDelayedUpdate()
|
||||||
|
return e.wgConfig.WgInterface.RemoveEndpointAddress(e.wgConfig.RemoteKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *EndpointUpdater) configureAsInitiator(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||||
|
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *EndpointUpdater) configureAsResponder(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||||
|
// prevent to run new update while cancel the previous update
|
||||||
|
e.waitForCloseTheDelayedUpdate()
|
||||||
|
|
||||||
|
e.log.Debugf("configure up WireGuard and wait for handshake")
|
||||||
|
var ctx context.Context
|
||||||
|
ctx, e.cancelFunc = context.WithCancel(context.Background())
|
||||||
|
e.updateWg.Add(1)
|
||||||
|
go e.scheduleDelayedUpdate(ctx, addr, presharedKey)
|
||||||
|
|
||||||
|
if err := e.updateWireGuardPeer(nil, presharedKey); err != nil {
|
||||||
|
e.waitForCloseTheDelayedUpdate()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() {
|
func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() {
|
||||||
if e.cancelFunc == nil {
|
if e.cancelFunc == nil {
|
||||||
return
|
return
|
||||||
@@ -101,3 +132,9 @@ func (e *EndpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr, presharedKe
|
|||||||
presharedKey,
|
presharedKey,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wgConfigWorkaround is a workaround for the issue with WireGuard configuration update
|
||||||
|
// When update a peer configuration in near to each other time, the second update can be ignored by WireGuard
|
||||||
|
func wgConfigWorkaround() {
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package ice
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -32,24 +33,6 @@ type ThreadSafeAgent struct {
|
|||||||
once sync.Once
|
once sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *ThreadSafeAgent) Close() error {
|
|
||||||
var err error
|
|
||||||
a.once.Do(func() {
|
|
||||||
done := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
done <- a.Agent.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err = <-done:
|
|
||||||
case <-time.After(iceAgentCloseTimeout):
|
|
||||||
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
||||||
iceKeepAlive := iceKeepAlive()
|
iceKeepAlive := iceKeepAlive()
|
||||||
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
||||||
@@ -93,9 +76,41 @@ func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, c
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if agent == nil {
|
||||||
|
return nil, fmt.Errorf("ice.NewAgent returned nil agent without error")
|
||||||
|
}
|
||||||
|
|
||||||
return &ThreadSafeAgent{Agent: agent}, nil
|
return &ThreadSafeAgent{Agent: agent}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *ThreadSafeAgent) Close() error {
|
||||||
|
var err error
|
||||||
|
a.once.Do(func() {
|
||||||
|
// Defensive check to prevent nil pointer dereference
|
||||||
|
// This can happen during sleep/wake transitions or memory corruption scenarios
|
||||||
|
// github.com/netbirdio/netbird/client/internal/peer/ice.(*ThreadSafeAgent).Close(0x40006883f0?)
|
||||||
|
// [signal 0xc0000005 code=0x0 addr=0x0 pc=0x7ff7e73af83c]
|
||||||
|
agent := a.Agent
|
||||||
|
if agent == nil {
|
||||||
|
log.Warnf("ICE agent is nil during close, skipping")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- agent.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err = <-done:
|
||||||
|
case <-time.After(iceAgentCloseTimeout):
|
||||||
|
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func GenerateICECredentials() (string, string, error) {
|
func GenerateICECredentials() (string, string, error) {
|
||||||
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
|
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -32,6 +32,8 @@ type WGWatcher struct {
|
|||||||
|
|
||||||
enabled bool
|
enabled bool
|
||||||
muEnabled sync.RWMutex
|
muEnabled sync.RWMutex
|
||||||
|
|
||||||
|
resetCh chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
|
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
|
||||||
@@ -40,6 +42,7 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
|
|||||||
wgIfaceStater: wgIfaceStater,
|
wgIfaceStater: wgIfaceStater,
|
||||||
peerKey: peerKey,
|
peerKey: peerKey,
|
||||||
stateDump: stateDump,
|
stateDump: stateDump,
|
||||||
|
resetCh: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,6 +79,15 @@ func (w *WGWatcher) IsEnabled() bool {
|
|||||||
return w.enabled
|
return w.enabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reset signals the watcher that the WireGuard peer has been reset and a new
|
||||||
|
// handshake is expected. This restarts the handshake timeout from scratch.
|
||||||
|
func (w *WGWatcher) Reset() {
|
||||||
|
select {
|
||||||
|
case w.resetCh <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
||||||
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
|
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
|
||||||
w.log.Infof("WireGuard watcher started")
|
w.log.Infof("WireGuard watcher started")
|
||||||
@@ -105,6 +117,12 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn
|
|||||||
w.stateDump.WGcheckSuccess()
|
w.stateDump.WGcheckSuccess()
|
||||||
|
|
||||||
w.log.Debugf("WireGuard watcher reset timer: %v", resetTime)
|
w.log.Debugf("WireGuard watcher reset timer: %v", resetTime)
|
||||||
|
case <-w.resetCh:
|
||||||
|
w.log.Infof("WireGuard watcher received peer reset, restarting handshake timeout")
|
||||||
|
lastHandshake = time.Time{}
|
||||||
|
enabledTime = time.Now()
|
||||||
|
timer.Stop()
|
||||||
|
timer.Reset(wgHandshakeOvertime)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
w.log.Infof("WireGuard watcher stopped")
|
w.log.Infof("WireGuard watcher stopped")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -52,8 +52,9 @@ type WorkerICE struct {
|
|||||||
// increase by one when disconnecting the agent
|
// increase by one when disconnecting the agent
|
||||||
// with it the remote peer can discard the already deprecated offer/answer
|
// with it the remote peer can discard the already deprecated offer/answer
|
||||||
// Without it the remote peer may recreate a workable ICE connection
|
// Without it the remote peer may recreate a workable ICE connection
|
||||||
sessionID ICESessionID
|
sessionID ICESessionID
|
||||||
muxAgent sync.Mutex
|
remoteSessionChanged bool
|
||||||
|
muxAgent sync.Mutex
|
||||||
|
|
||||||
localUfrag string
|
localUfrag string
|
||||||
localPwd string
|
localPwd string
|
||||||
@@ -106,9 +107,12 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.log.Debugf("agent already exists, recreate the connection")
|
w.log.Debugf("agent already exists, recreate the connection")
|
||||||
|
w.remoteSessionChanged = true
|
||||||
w.agentDialerCancel()
|
w.agentDialerCancel()
|
||||||
if err := w.agent.Close(); err != nil {
|
if w.agent != nil {
|
||||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
if err := w.agent.Close(); err != nil {
|
||||||
|
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionID, err := NewICESessionID()
|
sessionID, err := NewICESessionID()
|
||||||
@@ -304,13 +308,17 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
|
|||||||
w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
|
w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) {
|
func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) bool {
|
||||||
cancel()
|
cancel()
|
||||||
if err := agent.Close(); err != nil {
|
if err := agent.Close(); err != nil {
|
||||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
w.muxAgent.Lock()
|
w.muxAgent.Lock()
|
||||||
|
defer w.muxAgent.Unlock()
|
||||||
|
|
||||||
|
sessionChanged := w.remoteSessionChanged
|
||||||
|
w.remoteSessionChanged = false
|
||||||
|
|
||||||
if w.agent == agent {
|
if w.agent == agent {
|
||||||
// consider to remove from here and move to the OnNewOffer
|
// consider to remove from here and move to the OnNewOffer
|
||||||
@@ -323,7 +331,7 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
|
|||||||
w.agentConnecting = false
|
w.agentConnecting = false
|
||||||
w.remoteSessionID = ""
|
w.remoteSessionID = ""
|
||||||
}
|
}
|
||||||
w.muxAgent.Unlock()
|
return sessionChanged
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
||||||
@@ -424,11 +432,11 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
|
|||||||
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
||||||
// notify the conn.onICEStateDisconnected changes to update the current used priority
|
// notify the conn.onICEStateDisconnected changes to update the current used priority
|
||||||
|
|
||||||
w.closeAgent(agent, dialerCancel)
|
sessionChanged := w.closeAgent(agent, dialerCancel)
|
||||||
|
|
||||||
if w.lastKnownState == ice.ConnectionStateConnected {
|
if w.lastKnownState == ice.ConnectionStateConnected {
|
||||||
w.lastKnownState = ice.ConnectionStateDisconnected
|
w.lastKnownState = ice.ConnectionStateDisconnected
|
||||||
w.conn.onICEStateDisconnected()
|
w.conn.onICEStateDisconnected(sessionChanged)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -252,7 +252,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if config.AdminURL == nil {
|
if config.AdminURL == nil {
|
||||||
log.Infof("using default Admin URL %s", DefaultManagementURL)
|
log.Infof("using default Admin URL %s", DefaultAdminURL)
|
||||||
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
|
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
|
|||||||
@@ -351,6 +351,11 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.
|
|||||||
logger.Errorf("failed to update domain prefixes: %v", err)
|
logger.Errorf("failed to update domain prefixes: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Allow time for route changes to be applied before sending
|
||||||
|
// the DNS response (relevant on iOS where setTunnelNetworkSettings
|
||||||
|
// is asynchronous).
|
||||||
|
waitForRouteSettlement(logger)
|
||||||
|
|
||||||
d.replaceIPsInDNSResponse(r, newPrefixes, logger)
|
d.replaceIPsInDNSResponse(r, newPrefixes, logger)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
20
client/internal/routemanager/dnsinterceptor/handler_ios.go
Normal file
20
client/internal/routemanager/dnsinterceptor/handler_ios.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
|
package dnsinterceptor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const routeSettleDelay = 500 * time.Millisecond
|
||||||
|
|
||||||
|
// waitForRouteSettlement introduces a short delay on iOS to allow
|
||||||
|
// setTunnelNetworkSettings to apply route changes before the DNS
|
||||||
|
// response reaches the application. Without this, the first request
|
||||||
|
// to a newly resolved domain may bypass the tunnel.
|
||||||
|
func waitForRouteSettlement(logger *log.Entry) {
|
||||||
|
logger.Tracef("waiting %v for iOS route settlement", routeSettleDelay)
|
||||||
|
time.Sleep(routeSettleDelay)
|
||||||
|
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package dnsinterceptor
|
||||||
|
|
||||||
|
import log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
func waitForRouteSettlement(_ *log.Entry) {
|
||||||
|
// No-op on non-iOS platforms: route changes are applied synchronously by
|
||||||
|
// the kernel, so no settlement delay is needed before the DNS response
|
||||||
|
// reaches the application. The delay is only required on iOS where
|
||||||
|
// setTunnelNetworkSettings applies routes asynchronously.
|
||||||
|
}
|
||||||
@@ -173,12 +173,21 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
||||||
|
var once sync.Once
|
||||||
|
var wgIface *net.Interface
|
||||||
|
toInterface := func() *net.Interface {
|
||||||
|
once.Do(func() {
|
||||||
|
wgIface = m.wgInterface.ToInterface()
|
||||||
|
})
|
||||||
|
return wgIface
|
||||||
|
}
|
||||||
|
|
||||||
m.routeRefCounter = refcounter.New(
|
m.routeRefCounter = refcounter.New(
|
||||||
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
|
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
|
||||||
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface())
|
return struct{}{}, m.sysOps.AddVPNRoute(prefix, toInterface())
|
||||||
},
|
},
|
||||||
func(prefix netip.Prefix, _ struct{}) error {
|
func(prefix netip.Prefix, _ struct{}) error {
|
||||||
return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface())
|
return m.sysOps.RemoveVPNRoute(prefix, toInterface())
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -337,6 +346,23 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
// Begin batch mode to avoid calling applyHostConfig() after each DNS handler operation
|
||||||
|
batchStarted := false
|
||||||
|
if m.dnsServer != nil {
|
||||||
|
m.dnsServer.BeginBatch()
|
||||||
|
batchStarted = true
|
||||||
|
defer func() {
|
||||||
|
if merr != nil {
|
||||||
|
// On error, cancel batch to discard partial DNS state
|
||||||
|
m.dnsServer.CancelBatch()
|
||||||
|
} else {
|
||||||
|
// On success, apply accumulated DNS changes
|
||||||
|
m.dnsServer.EndBatch()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
for id, handler := range toRemove {
|
for id, handler := range toRemove {
|
||||||
if err := handler.RemoveRoute(); err != nil {
|
if err := handler.RemoveRoute(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err))
|
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err))
|
||||||
@@ -367,6 +393,7 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
|
|||||||
m.activeRoutes[id] = handler
|
m.activeRoutes[id] = handler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_ = batchStarted // Mark as used
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,16 +4,17 @@ package systemops
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
||||||
func filterRoutesByFlags(routeMessageFlags int) bool {
|
func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||||
if routeMessageFlags&syscall.RTF_UP == 0 {
|
if routeMessageFlags&unix.RTF_UP == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
|
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE|unix.RTF_WASCLONED) != 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -24,42 +25,51 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
|
|||||||
func formatBSDFlags(flags int) string {
|
func formatBSDFlags(flags int) string {
|
||||||
var flagStrs []string
|
var flagStrs []string
|
||||||
|
|
||||||
if flags&syscall.RTF_UP != 0 {
|
if flags&unix.RTF_UP != 0 {
|
||||||
flagStrs = append(flagStrs, "U")
|
flagStrs = append(flagStrs, "U")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_GATEWAY != 0 {
|
if flags&unix.RTF_GATEWAY != 0 {
|
||||||
flagStrs = append(flagStrs, "G")
|
flagStrs = append(flagStrs, "G")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_HOST != 0 {
|
if flags&unix.RTF_HOST != 0 {
|
||||||
flagStrs = append(flagStrs, "H")
|
flagStrs = append(flagStrs, "H")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_REJECT != 0 {
|
if flags&unix.RTF_REJECT != 0 {
|
||||||
flagStrs = append(flagStrs, "R")
|
flagStrs = append(flagStrs, "R")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_DYNAMIC != 0 {
|
if flags&unix.RTF_DYNAMIC != 0 {
|
||||||
flagStrs = append(flagStrs, "D")
|
flagStrs = append(flagStrs, "D")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_MODIFIED != 0 {
|
if flags&unix.RTF_MODIFIED != 0 {
|
||||||
flagStrs = append(flagStrs, "M")
|
flagStrs = append(flagStrs, "M")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_STATIC != 0 {
|
if flags&unix.RTF_STATIC != 0 {
|
||||||
flagStrs = append(flagStrs, "S")
|
flagStrs = append(flagStrs, "S")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_LLINFO != 0 {
|
if flags&unix.RTF_LLINFO != 0 {
|
||||||
flagStrs = append(flagStrs, "L")
|
flagStrs = append(flagStrs, "L")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_LOCAL != 0 {
|
if flags&unix.RTF_LOCAL != 0 {
|
||||||
flagStrs = append(flagStrs, "l")
|
flagStrs = append(flagStrs, "l")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_BLACKHOLE != 0 {
|
if flags&unix.RTF_BLACKHOLE != 0 {
|
||||||
flagStrs = append(flagStrs, "B")
|
flagStrs = append(flagStrs, "B")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_CLONING != 0 {
|
if flags&unix.RTF_CLONING != 0 {
|
||||||
flagStrs = append(flagStrs, "C")
|
flagStrs = append(flagStrs, "C")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_WASCLONED != 0 {
|
if flags&unix.RTF_WASCLONED != 0 {
|
||||||
flagStrs = append(flagStrs, "W")
|
flagStrs = append(flagStrs, "W")
|
||||||
}
|
}
|
||||||
|
if flags&unix.RTF_PROTO1 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "1")
|
||||||
|
}
|
||||||
|
if flags&unix.RTF_PROTO2 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "2")
|
||||||
|
}
|
||||||
|
if flags&unix.RTF_PROTO3 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "3")
|
||||||
|
}
|
||||||
|
|
||||||
if len(flagStrs) == 0 {
|
if len(flagStrs) == 0 {
|
||||||
return "-"
|
return "-"
|
||||||
|
|||||||
@@ -4,17 +4,18 @@ package systemops
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
||||||
func filterRoutesByFlags(routeMessageFlags int) bool {
|
func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||||
if routeMessageFlags&syscall.RTF_UP == 0 {
|
if routeMessageFlags&unix.RTF_UP == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0
|
// NOTE: RTF_WASCLONED deprecated in FreeBSD 8.0
|
||||||
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 {
|
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE) != 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -25,37 +26,46 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
|
|||||||
func formatBSDFlags(flags int) string {
|
func formatBSDFlags(flags int) string {
|
||||||
var flagStrs []string
|
var flagStrs []string
|
||||||
|
|
||||||
if flags&syscall.RTF_UP != 0 {
|
if flags&unix.RTF_UP != 0 {
|
||||||
flagStrs = append(flagStrs, "U")
|
flagStrs = append(flagStrs, "U")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_GATEWAY != 0 {
|
if flags&unix.RTF_GATEWAY != 0 {
|
||||||
flagStrs = append(flagStrs, "G")
|
flagStrs = append(flagStrs, "G")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_HOST != 0 {
|
if flags&unix.RTF_HOST != 0 {
|
||||||
flagStrs = append(flagStrs, "H")
|
flagStrs = append(flagStrs, "H")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_REJECT != 0 {
|
if flags&unix.RTF_REJECT != 0 {
|
||||||
flagStrs = append(flagStrs, "R")
|
flagStrs = append(flagStrs, "R")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_DYNAMIC != 0 {
|
if flags&unix.RTF_DYNAMIC != 0 {
|
||||||
flagStrs = append(flagStrs, "D")
|
flagStrs = append(flagStrs, "D")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_MODIFIED != 0 {
|
if flags&unix.RTF_MODIFIED != 0 {
|
||||||
flagStrs = append(flagStrs, "M")
|
flagStrs = append(flagStrs, "M")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_STATIC != 0 {
|
if flags&unix.RTF_STATIC != 0 {
|
||||||
flagStrs = append(flagStrs, "S")
|
flagStrs = append(flagStrs, "S")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_LLINFO != 0 {
|
if flags&unix.RTF_LLINFO != 0 {
|
||||||
flagStrs = append(flagStrs, "L")
|
flagStrs = append(flagStrs, "L")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_LOCAL != 0 {
|
if flags&unix.RTF_LOCAL != 0 {
|
||||||
flagStrs = append(flagStrs, "l")
|
flagStrs = append(flagStrs, "l")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_BLACKHOLE != 0 {
|
if flags&unix.RTF_BLACKHOLE != 0 {
|
||||||
flagStrs = append(flagStrs, "B")
|
flagStrs = append(flagStrs, "B")
|
||||||
}
|
}
|
||||||
// Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0
|
// Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0
|
||||||
|
if flags&unix.RTF_PROTO1 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "1")
|
||||||
|
}
|
||||||
|
if flags&unix.RTF_PROTO2 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "2")
|
||||||
|
}
|
||||||
|
if flags&unix.RTF_PROTO3 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "3")
|
||||||
|
}
|
||||||
|
|
||||||
if len(flagStrs) == 0 {
|
if len(flagStrs) == 0 {
|
||||||
return "-"
|
return "-"
|
||||||
|
|||||||
@@ -2,7 +2,10 @@
|
|||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
import "github.com/netbirdio/netbird/client/internal/peer"
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
)
|
||||||
|
|
||||||
// EnvList is an exported struct to be bound by gomobile
|
// EnvList is an exported struct to be bound by gomobile
|
||||||
type EnvList struct {
|
type EnvList struct {
|
||||||
@@ -32,3 +35,13 @@ func (el *EnvList) AllItems() map[string]string {
|
|||||||
func GetEnvKeyNBForceRelay() string {
|
func GetEnvKeyNBForceRelay() string {
|
||||||
return peer.EnvKeyNBForceRelay
|
return peer.EnvKeyNBForceRelay
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetEnvKeyNBLazyConn Exports the environment variable for the iOS client
|
||||||
|
func GetEnvKeyNBLazyConn() string {
|
||||||
|
return lazyconn.EnvEnableLazyConn
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEnvKeyNBInactivityThreshold Exports the environment variable for the iOS client
|
||||||
|
func GetEnvKeyNBInactivityThreshold() string {
|
||||||
|
return lazyconn.EnvInactivityThreshold
|
||||||
|
}
|
||||||
|
|||||||
5
combined/Dockerfile
Normal file
5
combined/Dockerfile
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
FROM ubuntu:24.04
|
||||||
|
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
|
||||||
|
ENTRYPOINT [ "/go/bin/netbird-server" ]
|
||||||
|
CMD ["--config", "/etc/netbird/config.yaml"]
|
||||||
|
COPY netbird-server /go/bin/netbird-server
|
||||||
25
combined/Dockerfile.multistage
Normal file
25
combined/Dockerfile.multistage
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
FROM golang:1.25-bookworm AS builder
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Install build dependencies
|
||||||
|
RUN apt-get update && apt-get install -y gcc libc6-dev git && rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
COPY go.mod go.sum ./
|
||||||
|
RUN go mod download
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
# Build with version info from git (matching goreleaser ldflags)
|
||||||
|
RUN CGO_ENABLED=1 GOOS=linux go build \
|
||||||
|
-ldflags="-s -w \
|
||||||
|
-X github.com/netbirdio/netbird/version.version=$(git describe --tags --always --dirty 2>/dev/null || echo 'dev') \
|
||||||
|
-X main.commit=$(git rev-parse --short HEAD 2>/dev/null || echo 'unknown') \
|
||||||
|
-X main.date=$(date -u +%Y-%m-%dT%H:%M:%SZ) \
|
||||||
|
-X main.builtBy=docker" \
|
||||||
|
-o netbird-server ./combined
|
||||||
|
|
||||||
|
FROM ubuntu:24.04
|
||||||
|
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
|
||||||
|
ENTRYPOINT [ "/go/bin/netbird-server" ]
|
||||||
|
CMD ["--config", "/etc/netbird/config.yaml"]
|
||||||
|
COPY --from=builder /app/netbird-server /go/bin/netbird-server
|
||||||
661
combined/LICENSE
Normal file
661
combined/LICENSE
Normal file
@@ -0,0 +1,661 @@
|
|||||||
|
GNU AFFERO GENERAL PUBLIC LICENSE
|
||||||
|
Version 3, 19 November 2007
|
||||||
|
|
||||||
|
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
||||||
|
Everyone is permitted to copy and distribute verbatim copies
|
||||||
|
of this license document, but changing it is not allowed.
|
||||||
|
|
||||||
|
Preamble
|
||||||
|
|
||||||
|
The GNU Affero General Public License is a free, copyleft license for
|
||||||
|
software and other kinds of works, specifically designed to ensure
|
||||||
|
cooperation with the community in the case of network server software.
|
||||||
|
|
||||||
|
The licenses for most software and other practical works are designed
|
||||||
|
to take away your freedom to share and change the works. By contrast,
|
||||||
|
our General Public Licenses are intended to guarantee your freedom to
|
||||||
|
share and change all versions of a program--to make sure it remains free
|
||||||
|
software for all its users.
|
||||||
|
|
||||||
|
When we speak of free software, we are referring to freedom, not
|
||||||
|
price. Our General Public Licenses are designed to make sure that you
|
||||||
|
have the freedom to distribute copies of free software (and charge for
|
||||||
|
them if you wish), that you receive source code or can get it if you
|
||||||
|
want it, that you can change the software or use pieces of it in new
|
||||||
|
free programs, and that you know you can do these things.
|
||||||
|
|
||||||
|
Developers that use our General Public Licenses protect your rights
|
||||||
|
with two steps: (1) assert copyright on the software, and (2) offer
|
||||||
|
you this License which gives you legal permission to copy, distribute
|
||||||
|
and/or modify the software.
|
||||||
|
|
||||||
|
A secondary benefit of defending all users' freedom is that
|
||||||
|
improvements made in alternate versions of the program, if they
|
||||||
|
receive widespread use, become available for other developers to
|
||||||
|
incorporate. Many developers of free software are heartened and
|
||||||
|
encouraged by the resulting cooperation. However, in the case of
|
||||||
|
software used on network servers, this result may fail to come about.
|
||||||
|
The GNU General Public License permits making a modified version and
|
||||||
|
letting the public access it on a server without ever releasing its
|
||||||
|
source code to the public.
|
||||||
|
|
||||||
|
The GNU Affero General Public License is designed specifically to
|
||||||
|
ensure that, in such cases, the modified source code becomes available
|
||||||
|
to the community. It requires the operator of a network server to
|
||||||
|
provide the source code of the modified version running there to the
|
||||||
|
users of that server. Therefore, public use of a modified version, on
|
||||||
|
a publicly accessible server, gives the public access to the source
|
||||||
|
code of the modified version.
|
||||||
|
|
||||||
|
An older license, called the Affero General Public License and
|
||||||
|
published by Affero, was designed to accomplish similar goals. This is
|
||||||
|
a different license, not a version of the Affero GPL, but Affero has
|
||||||
|
released a new version of the Affero GPL which permits relicensing under
|
||||||
|
this license.
|
||||||
|
|
||||||
|
The precise terms and conditions for copying, distribution and
|
||||||
|
modification follow.
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
0. Definitions.
|
||||||
|
|
||||||
|
"This License" refers to version 3 of the GNU Affero General Public License.
|
||||||
|
|
||||||
|
"Copyright" also means copyright-like laws that apply to other kinds of
|
||||||
|
works, such as semiconductor masks.
|
||||||
|
|
||||||
|
"The Program" refers to any copyrightable work licensed under this
|
||||||
|
License. Each licensee is addressed as "you". "Licensees" and
|
||||||
|
"recipients" may be individuals or organizations.
|
||||||
|
|
||||||
|
To "modify" a work means to copy from or adapt all or part of the work
|
||||||
|
in a fashion requiring copyright permission, other than the making of an
|
||||||
|
exact copy. The resulting work is called a "modified version" of the
|
||||||
|
earlier work or a work "based on" the earlier work.
|
||||||
|
|
||||||
|
A "covered work" means either the unmodified Program or a work based
|
||||||
|
on the Program.
|
||||||
|
|
||||||
|
To "propagate" a work means to do anything with it that, without
|
||||||
|
permission, would make you directly or secondarily liable for
|
||||||
|
infringement under applicable copyright law, except executing it on a
|
||||||
|
computer or modifying a private copy. Propagation includes copying,
|
||||||
|
distribution (with or without modification), making available to the
|
||||||
|
public, and in some countries other activities as well.
|
||||||
|
|
||||||
|
To "convey" a work means any kind of propagation that enables other
|
||||||
|
parties to make or receive copies. Mere interaction with a user through
|
||||||
|
a computer network, with no transfer of a copy, is not conveying.
|
||||||
|
|
||||||
|
An interactive user interface displays "Appropriate Legal Notices"
|
||||||
|
to the extent that it includes a convenient and prominently visible
|
||||||
|
feature that (1) displays an appropriate copyright notice, and (2)
|
||||||
|
tells the user that there is no warranty for the work (except to the
|
||||||
|
extent that warranties are provided), that licensees may convey the
|
||||||
|
work under this License, and how to view a copy of this License. If
|
||||||
|
the interface presents a list of user commands or options, such as a
|
||||||
|
menu, a prominent item in the list meets this criterion.
|
||||||
|
|
||||||
|
1. Source Code.
|
||||||
|
|
||||||
|
The "source code" for a work means the preferred form of the work
|
||||||
|
for making modifications to it. "Object code" means any non-source
|
||||||
|
form of a work.
|
||||||
|
|
||||||
|
A "Standard Interface" means an interface that either is an official
|
||||||
|
standard defined by a recognized standards body, or, in the case of
|
||||||
|
interfaces specified for a particular programming language, one that
|
||||||
|
is widely used among developers working in that language.
|
||||||
|
|
||||||
|
The "System Libraries" of an executable work include anything, other
|
||||||
|
than the work as a whole, that (a) is included in the normal form of
|
||||||
|
packaging a Major Component, but which is not part of that Major
|
||||||
|
Component, and (b) serves only to enable use of the work with that
|
||||||
|
Major Component, or to implement a Standard Interface for which an
|
||||||
|
implementation is available to the public in source code form. A
|
||||||
|
"Major Component", in this context, means a major essential component
|
||||||
|
(kernel, window system, and so on) of the specific operating system
|
||||||
|
(if any) on which the executable work runs, or a compiler used to
|
||||||
|
produce the work, or an object code interpreter used to run it.
|
||||||
|
|
||||||
|
The "Corresponding Source" for a work in object code form means all
|
||||||
|
the source code needed to generate, install, and (for an executable
|
||||||
|
work) run the object code and to modify the work, including scripts to
|
||||||
|
control those activities. However, it does not include the work's
|
||||||
|
System Libraries, or general-purpose tools or generally available free
|
||||||
|
programs which are used unmodified in performing those activities but
|
||||||
|
which are not part of the work. For example, Corresponding Source
|
||||||
|
includes interface definition files associated with source files for
|
||||||
|
the work, and the source code for shared libraries and dynamically
|
||||||
|
linked subprograms that the work is specifically designed to require,
|
||||||
|
such as by intimate data communication or control flow between those
|
||||||
|
subprograms and other parts of the work.
|
||||||
|
|
||||||
|
The Corresponding Source need not include anything that users
|
||||||
|
can regenerate automatically from other parts of the Corresponding
|
||||||
|
Source.
|
||||||
|
|
||||||
|
The Corresponding Source for a work in source code form is that
|
||||||
|
same work.
|
||||||
|
|
||||||
|
2. Basic Permissions.
|
||||||
|
|
||||||
|
All rights granted under this License are granted for the term of
|
||||||
|
copyright on the Program, and are irrevocable provided the stated
|
||||||
|
conditions are met. This License explicitly affirms your unlimited
|
||||||
|
permission to run the unmodified Program. The output from running a
|
||||||
|
covered work is covered by this License only if the output, given its
|
||||||
|
content, constitutes a covered work. This License acknowledges your
|
||||||
|
rights of fair use or other equivalent, as provided by copyright law.
|
||||||
|
|
||||||
|
You may make, run and propagate covered works that you do not
|
||||||
|
convey, without conditions so long as your license otherwise remains
|
||||||
|
in force. You may convey covered works to others for the sole purpose
|
||||||
|
of having them make modifications exclusively for you, or provide you
|
||||||
|
with facilities for running those works, provided that you comply with
|
||||||
|
the terms of this License in conveying all material for which you do
|
||||||
|
not control copyright. Those thus making or running the covered works
|
||||||
|
for you must do so exclusively on your behalf, under your direction
|
||||||
|
and control, on terms that prohibit them from making any copies of
|
||||||
|
your copyrighted material outside their relationship with you.
|
||||||
|
|
||||||
|
Conveying under any other circumstances is permitted solely under
|
||||||
|
the conditions stated below. Sublicensing is not allowed; section 10
|
||||||
|
makes it unnecessary.
|
||||||
|
|
||||||
|
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
||||||
|
|
||||||
|
No covered work shall be deemed part of an effective technological
|
||||||
|
measure under any applicable law fulfilling obligations under article
|
||||||
|
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
||||||
|
similar laws prohibiting or restricting circumvention of such
|
||||||
|
measures.
|
||||||
|
|
||||||
|
When you convey a covered work, you waive any legal power to forbid
|
||||||
|
circumvention of technological measures to the extent such circumvention
|
||||||
|
is effected by exercising rights under this License with respect to
|
||||||
|
the covered work, and you disclaim any intention to limit operation or
|
||||||
|
modification of the work as a means of enforcing, against the work's
|
||||||
|
users, your or third parties' legal rights to forbid circumvention of
|
||||||
|
technological measures.
|
||||||
|
|
||||||
|
4. Conveying Verbatim Copies.
|
||||||
|
|
||||||
|
You may convey verbatim copies of the Program's source code as you
|
||||||
|
receive it, in any medium, provided that you conspicuously and
|
||||||
|
appropriately publish on each copy an appropriate copyright notice;
|
||||||
|
keep intact all notices stating that this License and any
|
||||||
|
non-permissive terms added in accord with section 7 apply to the code;
|
||||||
|
keep intact all notices of the absence of any warranty; and give all
|
||||||
|
recipients a copy of this License along with the Program.
|
||||||
|
|
||||||
|
You may charge any price or no price for each copy that you convey,
|
||||||
|
and you may offer support or warranty protection for a fee.
|
||||||
|
|
||||||
|
5. Conveying Modified Source Versions.
|
||||||
|
|
||||||
|
You may convey a work based on the Program, or the modifications to
|
||||||
|
produce it from the Program, in the form of source code under the
|
||||||
|
terms of section 4, provided that you also meet all of these conditions:
|
||||||
|
|
||||||
|
a) The work must carry prominent notices stating that you modified
|
||||||
|
it, and giving a relevant date.
|
||||||
|
|
||||||
|
b) The work must carry prominent notices stating that it is
|
||||||
|
released under this License and any conditions added under section
|
||||||
|
7. This requirement modifies the requirement in section 4 to
|
||||||
|
"keep intact all notices".
|
||||||
|
|
||||||
|
c) You must license the entire work, as a whole, under this
|
||||||
|
License to anyone who comes into possession of a copy. This
|
||||||
|
License will therefore apply, along with any applicable section 7
|
||||||
|
additional terms, to the whole of the work, and all its parts,
|
||||||
|
regardless of how they are packaged. This License gives no
|
||||||
|
permission to license the work in any other way, but it does not
|
||||||
|
invalidate such permission if you have separately received it.
|
||||||
|
|
||||||
|
d) If the work has interactive user interfaces, each must display
|
||||||
|
Appropriate Legal Notices; however, if the Program has interactive
|
||||||
|
interfaces that do not display Appropriate Legal Notices, your
|
||||||
|
work need not make them do so.
|
||||||
|
|
||||||
|
A compilation of a covered work with other separate and independent
|
||||||
|
works, which are not by their nature extensions of the covered work,
|
||||||
|
and which are not combined with it such as to form a larger program,
|
||||||
|
in or on a volume of a storage or distribution medium, is called an
|
||||||
|
"aggregate" if the compilation and its resulting copyright are not
|
||||||
|
used to limit the access or legal rights of the compilation's users
|
||||||
|
beyond what the individual works permit. Inclusion of a covered work
|
||||||
|
in an aggregate does not cause this License to apply to the other
|
||||||
|
parts of the aggregate.
|
||||||
|
|
||||||
|
6. Conveying Non-Source Forms.
|
||||||
|
|
||||||
|
You may convey a covered work in object code form under the terms
|
||||||
|
of sections 4 and 5, provided that you also convey the
|
||||||
|
machine-readable Corresponding Source under the terms of this License,
|
||||||
|
in one of these ways:
|
||||||
|
|
||||||
|
a) Convey the object code in, or embodied in, a physical product
|
||||||
|
(including a physical distribution medium), accompanied by the
|
||||||
|
Corresponding Source fixed on a durable physical medium
|
||||||
|
customarily used for software interchange.
|
||||||
|
|
||||||
|
b) Convey the object code in, or embodied in, a physical product
|
||||||
|
(including a physical distribution medium), accompanied by a
|
||||||
|
written offer, valid for at least three years and valid for as
|
||||||
|
long as you offer spare parts or customer support for that product
|
||||||
|
model, to give anyone who possesses the object code either (1) a
|
||||||
|
copy of the Corresponding Source for all the software in the
|
||||||
|
product that is covered by this License, on a durable physical
|
||||||
|
medium customarily used for software interchange, for a price no
|
||||||
|
more than your reasonable cost of physically performing this
|
||||||
|
conveying of source, or (2) access to copy the
|
||||||
|
Corresponding Source from a network server at no charge.
|
||||||
|
|
||||||
|
c) Convey individual copies of the object code with a copy of the
|
||||||
|
written offer to provide the Corresponding Source. This
|
||||||
|
alternative is allowed only occasionally and noncommercially, and
|
||||||
|
only if you received the object code with such an offer, in accord
|
||||||
|
with subsection 6b.
|
||||||
|
|
||||||
|
d) Convey the object code by offering access from a designated
|
||||||
|
place (gratis or for a charge), and offer equivalent access to the
|
||||||
|
Corresponding Source in the same way through the same place at no
|
||||||
|
further charge. You need not require recipients to copy the
|
||||||
|
Corresponding Source along with the object code. If the place to
|
||||||
|
copy the object code is a network server, the Corresponding Source
|
||||||
|
may be on a different server (operated by you or a third party)
|
||||||
|
that supports equivalent copying facilities, provided you maintain
|
||||||
|
clear directions next to the object code saying where to find the
|
||||||
|
Corresponding Source. Regardless of what server hosts the
|
||||||
|
Corresponding Source, you remain obligated to ensure that it is
|
||||||
|
available for as long as needed to satisfy these requirements.
|
||||||
|
|
||||||
|
e) Convey the object code using peer-to-peer transmission, provided
|
||||||
|
you inform other peers where the object code and Corresponding
|
||||||
|
Source of the work are being offered to the general public at no
|
||||||
|
charge under subsection 6d.
|
||||||
|
|
||||||
|
A separable portion of the object code, whose source code is excluded
|
||||||
|
from the Corresponding Source as a System Library, need not be
|
||||||
|
included in conveying the object code work.
|
||||||
|
|
||||||
|
A "User Product" is either (1) a "consumer product", which means any
|
||||||
|
tangible personal property which is normally used for personal, family,
|
||||||
|
or household purposes, or (2) anything designed or sold for incorporation
|
||||||
|
into a dwelling. In determining whether a product is a consumer product,
|
||||||
|
doubtful cases shall be resolved in favor of coverage. For a particular
|
||||||
|
product received by a particular user, "normally used" refers to a
|
||||||
|
typical or common use of that class of product, regardless of the status
|
||||||
|
of the particular user or of the way in which the particular user
|
||||||
|
actually uses, or expects or is expected to use, the product. A product
|
||||||
|
is a consumer product regardless of whether the product has substantial
|
||||||
|
commercial, industrial or non-consumer uses, unless such uses represent
|
||||||
|
the only significant mode of use of the product.
|
||||||
|
|
||||||
|
"Installation Information" for a User Product means any methods,
|
||||||
|
procedures, authorization keys, or other information required to install
|
||||||
|
and execute modified versions of a covered work in that User Product from
|
||||||
|
a modified version of its Corresponding Source. The information must
|
||||||
|
suffice to ensure that the continued functioning of the modified object
|
||||||
|
code is in no case prevented or interfered with solely because
|
||||||
|
modification has been made.
|
||||||
|
|
||||||
|
If you convey an object code work under this section in, or with, or
|
||||||
|
specifically for use in, a User Product, and the conveying occurs as
|
||||||
|
part of a transaction in which the right of possession and use of the
|
||||||
|
User Product is transferred to the recipient in perpetuity or for a
|
||||||
|
fixed term (regardless of how the transaction is characterized), the
|
||||||
|
Corresponding Source conveyed under this section must be accompanied
|
||||||
|
by the Installation Information. But this requirement does not apply
|
||||||
|
if neither you nor any third party retains the ability to install
|
||||||
|
modified object code on the User Product (for example, the work has
|
||||||
|
been installed in ROM).
|
||||||
|
|
||||||
|
The requirement to provide Installation Information does not include a
|
||||||
|
requirement to continue to provide support service, warranty, or updates
|
||||||
|
for a work that has been modified or installed by the recipient, or for
|
||||||
|
the User Product in which it has been modified or installed. Access to a
|
||||||
|
network may be denied when the modification itself materially and
|
||||||
|
adversely affects the operation of the network or violates the rules and
|
||||||
|
protocols for communication across the network.
|
||||||
|
|
||||||
|
Corresponding Source conveyed, and Installation Information provided,
|
||||||
|
in accord with this section must be in a format that is publicly
|
||||||
|
documented (and with an implementation available to the public in
|
||||||
|
source code form), and must require no special password or key for
|
||||||
|
unpacking, reading or copying.
|
||||||
|
|
||||||
|
7. Additional Terms.
|
||||||
|
|
||||||
|
"Additional permissions" are terms that supplement the terms of this
|
||||||
|
License by making exceptions from one or more of its conditions.
|
||||||
|
Additional permissions that are applicable to the entire Program shall
|
||||||
|
be treated as though they were included in this License, to the extent
|
||||||
|
that they are valid under applicable law. If additional permissions
|
||||||
|
apply only to part of the Program, that part may be used separately
|
||||||
|
under those permissions, but the entire Program remains governed by
|
||||||
|
this License without regard to the additional permissions.
|
||||||
|
|
||||||
|
When you convey a copy of a covered work, you may at your option
|
||||||
|
remove any additional permissions from that copy, or from any part of
|
||||||
|
it. (Additional permissions may be written to require their own
|
||||||
|
removal in certain cases when you modify the work.) You may place
|
||||||
|
additional permissions on material, added by you to a covered work,
|
||||||
|
for which you have or can give appropriate copyright permission.
|
||||||
|
|
||||||
|
Notwithstanding any other provision of this License, for material you
|
||||||
|
add to a covered work, you may (if authorized by the copyright holders of
|
||||||
|
that material) supplement the terms of this License with terms:
|
||||||
|
|
||||||
|
a) Disclaiming warranty or limiting liability differently from the
|
||||||
|
terms of sections 15 and 16 of this License; or
|
||||||
|
|
||||||
|
b) Requiring preservation of specified reasonable legal notices or
|
||||||
|
author attributions in that material or in the Appropriate Legal
|
||||||
|
Notices displayed by works containing it; or
|
||||||
|
|
||||||
|
c) Prohibiting misrepresentation of the origin of that material, or
|
||||||
|
requiring that modified versions of such material be marked in
|
||||||
|
reasonable ways as different from the original version; or
|
||||||
|
|
||||||
|
d) Limiting the use for publicity purposes of names of licensors or
|
||||||
|
authors of the material; or
|
||||||
|
|
||||||
|
e) Declining to grant rights under trademark law for use of some
|
||||||
|
trade names, trademarks, or service marks; or
|
||||||
|
|
||||||
|
f) Requiring indemnification of licensors and authors of that
|
||||||
|
material by anyone who conveys the material (or modified versions of
|
||||||
|
it) with contractual assumptions of liability to the recipient, for
|
||||||
|
any liability that these contractual assumptions directly impose on
|
||||||
|
those licensors and authors.
|
||||||
|
|
||||||
|
All other non-permissive additional terms are considered "further
|
||||||
|
restrictions" within the meaning of section 10. If the Program as you
|
||||||
|
received it, or any part of it, contains a notice stating that it is
|
||||||
|
governed by this License along with a term that is a further
|
||||||
|
restriction, you may remove that term. If a license document contains
|
||||||
|
a further restriction but permits relicensing or conveying under this
|
||||||
|
License, you may add to a covered work material governed by the terms
|
||||||
|
of that license document, provided that the further restriction does
|
||||||
|
not survive such relicensing or conveying.
|
||||||
|
|
||||||
|
If you add terms to a covered work in accord with this section, you
|
||||||
|
must place, in the relevant source files, a statement of the
|
||||||
|
additional terms that apply to those files, or a notice indicating
|
||||||
|
where to find the applicable terms.
|
||||||
|
|
||||||
|
Additional terms, permissive or non-permissive, may be stated in the
|
||||||
|
form of a separately written license, or stated as exceptions;
|
||||||
|
the above requirements apply either way.
|
||||||
|
|
||||||
|
8. Termination.
|
||||||
|
|
||||||
|
You may not propagate or modify a covered work except as expressly
|
||||||
|
provided under this License. Any attempt otherwise to propagate or
|
||||||
|
modify it is void, and will automatically terminate your rights under
|
||||||
|
this License (including any patent licenses granted under the third
|
||||||
|
paragraph of section 11).
|
||||||
|
|
||||||
|
However, if you cease all violation of this License, then your
|
||||||
|
license from a particular copyright holder is reinstated (a)
|
||||||
|
provisionally, unless and until the copyright holder explicitly and
|
||||||
|
finally terminates your license, and (b) permanently, if the copyright
|
||||||
|
holder fails to notify you of the violation by some reasonable means
|
||||||
|
prior to 60 days after the cessation.
|
||||||
|
|
||||||
|
Moreover, your license from a particular copyright holder is
|
||||||
|
reinstated permanently if the copyright holder notifies you of the
|
||||||
|
violation by some reasonable means, this is the first time you have
|
||||||
|
received notice of violation of this License (for any work) from that
|
||||||
|
copyright holder, and you cure the violation prior to 30 days after
|
||||||
|
your receipt of the notice.
|
||||||
|
|
||||||
|
Termination of your rights under this section does not terminate the
|
||||||
|
licenses of parties who have received copies or rights from you under
|
||||||
|
this License. If your rights have been terminated and not permanently
|
||||||
|
reinstated, you do not qualify to receive new licenses for the same
|
||||||
|
material under section 10.
|
||||||
|
|
||||||
|
9. Acceptance Not Required for Having Copies.
|
||||||
|
|
||||||
|
You are not required to accept this License in order to receive or
|
||||||
|
run a copy of the Program. Ancillary propagation of a covered work
|
||||||
|
occurring solely as a consequence of using peer-to-peer transmission
|
||||||
|
to receive a copy likewise does not require acceptance. However,
|
||||||
|
nothing other than this License grants you permission to propagate or
|
||||||
|
modify any covered work. These actions infringe copyright if you do
|
||||||
|
not accept this License. Therefore, by modifying or propagating a
|
||||||
|
covered work, you indicate your acceptance of this License to do so.
|
||||||
|
|
||||||
|
10. Automatic Licensing of Downstream Recipients.
|
||||||
|
|
||||||
|
Each time you convey a covered work, the recipient automatically
|
||||||
|
receives a license from the original licensors, to run, modify and
|
||||||
|
propagate that work, subject to this License. You are not responsible
|
||||||
|
for enforcing compliance by third parties with this License.
|
||||||
|
|
||||||
|
An "entity transaction" is a transaction transferring control of an
|
||||||
|
organization, or substantially all assets of one, or subdividing an
|
||||||
|
organization, or merging organizations. If propagation of a covered
|
||||||
|
work results from an entity transaction, each party to that
|
||||||
|
transaction who receives a copy of the work also receives whatever
|
||||||
|
licenses to the work the party's predecessor in interest had or could
|
||||||
|
give under the previous paragraph, plus a right to possession of the
|
||||||
|
Corresponding Source of the work from the predecessor in interest, if
|
||||||
|
the predecessor has it or can get it with reasonable efforts.
|
||||||
|
|
||||||
|
You may not impose any further restrictions on the exercise of the
|
||||||
|
rights granted or affirmed under this License. For example, you may
|
||||||
|
not impose a license fee, royalty, or other charge for exercise of
|
||||||
|
rights granted under this License, and you may not initiate litigation
|
||||||
|
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
||||||
|
any patent claim is infringed by making, using, selling, offering for
|
||||||
|
sale, or importing the Program or any portion of it.
|
||||||
|
|
||||||
|
11. Patents.
|
||||||
|
|
||||||
|
A "contributor" is a copyright holder who authorizes use under this
|
||||||
|
License of the Program or a work on which the Program is based. The
|
||||||
|
work thus licensed is called the contributor's "contributor version".
|
||||||
|
|
||||||
|
A contributor's "essential patent claims" are all patent claims
|
||||||
|
owned or controlled by the contributor, whether already acquired or
|
||||||
|
hereafter acquired, that would be infringed by some manner, permitted
|
||||||
|
by this License, of making, using, or selling its contributor version,
|
||||||
|
but do not include claims that would be infringed only as a
|
||||||
|
consequence of further modification of the contributor version. For
|
||||||
|
purposes of this definition, "control" includes the right to grant
|
||||||
|
patent sublicenses in a manner consistent with the requirements of
|
||||||
|
this License.
|
||||||
|
|
||||||
|
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
||||||
|
patent license under the contributor's essential patent claims, to
|
||||||
|
make, use, sell, offer for sale, import and otherwise run, modify and
|
||||||
|
propagate the contents of its contributor version.
|
||||||
|
|
||||||
|
In the following three paragraphs, a "patent license" is any express
|
||||||
|
agreement or commitment, however denominated, not to enforce a patent
|
||||||
|
(such as an express permission to practice a patent or covenant not to
|
||||||
|
sue for patent infringement). To "grant" such a patent license to a
|
||||||
|
party means to make such an agreement or commitment not to enforce a
|
||||||
|
patent against the party.
|
||||||
|
|
||||||
|
If you convey a covered work, knowingly relying on a patent license,
|
||||||
|
and the Corresponding Source of the work is not available for anyone
|
||||||
|
to copy, free of charge and under the terms of this License, through a
|
||||||
|
publicly available network server or other readily accessible means,
|
||||||
|
then you must either (1) cause the Corresponding Source to be so
|
||||||
|
available, or (2) arrange to deprive yourself of the benefit of the
|
||||||
|
patent license for this particular work, or (3) arrange, in a manner
|
||||||
|
consistent with the requirements of this License, to extend the patent
|
||||||
|
license to downstream recipients. "Knowingly relying" means you have
|
||||||
|
actual knowledge that, but for the patent license, your conveying the
|
||||||
|
covered work in a country, or your recipient's use of the covered work
|
||||||
|
in a country, would infringe one or more identifiable patents in that
|
||||||
|
country that you have reason to believe are valid.
|
||||||
|
|
||||||
|
If, pursuant to or in connection with a single transaction or
|
||||||
|
arrangement, you convey, or propagate by procuring conveyance of, a
|
||||||
|
covered work, and grant a patent license to some of the parties
|
||||||
|
receiving the covered work authorizing them to use, propagate, modify
|
||||||
|
or convey a specific copy of the covered work, then the patent license
|
||||||
|
you grant is automatically extended to all recipients of the covered
|
||||||
|
work and works based on it.
|
||||||
|
|
||||||
|
A patent license is "discriminatory" if it does not include within
|
||||||
|
the scope of its coverage, prohibits the exercise of, or is
|
||||||
|
conditioned on the non-exercise of one or more of the rights that are
|
||||||
|
specifically granted under this License. You may not convey a covered
|
||||||
|
work if you are a party to an arrangement with a third party that is
|
||||||
|
in the business of distributing software, under which you make payment
|
||||||
|
to the third party based on the extent of your activity of conveying
|
||||||
|
the work, and under which the third party grants, to any of the
|
||||||
|
parties who would receive the covered work from you, a discriminatory
|
||||||
|
patent license (a) in connection with copies of the covered work
|
||||||
|
conveyed by you (or copies made from those copies), or (b) primarily
|
||||||
|
for and in connection with specific products or compilations that
|
||||||
|
contain the covered work, unless you entered into that arrangement,
|
||||||
|
or that patent license was granted, prior to 28 March 2007.
|
||||||
|
|
||||||
|
Nothing in this License shall be construed as excluding or limiting
|
||||||
|
any implied license or other defenses to infringement that may
|
||||||
|
otherwise be available to you under applicable patent law.
|
||||||
|
|
||||||
|
12. No Surrender of Others' Freedom.
|
||||||
|
|
||||||
|
If conditions are imposed on you (whether by court order, agreement or
|
||||||
|
otherwise) that contradict the conditions of this License, they do not
|
||||||
|
excuse you from the conditions of this License. If you cannot convey a
|
||||||
|
covered work so as to satisfy simultaneously your obligations under this
|
||||||
|
License and any other pertinent obligations, then as a consequence you may
|
||||||
|
not convey it at all. For example, if you agree to terms that obligate you
|
||||||
|
to collect a royalty for further conveying from those to whom you convey
|
||||||
|
the Program, the only way you could satisfy both those terms and this
|
||||||
|
License would be to refrain entirely from conveying the Program.
|
||||||
|
|
||||||
|
13. Remote Network Interaction; Use with the GNU General Public License.
|
||||||
|
|
||||||
|
Notwithstanding any other provision of this License, if you modify the
|
||||||
|
Program, your modified version must prominently offer all users
|
||||||
|
interacting with it remotely through a computer network (if your version
|
||||||
|
supports such interaction) an opportunity to receive the Corresponding
|
||||||
|
Source of your version by providing access to the Corresponding Source
|
||||||
|
from a network server at no charge, through some standard or customary
|
||||||
|
means of facilitating copying of software. This Corresponding Source
|
||||||
|
shall include the Corresponding Source for any work covered by version 3
|
||||||
|
of the GNU General Public License that is incorporated pursuant to the
|
||||||
|
following paragraph.
|
||||||
|
|
||||||
|
Notwithstanding any other provision of this License, you have
|
||||||
|
permission to link or combine any covered work with a work licensed
|
||||||
|
under version 3 of the GNU General Public License into a single
|
||||||
|
combined work, and to convey the resulting work. The terms of this
|
||||||
|
License will continue to apply to the part which is the covered work,
|
||||||
|
but the work with which it is combined will remain governed by version
|
||||||
|
3 of the GNU General Public License.
|
||||||
|
|
||||||
|
14. Revised Versions of this License.
|
||||||
|
|
||||||
|
The Free Software Foundation may publish revised and/or new versions of
|
||||||
|
the GNU Affero General Public License from time to time. Such new versions
|
||||||
|
will be similar in spirit to the present version, but may differ in detail to
|
||||||
|
address new problems or concerns.
|
||||||
|
|
||||||
|
Each version is given a distinguishing version number. If the
|
||||||
|
Program specifies that a certain numbered version of the GNU Affero General
|
||||||
|
Public License "or any later version" applies to it, you have the
|
||||||
|
option of following the terms and conditions either of that numbered
|
||||||
|
version or of any later version published by the Free Software
|
||||||
|
Foundation. If the Program does not specify a version number of the
|
||||||
|
GNU Affero General Public License, you may choose any version ever published
|
||||||
|
by the Free Software Foundation.
|
||||||
|
|
||||||
|
If the Program specifies that a proxy can decide which future
|
||||||
|
versions of the GNU Affero General Public License can be used, that proxy's
|
||||||
|
public statement of acceptance of a version permanently authorizes you
|
||||||
|
to choose that version for the Program.
|
||||||
|
|
||||||
|
Later license versions may give you additional or different
|
||||||
|
permissions. However, no additional obligations are imposed on any
|
||||||
|
author or copyright holder as a result of your choosing to follow a
|
||||||
|
later version.
|
||||||
|
|
||||||
|
15. Disclaimer of Warranty.
|
||||||
|
|
||||||
|
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
||||||
|
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
||||||
|
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
||||||
|
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
||||||
|
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||||
|
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
||||||
|
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
||||||
|
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
||||||
|
|
||||||
|
16. Limitation of Liability.
|
||||||
|
|
||||||
|
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||||
|
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
||||||
|
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
||||||
|
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
||||||
|
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||||
|
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
||||||
|
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
||||||
|
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
||||||
|
SUCH DAMAGES.
|
||||||
|
|
||||||
|
17. Interpretation of Sections 15 and 16.
|
||||||
|
|
||||||
|
If the disclaimer of warranty and limitation of liability provided
|
||||||
|
above cannot be given local legal effect according to their terms,
|
||||||
|
reviewing courts shall apply local law that most closely approximates
|
||||||
|
an absolute waiver of all civil liability in connection with the
|
||||||
|
Program, unless a warranty or assumption of liability accompanies a
|
||||||
|
copy of the Program in return for a fee.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
How to Apply These Terms to Your New Programs
|
||||||
|
|
||||||
|
If you develop a new program, and you want it to be of the greatest
|
||||||
|
possible use to the public, the best way to achieve this is to make it
|
||||||
|
free software which everyone can redistribute and change under these terms.
|
||||||
|
|
||||||
|
To do so, attach the following notices to the program. It is safest
|
||||||
|
to attach them to the start of each source file to most effectively
|
||||||
|
state the exclusion of warranty; and each file should have at least
|
||||||
|
the "copyright" line and a pointer to where the full notice is found.
|
||||||
|
|
||||||
|
<one line to give the program's name and a brief idea of what it does.>
|
||||||
|
Copyright (C) <year> <name of author>
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU Affero General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Affero General Public License
|
||||||
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
Also add information on how to contact you by electronic and paper mail.
|
||||||
|
|
||||||
|
If your software can interact with users remotely through a computer
|
||||||
|
network, you should also make sure that it provides a way for users to
|
||||||
|
get its source. For example, if your program is a web application, its
|
||||||
|
interface could display a "Source" link that leads users to an archive
|
||||||
|
of the code. There are many ways you could offer source, and different
|
||||||
|
solutions will be better for different programs; see section 13 for the
|
||||||
|
specific requirements.
|
||||||
|
|
||||||
|
You should also get your employer (if you work as a programmer) or school,
|
||||||
|
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
||||||
|
For more information on this, and how to apply and follow the GNU AGPL, see
|
||||||
|
<https://www.gnu.org/licenses/>.
|
||||||
723
combined/cmd/config.go
Normal file
723
combined/cmd/config.go
Normal file
@@ -0,0 +1,723 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
"github.com/netbirdio/netbird/util/crypt"
|
||||||
|
|
||||||
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CombinedConfig is the root configuration for the combined server.
|
||||||
|
// The combined server is primarily a Management server with optional embedded
|
||||||
|
// Signal, Relay, and STUN services.
|
||||||
|
//
|
||||||
|
// Architecture:
|
||||||
|
// - Management: Always runs locally (this IS the management server)
|
||||||
|
// - Signal: Runs locally by default; disabled if server.signalUri is set
|
||||||
|
// - Relay: Runs locally by default; disabled if server.relays is set
|
||||||
|
// - STUN: Runs locally on port 3478 by default; disabled if server.stuns is set
|
||||||
|
//
|
||||||
|
// All user-facing settings are under "server". The relay/signal/management
|
||||||
|
// fields are internal and populated automatically from server settings.
|
||||||
|
type CombinedConfig struct {
|
||||||
|
Server ServerConfig `yaml:"server"`
|
||||||
|
|
||||||
|
// Internal configs - populated from Server settings, not user-configurable
|
||||||
|
Relay RelayConfig `yaml:"-"`
|
||||||
|
Signal SignalConfig `yaml:"-"`
|
||||||
|
Management ManagementConfig `yaml:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerConfig contains server-wide settings
|
||||||
|
// In simplified mode, this contains all configuration
|
||||||
|
type ServerConfig struct {
|
||||||
|
ListenAddress string `yaml:"listenAddress"`
|
||||||
|
MetricsPort int `yaml:"metricsPort"`
|
||||||
|
HealthcheckAddress string `yaml:"healthcheckAddress"`
|
||||||
|
LogLevel string `yaml:"logLevel"`
|
||||||
|
LogFile string `yaml:"logFile"`
|
||||||
|
TLS TLSConfig `yaml:"tls"`
|
||||||
|
|
||||||
|
// Simplified config fields (used when relay/signal/management sections are omitted)
|
||||||
|
ExposedAddress string `yaml:"exposedAddress"` // Public address with protocol (e.g., "https://example.com:443")
|
||||||
|
StunPorts []int `yaml:"stunPorts"` // STUN ports (empty to disable local STUN)
|
||||||
|
AuthSecret string `yaml:"authSecret"` // Shared secret for relay authentication
|
||||||
|
DataDir string `yaml:"dataDir"` // Data directory for all services
|
||||||
|
|
||||||
|
// External service overrides (simplified mode)
|
||||||
|
// When these are set, the corresponding local service is NOT started
|
||||||
|
// and these values are used for client configuration instead
|
||||||
|
Stuns []HostConfig `yaml:"stuns"` // External STUN servers (disables local STUN)
|
||||||
|
Relays RelaysConfig `yaml:"relays"` // External relay servers (disables local relay)
|
||||||
|
SignalURI string `yaml:"signalUri"` // External signal server (disables local signal)
|
||||||
|
|
||||||
|
// Management settings (simplified mode)
|
||||||
|
DisableAnonymousMetrics bool `yaml:"disableAnonymousMetrics"`
|
||||||
|
DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"`
|
||||||
|
Auth AuthConfig `yaml:"auth"`
|
||||||
|
Store StoreConfig `yaml:"store"`
|
||||||
|
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLSConfig contains TLS/HTTPS settings
|
||||||
|
type TLSConfig struct {
|
||||||
|
CertFile string `yaml:"certFile"`
|
||||||
|
KeyFile string `yaml:"keyFile"`
|
||||||
|
LetsEncrypt LetsEncryptConfig `yaml:"letsencrypt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LetsEncryptConfig contains Let's Encrypt settings
|
||||||
|
type LetsEncryptConfig struct {
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
DataDir string `yaml:"dataDir"`
|
||||||
|
Domains []string `yaml:"domains"`
|
||||||
|
Email string `yaml:"email"`
|
||||||
|
AWSRoute53 bool `yaml:"awsRoute53"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RelayConfig contains relay service settings
|
||||||
|
type RelayConfig struct {
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
ExposedAddress string `yaml:"exposedAddress"`
|
||||||
|
AuthSecret string `yaml:"authSecret"`
|
||||||
|
LogLevel string `yaml:"logLevel"`
|
||||||
|
Stun StunConfig `yaml:"stun"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// StunConfig contains embedded STUN service settings
|
||||||
|
type StunConfig struct {
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
Ports []int `yaml:"ports"`
|
||||||
|
LogLevel string `yaml:"logLevel"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignalConfig contains signal service settings
|
||||||
|
type SignalConfig struct {
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
LogLevel string `yaml:"logLevel"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ManagementConfig contains management service settings
|
||||||
|
type ManagementConfig struct {
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
LogLevel string `yaml:"logLevel"`
|
||||||
|
DataDir string `yaml:"dataDir"`
|
||||||
|
DnsDomain string `yaml:"dnsDomain"`
|
||||||
|
DisableAnonymousMetrics bool `yaml:"disableAnonymousMetrics"`
|
||||||
|
DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"`
|
||||||
|
DisableDefaultPolicy bool `yaml:"disableDefaultPolicy"`
|
||||||
|
Auth AuthConfig `yaml:"auth"`
|
||||||
|
Stuns []HostConfig `yaml:"stuns"`
|
||||||
|
Relays RelaysConfig `yaml:"relays"`
|
||||||
|
SignalURI string `yaml:"signalUri"`
|
||||||
|
Store StoreConfig `yaml:"store"`
|
||||||
|
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthConfig contains authentication/identity provider settings
|
||||||
|
type AuthConfig struct {
|
||||||
|
Issuer string `yaml:"issuer"`
|
||||||
|
LocalAuthDisabled bool `yaml:"localAuthDisabled"`
|
||||||
|
SignKeyRefreshEnabled bool `yaml:"signKeyRefreshEnabled"`
|
||||||
|
Storage AuthStorageConfig `yaml:"storage"`
|
||||||
|
DashboardRedirectURIs []string `yaml:"dashboardRedirectURIs"`
|
||||||
|
CLIRedirectURIs []string `yaml:"cliRedirectURIs"`
|
||||||
|
Owner *AuthOwnerConfig `yaml:"owner,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthStorageConfig contains auth storage settings
|
||||||
|
type AuthStorageConfig struct {
|
||||||
|
Type string `yaml:"type"`
|
||||||
|
File string `yaml:"file"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthOwnerConfig contains initial admin user settings
|
||||||
|
type AuthOwnerConfig struct {
|
||||||
|
Email string `yaml:"email"`
|
||||||
|
Password string `yaml:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// HostConfig represents a STUN/TURN/Signal host
|
||||||
|
type HostConfig struct {
|
||||||
|
URI string `yaml:"uri"`
|
||||||
|
Proto string `yaml:"proto,omitempty"` // udp, dtls, tcp, http, https - defaults based on URI scheme
|
||||||
|
Username string `yaml:"username,omitempty"`
|
||||||
|
Password string `yaml:"password,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RelaysConfig contains external relay server settings for clients
|
||||||
|
type RelaysConfig struct {
|
||||||
|
Addresses []string `yaml:"addresses"`
|
||||||
|
CredentialsTTL string `yaml:"credentialsTTL"`
|
||||||
|
Secret string `yaml:"secret"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// StoreConfig contains database settings
|
||||||
|
type StoreConfig struct {
|
||||||
|
Engine string `yaml:"engine"`
|
||||||
|
EncryptionKey string `yaml:"encryptionKey"`
|
||||||
|
DSN string `yaml:"dsn"` // Connection string for postgres or mysql engines
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReverseProxyConfig contains reverse proxy settings
|
||||||
|
type ReverseProxyConfig struct {
|
||||||
|
TrustedHTTPProxies []string `yaml:"trustedHTTPProxies"`
|
||||||
|
TrustedHTTPProxiesCount uint `yaml:"trustedHTTPProxiesCount"`
|
||||||
|
TrustedPeers []string `yaml:"trustedPeers"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultConfig returns a CombinedConfig with default values
|
||||||
|
func DefaultConfig() *CombinedConfig {
|
||||||
|
return &CombinedConfig{
|
||||||
|
Server: ServerConfig{
|
||||||
|
ListenAddress: ":443",
|
||||||
|
MetricsPort: 9090,
|
||||||
|
HealthcheckAddress: ":9000",
|
||||||
|
LogLevel: "info",
|
||||||
|
LogFile: "console",
|
||||||
|
StunPorts: []int{3478},
|
||||||
|
DataDir: "/var/lib/netbird/",
|
||||||
|
Auth: AuthConfig{
|
||||||
|
Storage: AuthStorageConfig{
|
||||||
|
Type: "sqlite3",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Store: StoreConfig{
|
||||||
|
Engine: "sqlite",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Relay: RelayConfig{
|
||||||
|
// LogLevel inherited from Server.LogLevel via ApplySimplifiedDefaults
|
||||||
|
Stun: StunConfig{
|
||||||
|
Enabled: false,
|
||||||
|
Ports: []int{3478},
|
||||||
|
// LogLevel inherited from Server.LogLevel via ApplySimplifiedDefaults
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Signal: SignalConfig{
|
||||||
|
// LogLevel inherited from Server.LogLevel via ApplySimplifiedDefaults
|
||||||
|
},
|
||||||
|
Management: ManagementConfig{
|
||||||
|
DataDir: "/var/lib/netbird/",
|
||||||
|
Auth: AuthConfig{
|
||||||
|
Storage: AuthStorageConfig{
|
||||||
|
Type: "sqlite3",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Relays: RelaysConfig{
|
||||||
|
CredentialsTTL: "12h",
|
||||||
|
},
|
||||||
|
Store: StoreConfig{
|
||||||
|
Engine: "sqlite",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasRequiredSettings returns true if the configuration has the required server settings
|
||||||
|
func (c *CombinedConfig) hasRequiredSettings() bool {
|
||||||
|
return c.Server.ExposedAddress != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseExposedAddress extracts protocol, host, and host:port from the exposed address
|
||||||
|
// Input format: "https://example.com:443" or "http://example.com:8080" or "example.com:443"
|
||||||
|
// Returns: protocol ("https" or "http"), hostname only, and host:port
|
||||||
|
func parseExposedAddress(exposedAddress string) (protocol, hostname, hostPort string) {
|
||||||
|
// Default to https if no protocol specified
|
||||||
|
protocol = "https"
|
||||||
|
hostPort = exposedAddress
|
||||||
|
|
||||||
|
// Check for protocol prefix
|
||||||
|
if strings.HasPrefix(exposedAddress, "https://") {
|
||||||
|
protocol = "https"
|
||||||
|
hostPort = strings.TrimPrefix(exposedAddress, "https://")
|
||||||
|
} else if strings.HasPrefix(exposedAddress, "http://") {
|
||||||
|
protocol = "http"
|
||||||
|
hostPort = strings.TrimPrefix(exposedAddress, "http://")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract hostname (without port)
|
||||||
|
hostname = hostPort
|
||||||
|
if host, _, err := net.SplitHostPort(hostPort); err == nil {
|
||||||
|
hostname = host
|
||||||
|
}
|
||||||
|
|
||||||
|
return protocol, hostname, hostPort
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplySimplifiedDefaults populates internal relay/signal/management configs from server settings.
|
||||||
|
// Management is always enabled. Signal, Relay, and STUN are enabled unless external
|
||||||
|
// overrides are configured (server.signalUri, server.relays, server.stuns).
|
||||||
|
func (c *CombinedConfig) ApplySimplifiedDefaults() {
|
||||||
|
if !c.hasRequiredSettings() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse exposed address to extract protocol and hostname
|
||||||
|
exposedProto, exposedHost, exposedHostPort := parseExposedAddress(c.Server.ExposedAddress)
|
||||||
|
|
||||||
|
// Check for external service overrides
|
||||||
|
hasExternalRelay := len(c.Server.Relays.Addresses) > 0
|
||||||
|
hasExternalSignal := c.Server.SignalURI != ""
|
||||||
|
hasExternalStuns := len(c.Server.Stuns) > 0
|
||||||
|
|
||||||
|
// Default stunPorts to [3478] if not specified and no external STUN
|
||||||
|
if len(c.Server.StunPorts) == 0 && !hasExternalStuns {
|
||||||
|
c.Server.StunPorts = []int{3478}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.applyRelayDefaults(exposedProto, exposedHostPort, hasExternalRelay, hasExternalStuns)
|
||||||
|
c.applySignalDefaults(hasExternalSignal)
|
||||||
|
c.applyManagementDefaults(exposedHost)
|
||||||
|
|
||||||
|
// Auto-configure client settings (stuns, relays, signalUri)
|
||||||
|
c.autoConfigureClientSettings(exposedProto, exposedHost, exposedHostPort, hasExternalStuns, hasExternalRelay, hasExternalSignal)
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyRelayDefaults configures the relay service if no external relay is configured.
|
||||||
|
func (c *CombinedConfig) applyRelayDefaults(exposedProto, exposedHostPort string, hasExternalRelay, hasExternalStuns bool) {
|
||||||
|
if hasExternalRelay {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Relay.Enabled = true
|
||||||
|
relayProto := "rel"
|
||||||
|
if exposedProto == "https" {
|
||||||
|
relayProto = "rels"
|
||||||
|
}
|
||||||
|
c.Relay.ExposedAddress = fmt.Sprintf("%s://%s", relayProto, exposedHostPort)
|
||||||
|
c.Relay.AuthSecret = c.Server.AuthSecret
|
||||||
|
if c.Relay.LogLevel == "" {
|
||||||
|
c.Relay.LogLevel = c.Server.LogLevel
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable local STUN only if no external STUN servers and stunPorts are configured
|
||||||
|
if !hasExternalStuns && len(c.Server.StunPorts) > 0 {
|
||||||
|
c.Relay.Stun.Enabled = true
|
||||||
|
c.Relay.Stun.Ports = c.Server.StunPorts
|
||||||
|
if c.Relay.Stun.LogLevel == "" {
|
||||||
|
c.Relay.Stun.LogLevel = c.Server.LogLevel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// applySignalDefaults configures the signal service if no external signal is configured.
|
||||||
|
func (c *CombinedConfig) applySignalDefaults(hasExternalSignal bool) {
|
||||||
|
if hasExternalSignal {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Signal.Enabled = true
|
||||||
|
if c.Signal.LogLevel == "" {
|
||||||
|
c.Signal.LogLevel = c.Server.LogLevel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyManagementDefaults configures the management service (always enabled).
|
||||||
|
func (c *CombinedConfig) applyManagementDefaults(exposedHost string) {
|
||||||
|
c.Management.Enabled = true
|
||||||
|
if c.Management.LogLevel == "" {
|
||||||
|
c.Management.LogLevel = c.Server.LogLevel
|
||||||
|
}
|
||||||
|
if c.Management.DataDir == "" || c.Management.DataDir == "/var/lib/netbird/" {
|
||||||
|
c.Management.DataDir = c.Server.DataDir
|
||||||
|
}
|
||||||
|
c.Management.DnsDomain = exposedHost
|
||||||
|
c.Management.DisableAnonymousMetrics = c.Server.DisableAnonymousMetrics
|
||||||
|
c.Management.DisableGeoliteUpdate = c.Server.DisableGeoliteUpdate
|
||||||
|
// Copy auth config from server if management auth issuer is not set
|
||||||
|
if c.Management.Auth.Issuer == "" && c.Server.Auth.Issuer != "" {
|
||||||
|
c.Management.Auth = c.Server.Auth
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy store config from server if not set
|
||||||
|
if c.Management.Store.Engine == "" || c.Management.Store.Engine == "sqlite" {
|
||||||
|
if c.Server.Store.Engine != "" {
|
||||||
|
c.Management.Store = c.Server.Store
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy reverse proxy config from server
|
||||||
|
if len(c.Server.ReverseProxy.TrustedHTTPProxies) > 0 || c.Server.ReverseProxy.TrustedHTTPProxiesCount > 0 || len(c.Server.ReverseProxy.TrustedPeers) > 0 {
|
||||||
|
c.Management.ReverseProxy = c.Server.ReverseProxy
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// autoConfigureClientSettings sets up STUN/relay/signal URIs for clients
|
||||||
|
// External overrides from server config take precedence over auto-generated values
|
||||||
|
func (c *CombinedConfig) autoConfigureClientSettings(exposedProto, exposedHost, exposedHostPort string, hasExternalStuns, hasExternalRelay, hasExternalSignal bool) {
|
||||||
|
// Determine relay protocol from exposed protocol
|
||||||
|
relayProto := "rel"
|
||||||
|
if exposedProto == "https" {
|
||||||
|
relayProto = "rels"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure STUN servers for clients
|
||||||
|
if hasExternalStuns {
|
||||||
|
// Use external STUN servers from server config
|
||||||
|
c.Management.Stuns = c.Server.Stuns
|
||||||
|
} else if len(c.Server.StunPorts) > 0 && len(c.Management.Stuns) == 0 {
|
||||||
|
// Auto-configure local STUN servers for all ports
|
||||||
|
for _, port := range c.Server.StunPorts {
|
||||||
|
c.Management.Stuns = append(c.Management.Stuns, HostConfig{
|
||||||
|
URI: fmt.Sprintf("stun:%s:%d", exposedHost, port),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure relay for clients
|
||||||
|
if hasExternalRelay {
|
||||||
|
// Use external relay config from server
|
||||||
|
c.Management.Relays = c.Server.Relays
|
||||||
|
} else if len(c.Management.Relays.Addresses) == 0 {
|
||||||
|
// Auto-configure local relay
|
||||||
|
c.Management.Relays.Addresses = []string{
|
||||||
|
fmt.Sprintf("%s://%s", relayProto, exposedHostPort),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c.Management.Relays.Secret == "" {
|
||||||
|
c.Management.Relays.Secret = c.Server.AuthSecret
|
||||||
|
}
|
||||||
|
if c.Management.Relays.CredentialsTTL == "" {
|
||||||
|
c.Management.Relays.CredentialsTTL = "12h"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure signal for clients
|
||||||
|
if hasExternalSignal {
|
||||||
|
// Use external signal URI from server config
|
||||||
|
c.Management.SignalURI = c.Server.SignalURI
|
||||||
|
} else if c.Management.SignalURI == "" {
|
||||||
|
// Auto-configure local signal
|
||||||
|
c.Management.SignalURI = fmt.Sprintf("%s://%s", exposedProto, exposedHostPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadConfig loads configuration from a YAML file
|
||||||
|
func LoadConfig(configPath string) (*CombinedConfig, error) {
|
||||||
|
cfg := DefaultConfig()
|
||||||
|
|
||||||
|
if configPath == "" {
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := yaml.Unmarshal(data, cfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate internal configs from server settings
|
||||||
|
cfg.ApplySimplifiedDefaults()
|
||||||
|
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate validates the configuration
|
||||||
|
func (c *CombinedConfig) Validate() error {
|
||||||
|
if c.Server.ExposedAddress == "" {
|
||||||
|
return fmt.Errorf("server.exposedAddress is required")
|
||||||
|
}
|
||||||
|
if c.Server.DataDir == "" {
|
||||||
|
return fmt.Errorf("server.dataDir is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate STUN ports
|
||||||
|
seen := make(map[int]bool)
|
||||||
|
for _, port := range c.Server.StunPorts {
|
||||||
|
if port <= 0 || port > 65535 {
|
||||||
|
return fmt.Errorf("invalid server.stunPorts value %d: must be between 1 and 65535", port)
|
||||||
|
}
|
||||||
|
if seen[port] {
|
||||||
|
return fmt.Errorf("duplicate STUN port %d in server.stunPorts", port)
|
||||||
|
}
|
||||||
|
seen[port] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// authSecret is required only if running local relay (no external relay configured)
|
||||||
|
hasExternalRelay := len(c.Server.Relays.Addresses) > 0
|
||||||
|
if !hasExternalRelay && c.Server.AuthSecret == "" {
|
||||||
|
return fmt.Errorf("server.authSecret is required when running local relay")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasTLSCert returns true if TLS certificate files are configured
|
||||||
|
func (c *CombinedConfig) HasTLSCert() bool {
|
||||||
|
return c.Server.TLS.CertFile != "" && c.Server.TLS.KeyFile != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasLetsEncrypt returns true if Let's Encrypt is configured
|
||||||
|
func (c *CombinedConfig) HasLetsEncrypt() bool {
|
||||||
|
return c.Server.TLS.LetsEncrypt.Enabled &&
|
||||||
|
c.Server.TLS.LetsEncrypt.DataDir != "" &&
|
||||||
|
len(c.Server.TLS.LetsEncrypt.Domains) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseExplicitProtocol parses an explicit protocol string to nbconfig.Protocol
|
||||||
|
func parseExplicitProtocol(proto string) (nbconfig.Protocol, bool) {
|
||||||
|
switch strings.ToLower(proto) {
|
||||||
|
case "udp":
|
||||||
|
return nbconfig.UDP, true
|
||||||
|
case "dtls":
|
||||||
|
return nbconfig.DTLS, true
|
||||||
|
case "tcp":
|
||||||
|
return nbconfig.TCP, true
|
||||||
|
case "http":
|
||||||
|
return nbconfig.HTTP, true
|
||||||
|
case "https":
|
||||||
|
return nbconfig.HTTPS, true
|
||||||
|
default:
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseStunProtocol determines protocol for STUN/TURN servers.
|
||||||
|
// stun: → UDP, stuns: → DTLS, turn: → UDP, turns: → DTLS
|
||||||
|
// Explicit proto overrides URI scheme. Defaults to UDP.
|
||||||
|
func parseStunProtocol(uri, proto string) nbconfig.Protocol {
|
||||||
|
if proto != "" {
|
||||||
|
if p, ok := parseExplicitProtocol(proto); ok {
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
uri = strings.ToLower(uri)
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(uri, "stuns:"):
|
||||||
|
return nbconfig.DTLS
|
||||||
|
case strings.HasPrefix(uri, "turns:"):
|
||||||
|
return nbconfig.DTLS
|
||||||
|
default:
|
||||||
|
// stun:, turn:, or no scheme - default to UDP
|
||||||
|
return nbconfig.UDP
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseSignalProtocol determines protocol for Signal servers.
|
||||||
|
// https:// → HTTPS, http:// → HTTP. Defaults to HTTPS.
|
||||||
|
func parseSignalProtocol(uri string) nbconfig.Protocol {
|
||||||
|
uri = strings.ToLower(uri)
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(uri, "http://"):
|
||||||
|
return nbconfig.HTTP
|
||||||
|
default:
|
||||||
|
// https:// or no scheme - default to HTTPS
|
||||||
|
return nbconfig.HTTPS
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// stripSignalProtocol removes the protocol prefix from a signal URI.
|
||||||
|
// Returns just the host:port (e.g., "selfhosted2.demo.netbird.io:443").
|
||||||
|
func stripSignalProtocol(uri string) string {
|
||||||
|
uri = strings.TrimPrefix(uri, "https://")
|
||||||
|
uri = strings.TrimPrefix(uri, "http://")
|
||||||
|
return uri
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToManagementConfig converts CombinedConfig to management server config
|
||||||
|
func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
|
||||||
|
mgmt := c.Management
|
||||||
|
|
||||||
|
// Build STUN hosts
|
||||||
|
var stuns []*nbconfig.Host
|
||||||
|
for _, s := range mgmt.Stuns {
|
||||||
|
stuns = append(stuns, &nbconfig.Host{
|
||||||
|
URI: s.URI,
|
||||||
|
Proto: parseStunProtocol(s.URI, s.Proto),
|
||||||
|
Username: s.Username,
|
||||||
|
Password: s.Password,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build relay config
|
||||||
|
var relayConfig *nbconfig.Relay
|
||||||
|
if len(mgmt.Relays.Addresses) > 0 || mgmt.Relays.Secret != "" {
|
||||||
|
var ttl time.Duration
|
||||||
|
if mgmt.Relays.CredentialsTTL != "" {
|
||||||
|
var err error
|
||||||
|
ttl, err = time.ParseDuration(mgmt.Relays.CredentialsTTL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid relay credentials TTL %q: %w", mgmt.Relays.CredentialsTTL, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
relayConfig = &nbconfig.Relay{
|
||||||
|
Addresses: mgmt.Relays.Addresses,
|
||||||
|
CredentialsTTL: util.Duration{Duration: ttl},
|
||||||
|
Secret: mgmt.Relays.Secret,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build signal config
|
||||||
|
var signalConfig *nbconfig.Host
|
||||||
|
if mgmt.SignalURI != "" {
|
||||||
|
signalConfig = &nbconfig.Host{
|
||||||
|
URI: stripSignalProtocol(mgmt.SignalURI),
|
||||||
|
Proto: parseSignalProtocol(mgmt.SignalURI),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build store config
|
||||||
|
storeConfig := nbconfig.StoreConfig{
|
||||||
|
Engine: types.Engine(mgmt.Store.Engine),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build reverse proxy config
|
||||||
|
reverseProxy := nbconfig.ReverseProxy{
|
||||||
|
TrustedHTTPProxiesCount: mgmt.ReverseProxy.TrustedHTTPProxiesCount,
|
||||||
|
}
|
||||||
|
for _, p := range mgmt.ReverseProxy.TrustedHTTPProxies {
|
||||||
|
if prefix, err := netip.ParsePrefix(p); err == nil {
|
||||||
|
reverseProxy.TrustedHTTPProxies = append(reverseProxy.TrustedHTTPProxies, prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, p := range mgmt.ReverseProxy.TrustedPeers {
|
||||||
|
if prefix, err := netip.ParsePrefix(p); err == nil {
|
||||||
|
reverseProxy.TrustedPeers = append(reverseProxy.TrustedPeers, prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build HTTP config (required, even if empty)
|
||||||
|
httpConfig := &nbconfig.HttpServerConfig{}
|
||||||
|
|
||||||
|
// Build embedded IDP config (always enabled in combined server)
|
||||||
|
storageFile := mgmt.Auth.Storage.File
|
||||||
|
if storageFile == "" {
|
||||||
|
storageFile = path.Join(mgmt.DataDir, "idp.db")
|
||||||
|
}
|
||||||
|
|
||||||
|
embeddedIdP := &idp.EmbeddedIdPConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Issuer: mgmt.Auth.Issuer,
|
||||||
|
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
|
||||||
|
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
|
||||||
|
Storage: idp.EmbeddedStorageConfig{
|
||||||
|
Type: mgmt.Auth.Storage.Type,
|
||||||
|
Config: idp.EmbeddedStorageTypeConfig{
|
||||||
|
File: storageFile,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
|
||||||
|
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
|
||||||
|
}
|
||||||
|
|
||||||
|
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {
|
||||||
|
embeddedIdP.Owner = &idp.OwnerConfig{
|
||||||
|
Email: mgmt.Auth.Owner.Email,
|
||||||
|
Hash: mgmt.Auth.Owner.Password, // Will be hashed if plain text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set HTTP config fields for embedded IDP
|
||||||
|
httpConfig.AuthIssuer = mgmt.Auth.Issuer
|
||||||
|
httpConfig.AuthAudience = "netbird-dashboard"
|
||||||
|
httpConfig.AuthClientID = httpConfig.AuthAudience
|
||||||
|
httpConfig.CLIAuthAudience = "netbird-cli"
|
||||||
|
httpConfig.AuthUserIDClaim = "sub"
|
||||||
|
httpConfig.AuthKeysLocation = mgmt.Auth.Issuer + "/keys"
|
||||||
|
httpConfig.OIDCConfigEndpoint = mgmt.Auth.Issuer + "/.well-known/openid-configuration"
|
||||||
|
httpConfig.IdpSignKeyRefreshEnabled = mgmt.Auth.SignKeyRefreshEnabled
|
||||||
|
callbackURL := strings.TrimSuffix(httpConfig.AuthIssuer, "/oauth2")
|
||||||
|
httpConfig.AuthCallbackURL = callbackURL + types.ProxyCallbackEndpointFull
|
||||||
|
|
||||||
|
return &nbconfig.Config{
|
||||||
|
Stuns: stuns,
|
||||||
|
Relay: relayConfig,
|
||||||
|
Signal: signalConfig,
|
||||||
|
Datadir: mgmt.DataDir,
|
||||||
|
DataStoreEncryptionKey: mgmt.Store.EncryptionKey,
|
||||||
|
HttpConfig: httpConfig,
|
||||||
|
StoreConfig: storeConfig,
|
||||||
|
ReverseProxy: reverseProxy,
|
||||||
|
DisableDefaultPolicy: mgmt.DisableDefaultPolicy,
|
||||||
|
EmbeddedIdP: embeddedIdP,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyEmbeddedIdPConfig applies embedded IdP configuration to the management config.
|
||||||
|
// This mirrors the logic in management/cmd/management.go ApplyEmbeddedIdPConfig.
|
||||||
|
func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config, mgmtPort int, disableSingleAccMode bool) error {
|
||||||
|
if cfg.EmbeddedIdP == nil || !cfg.EmbeddedIdP.Enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Embedded IdP requires single account mode
|
||||||
|
if disableSingleAccMode {
|
||||||
|
return fmt.Errorf("embedded IdP requires single account mode; multiple account mode is not supported with embedded IdP")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set LocalAddress for embedded IdP, used for internal JWT validation
|
||||||
|
cfg.EmbeddedIdP.LocalAddress = fmt.Sprintf("localhost:%d", mgmtPort)
|
||||||
|
|
||||||
|
// Set storage defaults based on Datadir
|
||||||
|
if cfg.EmbeddedIdP.Storage.Type == "" {
|
||||||
|
cfg.EmbeddedIdP.Storage.Type = "sqlite3"
|
||||||
|
}
|
||||||
|
if cfg.EmbeddedIdP.Storage.Config.File == "" && cfg.Datadir != "" {
|
||||||
|
cfg.EmbeddedIdP.Storage.Config.File = path.Join(cfg.Datadir, "idp.db")
|
||||||
|
}
|
||||||
|
|
||||||
|
issuer := cfg.EmbeddedIdP.Issuer
|
||||||
|
|
||||||
|
// Ensure HttpConfig exists
|
||||||
|
if cfg.HttpConfig == nil {
|
||||||
|
cfg.HttpConfig = &nbconfig.HttpServerConfig{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set HttpConfig values from EmbeddedIdP
|
||||||
|
cfg.HttpConfig.AuthIssuer = issuer
|
||||||
|
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
|
||||||
|
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
|
||||||
|
cfg.HttpConfig.AuthUserIDClaim = "sub"
|
||||||
|
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
|
||||||
|
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
|
||||||
|
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnsureEncryptionKey generates an encryption key if not set.
|
||||||
|
// Unlike management server, we don't write back to the config file.
|
||||||
|
func EnsureEncryptionKey(ctx context.Context, cfg *nbconfig.Config) error {
|
||||||
|
if cfg.DataStoreEncryptionKey != "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Infof("DataStoreEncryptionKey is not set, generating a new key")
|
||||||
|
key, err := crypt.GenerateKey()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate datastore encryption key: %v", err)
|
||||||
|
}
|
||||||
|
cfg.DataStoreEncryptionKey = key
|
||||||
|
keyPreview := key[:8] + "..."
|
||||||
|
log.WithContext(ctx).Warnf("DataStoreEncryptionKey generated (%s); add it to your config file under 'server.store.encryptionKey' to persist across restarts", keyPreview)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogConfigInfo logs informational messages about the loaded configuration
|
||||||
|
func LogConfigInfo(cfg *nbconfig.Config) {
|
||||||
|
if cfg.EmbeddedIdP != nil && cfg.EmbeddedIdP.Enabled {
|
||||||
|
log.Infof("running with the embedded IdP: %v", cfg.EmbeddedIdP.Issuer)
|
||||||
|
}
|
||||||
|
if cfg.Relay != nil {
|
||||||
|
log.Infof("Relay addresses: %v", cfg.Relay.Addresses)
|
||||||
|
}
|
||||||
|
}
|
||||||
33
combined/cmd/pprof.go
Normal file
33
combined/cmd/pprof.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
//go:build pprof
|
||||||
|
// +build pprof
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
_ "net/http/pprof"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
addr := pprofAddr()
|
||||||
|
go pprof(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func pprofAddr() string {
|
||||||
|
listenAddr := os.Getenv("NB_PPROF_ADDR")
|
||||||
|
if listenAddr == "" {
|
||||||
|
return "localhost:6969"
|
||||||
|
}
|
||||||
|
|
||||||
|
return listenAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
func pprof(listenAddr string) {
|
||||||
|
log.Infof("listening pprof on: %s\n", listenAddr)
|
||||||
|
if err := http.ListenAndServe(listenAddr, nil); err != nil {
|
||||||
|
log.Fatalf("Failed to start pprof: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
715
combined/cmd/root.go
Normal file
715
combined/cmd/root.go
Normal file
@@ -0,0 +1,715 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/coder/websocket"
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"go.opentelemetry.io/otel/metric"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/encryption"
|
||||||
|
mgmtServer "github.com/netbirdio/netbird/management/internals/server"
|
||||||
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
"github.com/netbirdio/netbird/relay/healthcheck"
|
||||||
|
relayServer "github.com/netbirdio/netbird/relay/server"
|
||||||
|
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
||||||
|
sharedMetrics "github.com/netbirdio/netbird/shared/metrics"
|
||||||
|
"github.com/netbirdio/netbird/shared/relay/auth"
|
||||||
|
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||||
|
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||||
|
"github.com/netbirdio/netbird/stun"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
"github.com/netbirdio/netbird/util/wsproxy"
|
||||||
|
wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
configPath string
|
||||||
|
config *CombinedConfig
|
||||||
|
|
||||||
|
rootCmd = &cobra.Command{
|
||||||
|
Use: "combined",
|
||||||
|
Short: "Combined Netbird server (Management + Signal + Relay + STUN)",
|
||||||
|
Long: `Combined Netbird server for self-hosted deployments.
|
||||||
|
|
||||||
|
All services (Management, Signal, Relay) are multiplexed on a single port.
|
||||||
|
Optional STUN server runs on separate UDP ports.
|
||||||
|
|
||||||
|
Configuration is loaded from a YAML file specified with --config.`,
|
||||||
|
SilenceUsage: true,
|
||||||
|
SilenceErrors: true,
|
||||||
|
RunE: execute,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to YAML configuration file (required)")
|
||||||
|
_ = rootCmd.MarkPersistentFlagRequired("config")
|
||||||
|
|
||||||
|
rootCmd.AddCommand(newTokenCommands())
|
||||||
|
}
|
||||||
|
|
||||||
|
func Execute() error {
|
||||||
|
return rootCmd.Execute()
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForExitSignal() {
|
||||||
|
osSigs := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(osSigs, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
<-osSigs
|
||||||
|
}
|
||||||
|
|
||||||
|
func execute(cmd *cobra.Command, _ []string) error {
|
||||||
|
if err := initializeConfig(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Management is required as the base server when signal or relay are enabled
|
||||||
|
if (config.Signal.Enabled || config.Relay.Enabled) && !config.Management.Enabled {
|
||||||
|
return fmt.Errorf("management must be enabled when signal or relay are enabled (provides the base HTTP server)")
|
||||||
|
}
|
||||||
|
|
||||||
|
servers, err := createAllServers(cmd.Context(), config)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register services with management's gRPC server using AfterInit hook
|
||||||
|
setupServerHooks(servers, config)
|
||||||
|
|
||||||
|
// Start management server (this also starts the HTTP listener)
|
||||||
|
if servers.mgmtSrv != nil {
|
||||||
|
if err := servers.mgmtSrv.Start(cmd.Context()); err != nil {
|
||||||
|
cleanupSTUNListeners(servers.stunListeners)
|
||||||
|
return fmt.Errorf("failed to start management server: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start all other servers
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
startServers(&wg, servers.relaySrv, servers.healthcheck, servers.stunServer, servers.metricsServer)
|
||||||
|
|
||||||
|
waitForExitSignal()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err = shutdownServers(ctx, servers.relaySrv, servers.healthcheck, servers.stunServer, servers.mgmtSrv, servers.metricsServer)
|
||||||
|
wg.Wait()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// initializeConfig loads and validates the configuration, then initializes logging.
|
||||||
|
func initializeConfig() error {
|
||||||
|
var err error
|
||||||
|
config, err = LoadConfig(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := config.Validate(); err != nil {
|
||||||
|
return fmt.Errorf("invalid config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := util.InitLog(config.Server.LogLevel, config.Server.LogFile); err != nil {
|
||||||
|
return fmt.Errorf("failed to initialize log: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if dsn := config.Server.Store.DSN; dsn != "" {
|
||||||
|
switch strings.ToLower(config.Server.Store.Engine) {
|
||||||
|
case "postgres":
|
||||||
|
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
|
||||||
|
case "mysql":
|
||||||
|
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Starting combined NetBird server")
|
||||||
|
logConfig(config)
|
||||||
|
logEnvVars()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// serverInstances holds all server instances created during startup.
|
||||||
|
type serverInstances struct {
|
||||||
|
relaySrv *relayServer.Server
|
||||||
|
mgmtSrv *mgmtServer.BaseServer
|
||||||
|
signalSrv *signalServer.Server
|
||||||
|
healthcheck *healthcheck.Server
|
||||||
|
stunServer *stun.Server
|
||||||
|
stunListeners []*net.UDPConn
|
||||||
|
metricsServer *sharedMetrics.Metrics
|
||||||
|
}
|
||||||
|
|
||||||
|
// createAllServers creates all server instances based on configuration.
|
||||||
|
func createAllServers(ctx context.Context, cfg *CombinedConfig) (*serverInstances, error) {
|
||||||
|
metricsServer, err := sharedMetrics.NewServer(cfg.Server.MetricsPort, "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create metrics server: %w", err)
|
||||||
|
}
|
||||||
|
servers := &serverInstances{
|
||||||
|
metricsServer: metricsServer,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, tlsSupport, err := handleTLSConfig(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to setup TLS config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := servers.createRelayServer(cfg, tlsSupport); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := servers.createManagementServer(ctx, cfg); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := servers.createSignalServer(ctx, cfg); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := servers.createHealthcheckServer(cfg); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return servers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serverInstances) createRelayServer(cfg *CombinedConfig, tlsSupport bool) error {
|
||||||
|
if !cfg.Relay.Enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
s.stunListeners, err = createSTUNListeners(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
hashedSecret := sha256.Sum256([]byte(cfg.Relay.AuthSecret))
|
||||||
|
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
|
||||||
|
|
||||||
|
relayCfg := relayServer.Config{
|
||||||
|
Meter: s.metricsServer.Meter,
|
||||||
|
ExposedAddress: cfg.Relay.ExposedAddress,
|
||||||
|
AuthValidator: authenticator,
|
||||||
|
TLSSupport: tlsSupport,
|
||||||
|
}
|
||||||
|
|
||||||
|
s.relaySrv, err = createRelayServer(relayCfg, s.stunListeners)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Relay server created")
|
||||||
|
|
||||||
|
if len(s.stunListeners) > 0 {
|
||||||
|
s.stunServer = stun.NewServer(s.stunListeners, cfg.Relay.Stun.LogLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serverInstances) createManagementServer(ctx context.Context, cfg *CombinedConfig) error {
|
||||||
|
if !cfg.Management.Enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mgmtConfig, err := cfg.ToManagementConfig()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create management config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, portStr, portErr := net.SplitHostPort(cfg.Server.ListenAddress)
|
||||||
|
if portErr != nil {
|
||||||
|
portStr = "443"
|
||||||
|
}
|
||||||
|
mgmtPort, _ := strconv.Atoi(portStr)
|
||||||
|
|
||||||
|
if err := ApplyEmbeddedIdPConfig(ctx, mgmtConfig, mgmtPort, false); err != nil {
|
||||||
|
cleanupSTUNListeners(s.stunListeners)
|
||||||
|
return fmt.Errorf("failed to apply embedded IdP config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := EnsureEncryptionKey(ctx, mgmtConfig); err != nil {
|
||||||
|
cleanupSTUNListeners(s.stunListeners)
|
||||||
|
return fmt.Errorf("failed to ensure encryption key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
LogConfigInfo(mgmtConfig)
|
||||||
|
|
||||||
|
s.mgmtSrv, err = createManagementServer(cfg, mgmtConfig)
|
||||||
|
if err != nil {
|
||||||
|
cleanupSTUNListeners(s.stunListeners)
|
||||||
|
return fmt.Errorf("failed to create management server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inject externally-managed AppMetrics so management uses the shared metrics server
|
||||||
|
appMetrics, err := telemetry.NewAppMetricsWithMeter(ctx, s.metricsServer.Meter)
|
||||||
|
if err != nil {
|
||||||
|
cleanupSTUNListeners(s.stunListeners)
|
||||||
|
return fmt.Errorf("failed to create management app metrics: %w", err)
|
||||||
|
}
|
||||||
|
mgmtServer.Inject[telemetry.AppMetrics](s.mgmtSrv, appMetrics)
|
||||||
|
|
||||||
|
log.Infof("Management server created")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serverInstances) createSignalServer(ctx context.Context, cfg *CombinedConfig) error {
|
||||||
|
if !cfg.Signal.Enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
s.signalSrv, err = signalServer.NewServer(ctx, s.metricsServer.Meter, "signal_")
|
||||||
|
if err != nil {
|
||||||
|
cleanupSTUNListeners(s.stunListeners)
|
||||||
|
return fmt.Errorf("failed to create signal server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Signal server created")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serverInstances) createHealthcheckServer(cfg *CombinedConfig) error {
|
||||||
|
hCfg := healthcheck.Config{
|
||||||
|
ListenAddress: cfg.Server.HealthcheckAddress,
|
||||||
|
ServiceChecker: s.relaySrv,
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
s.healthcheck, err = createHealthCheck(hCfg, s.stunListeners)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupServerHooks registers services with management's gRPC server.
|
||||||
|
func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) {
|
||||||
|
if servers.mgmtSrv == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
servers.mgmtSrv.AfterInit(func(s *mgmtServer.BaseServer) {
|
||||||
|
grpcSrv := s.GRPCServer()
|
||||||
|
|
||||||
|
if servers.signalSrv != nil {
|
||||||
|
proto.RegisterSignalExchangeServer(grpcSrv, servers.signalSrv)
|
||||||
|
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
|
||||||
|
if servers.relaySrv != nil {
|
||||||
|
log.Infof("Relay WebSocket handler added (path: /relay)")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func startServers(wg *sync.WaitGroup, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, metricsServer *sharedMetrics.Metrics) {
|
||||||
|
if srv != nil {
|
||||||
|
instanceURL := srv.InstanceURL()
|
||||||
|
log.Infof("Relay server instance URL: %s", instanceURL.String())
|
||||||
|
log.Infof("Relay WebSocket multiplexed on management port (no separate relay listener)")
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
|
||||||
|
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
log.Fatalf("failed to start metrics server: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
log.Fatalf("failed to start healthcheck server: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if stunServer != nil {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := stunServer.Listen(); err != nil {
|
||||||
|
if errors.Is(err, stun.ErrServerClosed) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Errorf("STUN server error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func shutdownServers(ctx context.Context, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, mgmtSrv *mgmtServer.BaseServer, metricsServer *sharedMetrics.Metrics) error {
|
||||||
|
var errs error
|
||||||
|
|
||||||
|
if err := httpHealthcheck.Shutdown(ctx); err != nil {
|
||||||
|
errs = multierror.Append(errs, fmt.Errorf("failed to close healthcheck server: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if stunServer != nil {
|
||||||
|
if err := stunServer.Shutdown(); err != nil {
|
||||||
|
errs = multierror.Append(errs, fmt.Errorf("failed to close STUN server: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if srv != nil {
|
||||||
|
if err := srv.Shutdown(ctx); err != nil {
|
||||||
|
errs = multierror.Append(errs, fmt.Errorf("failed to close relay server: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if mgmtSrv != nil {
|
||||||
|
log.Infof("shutting down management and signal servers")
|
||||||
|
if err := mgmtSrv.Stop(); err != nil {
|
||||||
|
errs = multierror.Append(errs, fmt.Errorf("failed to close management server: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if metricsServer != nil {
|
||||||
|
log.Infof("shutting down metrics server")
|
||||||
|
if err := metricsServer.Shutdown(ctx); err != nil {
|
||||||
|
errs = multierror.Append(errs, fmt.Errorf("failed to close metrics server: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return errs
|
||||||
|
}
|
||||||
|
|
||||||
|
func createHealthCheck(hCfg healthcheck.Config, stunListeners []*net.UDPConn) (*healthcheck.Server, error) {
|
||||||
|
httpHealthcheck, err := healthcheck.NewServer(hCfg)
|
||||||
|
if err != nil {
|
||||||
|
cleanupSTUNListeners(stunListeners)
|
||||||
|
return nil, fmt.Errorf("failed to create healthcheck server: %w", err)
|
||||||
|
}
|
||||||
|
return httpHealthcheck, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createRelayServer(cfg relayServer.Config, stunListeners []*net.UDPConn) (*relayServer.Server, error) {
|
||||||
|
srv, err := relayServer.NewServer(cfg)
|
||||||
|
if err != nil {
|
||||||
|
cleanupSTUNListeners(stunListeners)
|
||||||
|
return nil, fmt.Errorf("failed to create relay server: %w", err)
|
||||||
|
}
|
||||||
|
return srv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupSTUNListeners(stunListeners []*net.UDPConn) {
|
||||||
|
for _, l := range stunListeners {
|
||||||
|
_ = l.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createSTUNListeners(cfg *CombinedConfig) ([]*net.UDPConn, error) {
|
||||||
|
var stunListeners []*net.UDPConn
|
||||||
|
if cfg.Relay.Stun.Enabled {
|
||||||
|
for _, port := range cfg.Relay.Stun.Ports {
|
||||||
|
listener, err := net.ListenUDP("udp", &net.UDPAddr{Port: port})
|
||||||
|
if err != nil {
|
||||||
|
cleanupSTUNListeners(stunListeners)
|
||||||
|
return nil, fmt.Errorf("failed to create STUN listener on port %d: %w", port, err)
|
||||||
|
}
|
||||||
|
stunListeners = append(stunListeners, listener)
|
||||||
|
log.Infof("STUN server listening on UDP port %d", port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return stunListeners, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) {
|
||||||
|
tlsCfg := cfg.Server.TLS
|
||||||
|
|
||||||
|
if tlsCfg.LetsEncrypt.AWSRoute53 {
|
||||||
|
log.Debugf("using Let's Encrypt DNS resolver with Route 53 support")
|
||||||
|
r53 := encryption.Route53TLS{
|
||||||
|
DataDir: tlsCfg.LetsEncrypt.DataDir,
|
||||||
|
Email: tlsCfg.LetsEncrypt.Email,
|
||||||
|
Domains: tlsCfg.LetsEncrypt.Domains,
|
||||||
|
}
|
||||||
|
tc, err := r53.GetCertificate()
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
return tc, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.HasLetsEncrypt() {
|
||||||
|
log.Infof("setting up TLS with Let's Encrypt")
|
||||||
|
certManager, err := encryption.CreateCertManager(tlsCfg.LetsEncrypt.DataDir, tlsCfg.LetsEncrypt.Domains...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, fmt.Errorf("failed creating LetsEncrypt cert manager: %w", err)
|
||||||
|
}
|
||||||
|
return certManager.TLSConfig(), true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.HasTLSCert() {
|
||||||
|
log.Debugf("using file based TLS config")
|
||||||
|
tc, err := encryption.LoadTLSConfig(tlsCfg.CertFile, tlsCfg.KeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
return tc, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) {
|
||||||
|
mgmt := cfg.Management
|
||||||
|
|
||||||
|
dnsDomain := mgmt.DnsDomain
|
||||||
|
singleAccModeDomain := dnsDomain
|
||||||
|
|
||||||
|
// Extract port from listen address
|
||||||
|
_, portStr, err := net.SplitHostPort(cfg.Server.ListenAddress)
|
||||||
|
if err != nil {
|
||||||
|
// If no port specified, assume default
|
||||||
|
portStr = "443"
|
||||||
|
}
|
||||||
|
mgmtPort, _ := strconv.Atoi(portStr)
|
||||||
|
|
||||||
|
mgmtSrv := mgmtServer.NewServer(
|
||||||
|
&mgmtServer.Config{
|
||||||
|
NbConfig: mgmtConfig,
|
||||||
|
DNSDomain: dnsDomain,
|
||||||
|
MgmtSingleAccModeDomain: singleAccModeDomain,
|
||||||
|
MgmtPort: mgmtPort,
|
||||||
|
MgmtMetricsPort: cfg.Server.MetricsPort,
|
||||||
|
DisableMetrics: mgmt.DisableAnonymousMetrics,
|
||||||
|
DisableGeoliteUpdate: mgmt.DisableGeoliteUpdate,
|
||||||
|
// Always enable user deletion from IDP in combined server (embedded IdP is always enabled)
|
||||||
|
UserDeleteFromIDPEnabled: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return mgmtSrv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createCombinedHandler creates an HTTP handler that multiplexes Management, Signal (via wsproxy), and Relay WebSocket traffic
|
||||||
|
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
|
||||||
|
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
|
||||||
|
|
||||||
|
var relayAcceptFn func(conn net.Conn)
|
||||||
|
if relaySrv != nil {
|
||||||
|
relayAcceptFn = relaySrv.RelayAccept()
|
||||||
|
}
|
||||||
|
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
// Native gRPC traffic (HTTP/2 with gRPC content-type)
|
||||||
|
case r.ProtoMajor == 2 && (strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc") ||
|
||||||
|
strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc+proto")):
|
||||||
|
grpcServer.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
// WebSocket proxy for Management gRPC
|
||||||
|
case r.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent:
|
||||||
|
wsProxy.Handler().ServeHTTP(w, r)
|
||||||
|
|
||||||
|
// WebSocket proxy for Signal gRPC
|
||||||
|
case r.URL.Path == wsproxy.ProxyPath+wsproxy.SignalComponent:
|
||||||
|
if cfg.Signal.Enabled {
|
||||||
|
wsProxy.Handler().ServeHTTP(w, r)
|
||||||
|
} else {
|
||||||
|
http.Error(w, "Signal service not enabled", http.StatusNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Relay WebSocket
|
||||||
|
case r.URL.Path == "/relay":
|
||||||
|
if relayAcceptFn != nil {
|
||||||
|
handleRelayWebSocket(w, r, relayAcceptFn, cfg)
|
||||||
|
} else {
|
||||||
|
http.Error(w, "Relay service not enabled", http.StatusNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Management HTTP API (default)
|
||||||
|
default:
|
||||||
|
httpHandler.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleRelayWebSocket handles incoming WebSocket connections for the relay service
|
||||||
|
func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn net.Conn), cfg *CombinedConfig) {
|
||||||
|
acceptOptions := &websocket.AcceptOptions{
|
||||||
|
OriginPatterns: []string{"*"},
|
||||||
|
}
|
||||||
|
|
||||||
|
wsConn, err := websocket.Accept(w, r, acceptOptions)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to accept relay ws connection: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
connRemoteAddr := r.RemoteAddr
|
||||||
|
if r.Header.Get("X-Real-Ip") != "" && r.Header.Get("X-Real-Port") != "" {
|
||||||
|
connRemoteAddr = net.JoinHostPort(r.Header.Get("X-Real-Ip"), r.Header.Get("X-Real-Port"))
|
||||||
|
}
|
||||||
|
|
||||||
|
rAddr, err := net.ResolveTCPAddr("tcp", connRemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
_ = wsConn.Close(websocket.StatusInternalError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
lAddr, err := net.ResolveTCPAddr("tcp", cfg.Server.ListenAddress)
|
||||||
|
if err != nil {
|
||||||
|
_ = wsConn.Close(websocket.StatusInternalError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Relay WS client connected from: %s", rAddr)
|
||||||
|
|
||||||
|
conn := ws.NewConn(wsConn, lAddr, rAddr)
|
||||||
|
acceptFn(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// logConfig prints all configuration parameters for debugging
|
||||||
|
func logConfig(cfg *CombinedConfig) {
|
||||||
|
log.Info("=== Configuration ===")
|
||||||
|
logServerConfig(cfg)
|
||||||
|
logComponentsConfig(cfg)
|
||||||
|
logRelayConfig(cfg)
|
||||||
|
logManagementConfig(cfg)
|
||||||
|
log.Info("=== End Configuration ===")
|
||||||
|
}
|
||||||
|
|
||||||
|
func logServerConfig(cfg *CombinedConfig) {
|
||||||
|
log.Info("--- Server ---")
|
||||||
|
log.Infof(" Listen address: %s", cfg.Server.ListenAddress)
|
||||||
|
log.Infof(" Exposed address: %s", cfg.Server.ExposedAddress)
|
||||||
|
log.Infof(" Healthcheck address: %s", cfg.Server.HealthcheckAddress)
|
||||||
|
log.Infof(" Metrics port: %d", cfg.Server.MetricsPort)
|
||||||
|
log.Infof(" Log level: %s", cfg.Server.LogLevel)
|
||||||
|
log.Infof(" Data dir: %s", cfg.Server.DataDir)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case cfg.HasTLSCert():
|
||||||
|
log.Infof(" TLS: cert=%s, key=%s", cfg.Server.TLS.CertFile, cfg.Server.TLS.KeyFile)
|
||||||
|
case cfg.HasLetsEncrypt():
|
||||||
|
log.Infof(" TLS: Let's Encrypt (domains=%v)", cfg.Server.TLS.LetsEncrypt.Domains)
|
||||||
|
default:
|
||||||
|
log.Info(" TLS: disabled (using reverse proxy)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func logComponentsConfig(cfg *CombinedConfig) {
|
||||||
|
log.Info("--- Components ---")
|
||||||
|
log.Infof(" Management: %v (log level: %s)", cfg.Management.Enabled, cfg.Management.LogLevel)
|
||||||
|
log.Infof(" Signal: %v (log level: %s)", cfg.Signal.Enabled, cfg.Signal.LogLevel)
|
||||||
|
log.Infof(" Relay: %v (log level: %s)", cfg.Relay.Enabled, cfg.Relay.LogLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func logRelayConfig(cfg *CombinedConfig) {
|
||||||
|
if !cfg.Relay.Enabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Info("--- Relay ---")
|
||||||
|
log.Infof(" Exposed address: %s", cfg.Relay.ExposedAddress)
|
||||||
|
log.Infof(" Auth secret: %s...", maskSecret(cfg.Relay.AuthSecret))
|
||||||
|
if cfg.Relay.Stun.Enabled {
|
||||||
|
log.Infof(" STUN ports: %v (log level: %s)", cfg.Relay.Stun.Ports, cfg.Relay.Stun.LogLevel)
|
||||||
|
} else {
|
||||||
|
log.Info(" STUN: disabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func logManagementConfig(cfg *CombinedConfig) {
|
||||||
|
if !cfg.Management.Enabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Info("--- Management ---")
|
||||||
|
log.Infof(" Data dir: %s", cfg.Management.DataDir)
|
||||||
|
log.Infof(" DNS domain: %s", cfg.Management.DnsDomain)
|
||||||
|
log.Infof(" Store engine: %s", cfg.Management.Store.Engine)
|
||||||
|
if cfg.Server.Store.DSN != "" {
|
||||||
|
log.Infof(" Store DSN: %s", maskDSNPassword(cfg.Server.Store.DSN))
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info(" Auth (embedded IdP):")
|
||||||
|
log.Infof(" Issuer: %s", cfg.Management.Auth.Issuer)
|
||||||
|
log.Infof(" Dashboard redirect URIs: %v", cfg.Management.Auth.DashboardRedirectURIs)
|
||||||
|
log.Infof(" CLI redirect URIs: %v", cfg.Management.Auth.CLIRedirectURIs)
|
||||||
|
|
||||||
|
log.Info(" Client settings:")
|
||||||
|
log.Infof(" Signal URI: %s", cfg.Management.SignalURI)
|
||||||
|
for _, s := range cfg.Management.Stuns {
|
||||||
|
log.Infof(" STUN: %s", s.URI)
|
||||||
|
}
|
||||||
|
if len(cfg.Management.Relays.Addresses) > 0 {
|
||||||
|
log.Infof(" Relay addresses: %v", cfg.Management.Relays.Addresses)
|
||||||
|
log.Infof(" Relay credentials TTL: %s", cfg.Management.Relays.CredentialsTTL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// logEnvVars logs all NB_ environment variables that are currently set
|
||||||
|
func logEnvVars() {
|
||||||
|
log.Info("=== Environment Variables ===")
|
||||||
|
found := false
|
||||||
|
for _, env := range os.Environ() {
|
||||||
|
if strings.HasPrefix(env, "NB_") {
|
||||||
|
key, _, _ := strings.Cut(env, "=")
|
||||||
|
value := os.Getenv(key)
|
||||||
|
if strings.Contains(strings.ToLower(key), "secret") || strings.Contains(strings.ToLower(key), "key") || strings.Contains(strings.ToLower(key), "password") {
|
||||||
|
value = maskSecret(value)
|
||||||
|
}
|
||||||
|
log.Infof(" %s=%s", key, value)
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
log.Info(" (none set)")
|
||||||
|
}
|
||||||
|
log.Info("=== End Environment Variables ===")
|
||||||
|
}
|
||||||
|
|
||||||
|
// maskDSNPassword masks the password in a DSN string.
|
||||||
|
// Handles both key=value format ("password=secret") and URI format ("user:secret@host").
|
||||||
|
func maskDSNPassword(dsn string) string {
|
||||||
|
// Key=value format: "host=localhost user=nb password=secret dbname=nb"
|
||||||
|
if strings.Contains(dsn, "password=") {
|
||||||
|
parts := strings.Fields(dsn)
|
||||||
|
for i, p := range parts {
|
||||||
|
if strings.HasPrefix(p, "password=") {
|
||||||
|
parts[i] = "password=****"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(parts, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// URI format: "user:password@host..."
|
||||||
|
if atIdx := strings.Index(dsn, "@"); atIdx != -1 {
|
||||||
|
prefix := dsn[:atIdx]
|
||||||
|
if colonIdx := strings.Index(prefix, ":"); colonIdx != -1 {
|
||||||
|
return prefix[:colonIdx+1] + "****" + dsn[atIdx:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return dsn
|
||||||
|
}
|
||||||
|
|
||||||
|
// maskSecret returns first 4 chars of secret followed by "..."
|
||||||
|
func maskSecret(secret string) string {
|
||||||
|
if len(secret) <= 4 {
|
||||||
|
return "****"
|
||||||
|
}
|
||||||
|
return secret[:4] + "..."
|
||||||
|
}
|
||||||
60
combined/cmd/token.go
Normal file
60
combined/cmd/token.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/formatter/hook"
|
||||||
|
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newTokenCommands creates the token command tree with combined-specific store opener.
|
||||||
|
func newTokenCommands() *cobra.Command {
|
||||||
|
return tokencmd.NewCommands(withTokenStore)
|
||||||
|
}
|
||||||
|
|
||||||
|
// withTokenStore loads the combined YAML config, initializes the store, and calls fn.
|
||||||
|
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||||
|
if err := util.InitLog("error", "console"); err != nil {
|
||||||
|
return fmt.Errorf("init log: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||||
|
|
||||||
|
cfg, err := LoadConfig(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("load config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if dsn := cfg.Server.Store.DSN; dsn != "" {
|
||||||
|
switch strings.ToLower(cfg.Server.Store.Engine) {
|
||||||
|
case "postgres":
|
||||||
|
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
|
||||||
|
case "mysql":
|
||||||
|
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
datadir := cfg.Management.DataDir
|
||||||
|
engine := types.Engine(cfg.Management.Store.Engine)
|
||||||
|
|
||||||
|
s, err := store.NewStore(ctx, engine, datadir, nil, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create store: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := s.Close(ctx); err != nil {
|
||||||
|
log.Debugf("close store: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return fn(ctx, s)
|
||||||
|
}
|
||||||
111
combined/config.yaml.example
Normal file
111
combined/config.yaml.example
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
# NetBird Combined Server Configuration
|
||||||
|
# Copy this file to config.yaml and customize for your deployment
|
||||||
|
#
|
||||||
|
# This is a Management server with optional embedded Signal, Relay, and STUN services.
|
||||||
|
# By default, all services run locally. You can use external services instead by
|
||||||
|
# setting the corresponding override fields.
|
||||||
|
#
|
||||||
|
# Architecture:
|
||||||
|
# - Management: Always runs locally (this IS the management server)
|
||||||
|
# - Signal: Local by default; set 'signalUri' to use external (disables local)
|
||||||
|
# - Relay: Local by default; set 'relays' to use external (disables local)
|
||||||
|
# - STUN: Local on port 3478 by default; set 'stuns' to use external instead
|
||||||
|
|
||||||
|
server:
|
||||||
|
# Main HTTP/gRPC port for all services (Management, Signal, Relay)
|
||||||
|
listenAddress: ":443"
|
||||||
|
|
||||||
|
# Public address that peers will use to connect to this server
|
||||||
|
# Used for relay connections and management DNS domain
|
||||||
|
# Format: protocol://hostname:port (e.g., https://server.mycompany.com:443)
|
||||||
|
exposedAddress: "https://server.mycompany.com:443"
|
||||||
|
|
||||||
|
# STUN server ports (defaults to [3478] if not specified; set 'stuns' to use external)
|
||||||
|
# stunPorts:
|
||||||
|
# - 3478
|
||||||
|
|
||||||
|
# Metrics endpoint port
|
||||||
|
metricsPort: 9090
|
||||||
|
|
||||||
|
# Healthcheck endpoint address
|
||||||
|
healthcheckAddress: ":9000"
|
||||||
|
|
||||||
|
# Logging configuration
|
||||||
|
logLevel: "info" # Default log level for all components: panic, fatal, error, warn, info, debug, trace
|
||||||
|
logFile: "console" # "console" or path to log file
|
||||||
|
|
||||||
|
# TLS configuration (optional)
|
||||||
|
tls:
|
||||||
|
certFile: ""
|
||||||
|
keyFile: ""
|
||||||
|
letsencrypt:
|
||||||
|
enabled: false
|
||||||
|
dataDir: ""
|
||||||
|
domains: []
|
||||||
|
email: ""
|
||||||
|
awsRoute53: false
|
||||||
|
|
||||||
|
# Shared secret for relay authentication (required when running local relay)
|
||||||
|
authSecret: "your-secret-key-here"
|
||||||
|
|
||||||
|
# Data directory for all services
|
||||||
|
dataDir: "/var/lib/netbird/"
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# External Service Overrides (optional)
|
||||||
|
# Use these to point to external Signal, Relay, or STUN servers instead of
|
||||||
|
# running them locally. When set, the corresponding local service is disabled.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# External STUN servers - disables local STUN server
|
||||||
|
# stuns:
|
||||||
|
# - uri: "stun:stun.example.com:3478"
|
||||||
|
# - uri: "stun:stun.example.com:3479"
|
||||||
|
|
||||||
|
# External relay servers - disables local relay server
|
||||||
|
# relays:
|
||||||
|
# addresses:
|
||||||
|
# - "rels://relay.example.com:443"
|
||||||
|
# credentialsTTL: "12h"
|
||||||
|
# secret: "relay-shared-secret"
|
||||||
|
|
||||||
|
# External signal server - disables local signal server
|
||||||
|
# signalUri: "https://signal.example.com:443"
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Management Settings
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# Metrics and updates
|
||||||
|
disableAnonymousMetrics: false
|
||||||
|
disableGeoliteUpdate: false
|
||||||
|
|
||||||
|
# Embedded authentication/identity provider (Dex) configuration (always enabled)
|
||||||
|
auth:
|
||||||
|
# OIDC issuer URL - must be publicly accessible
|
||||||
|
issuer: "https://example.com/oauth2"
|
||||||
|
localAuthDisabled: false
|
||||||
|
signKeyRefreshEnabled: false
|
||||||
|
# OAuth2 redirect URIs for dashboard
|
||||||
|
dashboardRedirectURIs:
|
||||||
|
- "https://app.example.com/nb-auth"
|
||||||
|
- "https://app.example.com/nb-silent-auth"
|
||||||
|
# OAuth2 redirect URIs for CLI
|
||||||
|
cliRedirectURIs:
|
||||||
|
- "http://localhost:53000/"
|
||||||
|
# Optional initial admin user
|
||||||
|
# owner:
|
||||||
|
# email: "admin@example.com"
|
||||||
|
# password: "initial-password"
|
||||||
|
|
||||||
|
# Store configuration
|
||||||
|
store:
|
||||||
|
engine: "sqlite" # sqlite, postgres, or mysql
|
||||||
|
dsn: "" # Connection string for postgres or mysql
|
||||||
|
encryptionKey: ""
|
||||||
|
|
||||||
|
# Reverse proxy settings (optional)
|
||||||
|
# reverseProxy:
|
||||||
|
# trustedHTTPProxies: []
|
||||||
|
# trustedHTTPProxiesCount: 0
|
||||||
|
# trustedPeers: []
|
||||||
13
combined/main.go
Normal file
13
combined/main.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/combined/cmd"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
if err := cmd.Execute(); err != nil {
|
||||||
|
log.Fatalf("failed to execute command: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,8 +1,6 @@
|
|||||||
package txt
|
package txt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/formatter/levels"
|
"github.com/netbirdio/netbird/formatter/levels"
|
||||||
@@ -18,7 +16,7 @@ type TextFormatter struct {
|
|||||||
func NewTextFormatter() *TextFormatter {
|
func NewTextFormatter() *TextFormatter {
|
||||||
return &TextFormatter{
|
return &TextFormatter{
|
||||||
levelDesc: levels.ValidLevelDesc,
|
levelDesc: levels.ValidLevelDesc,
|
||||||
timestampFormat: time.RFC3339, // or RFC3339
|
timestampFormat: "2006-01-02T15:04:05.000Z07:00",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,6 @@ func TestLogTextFormat(t *testing.T) {
|
|||||||
result, _ := formatter.Format(someEntry)
|
result, _ := formatter.Format(someEntry)
|
||||||
|
|
||||||
parsedString := string(result)
|
parsedString := string(result)
|
||||||
expectedString := "^2021-02-21T01:10:30Z WARN \\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] some/fancy/path.go:46: Some Message\\s+$"
|
expectedString := "^2021-02-21T01:10:30.000Z WARN \\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] some/fancy/path.go:46: Some Message\\s+$"
|
||||||
assert.Regexp(t, expectedString, parsedString)
|
assert.Regexp(t, expectedString, parsedString)
|
||||||
}
|
}
|
||||||
|
|||||||
26
go.mod
26
go.mod
@@ -40,8 +40,9 @@ require (
|
|||||||
github.com/c-robinson/iplib v1.0.3
|
github.com/c-robinson/iplib v1.0.3
|
||||||
github.com/caddyserver/certmagic v0.21.3
|
github.com/caddyserver/certmagic v0.21.3
|
||||||
github.com/cilium/ebpf v0.15.0
|
github.com/cilium/ebpf v0.15.0
|
||||||
github.com/coder/websocket v1.8.13
|
github.com/coder/websocket v1.8.14
|
||||||
github.com/coreos/go-iptables v0.7.0
|
github.com/coreos/go-iptables v0.7.0
|
||||||
|
github.com/coreos/go-oidc/v3 v3.14.1
|
||||||
github.com/creack/pty v1.1.24
|
github.com/creack/pty v1.1.24
|
||||||
github.com/dexidp/dex v0.0.0-00010101000000-000000000000
|
github.com/dexidp/dex v0.0.0-00010101000000-000000000000
|
||||||
github.com/dexidp/dex/api/v2 v2.4.0
|
github.com/dexidp/dex/api/v2 v2.4.0
|
||||||
@@ -68,7 +69,7 @@ require (
|
|||||||
github.com/mdlayher/socket v0.5.1
|
github.com/mdlayher/socket v0.5.1
|
||||||
github.com/miekg/dns v1.1.59
|
github.com/miekg/dns v1.1.59
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
||||||
github.com/oapi-codegen/runtime v1.1.2
|
github.com/oapi-codegen/runtime v1.1.2
|
||||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||||
@@ -82,6 +83,7 @@ require (
|
|||||||
github.com/pion/stun/v3 v3.1.0
|
github.com/pion/stun/v3 v3.1.0
|
||||||
github.com/pion/transport/v3 v3.1.1
|
github.com/pion/transport/v3 v3.1.1
|
||||||
github.com/pion/turn/v3 v3.0.1
|
github.com/pion/turn/v3 v3.0.1
|
||||||
|
github.com/pires/go-proxyproto v0.11.0
|
||||||
github.com/pkg/sftp v1.13.9
|
github.com/pkg/sftp v1.13.9
|
||||||
github.com/prometheus/client_golang v1.23.2
|
github.com/prometheus/client_golang v1.23.2
|
||||||
github.com/quic-go/quic-go v0.55.0
|
github.com/quic-go/quic-go v0.55.0
|
||||||
@@ -91,10 +93,10 @@ require (
|
|||||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||||
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
|
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
github.com/testcontainers/testcontainers-go v0.31.0
|
github.com/testcontainers/testcontainers-go v0.37.0
|
||||||
github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0
|
github.com/testcontainers/testcontainers-go/modules/mysql v0.37.0
|
||||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0
|
github.com/testcontainers/testcontainers-go/modules/postgres v0.37.0
|
||||||
github.com/testcontainers/testcontainers-go/modules/redis v0.31.0
|
github.com/testcontainers/testcontainers-go/modules/redis v0.37.0
|
||||||
github.com/things-go/go-socks5 v0.0.4
|
github.com/things-go/go-socks5 v0.0.4
|
||||||
github.com/ti-mo/conntrack v0.5.1
|
github.com/ti-mo/conntrack v0.5.1
|
||||||
github.com/ti-mo/netfilter v0.5.2
|
github.com/ti-mo/netfilter v0.5.2
|
||||||
@@ -140,7 +142,6 @@ require (
|
|||||||
github.com/Masterminds/semver/v3 v3.3.0 // indirect
|
github.com/Masterminds/semver/v3 v3.3.0 // indirect
|
||||||
github.com/Masterminds/sprig/v3 v3.3.0 // indirect
|
github.com/Masterminds/sprig/v3 v3.3.0 // indirect
|
||||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||||
github.com/Microsoft/hcsshim v0.12.3 // indirect
|
|
||||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
|
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
|
||||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
|
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
|
||||||
github.com/awnumar/memcall v0.4.0 // indirect
|
github.com/awnumar/memcall v0.4.0 // indirect
|
||||||
@@ -164,17 +165,16 @@ require (
|
|||||||
github.com/beorn7/perks v1.0.1 // indirect
|
github.com/beorn7/perks v1.0.1 // indirect
|
||||||
github.com/caddyserver/zerossl v0.1.3 // indirect
|
github.com/caddyserver/zerossl v0.1.3 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
github.com/containerd/containerd v1.7.29 // indirect
|
|
||||||
github.com/containerd/log v0.1.0 // indirect
|
github.com/containerd/log v0.1.0 // indirect
|
||||||
github.com/containerd/platforms v0.2.1 // indirect
|
github.com/containerd/platforms v0.2.1 // indirect
|
||||||
github.com/coreos/go-oidc/v3 v3.14.1 // indirect
|
|
||||||
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
github.com/distribution/reference v0.6.0 // indirect
|
github.com/distribution/reference v0.6.0 // indirect
|
||||||
github.com/docker/docker v26.1.5+incompatible // indirect
|
github.com/docker/docker v28.0.1+incompatible // indirect
|
||||||
github.com/docker/go-connections v0.5.0 // indirect
|
github.com/docker/go-connections v0.5.0 // indirect
|
||||||
github.com/docker/go-units v0.5.0 // indirect
|
github.com/docker/go-units v0.5.0 // indirect
|
||||||
|
github.com/ebitengine/purego v0.8.2 // indirect
|
||||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||||
github.com/fredbi/uri v1.1.1 // indirect
|
github.com/fredbi/uri v1.1.1 // indirect
|
||||||
github.com/fyne-io/gl-js v0.2.0 // indirect
|
github.com/fyne-io/gl-js v0.2.0 // indirect
|
||||||
@@ -220,9 +220,10 @@ require (
|
|||||||
github.com/lib/pq v1.10.9 // indirect
|
github.com/lib/pq v1.10.9 // indirect
|
||||||
github.com/libdns/libdns v0.2.2 // indirect
|
github.com/libdns/libdns v0.2.2 // indirect
|
||||||
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
|
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
|
||||||
github.com/magiconair/properties v1.8.7 // indirect
|
github.com/magiconair/properties v1.8.10 // indirect
|
||||||
github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect
|
github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect
|
||||||
github.com/mattn/go-sqlite3 v1.14.32 // indirect
|
github.com/mattn/go-sqlite3 v1.14.32 // indirect
|
||||||
|
github.com/mdelapenya/tlscert v0.2.0 // indirect
|
||||||
github.com/mdlayher/genetlink v1.3.2 // indirect
|
github.com/mdlayher/genetlink v1.3.2 // indirect
|
||||||
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect
|
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect
|
||||||
github.com/mholt/acmez/v2 v2.0.1 // indirect
|
github.com/mholt/acmez/v2 v2.0.1 // indirect
|
||||||
@@ -241,7 +242,7 @@ require (
|
|||||||
github.com/nxadm/tail v1.4.8 // indirect
|
github.com/nxadm/tail v1.4.8 // indirect
|
||||||
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
|
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
|
||||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||||
github.com/opencontainers/image-spec v1.1.0 // indirect
|
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||||
github.com/pion/dtls/v2 v2.2.10 // indirect
|
github.com/pion/dtls/v2 v2.2.10 // indirect
|
||||||
github.com/pion/dtls/v3 v3.0.9 // indirect
|
github.com/pion/dtls/v3 v3.0.9 // indirect
|
||||||
github.com/pion/mdns/v2 v2.0.7 // indirect
|
github.com/pion/mdns/v2 v2.0.7 // indirect
|
||||||
@@ -255,6 +256,7 @@ require (
|
|||||||
github.com/prometheus/procfs v0.16.1 // indirect
|
github.com/prometheus/procfs v0.16.1 // indirect
|
||||||
github.com/russellhaering/goxmldsig v1.5.0 // indirect
|
github.com/russellhaering/goxmldsig v1.5.0 // indirect
|
||||||
github.com/rymdport/portal v0.4.2 // indirect
|
github.com/rymdport/portal v0.4.2 // indirect
|
||||||
|
github.com/shirou/gopsutil/v4 v4.25.1 // indirect
|
||||||
github.com/shoenig/go-m1cpu v0.1.6 // indirect
|
github.com/shoenig/go-m1cpu v0.1.6 // indirect
|
||||||
github.com/shopspring/decimal v1.4.0 // indirect
|
github.com/shopspring/decimal v1.4.0 // indirect
|
||||||
github.com/spf13/cast v1.7.0 // indirect
|
github.com/spf13/cast v1.7.0 // indirect
|
||||||
|
|||||||
55
go.sum
55
go.sum
@@ -33,8 +33,6 @@ github.com/Masterminds/sprig/v3 v3.3.0 h1:mQh0Yrg1XPo6vjYXgtf5OtijNAKJRNcTdOOGZe
|
|||||||
github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSCzdgBfDb35Lz0=
|
github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSCzdgBfDb35Lz0=
|
||||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||||
github.com/Microsoft/hcsshim v0.12.3 h1:LS9NXqXhMoqNCplK1ApmVSfB4UnVLRDWRapB6EIlxE0=
|
|
||||||
github.com/Microsoft/hcsshim v0.12.3/go.mod h1:Iyl1WVpZzr+UkzjekHZbV8o5Z9ZkxNGx6CtY2Qg/JVQ=
|
|
||||||
github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
|
github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
|
||||||
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible h1:hqcTK6ZISdip65SR792lwYJTa/axESA0889D3UlZbLo=
|
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible h1:hqcTK6ZISdip65SR792lwYJTa/axESA0889D3UlZbLo=
|
||||||
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible/go.mod h1:6B1nuc1MUs6c62ODZDl7hVE5Pv7O2XGSkgg2olnq34I=
|
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible/go.mod h1:6B1nuc1MUs6c62ODZDl7hVE5Pv7O2XGSkgg2olnq34I=
|
||||||
@@ -107,10 +105,8 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF
|
|||||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
|
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
|
||||||
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
|
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
|
||||||
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
|
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
||||||
github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
|
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
||||||
github.com/containerd/containerd v1.7.29 h1:90fWABQsaN9mJhGkoVnuzEY+o1XDPbg9BTC9QTAHnuE=
|
|
||||||
github.com/containerd/containerd v1.7.29/go.mod h1:azUkWcOvHrWvaiUjSQH0fjzuHIwSPg1WL5PshGP4Szs=
|
|
||||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||||
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
|
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
|
||||||
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
|
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
|
||||||
@@ -135,12 +131,14 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
|||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||||
github.com/docker/docker v26.1.5+incompatible h1:NEAxTwEjxV6VbBMBoGG3zPqbiJosIApZjxlbrG9q3/g=
|
github.com/docker/docker v28.0.1+incompatible h1:FCHjSRdXhNRFjlHMTv4jUNlIBbTeRjrWfeFuJp7jpo0=
|
||||||
github.com/docker/docker v26.1.5+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
github.com/docker/docker v28.0.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||||
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
|
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
|
||||||
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
|
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
|
||||||
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||||
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||||
|
github.com/ebitengine/purego v0.8.2 h1:jPPGWs2sZ1UgOSgD2bClL0MJIqu58nOmIcBuXr62z1I=
|
||||||
|
github.com/ebitengine/purego v0.8.2/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||||
github.com/eko/gocache/lib/v4 v4.2.0 h1:MNykyi5Xw+5Wu3+PUrvtOCaKSZM1nUSVftbzmeC7Yuw=
|
github.com/eko/gocache/lib/v4 v4.2.0 h1:MNykyi5Xw+5Wu3+PUrvtOCaKSZM1nUSVftbzmeC7Yuw=
|
||||||
github.com/eko/gocache/lib/v4 v4.2.0/go.mod h1:7ViVmbU+CzDHzRpmB4SXKyyzyuJ8A3UW3/cszpcqB4M=
|
github.com/eko/gocache/lib/v4 v4.2.0/go.mod h1:7ViVmbU+CzDHzRpmB4SXKyyzyuJ8A3UW3/cszpcqB4M=
|
||||||
github.com/eko/gocache/store/go_cache/v4 v4.2.2 h1:tAI9nl6TLoJyKG1ujF0CS0n/IgTEMl+NivxtR5R3/hw=
|
github.com/eko/gocache/store/go_cache/v4 v4.2.2 h1:tAI9nl6TLoJyKG1ujF0CS0n/IgTEMl+NivxtR5R3/hw=
|
||||||
@@ -195,8 +193,6 @@ github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3yg
|
|||||||
github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY=
|
github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY=
|
||||||
github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI=
|
github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI=
|
||||||
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
|
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
|
||||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
|
||||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
|
||||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||||
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
|
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
|
||||||
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
|
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
|
||||||
@@ -357,13 +353,15 @@ github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81/go.mod h1:RD8ML/Y
|
|||||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||||
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tAFlj1FYZl8ztUZ13bdq+PLY+NOfbyI=
|
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tAFlj1FYZl8ztUZ13bdq+PLY+NOfbyI=
|
||||||
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k=
|
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k=
|
||||||
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
|
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
|
||||||
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||||
github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU=
|
github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU=
|
||||||
github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To=
|
github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To=
|
||||||
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
|
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
|
||||||
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
|
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
|
||||||
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||||
|
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||||
|
github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o=
|
||||||
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
|
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
|
||||||
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
|
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
|
||||||
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg=
|
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg=
|
||||||
@@ -406,8 +404,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
|||||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
||||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f h1:CTBf0je/FpKr2lVSMZLak7m8aaWcS6ur4SOfhSSazFI=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25 h1:iwAq/Ncaq0etl4uAlVsbNBzC1yY52o0AmY7uCm2AMTs=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||||
@@ -437,13 +435,12 @@ github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
|
|||||||
github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg=
|
github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg=
|
||||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||||
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
|
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||||
github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
|
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
||||||
github.com/oschwald/maxminddb-golang v1.12.0 h1:9FnTOD0YOhP7DGxGsq4glzpGy5+w7pq50AS6wALUMYs=
|
github.com/oschwald/maxminddb-golang v1.12.0 h1:9FnTOD0YOhP7DGxGsq4glzpGy5+w7pq50AS6wALUMYs=
|
||||||
github.com/oschwald/maxminddb-golang v1.12.0/go.mod h1:q0Nob5lTCqyQ8WT6FYgS1L7PXKVVbgiymefNwIjPzgY=
|
github.com/oschwald/maxminddb-golang v1.12.0/go.mod h1:q0Nob5lTCqyQ8WT6FYgS1L7PXKVVbgiymefNwIjPzgY=
|
||||||
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
||||||
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||||
github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3ve8=
|
|
||||||
github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0=
|
github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0=
|
||||||
github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
|
github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
|
||||||
github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203 h1:E7Kmf11E4K7B5hDti2K2NqPb1nlYlGYsu02S1JNd/Bs=
|
github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203 h1:E7Kmf11E4K7B5hDti2K2NqPb1nlYlGYsu02S1JNd/Bs=
|
||||||
@@ -474,6 +471,8 @@ github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
|
|||||||
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
|
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
|
||||||
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
|
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
|
||||||
github.com/pion/turn/v4 v4.1.1/go.mod h1:2123tHk1O++vmjI5VSD0awT50NywDAq5A2NNNU4Jjs8=
|
github.com/pion/turn/v4 v4.1.1/go.mod h1:2123tHk1O++vmjI5VSD0awT50NywDAq5A2NNNU4Jjs8=
|
||||||
|
github.com/pires/go-proxyproto v0.11.0 h1:gUQpS85X/VJMdUsYyEgyn59uLJvGqPhJV5YvG68wXH4=
|
||||||
|
github.com/pires/go-proxyproto v0.11.0/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU=
|
||||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
@@ -511,6 +510,8 @@ github.com/rymdport/portal v0.4.2 h1:7jKRSemwlTyVHHrTGgQg7gmNPJs88xkbKcIL3NlcmSU
|
|||||||
github.com/rymdport/portal v0.4.2/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4=
|
github.com/rymdport/portal v0.4.2/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4=
|
||||||
github.com/shirou/gopsutil/v3 v3.24.4 h1:dEHgzZXt4LMNm+oYELpzl9YCqV65Yr/6SfrvgRBtXeU=
|
github.com/shirou/gopsutil/v3 v3.24.4 h1:dEHgzZXt4LMNm+oYELpzl9YCqV65Yr/6SfrvgRBtXeU=
|
||||||
github.com/shirou/gopsutil/v3 v3.24.4/go.mod h1:lTd2mdiOspcqLgAnr9/nGi71NkeMpWKdmhuxm9GusH8=
|
github.com/shirou/gopsutil/v3 v3.24.4/go.mod h1:lTd2mdiOspcqLgAnr9/nGi71NkeMpWKdmhuxm9GusH8=
|
||||||
|
github.com/shirou/gopsutil/v4 v4.25.1 h1:QSWkTc+fu9LTAWfkZwZ6j8MSUk4A2LV7rbH0ZqmLjXs=
|
||||||
|
github.com/shirou/gopsutil/v4 v4.25.1/go.mod h1:RoUCUpndaJFtT+2zsZzzmhvbfGoDCJ7nFXKJf8GqJbI=
|
||||||
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
|
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
|
||||||
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
|
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
|
||||||
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
|
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
|
||||||
@@ -552,14 +553,14 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl
|
|||||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
github.com/testcontainers/testcontainers-go v0.31.0 h1:W0VwIhcEVhRflwL9as3dhY6jXjVCA27AkmbnZ+UTh3U=
|
github.com/testcontainers/testcontainers-go v0.37.0 h1:L2Qc0vkTw2EHWQ08djon0D2uw7Z/PtHS/QzZZ5Ra/hg=
|
||||||
github.com/testcontainers/testcontainers-go v0.31.0/go.mod h1:D2lAoA0zUFiSY+eAflqK5mcUx/A5hrrORaEQrd0SefI=
|
github.com/testcontainers/testcontainers-go v0.37.0/go.mod h1:QPzbxZhQ6Bclip9igjLFj6z0hs01bU8lrl2dHQmgFGM=
|
||||||
github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0 h1:790+S8ewZYCbG+o8IiFlZ8ZZ33XbNO6zV9qhU6xhlRk=
|
github.com/testcontainers/testcontainers-go/modules/mysql v0.37.0 h1:LqUos1oR5iuuzorFnSvxsHNdYdCHB/DfI82CuT58wbI=
|
||||||
github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0/go.mod h1:REFmO+lSG9S6uSBEwIMZCxeI36uhScjTwChYADeO3JA=
|
github.com/testcontainers/testcontainers-go/modules/mysql v0.37.0/go.mod h1:vHEEHx5Kf+uq5hveaVAMrTzPY8eeRZcKcl23MRw5Tkc=
|
||||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0 h1:isAwFS3KNKRbJMbWv+wolWqOFUECmjYZ+sIRZCIBc/E=
|
github.com/testcontainers/testcontainers-go/modules/postgres v0.37.0 h1:hsVwFkS6s+79MbKEO+W7A1wNIw1fmkMtF4fg83m6kbc=
|
||||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0/go.mod h1:ZNYY8vumNCEG9YI59A9d6/YaMY49uwRhmeU563EzFGw=
|
github.com/testcontainers/testcontainers-go/modules/postgres v0.37.0/go.mod h1:Qj/eGbRbO/rEYdcRLmN+bEojzatP/+NS1y8ojl2PQsc=
|
||||||
github.com/testcontainers/testcontainers-go/modules/redis v0.31.0 h1:5X6GhOdLwV86zcW8sxppJAMtsDC9u+r9tb3biBc9GKs=
|
github.com/testcontainers/testcontainers-go/modules/redis v0.37.0 h1:9HIY28I9ME/Zmb+zey1p/I1mto5+5ch0wLX+nJdOsQ4=
|
||||||
github.com/testcontainers/testcontainers-go/modules/redis v0.31.0/go.mod h1:dKi5xBwy1k4u8yb3saQHu7hMEJwewHXxzbcMAuLiA6o=
|
github.com/testcontainers/testcontainers-go/modules/redis v0.37.0/go.mod h1:Abu9g/25Qv+FkYVx3U4Voaynou1c+7D0HIhaQJXvk6E=
|
||||||
github.com/things-go/go-socks5 v0.0.4 h1:jMQjIc+qhD4z9cITOMnBiwo9dDmpGuXmBlkRFrl/qD0=
|
github.com/things-go/go-socks5 v0.0.4 h1:jMQjIc+qhD4z9cITOMnBiwo9dDmpGuXmBlkRFrl/qD0=
|
||||||
github.com/things-go/go-socks5 v0.0.4/go.mod h1:sh4K6WHrmHZpjxLTCHyYtXYH8OUuD+yZun41NomR1IQ=
|
github.com/things-go/go-socks5 v0.0.4/go.mod h1:sh4K6WHrmHZpjxLTCHyYtXYH8OUuD+yZun41NomR1IQ=
|
||||||
github.com/ti-mo/conntrack v0.5.1 h1:opEwkFICnDbQc0BUXl73PHBK0h23jEIFVjXsqvF4GY0=
|
github.com/ti-mo/conntrack v0.5.1 h1:opEwkFICnDbQc0BUXl73PHBK0h23jEIFVjXsqvF4GY0=
|
||||||
@@ -849,7 +850,7 @@ gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDa
|
|||||||
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||||
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
||||||
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
||||||
gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
|
gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
|
||||||
gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
||||||
gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c h1:pfzmXIkkDgydR4ZRP+e1hXywZfYR21FA0Fbk6ptMkiA=
|
gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c h1:pfzmXIkkDgydR4ZRP+e1hXywZfYR21FA0Fbk6ptMkiA=
|
||||||
gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c/go.mod h1:/mc6CfwbOm5KKmqoV7Qx20Q+Ja8+vO4g7FuCdlVoAfQ=
|
gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c/go.mod h1:/mc6CfwbOm5KKmqoV7Qx20Q+Ja8+vO4g7FuCdlVoAfQ=
|
||||||
|
|||||||
@@ -327,6 +327,60 @@ func ensureLocalConnector(ctx context.Context, stor storage.Storage) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HasNonLocalConnectors checks if there are any connectors other than the local connector.
|
||||||
|
func (p *Provider) HasNonLocalConnectors(ctx context.Context) (bool, error) {
|
||||||
|
connectors, err := p.storage.ListConnectors(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to list connectors: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.logger.Info("checking for non-local connectors", "total_connectors", len(connectors))
|
||||||
|
for _, conn := range connectors {
|
||||||
|
p.logger.Info("found connector in storage", "id", conn.ID, "type", conn.Type, "name", conn.Name)
|
||||||
|
if conn.ID != "local" || conn.Type != "local" {
|
||||||
|
p.logger.Info("found non-local connector", "id", conn.ID)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
p.logger.Info("no non-local connectors found")
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableLocalAuth removes the local (password) connector.
|
||||||
|
// Returns an error if no other connectors are configured.
|
||||||
|
func (p *Provider) DisableLocalAuth(ctx context.Context) error {
|
||||||
|
hasOthers, err := p.HasNonLocalConnectors(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !hasOthers {
|
||||||
|
return fmt.Errorf("cannot disable local authentication: no other identity providers configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if local connector exists
|
||||||
|
_, err = p.storage.GetConnector(ctx, "local")
|
||||||
|
if errors.Is(err, storage.ErrNotFound) {
|
||||||
|
// Already disabled
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to check local connector: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete the local connector
|
||||||
|
if err := p.storage.DeleteConnector(ctx, "local"); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete local connector: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.logger.Info("local authentication disabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnableLocalAuth creates the local (password) connector if it doesn't exist.
|
||||||
|
func (p *Provider) EnableLocalAuth(ctx context.Context) error {
|
||||||
|
return ensureLocalConnector(ctx, p.storage)
|
||||||
|
}
|
||||||
|
|
||||||
// ensureStaticConnectors creates or updates static connectors in storage
|
// ensureStaticConnectors creates or updates static connectors in storage
|
||||||
func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error {
|
func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error {
|
||||||
for _, conn := range connectors {
|
for _, conn := range connectors {
|
||||||
|
|||||||
@@ -99,15 +99,16 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
|
|||||||
|
|
||||||
// Build Dex server config - use Dex's types directly
|
// Build Dex server config - use Dex's types directly
|
||||||
dexConfig := server.Config{
|
dexConfig := server.Config{
|
||||||
Issuer: issuer,
|
Issuer: issuer,
|
||||||
Storage: stor,
|
Storage: stor,
|
||||||
SkipApprovalScreen: true,
|
SkipApprovalScreen: true,
|
||||||
SupportedResponseTypes: []string{"code"},
|
SupportedResponseTypes: []string{"code"},
|
||||||
Logger: logger,
|
ContinueOnConnectorFailure: true,
|
||||||
PrometheusRegistry: prometheus.NewRegistry(),
|
Logger: logger,
|
||||||
RotateKeysAfter: 6 * time.Hour,
|
PrometheusRegistry: prometheus.NewRegistry(),
|
||||||
IDTokensValidFor: 24 * time.Hour,
|
RotateKeysAfter: 6 * time.Hour,
|
||||||
RefreshTokenPolicy: refreshPolicy,
|
IDTokensValidFor: 24 * time.Hour,
|
||||||
|
RefreshTokenPolicy: refreshPolicy,
|
||||||
Web: server.WebConfig{
|
Web: server.WebConfig{
|
||||||
Issuer: "NetBird",
|
Issuer: "NetBird",
|
||||||
},
|
},
|
||||||
@@ -260,6 +261,7 @@ func buildDexConfig(yamlConfig *YAMLConfig, stor storage.Storage, logger *slog.L
|
|||||||
if len(cfg.SupportedResponseTypes) == 0 {
|
if len(cfg.SupportedResponseTypes) == 0 {
|
||||||
cfg.SupportedResponseTypes = []string{"code"}
|
cfg.SupportedResponseTypes = []string{"code"}
|
||||||
}
|
}
|
||||||
|
cfg.ContinueOnConnectorFailure = true
|
||||||
return cfg
|
return cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package dex
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -195,3 +196,64 @@ enablePasswordDB: true
|
|||||||
|
|
||||||
t.Logf("User lookup successful: rawID=%s, connectorID=%s", rawID, connID)
|
t.Logf("User lookup successful: rawID=%s, connectorID=%s", rawID, connID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewProvider_ContinueOnConnectorFailure(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
tmpDir, err := os.MkdirTemp("", "dex-connector-failure-*")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
Issuer: "http://localhost:5556/dex",
|
||||||
|
Port: 5556,
|
||||||
|
DataDir: tmpDir,
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, err := NewProvider(ctx, config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() { _ = provider.Stop(ctx) }()
|
||||||
|
|
||||||
|
// The provider should have started successfully even though
|
||||||
|
// ContinueOnConnectorFailure is an internal Dex config field.
|
||||||
|
// We verify the provider is functional by performing a basic operation.
|
||||||
|
assert.NotNil(t, provider.dexServer)
|
||||||
|
assert.NotNil(t, provider.storage)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildDexConfig_ContinueOnConnectorFailure(t *testing.T) {
|
||||||
|
tmpDir, err := os.MkdirTemp("", "dex-build-config-*")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
|
||||||
|
yamlContent := `
|
||||||
|
issuer: http://localhost:5556/dex
|
||||||
|
storage:
|
||||||
|
type: sqlite3
|
||||||
|
config:
|
||||||
|
file: ` + filepath.Join(tmpDir, "dex.db") + `
|
||||||
|
web:
|
||||||
|
http: 127.0.0.1:5556
|
||||||
|
enablePasswordDB: true
|
||||||
|
`
|
||||||
|
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||||
|
err = os.WriteFile(configPath, []byte(yamlContent), 0644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
yamlConfig, err := LoadConfig(configPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
stor, err := yamlConfig.Storage.OpenStorage(slog.New(slog.NewTextHandler(os.Stderr, nil)))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer stor.Close()
|
||||||
|
|
||||||
|
err = initializeStorage(ctx, stor, yamlConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||||
|
cfg := buildDexConfig(yamlConfig, stor, logger)
|
||||||
|
|
||||||
|
assert.True(t, cfg.ContinueOnConnectorFailure,
|
||||||
|
"buildDexConfig must set ContinueOnConnectorFailure to true so management starts even if an external IdP is down")
|
||||||
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
1286
infrastructure_files/migrate.sh
Executable file
1286
infrastructure_files/migrate.sh
Executable file
File diff suppressed because it is too large
Load Diff
17
management/Dockerfile.multistage
Normal file
17
management/Dockerfile.multistage
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
FROM golang:1.25-bookworm AS builder
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Install build dependencies
|
||||||
|
RUN apt-get update && apt-get install -y gcc libc6-dev && rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
COPY go.mod go.sum ./
|
||||||
|
RUN go mod download
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
RUN CGO_ENABLED=1 GOOS=linux go build -ldflags="-s -w" -o netbird-mgmt ./management
|
||||||
|
|
||||||
|
FROM ubuntu:24.04
|
||||||
|
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
|
||||||
|
ENTRYPOINT [ "/go/bin/netbird-mgmt","management"]
|
||||||
|
CMD ["--log-file", "console"]
|
||||||
|
COPY --from=builder /app/netbird-mgmt /go/bin/netbird-mgmt
|
||||||
@@ -19,6 +19,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/formatter/hook"
|
"github.com/netbirdio/netbird/formatter/hook"
|
||||||
"github.com/netbirdio/netbird/management/internals/server"
|
"github.com/netbirdio/netbird/management/internals/server"
|
||||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
@@ -27,11 +29,11 @@ import (
|
|||||||
"github.com/netbirdio/netbird/util/crypt"
|
"github.com/netbirdio/netbird/util/crypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
var newServer = func(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort int, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) server.Server {
|
var newServer = func(cfg *server.Config) server.Server {
|
||||||
return server.NewServer(config, dnsDomain, mgmtSingleAccModeDomain, mgmtPort, mgmtMetricsPort, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled)
|
return server.NewServer(cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetNewServer(fn func(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort int, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) server.Server) {
|
func SetNewServer(fn func(*server.Config) server.Server) {
|
||||||
newServer = fn
|
newServer = fn
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -55,7 +57,7 @@ var (
|
|||||||
// detect whether user specified a port
|
// detect whether user specified a port
|
||||||
userPort := cmd.Flag("port").Changed
|
userPort := cmd.Flag("port").Changed
|
||||||
|
|
||||||
config, err = loadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
|
config, err = LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading provided config file: %s: %v", nbconfig.MgmtConfigPath, err)
|
return fmt.Errorf("failed reading provided config file: %s: %v", nbconfig.MgmtConfigPath, err)
|
||||||
}
|
}
|
||||||
@@ -108,7 +110,17 @@ var (
|
|||||||
mgmtSingleAccModeDomain = ""
|
mgmtSingleAccModeDomain = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
srv := newServer(config, dnsDomain, mgmtSingleAccModeDomain, mgmtPort, mgmtMetricsPort, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled)
|
srv := newServer(&server.Config{
|
||||||
|
NbConfig: config,
|
||||||
|
DNSDomain: dnsDomain,
|
||||||
|
MgmtSingleAccModeDomain: mgmtSingleAccModeDomain,
|
||||||
|
MgmtPort: mgmtPort,
|
||||||
|
MgmtMetricsPort: mgmtMetricsPort,
|
||||||
|
DisableLegacyManagementPort: disableLegacyManagementPort,
|
||||||
|
DisableMetrics: disableMetrics,
|
||||||
|
DisableGeoliteUpdate: disableGeoliteUpdate,
|
||||||
|
UserDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
|
||||||
|
})
|
||||||
go func() {
|
go func() {
|
||||||
if err := srv.Start(cmd.Context()); err != nil {
|
if err := srv.Start(cmd.Context()); err != nil {
|
||||||
log.Fatalf("Server error: %v", err)
|
log.Fatalf("Server error: %v", err)
|
||||||
@@ -133,35 +145,35 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Config, error) {
|
func LoadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Config, error) {
|
||||||
loadedConfig := &nbconfig.Config{}
|
loadedConfig := &nbconfig.Config{}
|
||||||
if _, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig); err != nil {
|
if _, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
applyCommandLineOverrides(loadedConfig)
|
ApplyCommandLineOverrides(loadedConfig)
|
||||||
|
|
||||||
// Apply EmbeddedIdP config to HttpConfig if embedded IdP is enabled
|
// Apply EmbeddedIdP config to HttpConfig if embedded IdP is enabled
|
||||||
err := applyEmbeddedIdPConfig(ctx, loadedConfig)
|
err := ApplyEmbeddedIdPConfig(ctx, loadedConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := applyOIDCConfig(ctx, loadedConfig); err != nil {
|
if err := ApplyOIDCConfig(ctx, loadedConfig); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
logConfigInfo(loadedConfig)
|
LogConfigInfo(loadedConfig)
|
||||||
|
|
||||||
if err := ensureEncryptionKey(ctx, mgmtConfigPath, loadedConfig); err != nil {
|
if err := EnsureEncryptionKey(ctx, mgmtConfigPath, loadedConfig); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return loadedConfig, nil
|
return loadedConfig, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyCommandLineOverrides applies command-line flag overrides to the config
|
// ApplyCommandLineOverrides applies command-line flag overrides to the config
|
||||||
func applyCommandLineOverrides(cfg *nbconfig.Config) {
|
func ApplyCommandLineOverrides(cfg *nbconfig.Config) {
|
||||||
if mgmtLetsencryptDomain != "" {
|
if mgmtLetsencryptDomain != "" {
|
||||||
cfg.HttpConfig.LetsEncryptDomain = mgmtLetsencryptDomain
|
cfg.HttpConfig.LetsEncryptDomain = mgmtLetsencryptDomain
|
||||||
}
|
}
|
||||||
@@ -174,9 +186,9 @@ func applyCommandLineOverrides(cfg *nbconfig.Config) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyEmbeddedIdPConfig populates HttpConfig and EmbeddedIdP storage from config when embedded IdP is enabled.
|
// ApplyEmbeddedIdPConfig populates HttpConfig and EmbeddedIdP storage from config when embedded IdP is enabled.
|
||||||
// This allows users to only specify EmbeddedIdP config without duplicating values in HttpConfig.
|
// This allows users to only specify EmbeddedIdP config without duplicating values in HttpConfig.
|
||||||
func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
||||||
if cfg.EmbeddedIdP == nil || !cfg.EmbeddedIdP.Enabled {
|
if cfg.EmbeddedIdP == nil || !cfg.EmbeddedIdP.Enabled {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -213,17 +225,20 @@ func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
|||||||
// Set HttpConfig values from EmbeddedIdP
|
// Set HttpConfig values from EmbeddedIdP
|
||||||
cfg.HttpConfig.AuthIssuer = issuer
|
cfg.HttpConfig.AuthIssuer = issuer
|
||||||
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
|
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
|
||||||
|
cfg.HttpConfig.AuthClientID = cfg.HttpConfig.AuthAudience
|
||||||
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
|
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
|
||||||
cfg.HttpConfig.AuthUserIDClaim = "sub"
|
cfg.HttpConfig.AuthUserIDClaim = "sub"
|
||||||
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
|
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
|
||||||
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
|
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
|
||||||
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
|
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
|
||||||
|
callbackURL := strings.TrimSuffix(cfg.HttpConfig.AuthIssuer, "/oauth2")
|
||||||
|
cfg.HttpConfig.AuthCallbackURL = callbackURL + types.ProxyCallbackEndpointFull
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyOIDCConfig fetches and applies OIDC configuration if endpoint is specified
|
// ApplyOIDCConfig fetches and applies OIDC configuration if endpoint is specified
|
||||||
func applyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
func ApplyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
||||||
oidcEndpoint := cfg.HttpConfig.OIDCConfigEndpoint
|
oidcEndpoint := cfg.HttpConfig.OIDCConfigEndpoint
|
||||||
if oidcEndpoint == "" {
|
if oidcEndpoint == "" {
|
||||||
return nil
|
return nil
|
||||||
@@ -249,16 +264,16 @@ func applyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
|||||||
oidcConfig.JwksURI, cfg.HttpConfig.AuthKeysLocation)
|
oidcConfig.JwksURI, cfg.HttpConfig.AuthKeysLocation)
|
||||||
cfg.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
|
cfg.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
|
||||||
|
|
||||||
if err := applyDeviceAuthFlowConfig(ctx, cfg, &oidcConfig, oidcEndpoint); err != nil {
|
if err := ApplyDeviceAuthFlowConfig(ctx, cfg, &oidcConfig, oidcEndpoint); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
applyPKCEFlowConfig(ctx, cfg, &oidcConfig)
|
ApplyPKCEFlowConfig(ctx, cfg, &oidcConfig)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyDeviceAuthFlowConfig applies OIDC config to DeviceAuthorizationFlow if enabled
|
// ApplyDeviceAuthFlowConfig applies OIDC config to DeviceAuthorizationFlow if enabled
|
||||||
func applyDeviceAuthFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse, oidcEndpoint string) error {
|
func ApplyDeviceAuthFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse, oidcEndpoint string) error {
|
||||||
if cfg.DeviceAuthorizationFlow == nil || strings.ToLower(cfg.DeviceAuthorizationFlow.Provider) == string(nbconfig.NONE) {
|
if cfg.DeviceAuthorizationFlow == nil || strings.ToLower(cfg.DeviceAuthorizationFlow.Provider) == string(nbconfig.NONE) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -285,8 +300,8 @@ func applyDeviceAuthFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcCo
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyPKCEFlowConfig applies OIDC config to PKCEAuthorizationFlow if configured
|
// ApplyPKCEFlowConfig applies OIDC config to PKCEAuthorizationFlow if configured
|
||||||
func applyPKCEFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse) {
|
func ApplyPKCEFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse) {
|
||||||
if cfg.PKCEAuthorizationFlow == nil {
|
if cfg.PKCEAuthorizationFlow == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -299,8 +314,8 @@ func applyPKCEFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *
|
|||||||
cfg.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint = oidcConfig.AuthorizationEndpoint
|
cfg.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint = oidcConfig.AuthorizationEndpoint
|
||||||
}
|
}
|
||||||
|
|
||||||
// logConfigInfo logs informational messages about the loaded configuration
|
// LogConfigInfo logs informational messages about the loaded configuration
|
||||||
func logConfigInfo(cfg *nbconfig.Config) {
|
func LogConfigInfo(cfg *nbconfig.Config) {
|
||||||
if cfg.EmbeddedIdP != nil {
|
if cfg.EmbeddedIdP != nil {
|
||||||
log.Infof("running with the embedded IdP: %v", cfg.EmbeddedIdP.Issuer)
|
log.Infof("running with the embedded IdP: %v", cfg.EmbeddedIdP.Issuer)
|
||||||
}
|
}
|
||||||
@@ -309,8 +324,8 @@ func logConfigInfo(cfg *nbconfig.Config) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensureEncryptionKey generates and saves a DataStoreEncryptionKey if not set
|
// EnsureEncryptionKey generates and saves a DataStoreEncryptionKey if not set
|
||||||
func ensureEncryptionKey(ctx context.Context, configPath string, cfg *nbconfig.Config) error {
|
func EnsureEncryptionKey(ctx context.Context, configPath string, cfg *nbconfig.Config) error {
|
||||||
if cfg.DataStoreEncryptionKey != "" {
|
if cfg.DataStoreEncryptionKey != "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ func Test_loadMgmtConfig(t *testing.T) {
|
|||||||
t.Fatalf("failed to create config: %s", err)
|
t.Fatalf("failed to create config: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg, err := loadMgmtConfig(context.Background(), tmpFile)
|
cfg, err := LoadMgmtConfig(context.Background(), tmpFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load management config: %s", err)
|
t.Fatalf("failed to load management config: %s", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,21 +16,22 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
dnsDomain string
|
dnsDomain string
|
||||||
mgmtDataDir string
|
mgmtDataDir string
|
||||||
logLevel string
|
logLevel string
|
||||||
logFile string
|
logFile string
|
||||||
disableMetrics bool
|
disableMetrics bool
|
||||||
disableSingleAccMode bool
|
disableSingleAccMode bool
|
||||||
disableGeoliteUpdate bool
|
disableGeoliteUpdate bool
|
||||||
idpSignKeyRefreshEnabled bool
|
idpSignKeyRefreshEnabled bool
|
||||||
userDeleteFromIDPEnabled bool
|
userDeleteFromIDPEnabled bool
|
||||||
mgmtPort int
|
mgmtPort int
|
||||||
mgmtMetricsPort int
|
mgmtMetricsPort int
|
||||||
mgmtLetsencryptDomain string
|
disableLegacyManagementPort bool
|
||||||
mgmtSingleAccModeDomain string
|
mgmtLetsencryptDomain string
|
||||||
certFile string
|
mgmtSingleAccModeDomain string
|
||||||
certKey string
|
certFile string
|
||||||
|
certKey string
|
||||||
|
|
||||||
rootCmd = &cobra.Command{
|
rootCmd = &cobra.Command{
|
||||||
Use: "netbird-mgmt",
|
Use: "netbird-mgmt",
|
||||||
@@ -55,6 +56,7 @@ func Execute() error {
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
|
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
|
||||||
|
mgmtCmd.Flags().BoolVar(&disableLegacyManagementPort, "disable-legacy-port", false, "disabling the old legacy port (33073)")
|
||||||
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
|
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
|
||||||
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
|
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
|
||||||
mgmtCmd.Flags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
|
mgmtCmd.Flags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
|
||||||
@@ -80,4 +82,8 @@ func init() {
|
|||||||
migrationCmd.AddCommand(upCmd)
|
migrationCmd.AddCommand(upCmd)
|
||||||
|
|
||||||
rootCmd.AddCommand(migrationCmd)
|
rootCmd.AddCommand(migrationCmd)
|
||||||
|
|
||||||
|
tc := newTokenCommands()
|
||||||
|
tc.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
|
||||||
|
rootCmd.AddCommand(tc)
|
||||||
}
|
}
|
||||||
|
|||||||
55
management/cmd/token.go
Normal file
55
management/cmd/token.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/formatter/hook"
|
||||||
|
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||||
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
var tokenDatadir string
|
||||||
|
|
||||||
|
// newTokenCommands creates the token command tree with management-specific store opener.
|
||||||
|
func newTokenCommands() *cobra.Command {
|
||||||
|
cmd := tokencmd.NewCommands(withTokenStore)
|
||||||
|
cmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
|
||||||
|
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||||
|
if err := util.InitLog("error", "console"); err != nil {
|
||||||
|
return fmt.Errorf("init log: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||||
|
|
||||||
|
config, err := LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("load config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
datadir := config.Datadir
|
||||||
|
if tokenDatadir != "" {
|
||||||
|
datadir = tokenDatadir
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create store: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := s.Close(ctx); err != nil {
|
||||||
|
log.Debugf("close store: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return fn(ctx, s)
|
||||||
|
}
|
||||||
185
management/cmd/token/token.go
Normal file
185
management/cmd/token/token.go
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
// Package tokencmd provides reusable cobra commands for managing proxy access tokens.
|
||||||
|
// Both the management and combined binaries use these commands, each providing
|
||||||
|
// their own StoreOpener to handle config loading and store initialization.
|
||||||
|
package tokencmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strconv"
|
||||||
|
"text/tabwriter"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StoreOpener initializes a store from the command context and calls fn.
|
||||||
|
type StoreOpener func(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error
|
||||||
|
|
||||||
|
// NewCommands creates the token command tree with the given store opener.
|
||||||
|
// Returns the parent "token" command with create, list, and revoke subcommands.
|
||||||
|
func NewCommands(opener StoreOpener) *cobra.Command {
|
||||||
|
var (
|
||||||
|
tokenName string
|
||||||
|
tokenExpireIn string
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenCmd := &cobra.Command{
|
||||||
|
Use: "token",
|
||||||
|
Short: "Manage proxy access tokens",
|
||||||
|
Long: "Commands for creating, listing, and revoking proxy access tokens used by reverse proxy instances to authenticate with the management server.",
|
||||||
|
}
|
||||||
|
|
||||||
|
createCmd := &cobra.Command{
|
||||||
|
Use: "create",
|
||||||
|
Short: "Create a new proxy access token",
|
||||||
|
Long: "Creates a new proxy access token. The plain text token is displayed only once at creation time.",
|
||||||
|
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||||
|
return opener(cmd, func(ctx context.Context, s store.Store) error {
|
||||||
|
return runCreate(ctx, s, cmd.OutOrStdout(), tokenName, tokenExpireIn)
|
||||||
|
})
|
||||||
|
},
|
||||||
|
}
|
||||||
|
createCmd.Flags().StringVar(&tokenName, "name", "", "Name for the token (required)")
|
||||||
|
createCmd.Flags().StringVar(&tokenExpireIn, "expires-in", "", "Token expiration duration (e.g., 365d, 24h, 30d). Empty means no expiration")
|
||||||
|
if err := createCmd.MarkFlagRequired("name"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
listCmd := &cobra.Command{
|
||||||
|
Use: "list",
|
||||||
|
Aliases: []string{"ls"},
|
||||||
|
Short: "List all proxy access tokens",
|
||||||
|
Long: "Lists all proxy access tokens with their IDs, names, creation dates, expiration, and revocation status.",
|
||||||
|
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||||
|
return opener(cmd, func(ctx context.Context, s store.Store) error {
|
||||||
|
return runList(ctx, s, cmd.OutOrStdout())
|
||||||
|
})
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
revokeCmd := &cobra.Command{
|
||||||
|
Use: "revoke [token-id]",
|
||||||
|
Short: "Revoke a proxy access token",
|
||||||
|
Long: "Revokes a proxy access token by its ID. Revoked tokens can no longer be used for authentication.",
|
||||||
|
Args: cobra.ExactArgs(1),
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
return opener(cmd, func(ctx context.Context, s store.Store) error {
|
||||||
|
return runRevoke(ctx, s, cmd.OutOrStdout(), args[0])
|
||||||
|
})
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenCmd.AddCommand(createCmd, listCmd, revokeCmd)
|
||||||
|
return tokenCmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func runCreate(ctx context.Context, s store.Store, w io.Writer, name string, expireIn string) error {
|
||||||
|
expiresIn, err := ParseDuration(expireIn)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse expiration: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
generated, err := types.CreateNewProxyAccessToken(name, expiresIn, nil, "CLI")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generate token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.SaveProxyAccessToken(ctx, &generated.ProxyAccessToken); err != nil {
|
||||||
|
return fmt.Errorf("save token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = fmt.Fprintln(w, "Token created successfully!")
|
||||||
|
_, _ = fmt.Fprintf(w, "Token: %s\n", generated.PlainToken)
|
||||||
|
_, _ = fmt.Fprintln(w)
|
||||||
|
_, _ = fmt.Fprintln(w, "IMPORTANT: Save this token now. It will not be shown again.")
|
||||||
|
_, _ = fmt.Fprintf(w, "Token ID: %s\n", generated.ID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func runList(ctx context.Context, s store.Store, out io.Writer) error {
|
||||||
|
tokens, err := s.GetAllProxyAccessTokens(ctx, store.LockingStrengthNone)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("list tokens: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tokens) == 0 {
|
||||||
|
_, _ = fmt.Fprintln(out, "No proxy access tokens found.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
w := tabwriter.NewWriter(out, 0, 0, 2, ' ', 0)
|
||||||
|
_, _ = fmt.Fprintln(w, "ID\tNAME\tCREATED\tEXPIRES\tLAST USED\tREVOKED")
|
||||||
|
_, _ = fmt.Fprintln(w, "--\t----\t-------\t-------\t---------\t-------")
|
||||||
|
|
||||||
|
for _, t := range tokens {
|
||||||
|
expires := "never"
|
||||||
|
if t.ExpiresAt != nil {
|
||||||
|
expires = t.ExpiresAt.Format("2006-01-02")
|
||||||
|
}
|
||||||
|
|
||||||
|
lastUsed := "never"
|
||||||
|
if t.LastUsed != nil {
|
||||||
|
lastUsed = t.LastUsed.Format("2006-01-02 15:04")
|
||||||
|
}
|
||||||
|
|
||||||
|
revoked := "no"
|
||||||
|
if t.Revoked {
|
||||||
|
revoked = "yes"
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
|
||||||
|
t.ID,
|
||||||
|
t.Name,
|
||||||
|
t.CreatedAt.Format("2006-01-02"),
|
||||||
|
expires,
|
||||||
|
lastUsed,
|
||||||
|
revoked,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Flush()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func runRevoke(ctx context.Context, s store.Store, w io.Writer, tokenID string) error {
|
||||||
|
if err := s.RevokeProxyAccessToken(ctx, tokenID); err != nil {
|
||||||
|
return fmt.Errorf("revoke token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = fmt.Fprintf(w, "Token %s revoked successfully.\n", tokenID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseDuration parses a duration string with support for days (e.g., "30d", "365d").
|
||||||
|
// An empty string returns zero duration (no expiration).
|
||||||
|
func ParseDuration(s string) (time.Duration, error) {
|
||||||
|
if len(s) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if s[len(s)-1] == 'd' {
|
||||||
|
d, err := strconv.Atoi(s[:len(s)-1])
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("invalid day format: %s", s)
|
||||||
|
}
|
||||||
|
if d <= 0 {
|
||||||
|
return 0, fmt.Errorf("duration must be positive: %s", s)
|
||||||
|
}
|
||||||
|
return time.Duration(d) * 24 * time.Hour, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := time.ParseDuration(s)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if d <= 0 {
|
||||||
|
return 0, fmt.Errorf("duration must be positive: %s", s)
|
||||||
|
}
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
101
management/cmd/token/token_test.go
Normal file
101
management/cmd/token/token_test.go
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
package tokencmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseDuration(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected time.Duration
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty string returns zero",
|
||||||
|
input: "",
|
||||||
|
expected: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "days suffix",
|
||||||
|
input: "30d",
|
||||||
|
expected: 30 * 24 * time.Hour,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "one day",
|
||||||
|
input: "1d",
|
||||||
|
expected: 24 * time.Hour,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "365 days",
|
||||||
|
input: "365d",
|
||||||
|
expected: 365 * 24 * time.Hour,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hours via Go duration",
|
||||||
|
input: "24h",
|
||||||
|
expected: 24 * time.Hour,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "minutes via Go duration",
|
||||||
|
input: "30m",
|
||||||
|
expected: 30 * time.Minute,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex Go duration",
|
||||||
|
input: "1h30m",
|
||||||
|
expected: 90 * time.Minute,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid day format",
|
||||||
|
input: "abcd",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative days",
|
||||||
|
input: "-1d",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero days",
|
||||||
|
input: "0d",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-numeric days",
|
||||||
|
input: "xyzd",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative Go duration",
|
||||||
|
input: "-24h",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero Go duration",
|
||||||
|
input: "0s",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid Go duration",
|
||||||
|
input: "notaduration",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := ParseDuration(tt.input)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -174,6 +174,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
semaphore := make(chan struct{}, 10)
|
semaphore := make(chan struct{}, 10)
|
||||||
|
|
||||||
|
account.InjectProxyPolicies(ctx)
|
||||||
dnsCache := &cache.DNSConfigCache{}
|
dnsCache := &cache.DNSConfigCache{}
|
||||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||||
@@ -247,7 +248,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||||
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
||||||
|
|
||||||
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update})
|
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
|
||||||
|
Update: update,
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
})
|
||||||
}(peer)
|
}(peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -323,6 +327,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
return fmt.Errorf("failed to get validated peers: %v", err)
|
return fmt.Errorf("failed to get validated peers: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
account.InjectProxyPolicies(ctx)
|
||||||
dnsCache := &cache.DNSConfigCache{}
|
dnsCache := &cache.DNSConfigCache{}
|
||||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||||
@@ -370,7 +375,10 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||||
|
|
||||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||||
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update})
|
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{
|
||||||
|
Update: update,
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -435,6 +443,8 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
account.InjectProxyPolicies(ctx)
|
||||||
|
|
||||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, 0, err
|
return nil, nil, nil, 0, err
|
||||||
@@ -778,6 +788,7 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
})
|
})
|
||||||
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
||||||
|
|
||||||
@@ -840,6 +851,7 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
|||||||
if c.experimentalNetworkMap(peer.AccountID) {
|
if c.experimentalNetworkMap(peer.AccountID) {
|
||||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
|
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
|
||||||
} else {
|
} else {
|
||||||
|
account.InjectProxyPolicies(ctx)
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
|
|||||||
@@ -25,11 +25,14 @@ func TestCreateChannel(t *testing.T) {
|
|||||||
func TestSendUpdate(t *testing.T) {
|
func TestSendUpdate(t *testing.T) {
|
||||||
peer := "test-sendupdate"
|
peer := "test-sendupdate"
|
||||||
peersUpdater := NewPeersUpdateManager(nil)
|
peersUpdater := NewPeersUpdateManager(nil)
|
||||||
update1 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
|
update1 := &network_map.UpdateMessage{
|
||||||
NetworkMap: &proto.NetworkMap{
|
Update: &proto.SyncResponse{
|
||||||
Serial: 0,
|
NetworkMap: &proto.NetworkMap{
|
||||||
|
Serial: 0,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}}
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
_ = peersUpdater.CreateChannel(context.Background(), peer)
|
_ = peersUpdater.CreateChannel(context.Background(), peer)
|
||||||
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
||||||
t.Error("Error creating the channel")
|
t.Error("Error creating the channel")
|
||||||
@@ -45,11 +48,14 @@ func TestSendUpdate(t *testing.T) {
|
|||||||
peersUpdater.SendUpdate(context.Background(), peer, update1)
|
peersUpdater.SendUpdate(context.Background(), peer, update1)
|
||||||
}
|
}
|
||||||
|
|
||||||
update2 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
|
update2 := &network_map.UpdateMessage{
|
||||||
NetworkMap: &proto.NetworkMap{
|
Update: &proto.SyncResponse{
|
||||||
Serial: 10,
|
NetworkMap: &proto.NetworkMap{
|
||||||
|
Serial: 10,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}}
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
|
}
|
||||||
|
|
||||||
peersUpdater.SendUpdate(context.Background(), peer, update2)
|
peersUpdater.SendUpdate(context.Background(), peer, update2)
|
||||||
timeout := time.After(5 * time.Second)
|
timeout := time.After(5 * time.Second)
|
||||||
|
|||||||
@@ -4,6 +4,19 @@ import (
|
|||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// MessageType indicates the type of update message for debouncing strategy
|
||||||
|
type MessageType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// MessageTypeNetworkMap represents network map updates (peers, routes, DNS, firewall)
|
||||||
|
// These updates can be safely debounced - only the latest state matters
|
||||||
|
MessageTypeNetworkMap MessageType = iota
|
||||||
|
// MessageTypeControlConfig represents control/config updates (tokens, peer expiration)
|
||||||
|
// These updates should not be dropped as they contain time-sensitive information
|
||||||
|
MessageTypeControlConfig
|
||||||
|
)
|
||||||
|
|
||||||
type UpdateMessage struct {
|
type UpdateMessage struct {
|
||||||
Update *proto.SyncResponse
|
Update *proto.SyncResponse
|
||||||
|
MessageType MessageType
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
@@ -32,6 +33,7 @@ type Manager interface {
|
|||||||
SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator)
|
SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator)
|
||||||
SetAccountManager(accountManager account.Manager)
|
SetAccountManager(accountManager account.Manager)
|
||||||
GetPeerID(ctx context.Context, peerKey string) (string, error)
|
GetPeerID(ctx context.Context, peerKey string) (string, error)
|
||||||
|
CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type managerImpl struct {
|
type managerImpl struct {
|
||||||
@@ -182,3 +184,36 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
|
|||||||
func (m *managerImpl) GetPeerID(ctx context.Context, peerKey string) (string, error) {
|
func (m *managerImpl) GetPeerID(ctx context.Context, peerKey string) (string, error) {
|
||||||
return m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
|
return m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
|
||||||
|
existingPeerID, err := m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
|
||||||
|
if err == nil && existingPeerID != "" {
|
||||||
|
// Peer already exists
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
name := fmt.Sprintf("proxy-%s", xid.New().String())
|
||||||
|
peer := &peer.Peer{
|
||||||
|
Ephemeral: true,
|
||||||
|
ProxyMeta: peer.ProxyMeta{
|
||||||
|
Cluster: cluster,
|
||||||
|
Embedded: true,
|
||||||
|
},
|
||||||
|
Name: name,
|
||||||
|
Key: peerKey,
|
||||||
|
LoginExpirationEnabled: false,
|
||||||
|
InactivityExpirationEnabled: false,
|
||||||
|
Meta: peer.PeerSystemMeta{
|
||||||
|
Hostname: name,
|
||||||
|
GoOS: "proxy",
|
||||||
|
OS: "proxy",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create proxy peer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -162,3 +162,17 @@ func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController
|
|||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateProxyPeer mocks base method.
|
||||||
|
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
|
||||||
|
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,105 @@
|
|||||||
|
package accesslogs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AccessLogEntry struct {
|
||||||
|
ID string `gorm:"primaryKey"`
|
||||||
|
AccountID string `gorm:"index"`
|
||||||
|
ServiceID string `gorm:"index"`
|
||||||
|
Timestamp time.Time `gorm:"index"`
|
||||||
|
GeoLocation peer.Location `gorm:"embedded;embeddedPrefix:location_"`
|
||||||
|
Method string `gorm:"index"`
|
||||||
|
Host string `gorm:"index"`
|
||||||
|
Path string `gorm:"index"`
|
||||||
|
Duration time.Duration `gorm:"index"`
|
||||||
|
StatusCode int `gorm:"index"`
|
||||||
|
Reason string
|
||||||
|
UserId string `gorm:"index"`
|
||||||
|
AuthMethodUsed string `gorm:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromProto creates an AccessLogEntry from a proto.AccessLog
|
||||||
|
func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) {
|
||||||
|
a.ID = serviceLog.GetLogId()
|
||||||
|
a.ServiceID = serviceLog.GetServiceId()
|
||||||
|
a.Timestamp = serviceLog.GetTimestamp().AsTime()
|
||||||
|
a.Method = serviceLog.GetMethod()
|
||||||
|
a.Host = serviceLog.GetHost()
|
||||||
|
a.Path = serviceLog.GetPath()
|
||||||
|
a.Duration = time.Duration(serviceLog.GetDurationMs()) * time.Millisecond
|
||||||
|
a.StatusCode = int(serviceLog.GetResponseCode())
|
||||||
|
a.UserId = serviceLog.GetUserId()
|
||||||
|
a.AuthMethodUsed = serviceLog.GetAuthMechanism()
|
||||||
|
a.AccountID = serviceLog.GetAccountId()
|
||||||
|
|
||||||
|
if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" {
|
||||||
|
if ip, err := netip.ParseAddr(sourceIP); err == nil {
|
||||||
|
a.GeoLocation.ConnectionIP = net.IP(ip.AsSlice())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !serviceLog.GetAuthSuccess() {
|
||||||
|
a.Reason = "Authentication failed"
|
||||||
|
} else if serviceLog.GetResponseCode() >= 400 {
|
||||||
|
a.Reason = "Request failed"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToAPIResponse converts an AccessLogEntry to the API ProxyAccessLog type
|
||||||
|
func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
|
||||||
|
var sourceIP *string
|
||||||
|
if a.GeoLocation.ConnectionIP != nil {
|
||||||
|
ip := a.GeoLocation.ConnectionIP.String()
|
||||||
|
sourceIP = &ip
|
||||||
|
}
|
||||||
|
|
||||||
|
var reason *string
|
||||||
|
if a.Reason != "" {
|
||||||
|
reason = &a.Reason
|
||||||
|
}
|
||||||
|
|
||||||
|
var userID *string
|
||||||
|
if a.UserId != "" {
|
||||||
|
userID = &a.UserId
|
||||||
|
}
|
||||||
|
|
||||||
|
var authMethod *string
|
||||||
|
if a.AuthMethodUsed != "" {
|
||||||
|
authMethod = &a.AuthMethodUsed
|
||||||
|
}
|
||||||
|
|
||||||
|
var countryCode *string
|
||||||
|
if a.GeoLocation.CountryCode != "" {
|
||||||
|
countryCode = &a.GeoLocation.CountryCode
|
||||||
|
}
|
||||||
|
|
||||||
|
var cityName *string
|
||||||
|
if a.GeoLocation.CityName != "" {
|
||||||
|
cityName = &a.GeoLocation.CityName
|
||||||
|
}
|
||||||
|
|
||||||
|
return &api.ProxyAccessLog{
|
||||||
|
Id: a.ID,
|
||||||
|
ServiceId: a.ServiceID,
|
||||||
|
Timestamp: a.Timestamp,
|
||||||
|
Method: a.Method,
|
||||||
|
Host: a.Host,
|
||||||
|
Path: a.Path,
|
||||||
|
DurationMs: int(a.Duration.Milliseconds()),
|
||||||
|
StatusCode: a.StatusCode,
|
||||||
|
SourceIp: sourceIP,
|
||||||
|
Reason: reason,
|
||||||
|
UserId: userID,
|
||||||
|
AuthMethodUsed: authMethod,
|
||||||
|
CountryCode: countryCode,
|
||||||
|
CityName: cityName,
|
||||||
|
}
|
||||||
|
}
|
||||||
178
management/internals/modules/reverseproxy/accesslogs/filter.go
Normal file
178
management/internals/modules/reverseproxy/accesslogs/filter.go
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
package accesslogs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultPageSize is the default number of records per page
|
||||||
|
DefaultPageSize = 50
|
||||||
|
// MaxPageSize is the maximum number of records allowed per page
|
||||||
|
MaxPageSize = 100
|
||||||
|
|
||||||
|
// Default sorting
|
||||||
|
DefaultSortBy = "timestamp"
|
||||||
|
DefaultSortOrder = "desc"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Valid sortable fields mapped to their database column names or expressions
|
||||||
|
// For multi-column sorts, columns are separated by comma (e.g., "host, path")
|
||||||
|
var validSortFields = map[string]string{
|
||||||
|
"timestamp": "timestamp",
|
||||||
|
"url": "host, path", // Sort by host first, then path
|
||||||
|
"host": "host",
|
||||||
|
"path": "path",
|
||||||
|
"method": "method",
|
||||||
|
"status_code": "status_code",
|
||||||
|
"duration": "duration",
|
||||||
|
"source_ip": "location_connection_ip",
|
||||||
|
"user_id": "user_id",
|
||||||
|
"auth_method": "auth_method_used",
|
||||||
|
"reason": "reason",
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccessLogFilter holds pagination, filtering, and sorting parameters for access logs
|
||||||
|
type AccessLogFilter struct {
|
||||||
|
// Page is the current page number (1-indexed)
|
||||||
|
Page int
|
||||||
|
// PageSize is the number of records per page
|
||||||
|
PageSize int
|
||||||
|
|
||||||
|
// Sorting parameters
|
||||||
|
SortBy string // Field to sort by: timestamp, url, host, path, method, status_code, duration, source_ip, user_id, auth_method, reason
|
||||||
|
SortOrder string // Sort order: asc or desc (default: desc)
|
||||||
|
|
||||||
|
// Filtering parameters
|
||||||
|
Search *string // General search across log ID, host, path, source IP, and user fields
|
||||||
|
SourceIP *string // Filter by source IP address
|
||||||
|
Host *string // Filter by host header
|
||||||
|
Path *string // Filter by request path (supports LIKE pattern)
|
||||||
|
UserID *string // Filter by authenticated user ID
|
||||||
|
UserEmail *string // Filter by user email (requires user lookup)
|
||||||
|
UserName *string // Filter by user name (requires user lookup)
|
||||||
|
Method *string // Filter by HTTP method
|
||||||
|
Status *string // Filter by status: "success" (2xx/3xx) or "failed" (1xx/4xx/5xx)
|
||||||
|
StatusCode *int // Filter by HTTP status code
|
||||||
|
StartDate *time.Time // Filter by timestamp >= start_date
|
||||||
|
EndDate *time.Time // Filter by timestamp <= end_date
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseFromRequest parses pagination, sorting, and filter parameters from HTTP request query parameters
|
||||||
|
func (f *AccessLogFilter) ParseFromRequest(r *http.Request) {
|
||||||
|
queryParams := r.URL.Query()
|
||||||
|
|
||||||
|
f.Page = parsePositiveInt(queryParams.Get("page"), 1)
|
||||||
|
f.PageSize = min(parsePositiveInt(queryParams.Get("page_size"), DefaultPageSize), MaxPageSize)
|
||||||
|
|
||||||
|
f.SortBy = parseSortField(queryParams.Get("sort_by"))
|
||||||
|
f.SortOrder = parseSortOrder(queryParams.Get("sort_order"))
|
||||||
|
|
||||||
|
f.Search = parseOptionalString(queryParams.Get("search"))
|
||||||
|
f.SourceIP = parseOptionalString(queryParams.Get("source_ip"))
|
||||||
|
f.Host = parseOptionalString(queryParams.Get("host"))
|
||||||
|
f.Path = parseOptionalString(queryParams.Get("path"))
|
||||||
|
f.UserID = parseOptionalString(queryParams.Get("user_id"))
|
||||||
|
f.UserEmail = parseOptionalString(queryParams.Get("user_email"))
|
||||||
|
f.UserName = parseOptionalString(queryParams.Get("user_name"))
|
||||||
|
f.Method = parseOptionalString(queryParams.Get("method"))
|
||||||
|
f.Status = parseOptionalString(queryParams.Get("status"))
|
||||||
|
f.StatusCode = parseOptionalInt(queryParams.Get("status_code"))
|
||||||
|
f.StartDate = parseOptionalRFC3339(queryParams.Get("start_date"))
|
||||||
|
f.EndDate = parseOptionalRFC3339(queryParams.Get("end_date"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// parsePositiveInt parses a positive integer from a string, returning defaultValue if invalid
|
||||||
|
func parsePositiveInt(s string, defaultValue int) int {
|
||||||
|
if s == "" {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
if val, err := strconv.Atoi(s); err == nil && val > 0 {
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseOptionalString returns a pointer to the string if non-empty, otherwise nil
|
||||||
|
func parseOptionalString(s string) *string {
|
||||||
|
if s == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &s
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseOptionalInt parses an optional positive integer from a string
|
||||||
|
func parseOptionalInt(s string) *int {
|
||||||
|
if s == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if val, err := strconv.Atoi(s); err == nil && val > 0 {
|
||||||
|
v := val
|
||||||
|
return &v
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseOptionalRFC3339 parses an optional RFC3339 timestamp from a string
|
||||||
|
func parseOptionalRFC3339(s string) *time.Time {
|
||||||
|
if s == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||||
|
return &t
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOffset calculates the database offset for pagination
|
||||||
|
func (f *AccessLogFilter) GetOffset() int {
|
||||||
|
return (f.Page - 1) * f.PageSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLimit returns the page size for database queries
|
||||||
|
func (f *AccessLogFilter) GetLimit() int {
|
||||||
|
return f.PageSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSortColumn returns the validated database column name for sorting
|
||||||
|
func (f *AccessLogFilter) GetSortColumn() string {
|
||||||
|
if column, ok := validSortFields[f.SortBy]; ok {
|
||||||
|
return column
|
||||||
|
}
|
||||||
|
return validSortFields[DefaultSortBy]
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSortOrder returns the validated sort order (ASC or DESC)
|
||||||
|
func (f *AccessLogFilter) GetSortOrder() string {
|
||||||
|
if f.SortOrder == "asc" || f.SortOrder == "desc" {
|
||||||
|
return f.SortOrder
|
||||||
|
}
|
||||||
|
return DefaultSortOrder
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseSortField validates and returns the sort field, defaulting if invalid
|
||||||
|
func parseSortField(s string) string {
|
||||||
|
if s == "" {
|
||||||
|
return DefaultSortBy
|
||||||
|
}
|
||||||
|
// Check if the field is valid
|
||||||
|
if _, ok := validSortFields[s]; ok {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return DefaultSortBy
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseSortOrder validates and returns the sort order, defaulting if invalid
|
||||||
|
func parseSortOrder(s string) string {
|
||||||
|
if s == "" {
|
||||||
|
return DefaultSortOrder
|
||||||
|
}
|
||||||
|
// Normalize to lowercase
|
||||||
|
s = strings.ToLower(s)
|
||||||
|
if s == "asc" || s == "desc" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return DefaultSortOrder
|
||||||
|
}
|
||||||
@@ -0,0 +1,570 @@
|
|||||||
|
package accesslogs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAccessLogFilter_ParseFromRequest(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
queryParams map[string]string
|
||||||
|
expectedPage int
|
||||||
|
expectedPageSize int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default values when no params provided",
|
||||||
|
queryParams: map[string]string{},
|
||||||
|
expectedPage: 1,
|
||||||
|
expectedPageSize: DefaultPageSize,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid page and page_size",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "2",
|
||||||
|
"page_size": "25",
|
||||||
|
},
|
||||||
|
expectedPage: 2,
|
||||||
|
expectedPageSize: 25,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "page_size exceeds max, should cap at MaxPageSize",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "1",
|
||||||
|
"page_size": "200",
|
||||||
|
},
|
||||||
|
expectedPage: 1,
|
||||||
|
expectedPageSize: MaxPageSize,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid page number, should use default",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "invalid",
|
||||||
|
"page_size": "10",
|
||||||
|
},
|
||||||
|
expectedPage: 1,
|
||||||
|
expectedPageSize: 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid page_size, should use default",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "2",
|
||||||
|
"page_size": "invalid",
|
||||||
|
},
|
||||||
|
expectedPage: 2,
|
||||||
|
expectedPageSize: DefaultPageSize,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero page number, should use default",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "0",
|
||||||
|
"page_size": "10",
|
||||||
|
},
|
||||||
|
expectedPage: 1,
|
||||||
|
expectedPageSize: 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative page number, should use default",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "-1",
|
||||||
|
"page_size": "10",
|
||||||
|
},
|
||||||
|
expectedPage: 1,
|
||||||
|
expectedPageSize: 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero page_size, should use default",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"page": "1",
|
||||||
|
"page_size": "0",
|
||||||
|
},
|
||||||
|
expectedPage: 1,
|
||||||
|
expectedPageSize: DefaultPageSize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
q := req.URL.Query()
|
||||||
|
for key, value := range tt.queryParams {
|
||||||
|
q.Set(key, value)
|
||||||
|
}
|
||||||
|
req.URL.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
filter := &AccessLogFilter{}
|
||||||
|
filter.ParseFromRequest(req)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedPage, filter.Page, "Page mismatch")
|
||||||
|
assert.Equal(t, tt.expectedPageSize, filter.PageSize, "PageSize mismatch")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogFilter_GetOffset(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
page int
|
||||||
|
pageSize int
|
||||||
|
expectedOffset int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "first page",
|
||||||
|
page: 1,
|
||||||
|
pageSize: 50,
|
||||||
|
expectedOffset: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "second page",
|
||||||
|
page: 2,
|
||||||
|
pageSize: 50,
|
||||||
|
expectedOffset: 50,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "third page with page size 25",
|
||||||
|
page: 3,
|
||||||
|
pageSize: 25,
|
||||||
|
expectedOffset: 50,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "page 10 with page size 10",
|
||||||
|
page: 10,
|
||||||
|
pageSize: 10,
|
||||||
|
expectedOffset: 90,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
filter := &AccessLogFilter{
|
||||||
|
Page: tt.page,
|
||||||
|
PageSize: tt.pageSize,
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := filter.GetOffset()
|
||||||
|
assert.Equal(t, tt.expectedOffset, offset)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogFilter_GetLimit(t *testing.T) {
|
||||||
|
filter := &AccessLogFilter{
|
||||||
|
Page: 2,
|
||||||
|
PageSize: 25,
|
||||||
|
}
|
||||||
|
|
||||||
|
limit := filter.GetLimit()
|
||||||
|
assert.Equal(t, 25, limit, "GetLimit should return PageSize")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogFilter_ParseFromRequest_FilterParams(t *testing.T) {
|
||||||
|
startDate := "2024-01-15T10:30:00Z"
|
||||||
|
endDate := "2024-01-16T15:45:00Z"
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
q := req.URL.Query()
|
||||||
|
q.Set("search", "test query")
|
||||||
|
q.Set("source_ip", "192.168.1.1")
|
||||||
|
q.Set("host", "example.com")
|
||||||
|
q.Set("path", "/api/users")
|
||||||
|
q.Set("user_id", "user123")
|
||||||
|
q.Set("user_email", "user@example.com")
|
||||||
|
q.Set("user_name", "John Doe")
|
||||||
|
q.Set("method", "GET")
|
||||||
|
q.Set("status", "success")
|
||||||
|
q.Set("status_code", "200")
|
||||||
|
q.Set("start_date", startDate)
|
||||||
|
q.Set("end_date", endDate)
|
||||||
|
req.URL.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
filter := &AccessLogFilter{}
|
||||||
|
filter.ParseFromRequest(req)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.Search)
|
||||||
|
assert.Equal(t, "test query", *filter.Search)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.SourceIP)
|
||||||
|
assert.Equal(t, "192.168.1.1", *filter.SourceIP)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.Host)
|
||||||
|
assert.Equal(t, "example.com", *filter.Host)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.Path)
|
||||||
|
assert.Equal(t, "/api/users", *filter.Path)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.UserID)
|
||||||
|
assert.Equal(t, "user123", *filter.UserID)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.UserEmail)
|
||||||
|
assert.Equal(t, "user@example.com", *filter.UserEmail)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.UserName)
|
||||||
|
assert.Equal(t, "John Doe", *filter.UserName)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.Method)
|
||||||
|
assert.Equal(t, "GET", *filter.Method)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.Status)
|
||||||
|
assert.Equal(t, "success", *filter.Status)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.StatusCode)
|
||||||
|
assert.Equal(t, 200, *filter.StatusCode)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.StartDate)
|
||||||
|
expectedStart, _ := time.Parse(time.RFC3339, startDate)
|
||||||
|
assert.Equal(t, expectedStart, *filter.StartDate)
|
||||||
|
|
||||||
|
require.NotNil(t, filter.EndDate)
|
||||||
|
expectedEnd, _ := time.Parse(time.RFC3339, endDate)
|
||||||
|
assert.Equal(t, expectedEnd, *filter.EndDate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogFilter_ParseFromRequest_EmptyFilters(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
|
||||||
|
filter := &AccessLogFilter{}
|
||||||
|
filter.ParseFromRequest(req)
|
||||||
|
|
||||||
|
assert.Nil(t, filter.Search)
|
||||||
|
assert.Nil(t, filter.SourceIP)
|
||||||
|
assert.Nil(t, filter.Host)
|
||||||
|
assert.Nil(t, filter.Path)
|
||||||
|
assert.Nil(t, filter.UserID)
|
||||||
|
assert.Nil(t, filter.UserEmail)
|
||||||
|
assert.Nil(t, filter.UserName)
|
||||||
|
assert.Nil(t, filter.Method)
|
||||||
|
assert.Nil(t, filter.Status)
|
||||||
|
assert.Nil(t, filter.StatusCode)
|
||||||
|
assert.Nil(t, filter.StartDate)
|
||||||
|
assert.Nil(t, filter.EndDate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogFilter_ParseFromRequest_InvalidFilters(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
q := req.URL.Query()
|
||||||
|
q.Set("status_code", "invalid")
|
||||||
|
q.Set("start_date", "not-a-date")
|
||||||
|
q.Set("end_date", "2024-99-99")
|
||||||
|
req.URL.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
filter := &AccessLogFilter{}
|
||||||
|
filter.ParseFromRequest(req)
|
||||||
|
|
||||||
|
assert.Nil(t, filter.StatusCode, "invalid status_code should be nil")
|
||||||
|
assert.Nil(t, filter.StartDate, "invalid start_date should be nil")
|
||||||
|
assert.Nil(t, filter.EndDate, "invalid end_date should be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePositiveInt(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
defaultValue int
|
||||||
|
expected int
|
||||||
|
}{
|
||||||
|
{"empty string", "", 10, 10},
|
||||||
|
{"valid positive int", "25", 10, 25},
|
||||||
|
{"zero", "0", 10, 10},
|
||||||
|
{"negative", "-5", 10, 10},
|
||||||
|
{"invalid string", "abc", 10, 10},
|
||||||
|
{"float", "3.14", 10, 10},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := parsePositiveInt(tt.input, tt.defaultValue)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseOptionalString(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected *string
|
||||||
|
}{
|
||||||
|
{"empty string", "", nil},
|
||||||
|
{"valid string", "hello", strPtr("hello")},
|
||||||
|
{"whitespace", " ", strPtr(" ")},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := parseOptionalString(tt.input)
|
||||||
|
if tt.expected == nil {
|
||||||
|
assert.Nil(t, result)
|
||||||
|
} else {
|
||||||
|
require.NotNil(t, result)
|
||||||
|
assert.Equal(t, *tt.expected, *result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseOptionalInt(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected *int
|
||||||
|
}{
|
||||||
|
{"empty string", "", nil},
|
||||||
|
{"valid positive int", "42", intPtr(42)},
|
||||||
|
{"zero", "0", nil},
|
||||||
|
{"negative", "-10", nil},
|
||||||
|
{"invalid string", "abc", nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := parseOptionalInt(tt.input)
|
||||||
|
if tt.expected == nil {
|
||||||
|
assert.Nil(t, result)
|
||||||
|
} else {
|
||||||
|
require.NotNil(t, result)
|
||||||
|
assert.Equal(t, *tt.expected, *result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseOptionalRFC3339(t *testing.T) {
|
||||||
|
validDate := "2024-01-15T10:30:00Z"
|
||||||
|
expectedTime, _ := time.Parse(time.RFC3339, validDate)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected *time.Time
|
||||||
|
}{
|
||||||
|
{"empty string", "", nil},
|
||||||
|
{"valid RFC3339", validDate, &expectedTime},
|
||||||
|
{"invalid format", "2024-01-15", nil},
|
||||||
|
{"invalid date", "not-a-date", nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := parseOptionalRFC3339(tt.input)
|
||||||
|
if tt.expected == nil {
|
||||||
|
assert.Nil(t, result)
|
||||||
|
} else {
|
||||||
|
require.NotNil(t, result)
|
||||||
|
assert.Equal(t, *tt.expected, *result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogFilter_SortingDefaults(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
|
||||||
|
filter := &AccessLogFilter{}
|
||||||
|
filter.ParseFromRequest(req)
|
||||||
|
|
||||||
|
assert.Equal(t, DefaultSortBy, filter.SortBy, "SortBy should default to timestamp")
|
||||||
|
assert.Equal(t, DefaultSortOrder, filter.SortOrder, "SortOrder should default to desc")
|
||||||
|
assert.Equal(t, "timestamp", filter.GetSortColumn(), "GetSortColumn should return timestamp")
|
||||||
|
assert.Equal(t, "desc", filter.GetSortOrder(), "GetSortOrder should return desc")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogFilter_ValidSortFields(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sortBy string
|
||||||
|
expectedColumn string
|
||||||
|
expectedSortByVal string
|
||||||
|
}{
|
||||||
|
{"timestamp", "timestamp", "timestamp", "timestamp"},
|
||||||
|
{"url", "url", "host, path", "url"},
|
||||||
|
{"host", "host", "host", "host"},
|
||||||
|
{"path", "path", "path", "path"},
|
||||||
|
{"method", "method", "method", "method"},
|
||||||
|
{"status_code", "status_code", "status_code", "status_code"},
|
||||||
|
{"duration", "duration", "duration", "duration"},
|
||||||
|
{"source_ip", "source_ip", "location_connection_ip", "source_ip"},
|
||||||
|
{"user_id", "user_id", "user_id", "user_id"},
|
||||||
|
{"auth_method", "auth_method", "auth_method_used", "auth_method"},
|
||||||
|
{"reason", "reason", "reason", "reason"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test?sort_by="+tt.sortBy, nil)
|
||||||
|
|
||||||
|
filter := &AccessLogFilter{}
|
||||||
|
filter.ParseFromRequest(req)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedSortByVal, filter.SortBy, "SortBy mismatch")
|
||||||
|
assert.Equal(t, tt.expectedColumn, filter.GetSortColumn(), "GetSortColumn mismatch")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogFilter_InvalidSortField(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sortBy string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"invalid field", "invalid_field", DefaultSortBy},
|
||||||
|
{"empty field", "", DefaultSortBy},
|
||||||
|
{"malicious input", "timestamp--DROP", DefaultSortBy},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
q := req.URL.Query()
|
||||||
|
q.Set("sort_by", tt.sortBy)
|
||||||
|
req.URL.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
filter := &AccessLogFilter{}
|
||||||
|
filter.ParseFromRequest(req)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expected, filter.SortBy, "Invalid sort field should default to timestamp")
|
||||||
|
assert.Equal(t, validSortFields[DefaultSortBy], filter.GetSortColumn())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogFilter_SortOrder(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sortOrder string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"ascending", "asc", "asc"},
|
||||||
|
{"descending", "desc", "desc"},
|
||||||
|
{"uppercase ASC", "ASC", "asc"},
|
||||||
|
{"uppercase DESC", "DESC", "desc"},
|
||||||
|
{"mixed case Asc", "Asc", "asc"},
|
||||||
|
{"invalid order", "invalid", DefaultSortOrder},
|
||||||
|
{"empty order", "", DefaultSortOrder},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test?sort_order="+tt.sortOrder, nil)
|
||||||
|
|
||||||
|
filter := &AccessLogFilter{}
|
||||||
|
filter.ParseFromRequest(req)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expected, filter.GetSortOrder(), "GetSortOrder mismatch")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogFilter_CompleteSortingScenarios(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sortBy string
|
||||||
|
sortOrder string
|
||||||
|
expectedColumn string
|
||||||
|
expectedOrder string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "sort by host ascending",
|
||||||
|
sortBy: "host",
|
||||||
|
sortOrder: "asc",
|
||||||
|
expectedColumn: "host",
|
||||||
|
expectedOrder: "asc",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sort by duration descending",
|
||||||
|
sortBy: "duration",
|
||||||
|
sortOrder: "desc",
|
||||||
|
expectedColumn: "duration",
|
||||||
|
expectedOrder: "desc",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sort by status_code ascending",
|
||||||
|
sortBy: "status_code",
|
||||||
|
sortOrder: "asc",
|
||||||
|
expectedColumn: "status_code",
|
||||||
|
expectedOrder: "asc",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid sort with valid order",
|
||||||
|
sortBy: "invalid",
|
||||||
|
sortOrder: "asc",
|
||||||
|
expectedColumn: "timestamp",
|
||||||
|
expectedOrder: "asc",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid sort with invalid order",
|
||||||
|
sortBy: "method",
|
||||||
|
sortOrder: "invalid",
|
||||||
|
expectedColumn: "method",
|
||||||
|
expectedOrder: DefaultSortOrder,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test?sort_by="+tt.sortBy+"&sort_order="+tt.sortOrder, nil)
|
||||||
|
|
||||||
|
filter := &AccessLogFilter{}
|
||||||
|
filter.ParseFromRequest(req)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedColumn, filter.GetSortColumn())
|
||||||
|
assert.Equal(t, tt.expectedOrder, filter.GetSortOrder())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseSortField(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"valid field", "host", "host"},
|
||||||
|
{"empty string", "", DefaultSortBy},
|
||||||
|
{"invalid field", "invalid", DefaultSortBy},
|
||||||
|
{"malicious input", "timestamp--DROP", DefaultSortBy},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := parseSortField(tt.input)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseSortOrder(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"asc lowercase", "asc", "asc"},
|
||||||
|
{"desc lowercase", "desc", "desc"},
|
||||||
|
{"ASC uppercase", "ASC", "asc"},
|
||||||
|
{"DESC uppercase", "DESC", "desc"},
|
||||||
|
{"invalid", "invalid", DefaultSortOrder},
|
||||||
|
{"empty", "", DefaultSortOrder},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := parseSortOrder(tt.input)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions for creating pointers
|
||||||
|
func strPtr(s string) *string {
|
||||||
|
return &s
|
||||||
|
}
|
||||||
|
|
||||||
|
func intPtr(i int) *int {
|
||||||
|
return &i
|
||||||
|
}
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
package accesslogs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Manager interface {
|
||||||
|
SaveAccessLog(ctx context.Context, proxyLog *AccessLogEntry) error
|
||||||
|
GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *AccessLogFilter) ([]*AccessLogEntry, int64, error)
|
||||||
|
CleanupOldAccessLogs(ctx context.Context, retentionDays int) (int64, error)
|
||||||
|
StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int)
|
||||||
|
StopPeriodicCleanup()
|
||||||
|
}
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
type handler struct {
|
||||||
|
manager accesslogs.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager) {
|
||||||
|
h := &handler{
|
||||||
|
manager: manager,
|
||||||
|
}
|
||||||
|
|
||||||
|
router.HandleFunc("/events/proxy", h.getAccessLogs).Methods("GET", "OPTIONS")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var filter accesslogs.AccessLogFilter
|
||||||
|
filter.ParseFromRequest(r)
|
||||||
|
|
||||||
|
logs, totalCount, err := h.manager.GetAllAccessLogs(r.Context(), userAuth.AccountId, userAuth.UserId, &filter)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
apiLogs := make([]api.ProxyAccessLog, 0, len(logs))
|
||||||
|
for _, log := range logs {
|
||||||
|
apiLogs = append(apiLogs, *log.ToAPIResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
response := &api.ProxyAccessLogsResponse{
|
||||||
|
Data: apiLogs,
|
||||||
|
Page: filter.Page,
|
||||||
|
PageSize: filter.PageSize,
|
||||||
|
TotalRecords: int(totalCount),
|
||||||
|
TotalPages: getTotalPageCount(int(totalCount), filter.PageSize),
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTotalPageCount calculates the total number of pages
|
||||||
|
func getTotalPageCount(totalCount, pageSize int) int {
|
||||||
|
if pageSize <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return (totalCount + pageSize - 1) / pageSize
|
||||||
|
}
|
||||||
@@ -0,0 +1,178 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
type managerImpl struct {
|
||||||
|
store store.Store
|
||||||
|
permissionsManager permissions.Manager
|
||||||
|
geo geolocation.Geolocation
|
||||||
|
cleanupCancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(store store.Store, permissionsManager permissions.Manager, geo geolocation.Geolocation) accesslogs.Manager {
|
||||||
|
return &managerImpl{
|
||||||
|
store: store,
|
||||||
|
permissionsManager: permissionsManager,
|
||||||
|
geo: geo,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveAccessLog saves an access log entry to the database after enriching it
|
||||||
|
func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.AccessLogEntry) error {
|
||||||
|
if m.geo != nil && logEntry.GeoLocation.ConnectionIP != nil {
|
||||||
|
location, err := m.geo.Lookup(logEntry.GeoLocation.ConnectionIP)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to get location for access log source IP [%s]: %v", logEntry.GeoLocation.ConnectionIP.String(), err)
|
||||||
|
} else {
|
||||||
|
logEntry.GeoLocation.CountryCode = location.Country.ISOCode
|
||||||
|
logEntry.GeoLocation.CityName = location.City.Names.En
|
||||||
|
logEntry.GeoLocation.GeoNameID = location.City.GeonameID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.store.CreateAccessLog(ctx, logEntry); err != nil {
|
||||||
|
log.WithContext(ctx).WithFields(log.Fields{
|
||||||
|
"service_id": logEntry.ServiceID,
|
||||||
|
"method": logEntry.Method,
|
||||||
|
"host": logEntry.Host,
|
||||||
|
"path": logEntry.Path,
|
||||||
|
"status": logEntry.StatusCode,
|
||||||
|
}).Errorf("failed to save access log: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllAccessLogs retrieves access logs for an account with pagination and filtering
|
||||||
|
func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, 0, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.resolveUserFilters(ctx, accountID, filter); err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to resolve user filters: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logs, totalCount, err := m.store.GetAccountAccessLogs(ctx, store.LockingStrengthNone, accountID, *filter)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return logs, totalCount, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanupOldAccessLogs deletes access logs older than the specified retention period
|
||||||
|
func (m *managerImpl) CleanupOldAccessLogs(ctx context.Context, retentionDays int) (int64, error) {
|
||||||
|
if retentionDays <= 0 {
|
||||||
|
log.WithContext(ctx).Debug("access log cleanup skipped: retention days is 0 or negative")
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cutoffTime := time.Now().AddDate(0, 0, -retentionDays)
|
||||||
|
deletedCount, err := m.store.DeleteOldAccessLogs(ctx, cutoffTime)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to cleanup old access logs: %v", err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if deletedCount > 0 {
|
||||||
|
log.WithContext(ctx).Infof("cleaned up %d access logs older than %d days", deletedCount, retentionDays)
|
||||||
|
}
|
||||||
|
|
||||||
|
return deletedCount, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartPeriodicCleanup starts a background goroutine that periodically cleans up old access logs
|
||||||
|
func (m *managerImpl) StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int) {
|
||||||
|
if retentionDays <= 0 {
|
||||||
|
log.WithContext(ctx).Debug("periodic access log cleanup disabled: retention days is 0 or negative")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if cleanupIntervalHours <= 0 {
|
||||||
|
cleanupIntervalHours = 24
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanupCtx, cancel := context.WithCancel(ctx)
|
||||||
|
m.cleanupCancel = cancel
|
||||||
|
|
||||||
|
cleanupInterval := time.Duration(cleanupIntervalHours) * time.Hour
|
||||||
|
ticker := time.NewTicker(cleanupInterval)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
// Run cleanup immediately on startup
|
||||||
|
log.WithContext(cleanupCtx).Infof("starting access log cleanup routine (retention: %d days, interval: %d hours)", retentionDays, cleanupIntervalHours)
|
||||||
|
if _, err := m.CleanupOldAccessLogs(cleanupCtx, retentionDays); err != nil {
|
||||||
|
log.WithContext(cleanupCtx).Errorf("initial access log cleanup failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-cleanupCtx.Done():
|
||||||
|
log.WithContext(cleanupCtx).Info("stopping access log cleanup routine")
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if _, err := m.CleanupOldAccessLogs(cleanupCtx, retentionDays); err != nil {
|
||||||
|
log.WithContext(cleanupCtx).Errorf("periodic access log cleanup failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopPeriodicCleanup stops the periodic cleanup routine
|
||||||
|
func (m *managerImpl) StopPeriodicCleanup() {
|
||||||
|
if m.cleanupCancel != nil {
|
||||||
|
m.cleanupCancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveUserFilters converts user email/name filters to user ID filter
|
||||||
|
func (m *managerImpl) resolveUserFilters(ctx context.Context, accountID string, filter *accesslogs.AccessLogFilter) error {
|
||||||
|
if filter.UserEmail == nil && filter.UserName == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
users, err := m.store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var matchingUserIDs []string
|
||||||
|
for _, user := range users {
|
||||||
|
if filter.UserEmail != nil && strings.Contains(strings.ToLower(user.Email), strings.ToLower(*filter.UserEmail)) {
|
||||||
|
matchingUserIDs = append(matchingUserIDs, user.Id)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if filter.UserName != nil && strings.Contains(strings.ToLower(user.Name), strings.ToLower(*filter.UserName)) {
|
||||||
|
matchingUserIDs = append(matchingUserIDs, user.Id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(matchingUserIDs) > 0 {
|
||||||
|
filter.UserID = &matchingUserIDs[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,281 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCleanupOldAccessLogs(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
retentionDays int
|
||||||
|
setupMock func(*store.MockStore)
|
||||||
|
expectedCount int64
|
||||||
|
expectedError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "cleanup logs older than retention period",
|
||||||
|
retentionDays: 30,
|
||||||
|
setupMock: func(mockStore *store.MockStore) {
|
||||||
|
mockStore.EXPECT().
|
||||||
|
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||||
|
DoAndReturn(func(ctx context.Context, olderThan time.Time) (int64, error) {
|
||||||
|
expectedCutoff := time.Now().AddDate(0, 0, -30)
|
||||||
|
timeDiff := olderThan.Sub(expectedCutoff)
|
||||||
|
if timeDiff.Abs() > time.Second {
|
||||||
|
t.Errorf("cutoff time not as expected: got %v, want ~%v", olderThan, expectedCutoff)
|
||||||
|
}
|
||||||
|
return 5, nil
|
||||||
|
})
|
||||||
|
},
|
||||||
|
expectedCount: 5,
|
||||||
|
expectedError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no logs to cleanup",
|
||||||
|
retentionDays: 30,
|
||||||
|
setupMock: func(mockStore *store.MockStore) {
|
||||||
|
mockStore.EXPECT().
|
||||||
|
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||||
|
Return(int64(0), nil)
|
||||||
|
},
|
||||||
|
expectedCount: 0,
|
||||||
|
expectedError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero retention days skips cleanup",
|
||||||
|
retentionDays: 0,
|
||||||
|
setupMock: func(mockStore *store.MockStore) {
|
||||||
|
// No expectations - DeleteOldAccessLogs should not be called
|
||||||
|
},
|
||||||
|
expectedCount: 0,
|
||||||
|
expectedError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative retention days skips cleanup",
|
||||||
|
retentionDays: -10,
|
||||||
|
setupMock: func(mockStore *store.MockStore) {
|
||||||
|
// No expectations - DeleteOldAccessLogs should not be called
|
||||||
|
},
|
||||||
|
expectedCount: 0,
|
||||||
|
expectedError: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
tt.setupMock(mockStore)
|
||||||
|
|
||||||
|
manager := &managerImpl{
|
||||||
|
store: mockStore,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
deletedCount, err := manager.CleanupOldAccessLogs(ctx, tt.retentionDays)
|
||||||
|
|
||||||
|
if tt.expectedError {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
assert.Equal(t, tt.expectedCount, deletedCount, "unexpected number of deleted logs")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCleanupWithExactBoundary(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
|
||||||
|
mockStore.EXPECT().
|
||||||
|
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||||
|
DoAndReturn(func(ctx context.Context, olderThan time.Time) (int64, error) {
|
||||||
|
expectedCutoff := time.Now().AddDate(0, 0, -30)
|
||||||
|
timeDiff := olderThan.Sub(expectedCutoff)
|
||||||
|
assert.Less(t, timeDiff.Abs(), time.Second, "cutoff time should be close to expected value")
|
||||||
|
return 1, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
manager := &managerImpl{
|
||||||
|
store: mockStore,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
deletedCount, err := manager.CleanupOldAccessLogs(ctx, 30)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(1), deletedCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartPeriodicCleanup(t *testing.T) {
|
||||||
|
t.Run("periodic cleanup disabled with zero retention", func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
// No expectations - cleanup should not run
|
||||||
|
|
||||||
|
manager := &managerImpl{
|
||||||
|
store: mockStore,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
manager.StartPeriodicCleanup(ctx, 0, 1)
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// If DeleteOldAccessLogs was called, the test will fail due to unexpected call
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("periodic cleanup runs immediately on start", func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
|
||||||
|
mockStore.EXPECT().
|
||||||
|
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||||
|
Return(int64(2), nil).
|
||||||
|
Times(1)
|
||||||
|
|
||||||
|
manager := &managerImpl{
|
||||||
|
store: mockStore,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
manager.StartPeriodicCleanup(ctx, 30, 24)
|
||||||
|
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
// Expectations verified by gomock on defer ctrl.Finish()
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("periodic cleanup stops on context cancel", func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
|
||||||
|
mockStore.EXPECT().
|
||||||
|
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||||
|
Return(int64(1), nil).
|
||||||
|
Times(1)
|
||||||
|
|
||||||
|
manager := &managerImpl{
|
||||||
|
store: mockStore,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
manager.StartPeriodicCleanup(ctx, 30, 24)
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cleanup interval defaults to 24 hours when invalid", func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
|
||||||
|
mockStore.EXPECT().
|
||||||
|
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||||
|
Return(int64(0), nil).
|
||||||
|
Times(1)
|
||||||
|
|
||||||
|
manager := &managerImpl{
|
||||||
|
store: mockStore,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
manager.StartPeriodicCleanup(ctx, 30, 0)
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
manager.StopPeriodicCleanup()
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cleanup interval uses configured hours", func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
|
||||||
|
mockStore.EXPECT().
|
||||||
|
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||||
|
Return(int64(3), nil).
|
||||||
|
Times(1)
|
||||||
|
|
||||||
|
manager := &managerImpl{
|
||||||
|
store: mockStore,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
manager.StartPeriodicCleanup(ctx, 30, 12)
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
manager.StopPeriodicCleanup()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStopPeriodicCleanup(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
|
||||||
|
mockStore.EXPECT().
|
||||||
|
DeleteOldAccessLogs(gomock.Any(), gomock.Any()).
|
||||||
|
Return(int64(1), nil).
|
||||||
|
Times(1)
|
||||||
|
|
||||||
|
manager := &managerImpl{
|
||||||
|
store: mockStore,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
manager.StartPeriodicCleanup(ctx, 30, 24)
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
manager.StopPeriodicCleanup()
|
||||||
|
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
// Expectations verified by gomock - would fail if more than 1 call happened
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStopPeriodicCleanup_NotStarted(t *testing.T) {
|
||||||
|
manager := &managerImpl{}
|
||||||
|
|
||||||
|
// Should not panic if cleanup was never started
|
||||||
|
manager.StopPeriodicCleanup()
|
||||||
|
}
|
||||||
17
management/internals/modules/reverseproxy/domain/domain.go
Normal file
17
management/internals/modules/reverseproxy/domain/domain.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
type Type string
|
||||||
|
|
||||||
|
const (
|
||||||
|
TypeFree Type = "free"
|
||||||
|
TypeCustom Type = "custom"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Domain struct {
|
||||||
|
ID string `gorm:"unique;primaryKey;autoIncrement"`
|
||||||
|
Domain string `gorm:"unique"` // Domain records must be unique, this avoids domain reuse across accounts.
|
||||||
|
AccountID string `gorm:"index"`
|
||||||
|
TargetCluster string // The proxy cluster this domain should be validated against
|
||||||
|
Type Type `gorm:"-"`
|
||||||
|
Validated bool
|
||||||
|
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Manager interface {
|
||||||
|
GetDomains(ctx context.Context, accountID, userID string) ([]*Domain, error)
|
||||||
|
CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*Domain, error)
|
||||||
|
DeleteDomain(ctx context.Context, accountID, userID, domainID string) error
|
||||||
|
ValidateDomain(ctx context.Context, accountID, userID, domainID string)
|
||||||
|
}
|
||||||
136
management/internals/modules/reverseproxy/domain/manager/api.go
Normal file
136
management/internals/modules/reverseproxy/domain/manager/api.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||||
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
type handler struct {
|
||||||
|
manager Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterEndpoints(router *mux.Router, manager Manager) {
|
||||||
|
h := &handler{
|
||||||
|
manager: manager,
|
||||||
|
}
|
||||||
|
|
||||||
|
router.HandleFunc("/domains", h.getAllDomains).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/domains", h.createCustomDomain).Methods("POST", "OPTIONS")
|
||||||
|
router.HandleFunc("/domains/{domainId}", h.deleteCustomDomain).Methods("DELETE", "OPTIONS")
|
||||||
|
router.HandleFunc("/domains/{domainId}/validate", h.triggerCustomDomainValidation).Methods("GET", "OPTIONS")
|
||||||
|
}
|
||||||
|
|
||||||
|
func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType {
|
||||||
|
switch t {
|
||||||
|
case domain.TypeCustom:
|
||||||
|
return api.ReverseProxyDomainTypeCustom
|
||||||
|
case domain.TypeFree:
|
||||||
|
return api.ReverseProxyDomainTypeFree
|
||||||
|
}
|
||||||
|
// By default return as a "free" domain as that is more restrictive.
|
||||||
|
// TODO: is this correct?
|
||||||
|
return api.ReverseProxyDomainTypeFree
|
||||||
|
}
|
||||||
|
|
||||||
|
func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
|
||||||
|
resp := api.ReverseProxyDomain{
|
||||||
|
Domain: d.Domain,
|
||||||
|
Id: d.ID,
|
||||||
|
Type: domainTypeToApi(d.Type),
|
||||||
|
Validated: d.Validated,
|
||||||
|
}
|
||||||
|
if d.TargetCluster != "" {
|
||||||
|
resp.TargetCluster = &d.TargetCluster
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
domains, err := h.manager.GetDomains(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ret := make([]api.ReverseProxyDomain, 0)
|
||||||
|
for _, d := range domains {
|
||||||
|
ret = append(ret, domainToApi(d))
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req api.PostApiReverseProxiesDomainsJSONRequestBody
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
domain, err := h.manager.CreateDomain(r.Context(), userAuth.AccountId, userAuth.UserId, req.Domain, req.TargetCluster)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, domainToApi(domain))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
domainID := mux.Vars(r)["domainId"]
|
||||||
|
if domainID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.manager.DeleteDomain(r.Context(), userAuth.AccountId, userAuth.UserId, domainID); err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
domainID := mux.Vars(r)["domainId"]
|
||||||
|
if domainID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go h.manager.ValidateDomain(r.Context(), userAuth.AccountId, userAuth.UserId, domainID)
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusAccepted)
|
||||||
|
}
|
||||||
@@ -0,0 +1,279 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
type store interface {
|
||||||
|
GetAccount(ctx context.Context, accountID string) (*types.Account, error)
|
||||||
|
|
||||||
|
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)
|
||||||
|
ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
|
||||||
|
ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error)
|
||||||
|
CreateCustomDomain(ctx context.Context, accountID string, domainName string, targetCluster string, validated bool) (*domain.Domain, error)
|
||||||
|
UpdateCustomDomain(ctx context.Context, accountID string, d *domain.Domain) (*domain.Domain, error)
|
||||||
|
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type proxyManager interface {
|
||||||
|
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
store store
|
||||||
|
validator domain.Validator
|
||||||
|
proxyManager proxyManager
|
||||||
|
permissionsManager permissions.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager) Manager {
|
||||||
|
return Manager{
|
||||||
|
store: store,
|
||||||
|
proxyManager: proxyMgr,
|
||||||
|
validator: domain.Validator{
|
||||||
|
Resolver: net.DefaultResolver,
|
||||||
|
},
|
||||||
|
permissionsManager: permissionsManager,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
domains, err := m.store.ListCustomDomains(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("list custom domains: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var ret []*domain.Domain
|
||||||
|
|
||||||
|
// Add connected proxy clusters as free domains.
|
||||||
|
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
|
||||||
|
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"proxyAllowList": allowList,
|
||||||
|
}).Debug("getting domains with proxy allow list")
|
||||||
|
|
||||||
|
for _, cluster := range allowList {
|
||||||
|
ret = append(ret, &domain.Domain{
|
||||||
|
Domain: cluster,
|
||||||
|
AccountID: accountID,
|
||||||
|
Type: domain.TypeFree,
|
||||||
|
Validated: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add custom domains.
|
||||||
|
for _, d := range domains {
|
||||||
|
ret = append(ret, &domain.Domain{
|
||||||
|
ID: d.ID,
|
||||||
|
Domain: d.Domain,
|
||||||
|
AccountID: accountID,
|
||||||
|
TargetCluster: d.TargetCluster,
|
||||||
|
Type: domain.TypeCustom,
|
||||||
|
Validated: d.Validated,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*domain.Domain, error) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the target cluster is in the available clusters
|
||||||
|
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
||||||
|
}
|
||||||
|
clusterValid := false
|
||||||
|
for _, cluster := range allowList {
|
||||||
|
if cluster == targetCluster {
|
||||||
|
clusterValid = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !clusterValid {
|
||||||
|
return nil, fmt.Errorf("target cluster %s is not available", targetCluster)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt an initial validation against the specified cluster only
|
||||||
|
var validated bool
|
||||||
|
if m.validator.IsValid(ctx, domainName, []string{targetCluster}) {
|
||||||
|
validated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := m.store.CreateCustomDomain(ctx, accountID, domainName, targetCluster, validated)
|
||||||
|
if err != nil {
|
||||||
|
return d, fmt.Errorf("create domain in store: %w", err)
|
||||||
|
}
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID string) error {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||||
|
if err != nil {
|
||||||
|
return status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.store.DeleteCustomDomain(ctx, accountID, domainID); err != nil {
|
||||||
|
// TODO: check for "no records" type error. Because that is a success condition.
|
||||||
|
return fmt.Errorf("delete domain from store: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID string) {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
}).WithError(err).Error("validate domain")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
}).WithError(err).Error("validate domain")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
}).Info("starting domain validation")
|
||||||
|
|
||||||
|
d, err := m.store.GetCustomDomain(context.Background(), accountID, domainID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
}).WithError(err).Error("get custom domain from store")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate only against the domain's target cluster
|
||||||
|
targetCluster := d.TargetCluster
|
||||||
|
if targetCluster == "" {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
"domain": d.Domain,
|
||||||
|
}).Warn("domain has no target cluster set, skipping validation")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
"domain": d.Domain,
|
||||||
|
"targetCluster": targetCluster,
|
||||||
|
}).Info("validating domain against target cluster")
|
||||||
|
|
||||||
|
if m.validator.IsValid(context.Background(), d.Domain, []string{targetCluster}) {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
"domain": d.Domain,
|
||||||
|
}).Info("domain validated successfully")
|
||||||
|
d.Validated = true
|
||||||
|
if _, err := m.store.UpdateCustomDomain(context.Background(), accountID, d); err != nil {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
"domain": d.Domain,
|
||||||
|
}).WithError(err).Error("update custom domain in store")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"accountID": accountID,
|
||||||
|
"domainID": domainID,
|
||||||
|
"domain": d.Domain,
|
||||||
|
"targetCluster": targetCluster,
|
||||||
|
}).Warn("domain validation failed - CNAME does not match target cluster")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeriveClusterFromDomain determines the proxy cluster for a given domain.
|
||||||
|
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
|
||||||
|
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
|
||||||
|
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
|
||||||
|
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
||||||
|
}
|
||||||
|
if len(allowList) == 0 {
|
||||||
|
return "", fmt.Errorf("no proxy clusters available")
|
||||||
|
}
|
||||||
|
|
||||||
|
if cluster, ok := ExtractClusterFromFreeDomain(domain, allowList); ok {
|
||||||
|
return cluster, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
customDomains, err := m.store.ListCustomDomains(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("list custom domains: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
targetCluster, valid := extractClusterFromCustomDomains(domain, customDomains)
|
||||||
|
if valid {
|
||||||
|
return targetCluster, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractClusterFromCustomDomains(domain string, customDomains []*domain.Domain) (string, bool) {
|
||||||
|
for _, customDomain := range customDomains {
|
||||||
|
if strings.HasSuffix(domain, "."+customDomain.Domain) {
|
||||||
|
return customDomain.TargetCluster, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractClusterFromFreeDomain extracts the cluster address from a free domain.
|
||||||
|
// Free domains have the format: <name>.<nonce>.<cluster> (e.g., myapp.abc123.eu.proxy.netbird.io)
|
||||||
|
// It matches the domain suffix against available clusters and returns the matching cluster.
|
||||||
|
func ExtractClusterFromFreeDomain(domain string, availableClusters []string) (string, bool) {
|
||||||
|
for _, cluster := range availableClusters {
|
||||||
|
if strings.HasSuffix(domain, "."+cluster) {
|
||||||
|
return cluster, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type resolver interface {
|
||||||
|
LookupCNAME(context.Context, string) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Validator struct {
|
||||||
|
Resolver resolver
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewValidator initializes a validator with a specific DNS Resolver.
|
||||||
|
// If a Validator is used without specifying a Resolver, then it will
|
||||||
|
// use the net.DefaultResolver.
|
||||||
|
func NewValidator(resolver resolver) *Validator {
|
||||||
|
return &Validator{
|
||||||
|
Resolver: resolver,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValid looks up the CNAME record for the passed domain with a prefix
|
||||||
|
// and compares it against the acceptable domains.
|
||||||
|
// If the returned CNAME matches any accepted domain, it will return true,
|
||||||
|
// otherwise, including in the event of a DNS error, it will return false.
|
||||||
|
// The comparison is very simple, so wildcards will not match if included
|
||||||
|
// in the acceptable domain list.
|
||||||
|
func (v *Validator) IsValid(ctx context.Context, domain string, accept []string) bool {
|
||||||
|
_, valid := v.ValidateWithCluster(ctx, domain, accept)
|
||||||
|
return valid
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateWithCluster validates a custom domain and returns the matched cluster address.
|
||||||
|
// Returns the cluster address and true if valid, or empty string and false if invalid.
|
||||||
|
func (v *Validator) ValidateWithCluster(ctx context.Context, domain string, accept []string) (string, bool) {
|
||||||
|
if v.Resolver == nil {
|
||||||
|
v.Resolver = net.DefaultResolver
|
||||||
|
}
|
||||||
|
|
||||||
|
lookupDomain := "validation." + domain
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"domain": domain,
|
||||||
|
"lookupDomain": lookupDomain,
|
||||||
|
"acceptList": accept,
|
||||||
|
}).Debug("looking up CNAME for domain validation")
|
||||||
|
|
||||||
|
cname, err := v.Resolver.LookupCNAME(ctx, lookupDomain)
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"domain": domain,
|
||||||
|
"lookupDomain": lookupDomain,
|
||||||
|
}).WithError(err).Warn("CNAME lookup failed for domain validation")
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
nakedCNAME := strings.TrimSuffix(cname, ".")
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"domain": domain,
|
||||||
|
"cname": cname,
|
||||||
|
"nakedCNAME": nakedCNAME,
|
||||||
|
"acceptList": accept,
|
||||||
|
}).Debug("CNAME lookup result for domain validation")
|
||||||
|
|
||||||
|
for _, acceptDomain := range accept {
|
||||||
|
normalizedAccept := strings.TrimSuffix(acceptDomain, ".")
|
||||||
|
if nakedCNAME == normalizedAccept {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"domain": domain,
|
||||||
|
"cname": nakedCNAME,
|
||||||
|
"cluster": acceptDomain,
|
||||||
|
}).Info("domain CNAME matched cluster")
|
||||||
|
return acceptDomain, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"domain": domain,
|
||||||
|
"cname": nakedCNAME,
|
||||||
|
"acceptList": accept,
|
||||||
|
}).Warn("domain CNAME does not match any accepted cluster")
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user