mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 00:06:38 +00:00
Compare commits
50 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
68c481fa44 | ||
|
|
01a9cd4651 | ||
|
|
f53155562f | ||
|
|
edce11b34d | ||
|
|
841b2d26c6 | ||
|
|
d3eeb6d8ee | ||
|
|
7ebf37ef20 | ||
|
|
64b849c801 | ||
|
|
69d4b5d821 | ||
|
|
3dfa97dcbd | ||
|
|
1ddc9ce2bf | ||
|
|
2de1949018 | ||
|
|
fc88399c23 | ||
|
|
6981fdce7e | ||
|
|
08403f64aa | ||
|
|
391221a986 | ||
|
|
7bc85107eb | ||
|
|
3be16d19a0 | ||
|
|
af8f730bda | ||
|
|
c3f176f348 | ||
|
|
0119f3e9f4 | ||
|
|
1b96648d4d | ||
|
|
d2f9653cea | ||
|
|
194a986926 | ||
|
|
f7732557fa | ||
|
|
d488f58311 | ||
|
|
6fdc00ff41 | ||
|
|
b20d484972 | ||
|
|
8931293343 | ||
|
|
7b830d8f72 | ||
|
|
3a0cf230a1 | ||
|
|
0c990ab662 | ||
|
|
101c813e98 | ||
|
|
5333e55a81 | ||
|
|
81c11df103 | ||
|
|
f74bc48d16 | ||
|
|
0169e4540f | ||
|
|
cead3f38ee | ||
|
|
b55262d4a2 | ||
|
|
2248ff392f | ||
|
|
06966da012 | ||
|
|
d4f7df271a | ||
|
|
5299549eb6 | ||
|
|
7d791620a6 | ||
|
|
44ab454a13 | ||
|
|
11f50d6c38 | ||
|
|
05af39a69b | ||
|
|
074df56c3d | ||
|
|
2381e216e4 | ||
|
|
ded04b7627 |
6
.dockerignore
Normal file
6
.dockerignore
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
.env
|
||||||
|
.env.*
|
||||||
|
*.pem
|
||||||
|
*.key
|
||||||
|
*.crt
|
||||||
|
*.p12
|
||||||
10
.github/workflows/check-license-dependencies.yml
vendored
10
.github/workflows/check-license-dependencies.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Check for problematic license dependencies
|
- name: Check for problematic license dependencies
|
||||||
run: |
|
run: |
|
||||||
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
echo "Checking for dependencies on management/, signal/, relay/, and proxy/ packages..."
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# Find all directories except the problematic ones and system dirs
|
# Find all directories except the problematic ones and system dirs
|
||||||
@@ -31,7 +31,7 @@ jobs:
|
|||||||
while IFS= read -r dir; do
|
while IFS= read -r dir; do
|
||||||
echo "=== Checking $dir ==="
|
echo "=== Checking $dir ==="
|
||||||
# Search for problematic imports, excluding test files
|
# Search for problematic imports, excluding test files
|
||||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||||
if [ -n "$RESULTS" ]; then
|
if [ -n "$RESULTS" ]; then
|
||||||
echo "❌ Found problematic dependencies:"
|
echo "❌ Found problematic dependencies:"
|
||||||
echo "$RESULTS"
|
echo "$RESULTS"
|
||||||
@@ -39,11 +39,11 @@ jobs:
|
|||||||
else
|
else
|
||||||
echo "✓ No problematic dependencies found"
|
echo "✓ No problematic dependencies found"
|
||||||
fi
|
fi
|
||||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
|
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name "combined" -not -name ".git*" | sort)
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
echo "❌ Found dependencies on management/, signal/, relay/, or proxy/ packages"
|
||||||
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
||||||
exit 1
|
exit 1
|
||||||
else
|
else
|
||||||
@@ -88,7 +88,7 @@ jobs:
|
|||||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||||
|
|
||||||
# Check if any importer is NOT in management/signal/relay
|
# Check if any importer is NOT in management/signal/relay
|
||||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
|
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\)" | head -1)
|
||||||
|
|
||||||
if [ -n "$BSD_IMPORTER" ]; then
|
if [ -n "$BSD_IMPORTER" ]; then
|
||||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||||
|
|||||||
2
.github/workflows/golang-test-darwin.yml
vendored
2
.github/workflows/golang-test-darwin.yml
vendored
@@ -43,5 +43,5 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
|
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||||
|
|
||||||
|
|||||||
1
.github/workflows/golang-test-freebsd.yml
vendored
1
.github/workflows/golang-test-freebsd.yml
vendored
@@ -46,6 +46,5 @@ jobs:
|
|||||||
time go test -timeout 1m -failfast ./client/iface/...
|
time go test -timeout 1m -failfast ./client/iface/...
|
||||||
time go test -timeout 1m -failfast ./route/...
|
time go test -timeout 1m -failfast ./route/...
|
||||||
time go test -timeout 1m -failfast ./sharedsock/...
|
time go test -timeout 1m -failfast ./sharedsock/...
|
||||||
time go test -timeout 1m -failfast ./signal/...
|
|
||||||
time go test -timeout 1m -failfast ./util/...
|
time go test -timeout 1m -failfast ./util/...
|
||||||
time go test -timeout 1m -failfast ./version/...
|
time go test -timeout 1m -failfast ./version/...
|
||||||
|
|||||||
61
.github/workflows/golang-test-linux.yml
vendored
61
.github/workflows/golang-test-linux.yml
vendored
@@ -97,6 +97,16 @@ jobs:
|
|||||||
working-directory: relay
|
working-directory: relay
|
||||||
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
|
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
|
||||||
|
|
||||||
|
- name: Build combined
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
working-directory: combined
|
||||||
|
run: CGO_ENABLED=1 go build .
|
||||||
|
|
||||||
|
- name: Build combined 386
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
working-directory: combined
|
||||||
|
run: CGO_ENABLED=1 GOARCH=386 go build -o combined-386 .
|
||||||
|
|
||||||
test:
|
test:
|
||||||
name: "Client / Unit"
|
name: "Client / Unit"
|
||||||
needs: [build-cache]
|
needs: [build-cache]
|
||||||
@@ -144,7 +154,7 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||||
|
|
||||||
test_client_on_docker:
|
test_client_on_docker:
|
||||||
name: "Client (Docker) / Unit"
|
name: "Client (Docker) / Unit"
|
||||||
@@ -204,7 +214,7 @@ jobs:
|
|||||||
sh -c ' \
|
sh -c ' \
|
||||||
apk update; apk add --no-cache \
|
apk update; apk add --no-cache \
|
||||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
|
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
||||||
'
|
'
|
||||||
|
|
||||||
test_relay:
|
test_relay:
|
||||||
@@ -261,6 +271,53 @@ jobs:
|
|||||||
-exec 'sudo' \
|
-exec 'sudo' \
|
||||||
-timeout 10m -p 1 ./relay/... ./shared/relay/...
|
-timeout 10m -p 1 ./relay/... ./shared/relay/...
|
||||||
|
|
||||||
|
test_proxy:
|
||||||
|
name: "Proxy / Unit"
|
||||||
|
needs: [build-cache]
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
arch: [ '386','amd64' ]
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: "go.mod"
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
run: |
|
||||||
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ env.cache }}
|
||||||
|
${{ env.modcache }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
|
- name: Install modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: |
|
||||||
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
|
go test -timeout 10m -p 1 ./proxy/...
|
||||||
|
|
||||||
test_signal:
|
test_signal:
|
||||||
name: "Signal / Unit"
|
name: "Signal / Unit"
|
||||||
needs: [build-cache]
|
needs: [build-cache]
|
||||||
|
|||||||
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@@ -63,7 +63,7 @@ jobs:
|
|||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
||||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV
|
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' })" >> $env:GITHUB_ENV
|
||||||
|
|
||||||
- name: test
|
- name: test
|
||||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
||||||
|
|||||||
4
.github/workflows/golangci-lint.yml
vendored
4
.github/workflows/golangci-lint.yml
vendored
@@ -19,8 +19,8 @@ jobs:
|
|||||||
- name: codespell
|
- name: codespell
|
||||||
uses: codespell-project/actions-codespell@v2
|
uses: codespell-project/actions-codespell@v2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver
|
||||||
skip: go.mod,go.sum
|
skip: go.mod,go.sum,**/proxy/web/**
|
||||||
golangci:
|
golangci:
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|||||||
18
.github/workflows/release.yml
vendored
18
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.1.0"
|
SIGN_PIPE_VER: "v0.1.1"
|
||||||
GORELEASER_VER: "v2.3.2"
|
GORELEASER_VER: "v2.3.2"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "NetBird GmbH"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
@@ -160,7 +160,7 @@ jobs:
|
|||||||
username: ${{ secrets.DOCKER_USER }}
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
password: ${{ secrets.DOCKER_TOKEN }}
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
- name: Log in to the GitHub container registry
|
- name: Log in to the GitHub container registry
|
||||||
if: github.event_name != 'pull_request'
|
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||||
uses: docker/login-action@v3
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
registry: ghcr.io
|
registry: ghcr.io
|
||||||
@@ -176,6 +176,7 @@ jobs:
|
|||||||
- name: Generate windows syso arm64
|
- name: Generate windows syso arm64
|
||||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||||
- name: Run GoReleaser
|
- name: Run GoReleaser
|
||||||
|
id: goreleaser
|
||||||
uses: goreleaser/goreleaser-action@v4
|
uses: goreleaser/goreleaser-action@v4
|
||||||
with:
|
with:
|
||||||
version: ${{ env.GORELEASER_VER }}
|
version: ${{ env.GORELEASER_VER }}
|
||||||
@@ -185,6 +186,19 @@ jobs:
|
|||||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||||
|
- name: Tag and push PR images (amd64 only)
|
||||||
|
if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository
|
||||||
|
run: |
|
||||||
|
PR_TAG="pr-${{ github.event.pull_request.number }}"
|
||||||
|
echo '${{ steps.goreleaser.outputs.artifacts }}' | \
|
||||||
|
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
|
||||||
|
grep '^ghcr.io/' | while read -r SRC; do
|
||||||
|
IMG_NAME="${SRC%%:*}"
|
||||||
|
DST="${IMG_NAME}:${PR_TAG}"
|
||||||
|
echo "Tagging ${SRC} -> ${DST}"
|
||||||
|
docker tag "$SRC" "$DST"
|
||||||
|
docker push "$DST"
|
||||||
|
done
|
||||||
- name: upload non tags for debug purposes
|
- name: upload non tags for debug purposes
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -2,6 +2,7 @@
|
|||||||
.run
|
.run
|
||||||
*.iml
|
*.iml
|
||||||
dist/
|
dist/
|
||||||
|
!proxy/web/dist/
|
||||||
bin/
|
bin/
|
||||||
.env
|
.env
|
||||||
conf.json
|
conf.json
|
||||||
|
|||||||
181
.goreleaser.yaml
181
.goreleaser.yaml
@@ -106,6 +106,26 @@ builds:
|
|||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
|
- id: netbird-server
|
||||||
|
dir: combined
|
||||||
|
env:
|
||||||
|
- CGO_ENABLED=1
|
||||||
|
- >-
|
||||||
|
{{- if eq .Runtime.Goos "linux" }}
|
||||||
|
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
|
||||||
|
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
|
||||||
|
{{- end }}
|
||||||
|
binary: netbird-server
|
||||||
|
goos:
|
||||||
|
- linux
|
||||||
|
goarch:
|
||||||
|
- amd64
|
||||||
|
- arm64
|
||||||
|
- arm
|
||||||
|
ldflags:
|
||||||
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
- id: netbird-upload
|
- id: netbird-upload
|
||||||
dir: upload-server
|
dir: upload-server
|
||||||
env: [CGO_ENABLED=0]
|
env: [CGO_ENABLED=0]
|
||||||
@@ -120,6 +140,20 @@ builds:
|
|||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
|
- id: netbird-proxy
|
||||||
|
dir: proxy/cmd/proxy
|
||||||
|
env: [CGO_ENABLED=0]
|
||||||
|
binary: netbird-proxy
|
||||||
|
goos:
|
||||||
|
- linux
|
||||||
|
goarch:
|
||||||
|
- amd64
|
||||||
|
- arm64
|
||||||
|
- arm
|
||||||
|
ldflags:
|
||||||
|
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
|
||||||
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
universal_binaries:
|
universal_binaries:
|
||||||
- id: netbird
|
- id: netbird
|
||||||
|
|
||||||
@@ -520,6 +554,104 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-amd64
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||||
|
ids:
|
||||||
|
- netbird-server
|
||||||
|
goarch: amd64
|
||||||
|
use: buildx
|
||||||
|
dockerfile: combined/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/amd64"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||||
|
ids:
|
||||||
|
- netbird-server
|
||||||
|
goarch: arm64
|
||||||
|
use: buildx
|
||||||
|
dockerfile: combined/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/arm64"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||||
|
ids:
|
||||||
|
- netbird-server
|
||||||
|
goarch: arm
|
||||||
|
goarm: 6
|
||||||
|
use: buildx
|
||||||
|
dockerfile: combined/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/arm"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||||
|
ids:
|
||||||
|
- netbird-proxy
|
||||||
|
goarch: amd64
|
||||||
|
use: buildx
|
||||||
|
dockerfile: proxy/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/amd64"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||||
|
ids:
|
||||||
|
- netbird-proxy
|
||||||
|
goarch: arm64
|
||||||
|
use: buildx
|
||||||
|
dockerfile: proxy/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/arm64"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||||
|
ids:
|
||||||
|
- netbird-proxy
|
||||||
|
goarch: arm
|
||||||
|
goarm: 6
|
||||||
|
use: buildx
|
||||||
|
dockerfile: proxy/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/arm"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
docker_manifests:
|
docker_manifests:
|
||||||
- name_template: netbirdio/netbird:{{ .Version }}
|
- name_template: netbirdio/netbird:{{ .Version }}
|
||||||
image_templates:
|
image_templates:
|
||||||
@@ -598,6 +730,18 @@ docker_manifests:
|
|||||||
- netbirdio/upload:{{ .Version }}-arm
|
- netbirdio/upload:{{ .Version }}-arm
|
||||||
- netbirdio/upload:{{ .Version }}-amd64
|
- netbirdio/upload:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: netbirdio/netbird-server:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-arm
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: netbirdio/netbird-server:latest
|
||||||
|
image_templates:
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-arm
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-amd64
|
||||||
|
|
||||||
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
|
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
|
||||||
image_templates:
|
image_templates:
|
||||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||||
@@ -675,6 +819,43 @@ docker_manifests:
|
|||||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
||||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
||||||
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/netbird-server:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/netbird-server:latest
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: netbirdio/reverse-proxy:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: netbirdio/reverse-proxy:latest
|
||||||
|
image_templates:
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/reverse-proxy:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/reverse-proxy:latest
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||||
|
|
||||||
brews:
|
brews:
|
||||||
- ids:
|
- ids:
|
||||||
- default
|
- default
|
||||||
|
|||||||
2
LICENSE
2
LICENSE
@@ -1,4 +1,4 @@
|
|||||||
This BSD‑3‑Clause license applies to all parts of the repository except for the directories management/, signal/ and relay/.
|
This BSD‑3‑Clause license applies to all parts of the repository except for the directories management/, signal/, relay/ and combined/.
|
||||||
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
|
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
|
||||||
|
|
||||||
BSD 3-Clause License
|
BSD 3-Clause License
|
||||||
|
|||||||
@@ -60,8 +60,8 @@
|
|||||||
|
|
||||||
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||||
|
|
||||||
### NetBird on Lawrence Systems (Video)
|
### Self-Host NetBird (Video)
|
||||||
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
[](https://youtu.be/bZAgpT6nzaQ)
|
||||||
|
|
||||||
### Key features
|
### Key features
|
||||||
|
|
||||||
|
|||||||
@@ -3,15 +3,7 @@ package android
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
gstatus "google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/cmd"
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/auth"
|
"github.com/netbirdio/netbird/client/internal/auth"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
@@ -84,34 +76,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||||
supportsSSO := true
|
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
if err != nil {
|
||||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
return false, fmt.Errorf("failed to create auth client: %v", err)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
}
|
||||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
defer authClient.Close()
|
||||||
s, ok := gstatus.FromError(err)
|
|
||||||
if !ok {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
|
|
||||||
supportsSSO = false
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
|
||||||
}
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to check SSO support: %v", err)
|
||||||
return err
|
}
|
||||||
})
|
|
||||||
|
|
||||||
if !supportsSSO {
|
if !supportsSSO {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
@@ -129,19 +108,17 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
||||||
|
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
//nolint
|
//nolint
|
||||||
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||||
|
err, _ = authClient.Login(ctxWithValues, setupKey, "")
|
||||||
err := a.withBackOff(a.ctx, func() error {
|
|
||||||
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
|
|
||||||
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
|
||||||
// we got an answer from management, exit backoff earlier
|
|
||||||
return backoff.Permanent(backoffErr)
|
|
||||||
}
|
|
||||||
return backoffErr
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
||||||
@@ -160,49 +137,41 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidT
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
|
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
|
||||||
var needsLogin bool
|
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
// check if we need to generate JWT token
|
// check if we need to generate JWT token
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
needsLogin, err := authClient.IsLoginRequired(a.ctx)
|
||||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
|
|
||||||
return
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("failed to check login requirement: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
if needsLogin {
|
if needsLogin {
|
||||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
|
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, isAndroidTV)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||||
}
|
}
|
||||||
jwtToken = tokenInfo.GetTokenToUse()
|
jwtToken = tokenInfo.GetTokenToUse()
|
||||||
}
|
}
|
||||||
|
|
||||||
err = a.withBackOff(a.ctx, func() error {
|
err, _ = authClient.Login(a.ctx, "", jwtToken)
|
||||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
go urlOpener.OnLoginSuccess()
|
|
||||||
}
|
|
||||||
|
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
go urlOpener.OnLoginSuccess()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
|
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
|
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, isAndroidTV)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
||||||
@@ -212,22 +181,10 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*a
|
|||||||
|
|
||||||
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||||
|
|
||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
tokenInfo, err := oAuthFlow.WaitToken(a.ctx, flowInfo)
|
||||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
|
||||||
defer cancel()
|
|
||||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &tokenInfo, nil
|
return &tokenInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
|
|
||||||
return backoff.RetryNotify(
|
|
||||||
bf,
|
|
||||||
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
|
|
||||||
func(err error, duration time.Duration) {
|
|
||||||
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"os/user"
|
"os/user"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
@@ -277,18 +276,15 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
|
|||||||
}
|
}
|
||||||
|
|
||||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
|
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
|
||||||
needsLogin := false
|
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||||
|
|
||||||
err := WithBackOff(func() error {
|
|
||||||
err := internal.Login(ctx, config, "", "")
|
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
|
||||||
needsLogin = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
needsLogin, err := authClient.IsLoginRequired(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("check login required: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
@@ -300,23 +296,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
|||||||
jwtToken = tokenInfo.GetTokenToUse()
|
jwtToken = tokenInfo.GetTokenToUse()
|
||||||
}
|
}
|
||||||
|
|
||||||
var lastError error
|
err, _ = authClient.Login(ctx, setupKey, jwtToken)
|
||||||
|
|
||||||
err = WithBackOff(func() error {
|
|
||||||
err := internal.Login(ctx, config, setupKey, jwtToken)
|
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
|
||||||
lastError = err
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
|
|
||||||
if lastError != nil {
|
|
||||||
return fmt.Errorf("login failed: %v", lastError)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -344,11 +326,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
|||||||
|
|
||||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
||||||
|
|
||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
|
||||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
|
|
||||||
defer c()
|
|
||||||
|
|
||||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/auth"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||||
@@ -30,6 +31,14 @@ var (
|
|||||||
ErrConfigNotInitialized = errors.New("config not initialized")
|
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PeerConnStatus is a peer's connection status.
|
||||||
|
type PeerConnStatus = peer.ConnStatus
|
||||||
|
|
||||||
|
const (
|
||||||
|
// PeerStatusConnected indicates the peer is in connected state.
|
||||||
|
PeerStatusConnected = peer.StatusConnected
|
||||||
|
)
|
||||||
|
|
||||||
// Client manages a netbird embedded client instance.
|
// Client manages a netbird embedded client instance.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
deviceName string
|
deviceName string
|
||||||
@@ -68,6 +77,10 @@ type Options struct {
|
|||||||
StatePath string
|
StatePath string
|
||||||
// DisableClientRoutes disables the client routes
|
// DisableClientRoutes disables the client routes
|
||||||
DisableClientRoutes bool
|
DisableClientRoutes bool
|
||||||
|
// BlockInbound blocks all inbound connections from peers
|
||||||
|
BlockInbound bool
|
||||||
|
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
|
||||||
|
WireguardPort *int
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateCredentials checks that exactly one credential type is provided
|
// validateCredentials checks that exactly one credential type is provided
|
||||||
@@ -136,6 +149,8 @@ func New(opts Options) (*Client, error) {
|
|||||||
PreSharedKey: &opts.PreSharedKey,
|
PreSharedKey: &opts.PreSharedKey,
|
||||||
DisableServerRoutes: &t,
|
DisableServerRoutes: &t,
|
||||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||||
|
BlockInbound: &opts.BlockInbound,
|
||||||
|
WireguardPort: opts.WireguardPort,
|
||||||
}
|
}
|
||||||
if opts.ConfigPath != "" {
|
if opts.ConfigPath != "" {
|
||||||
config, err = profilemanager.UpdateOrCreateConfig(input)
|
config, err = profilemanager.UpdateOrCreateConfig(input)
|
||||||
@@ -155,6 +170,7 @@ func New(opts Options) (*Client, error) {
|
|||||||
setupKey: opts.SetupKey,
|
setupKey: opts.SetupKey,
|
||||||
jwtToken: opts.JWTToken,
|
jwtToken: opts.JWTToken,
|
||||||
config: config,
|
config: config,
|
||||||
|
recorder: peer.NewRecorder(config.ManagementURL.String()),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,13 +192,17 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
|
|
||||||
// nolint:staticcheck
|
// nolint:staticcheck
|
||||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||||
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
|
|
||||||
|
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create auth client: %w", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
||||||
return fmt.Errorf("login: %w", err)
|
return fmt.Errorf("login: %w", err)
|
||||||
}
|
}
|
||||||
|
client := internal.NewConnectClient(ctx, c.config, c.recorder, false)
|
||||||
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
|
||||||
c.recorder = recorder
|
|
||||||
client := internal.NewConnectClient(ctx, c.config, recorder, false)
|
|
||||||
client.SetSyncResponsePersistence(true)
|
client.SetSyncResponsePersistence(true)
|
||||||
|
|
||||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||||
@@ -335,14 +355,9 @@ func (c *Client) NewHTTPClient() *http.Client {
|
|||||||
// Status returns the current status of the client.
|
// Status returns the current status of the client.
|
||||||
func (c *Client) Status() (peer.FullStatus, error) {
|
func (c *Client) Status() (peer.FullStatus, error) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
recorder := c.recorder
|
|
||||||
connect := c.connect
|
connect := c.connect
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
|
|
||||||
if recorder == nil {
|
|
||||||
return peer.FullStatus{}, errors.New("client not started")
|
|
||||||
}
|
|
||||||
|
|
||||||
if connect != nil {
|
if connect != nil {
|
||||||
engine := connect.Engine()
|
engine := connect.Engine()
|
||||||
if engine != nil {
|
if engine != nil {
|
||||||
@@ -350,7 +365,7 @@ func (c *Client) Status() (peer.FullStatus, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return recorder.GetFullStatus(), nil
|
return c.recorder.GetFullStatus(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLatestSyncResponse returns the latest sync response from the management server.
|
// GetLatestSyncResponse returns the latest sync response from the management server.
|
||||||
|
|||||||
@@ -83,6 +83,10 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
return fmt.Errorf("acl manager init: %w", err)
|
return fmt.Errorf("acl manager init: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := m.initNoTrackChain(); err != nil {
|
||||||
|
return fmt.Errorf("init notrack chain: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// persist early to ensure cleanup of chains
|
// persist early to ensure cleanup of chains
|
||||||
go func() {
|
go func() {
|
||||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||||
@@ -177,6 +181,10 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := m.cleanupNoTrackChain(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
if err := m.aclMgr.Reset(); err != nil {
|
if err := m.aclMgr.Reset(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
||||||
}
|
}
|
||||||
@@ -277,6 +285,125 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
|||||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
chainNameRaw = "NETBIRD-RAW"
|
||||||
|
chainOUTPUT = "OUTPUT"
|
||||||
|
tableRaw = "raw"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||||
|
// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which
|
||||||
|
// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
|
||||||
|
//
|
||||||
|
// Traffic flows that need NOTRACK:
|
||||||
|
//
|
||||||
|
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
|
||||||
|
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
|
||||||
|
// Matched by: sport=wgPort
|
||||||
|
//
|
||||||
|
// 2. Egress: Proxy -> WireGuard (via raw socket)
|
||||||
|
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
|
||||||
|
// Matched by: dport=wgPort
|
||||||
|
//
|
||||||
|
// 3. Ingress: Packets to WireGuard
|
||||||
|
// dst=127.0.0.1:wgPort
|
||||||
|
// Matched by: dport=wgPort
|
||||||
|
//
|
||||||
|
// 4. Ingress: Packets to proxy (after eBPF rewrite)
|
||||||
|
// dst=127.0.0.1:proxyPort
|
||||||
|
// Matched by: dport=proxyPort
|
||||||
|
//
|
||||||
|
// Rules are cleaned up when the firewall manager is closed.
|
||||||
|
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
wgPortStr := fmt.Sprintf("%d", wgPort)
|
||||||
|
proxyPortStr := fmt.Sprintf("%d", proxyPort)
|
||||||
|
|
||||||
|
// Egress rules: match outgoing loopback UDP packets
|
||||||
|
outputRuleSport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--sport", wgPortStr, "-j", "NOTRACK"}
|
||||||
|
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleSport...); err != nil {
|
||||||
|
return fmt.Errorf("add output sport notrack rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
outputRuleDport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"}
|
||||||
|
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleDport...); err != nil {
|
||||||
|
return fmt.Errorf("add output dport notrack rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ingress rules: match incoming loopback UDP packets
|
||||||
|
preroutingRuleWg := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"}
|
||||||
|
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleWg...); err != nil {
|
||||||
|
return fmt.Errorf("add prerouting wg notrack rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
preroutingRuleProxy := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", proxyPortStr, "-j", "NOTRACK"}
|
||||||
|
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleProxy...); err != nil {
|
||||||
|
return fmt.Errorf("add prerouting proxy notrack rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) initNoTrackChain() error {
|
||||||
|
if err := m.cleanupNoTrackChain(); err != nil {
|
||||||
|
log.Debugf("cleanup notrack chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.ipv4Client.NewChain(tableRaw, chainNameRaw); err != nil {
|
||||||
|
return fmt.Errorf("create chain: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
jumpRule := []string{"-j", chainNameRaw}
|
||||||
|
|
||||||
|
if err := m.ipv4Client.InsertUnique(tableRaw, chainOUTPUT, 1, jumpRule...); err != nil {
|
||||||
|
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
||||||
|
log.Debugf("delete orphan chain: %v", delErr)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("add output jump rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.ipv4Client.InsertUnique(tableRaw, chainPREROUTING, 1, jumpRule...); err != nil {
|
||||||
|
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); delErr != nil {
|
||||||
|
log.Debugf("delete output jump rule: %v", delErr)
|
||||||
|
}
|
||||||
|
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
||||||
|
log.Debugf("delete orphan chain: %v", delErr)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("add prerouting jump rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) cleanupNoTrackChain() error {
|
||||||
|
exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("check chain exists: %w", err)
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
jumpRule := []string{"-j", chainNameRaw}
|
||||||
|
|
||||||
|
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); err != nil {
|
||||||
|
return fmt.Errorf("remove output jump rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPREROUTING, jumpRule...); err != nil {
|
||||||
|
return fmt.Errorf("remove prerouting jump rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.ipv4Client.ClearAndDeleteChain(tableRaw, chainNameRaw); err != nil {
|
||||||
|
return fmt.Errorf("clear and delete chain: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func getConntrackEstablished() []string {
|
func getConntrackEstablished() []string {
|
||||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -168,6 +168,10 @@ type Manager interface {
|
|||||||
|
|
||||||
// RemoveInboundDNAT removes inbound DNAT rule
|
// RemoveInboundDNAT removes inbound DNAT rule
|
||||||
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||||
|
|
||||||
|
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
|
||||||
|
// This prevents conntrack from interfering with WireGuard proxy communication.
|
||||||
|
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenKey(format string, pair RouterPair) string {
|
func GenKey(format string, pair RouterPair) string {
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/google/nftables/binaryutil"
|
"github.com/google/nftables/binaryutil"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
@@ -48,8 +49,10 @@ type Manager struct {
|
|||||||
rConn *nftables.Conn
|
rConn *nftables.Conn
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
|
|
||||||
router *router
|
router *router
|
||||||
aclManager *AclManager
|
aclManager *AclManager
|
||||||
|
notrackOutputChain *nftables.Chain
|
||||||
|
notrackPreroutingChain *nftables.Chain
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create nftables firewall manager
|
// Create nftables firewall manager
|
||||||
@@ -91,6 +94,10 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
return fmt.Errorf("acl manager init: %w", err)
|
return fmt.Errorf("acl manager init: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := m.initNoTrackChains(workTable); err != nil {
|
||||||
|
return fmt.Errorf("init notrack chains: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
stateManager.RegisterState(&ShutdownState{})
|
stateManager.RegisterState(&ShutdownState{})
|
||||||
|
|
||||||
// We only need to record minimal interface state for potential recreation.
|
// We only need to record minimal interface state for potential recreation.
|
||||||
@@ -288,7 +295,15 @@ func (m *Manager) Flush() error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.aclManager.Flush()
|
if err := m.aclManager.Flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.refreshNoTrackChains(); err != nil {
|
||||||
|
log.Errorf("failed to refresh notrack chains: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddDNATRule adds a DNAT rule
|
// AddDNATRule adds a DNAT rule
|
||||||
@@ -331,6 +346,176 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
|||||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
chainNameRawOutput = "netbird-raw-out"
|
||||||
|
chainNameRawPrerouting = "netbird-raw-pre"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||||
|
// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which
|
||||||
|
// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
|
||||||
|
//
|
||||||
|
// Traffic flows that need NOTRACK:
|
||||||
|
//
|
||||||
|
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
|
||||||
|
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
|
||||||
|
// Matched by: sport=wgPort
|
||||||
|
//
|
||||||
|
// 2. Egress: Proxy -> WireGuard (via raw socket)
|
||||||
|
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
|
||||||
|
// Matched by: dport=wgPort
|
||||||
|
//
|
||||||
|
// 3. Ingress: Packets to WireGuard
|
||||||
|
// dst=127.0.0.1:wgPort
|
||||||
|
// Matched by: dport=wgPort
|
||||||
|
//
|
||||||
|
// 4. Ingress: Packets to proxy (after eBPF rewrite)
|
||||||
|
// dst=127.0.0.1:proxyPort
|
||||||
|
// Matched by: dport=proxyPort
|
||||||
|
//
|
||||||
|
// Rules are cleaned up when the firewall manager is closed.
|
||||||
|
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if m.notrackOutputChain == nil || m.notrackPreroutingChain == nil {
|
||||||
|
return fmt.Errorf("notrack chains not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyPortBytes := binaryutil.BigEndian.PutUint16(proxyPort)
|
||||||
|
wgPortBytes := binaryutil.BigEndian.PutUint16(wgPort)
|
||||||
|
loopback := []byte{127, 0, 0, 1}
|
||||||
|
|
||||||
|
// Egress rules: match outgoing loopback UDP packets
|
||||||
|
m.rConn.AddRule(&nftables.Rule{
|
||||||
|
Table: m.notrackOutputChain.Table,
|
||||||
|
Chain: m.notrackOutputChain,
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||||
|
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||||
|
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||||
|
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||||
|
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 0, Len: 2},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // sport=wgPort
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Notrack{},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
m.rConn.AddRule(&nftables.Rule{
|
||||||
|
Table: m.notrackOutputChain.Table,
|
||||||
|
Chain: m.notrackOutputChain,
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||||
|
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||||
|
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||||
|
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||||
|
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Notrack{},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Ingress rules: match incoming loopback UDP packets
|
||||||
|
m.rConn.AddRule(&nftables.Rule{
|
||||||
|
Table: m.notrackPreroutingChain.Table,
|
||||||
|
Chain: m.notrackPreroutingChain,
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||||
|
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||||
|
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||||
|
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||||
|
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Notrack{},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
m.rConn.AddRule(&nftables.Rule{
|
||||||
|
Table: m.notrackPreroutingChain.Table,
|
||||||
|
Chain: m.notrackPreroutingChain,
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||||
|
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||||
|
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||||
|
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||||
|
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: proxyPortBytes}, // dport=proxyPort
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Notrack{},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush notrack rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) initNoTrackChains(table *nftables.Table) error {
|
||||||
|
m.notrackOutputChain = m.rConn.AddChain(&nftables.Chain{
|
||||||
|
Name: chainNameRawOutput,
|
||||||
|
Table: table,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookOutput,
|
||||||
|
Priority: nftables.ChainPriorityRaw,
|
||||||
|
})
|
||||||
|
|
||||||
|
m.notrackPreroutingChain = m.rConn.AddChain(&nftables.Chain{
|
||||||
|
Name: chainNameRawPrerouting,
|
||||||
|
Table: table,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
|
Priority: nftables.ChainPriorityRaw,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush chain creation: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) refreshNoTrackChains() error {
|
||||||
|
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("list chains: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tableName := getTableName()
|
||||||
|
for _, c := range chains {
|
||||||
|
if c.Table.Name != tableName {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch c.Name {
|
||||||
|
case chainNameRawOutput:
|
||||||
|
m.notrackOutputChain = c
|
||||||
|
case chainNameRawPrerouting:
|
||||||
|
m.notrackPreroutingChain = c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -483,7 +483,12 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if nftRule.Handle == 0 {
|
if nftRule.Handle == 0 {
|
||||||
return fmt.Errorf("route rule %s has no handle", ruleKey)
|
log.Warnf("route rule %s has no handle, removing stale entry", ruleKey)
|
||||||
|
if err := r.decrementSetCounter(nftRule); err != nil {
|
||||||
|
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
||||||
@@ -660,13 +665,32 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
// TODO: rollback ipset counter
|
r.rollbackRules(pair)
|
||||||
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
|
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
|
||||||
|
func (r *router) rollbackRules(pair firewall.RouterPair) {
|
||||||
|
keys := []string{
|
||||||
|
firewall.GenKey(firewall.ForwardingFormat, pair),
|
||||||
|
firewall.GenKey(firewall.PreroutingFormat, pair),
|
||||||
|
firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair)),
|
||||||
|
}
|
||||||
|
for _, key := range keys {
|
||||||
|
rule, ok := r.rules[key]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
log.Warnf("rollback set counter for %s: %v", key, err)
|
||||||
|
}
|
||||||
|
delete(r.rules, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// addNatRule inserts a nftables rule to the conn client flush queue
|
// addNatRule inserts a nftables rule to the conn client flush queue
|
||||||
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||||
@@ -928,18 +952,30 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
|||||||
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
rule, exists := r.rules[ruleKey]
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
if !exists {
|
||||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
|
||||||
|
|
||||||
delete(r.rules, ruleKey)
|
|
||||||
|
|
||||||
|
if rule.Handle == 0 {
|
||||||
|
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleKey)
|
||||||
if err := r.decrementSetCounter(rule); err != nil {
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
return fmt.Errorf("decrement set counter: %w", err)
|
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||||
}
|
}
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||||
|
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement set counter: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -1329,65 +1365,89 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
if pair.Masquerade {
|
if pair.Masquerade {
|
||||||
if err := r.removeNatRule(pair); err != nil {
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||||
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set counters are decremented in the sub-methods above before flush. If flush fails,
|
||||||
|
// counters will be off until the next successful removal or refresh cycle.
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
// TODO: rollback set counter
|
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
|
||||||
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// removeNatRule deletes the prerouting NAT rule for the given pair from the
// kernel and from the in-memory rule map, and decrements the associated set
// counter. A rule missing from the map is a no-op. An entry whose Handle is 0
// (never confirmed by the kernel, e.g. left over from a failed flush) is
// treated as stale: only local bookkeeping is cleaned up, since there is
// nothing in the kernel to delete.
func (r *router) removeNatRule(pair firewall.RouterPair) error {
	ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)

	rule, exists := r.rules[ruleKey]
	if !exists {
		log.Debugf("prerouting rule %s not found", ruleKey)
		return nil
	}

	if rule.Handle == 0 {
		// Stale entry: decrement the counter best-effort and drop the map entry.
		log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleKey)
		if err := r.decrementSetCounter(rule); err != nil {
			log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
		}
		delete(r.rules, ruleKey)
		return nil
	}

	if err := r.conn.DelRule(rule); err != nil {
		return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
	}

	log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)

	delete(r.rules, ruleKey)

	if err := r.decrementSetCounter(rule); err != nil {
		return fmt.Errorf("decrement set counter: %w", err)
	}

	return nil
}
|
||||||
|
|
||||||
// refreshRulesMap rebuilds the rule map from the kernel. This removes stale entries
// (e.g. from failed flushes) and updates handles for all existing rules.
// Listing errors for a chain are collected rather than aborting; the old map
// entries for that chain are kept, since their kernel state cannot be verified.
func (r *router) refreshRulesMap() error {
	var merr *multierror.Error
	newRules := make(map[string]*nftables.Rule)
	for _, chain := range r.chains {
		rules, err := r.conn.GetRules(chain.Table, chain)
		if err != nil {
			merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
			// preserve existing entries for this chain since we can't verify their state
			for k, v := range r.rules {
				if v.Chain != nil && v.Chain.Name == chain.Name {
					newRules[k] = v
				}
			}
			continue
		}
		for _, rule := range rules {
			// only rules tagged with UserData (our rule key) are tracked in the map
			if len(rule.UserData) > 0 {
				newRules[string(rule.UserData)] = rule
			}
		}
	}
	// swap in the rebuilt map; anything not found in the kernel is dropped
	r.rules = newRules
	return nberrors.FormatErrorOrNil(merr)
}
|
||||||
|
|
||||||
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
@@ -1629,20 +1689,34 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
var needsFlush bool
|
||||||
|
|
||||||
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||||
if err := r.conn.DelRule(dnatRule); err != nil {
|
if dnatRule.Handle == 0 {
|
||||||
|
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleKey+dnatSuffix)
|
||||||
|
delete(r.rules, ruleKey+dnatSuffix)
|
||||||
|
} else if err := r.conn.DelRule(dnatRule); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
|
||||||
|
} else {
|
||||||
|
needsFlush = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
||||||
if err := r.conn.DelRule(masqRule); err != nil {
|
if masqRule.Handle == 0 {
|
||||||
|
log.Warnf("snat rule %s has no handle, removing stale entry", ruleKey+snatSuffix)
|
||||||
|
delete(r.rules, ruleKey+snatSuffix)
|
||||||
|
} else if err := r.conn.DelRule(masqRule); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
|
||||||
|
} else {
|
||||||
|
needsFlush = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if needsFlush {
|
||||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if merr == nil {
|
if merr == nil {
|
||||||
@@ -1757,16 +1831,25 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
|||||||
|
|
||||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleID]; exists {
|
rule, exists := r.rules[ruleID]
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
if !exists {
|
||||||
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
return nil
|
||||||
}
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
|
||||||
}
|
|
||||||
delete(r.rules, ruleID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if rule.Handle == 0 {
|
||||||
|
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
||||||
|
}
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/test"
|
"github.com/netbirdio/netbird/client/firewall/test"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -719,3 +720,137 @@ func deleteWorkTable() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestRouter_RefreshRulesMap_RemovesStaleEntries verifies that refreshRulesMap
// drops map entries that have no backing kernel rule (Handle == 0) while
// keeping real rules, with their kernel-assigned handles, intact.
// Requires a system with nftables support; skipped otherwise.
func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
	if check() != NFTABLES {
		t.Skip("nftables not supported on this system")
	}

	workTable, err := createWorkTable()
	require.NoError(t, err)
	defer deleteWorkTable()

	r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
	require.NoError(t, err)
	require.NoError(t, r.init(workTable))
	defer func() { require.NoError(t, r.Reset()) }()

	// Add a real rule to the kernel
	ruleKey, err := r.AddRouteFiltering(
		nil,
		[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
		firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
		firewall.ProtocolTCP,
		nil,
		&firewall.Port{Values: []uint16{80}},
		firewall.ActionAccept,
	)
	require.NoError(t, err)
	t.Cleanup(func() {
		require.NoError(t, r.DeleteRouteRule(ruleKey))
	})

	// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
	staleKey := "stale-rule-that-does-not-exist"
	r.rules[staleKey] = &nftables.Rule{
		Table:    r.workTable,
		Chain:    r.chains[chainNameRoutingFw],
		Handle:   0,
		UserData: []byte(staleKey),
	}

	require.Contains(t, r.rules, staleKey, "stale entry should be in map before refresh")

	err = r.refreshRulesMap()
	require.NoError(t, err)

	assert.NotContains(t, r.rules, staleKey, "stale entry should be removed after refresh")

	// The real rule must survive the refresh and carry a kernel handle.
	realRule, ok := r.rules[ruleKey.ID()]
	assert.True(t, ok, "real rule should still exist after refresh")
	assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
}
|
||||||
|
|
||||||
|
// TestRouter_DeleteRouteRule_StaleHandle verifies that deleting a route rule
// whose map entry has no kernel handle (Handle == 0) succeeds and removes the
// stale entry instead of returning an error.
// Requires a system with nftables support; skipped otherwise.
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
	if check() != NFTABLES {
		t.Skip("nftables not supported on this system")
	}

	workTable, err := createWorkTable()
	require.NoError(t, err)
	defer deleteWorkTable()

	r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
	require.NoError(t, err)
	require.NoError(t, r.init(workTable))
	defer func() { require.NoError(t, r.Reset()) }()

	// Inject a stale entry with Handle=0
	staleKey := "stale-route-rule"
	r.rules[staleKey] = &nftables.Rule{
		Table:    r.workTable,
		Chain:    r.chains[chainNameRoutingFw],
		Handle:   0,
		UserData: []byte(staleKey),
	}

	// DeleteRouteRule should not return an error for stale handles
	err = r.DeleteRouteRule(id.RuleID(staleKey))
	assert.NoError(t, err, "deleting a stale rule should not error")
	assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
}
|
||||||
|
|
||||||
|
// TestRouter_AddNatRule_WithStaleEntry verifies that re-adding a NAT rule
// succeeds even when the existing map entries have corrupted (zero) handles,
// and that exactly one copy of the rule ends up in the kernel.
// Requires a system with nftables support; skipped otherwise.
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
	if check() != NFTABLES {
		t.Skip("nftables not supported on this system")
	}

	manager, err := Create(ifaceMock, iface.DefaultMTU)
	require.NoError(t, err)
	require.NoError(t, manager.Init(nil))
	t.Cleanup(func() {
		require.NoError(t, manager.Close(nil))
	})

	pair := firewall.RouterPair{
		ID:          "staletest",
		Source:      firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
		Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
		Masquerade:  true,
	}

	rtr := manager.router

	// First add succeeds
	err = rtr.AddNatRule(pair)
	require.NoError(t, err)
	t.Cleanup(func() {
		require.NoError(t, rtr.RemoveNatRule(pair))
	})

	// Corrupt the handle to simulate stale state
	natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
	if rule, exists := rtr.rules[natRuleKey]; exists {
		rule.Handle = 0
	}
	inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
	if rule, exists := rtr.rules[inverseKey]; exists {
		rule.Handle = 0
	}

	// Adding the same rule again should succeed despite stale handles
	err = rtr.AddNatRule(pair)
	assert.NoError(t, err, "AddNatRule should succeed even with stale entries")

	// Verify rules exist in kernel
	rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
	require.NoError(t, err)

	// Count kernel rules tagged with our key; exactly one copy is expected.
	found := 0
	for _, rule := range rules {
		if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
			found++
		}
	}
	assert.Equal(t, 1, found, "NAT rule should exist in kernel")
}
|
||||||
|
|||||||
@@ -3,12 +3,6 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/netip"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -17,33 +11,7 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
m.resetState()
|
||||||
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
|
|
||||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
|
||||||
|
|
||||||
if m.udpTracker != nil {
|
|
||||||
m.udpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
|
||||||
m.icmpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
|
||||||
m.tcpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if fwder := m.forwarder.Load(); fwder != nil {
|
|
||||||
fwder.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.logger != nil {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := m.logger.Stop(ctx); err != nil {
|
|
||||||
log.Errorf("failed to shutdown logger: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.Close(stateManager)
|
return m.nativeFirewall.Close(stateManager)
|
||||||
|
|||||||
@@ -1,12 +1,9 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -26,33 +23,7 @@ func (m *Manager) Close(*statemanager.Manager) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
m.resetState()
|
||||||
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
|
|
||||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
|
||||||
|
|
||||||
if m.udpTracker != nil {
|
|
||||||
m.udpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
|
||||||
m.icmpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
|
||||||
m.tcpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if fwder := m.forwarder.Load(); fwder != nil {
|
|
||||||
fwder.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.logger != nil {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := m.logger.Stop(ctx); err != nil {
|
|
||||||
log.Errorf("failed to shutdown logger: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isWindowsFirewallReachable() {
|
if !isWindowsFirewallReachable() {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -115,6 +115,17 @@ func (t *TCPConnTrack) IsTombstone() bool {
|
|||||||
return t.tombstone.Load()
|
return t.tombstone.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsSupersededBy returns true if this connection should be replaced by a new one
|
||||||
|
// carrying the given flags. Tombstoned connections are always superseded; TIME-WAIT
|
||||||
|
// connections are superseded by a pure SYN (a new connection attempt for the same
|
||||||
|
// four-tuple, as contemplated by RFC 1122 §4.2.2.13 and RFC 6191).
|
||||||
|
func (t *TCPConnTrack) IsSupersededBy(flags uint8) bool {
|
||||||
|
if t.tombstone.Load() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return flags&TCPSyn != 0 && flags&TCPAck == 0 && TCPState(t.state.Load()) == TCPStateTimeWait
|
||||||
|
}
|
||||||
|
|
||||||
// SetTombstone safely marks the connection for deletion
|
// SetTombstone safely marks the connection for deletion
|
||||||
func (t *TCPConnTrack) SetTombstone() {
|
func (t *TCPConnTrack) SetTombstone() {
|
||||||
t.tombstone.Store(true)
|
t.tombstone.Store(true)
|
||||||
@@ -169,7 +180,7 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
|||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if exists {
|
if exists && !conn.IsSupersededBy(flags) {
|
||||||
t.updateState(key, conn, flags, direction, size)
|
t.updateState(key, conn, flags, direction, size)
|
||||||
return key, uint16(conn.DNATOrigPort.Load()), true
|
return key, uint16(conn.DNATOrigPort.Load()), true
|
||||||
}
|
}
|
||||||
@@ -241,7 +252,7 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
|||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if !exists || conn.IsTombstone() {
|
if !exists || conn.IsSupersededBy(flags) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -485,6 +485,261 @@ func TestTCPAbnormalSequences(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestTCPPortReuseTombstone verifies that a new connection on a port with a
|
||||||
|
// tombstoned (closed) conntrack entry is properly tracked. Without the fix,
|
||||||
|
// updateIfExists treats tombstoned entries as live, causing track() to skip
|
||||||
|
// creating a new connection. The subsequent SYN-ACK then fails IsValidInbound
|
||||||
|
// because the entry is tombstoned, and the response packet gets dropped by ACL.
|
||||||
|
func TestTCPPortReuseTombstone(t *testing.T) {
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
t.Run("Outbound port reuse after graceful close", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish and gracefully close a connection (server-initiated close)
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Server sends FIN
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
// Client sends FIN-ACK
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
|
||||||
|
// Server sends final ACK
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
// Connection should be tombstoned
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn, "old connection should still be in map")
|
||||||
|
require.True(t, conn.IsTombstone(), "old connection should be tombstoned")
|
||||||
|
|
||||||
|
// Now reuse the same port for a new connection
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||||
|
|
||||||
|
// The old tombstoned entry should be replaced with a new one
|
||||||
|
newConn := tracker.connections[key]
|
||||||
|
require.NotNil(t, newConn, "new connection should exist")
|
||||||
|
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
|
||||||
|
require.Equal(t, TCPStateSynSent, newConn.GetState())
|
||||||
|
|
||||||
|
// SYN-ACK for the new connection should be valid
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||||
|
require.True(t, valid, "SYN-ACK for new connection on reused port should be accepted")
|
||||||
|
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||||
|
|
||||||
|
// Data transfer should work
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 500)
|
||||||
|
require.True(t, valid, "data should be allowed on new connection")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Outbound port reuse after RST", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish and RST a connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.True(t, conn.IsTombstone(), "RST connection should be tombstoned")
|
||||||
|
|
||||||
|
// Reuse the same port
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||||
|
|
||||||
|
newConn := tracker.connections[key]
|
||||||
|
require.NotNil(t, newConn)
|
||||||
|
require.False(t, newConn.IsTombstone())
|
||||||
|
require.Equal(t, TCPStateSynSent, newConn.GetState())
|
||||||
|
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||||
|
require.True(t, valid, "SYN-ACK should be accepted after RST tombstone")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Inbound port reuse after close", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
clientIP := srcIP
|
||||||
|
serverIP := dstIP
|
||||||
|
clientPort := srcPort
|
||||||
|
serverPort := dstPort
|
||||||
|
key := ConnKey{SrcIP: clientIP, DstIP: serverIP, SrcPort: clientPort, DstPort: serverPort}
|
||||||
|
|
||||||
|
// Inbound connection: client SYN → server SYN-ACK → client ACK
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||||
|
|
||||||
|
// Server-initiated close to reach Closed/tombstoned:
|
||||||
|
// Server FIN (opposite dir) → CloseWait
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPFin|TCPAck, 100)
|
||||||
|
require.Equal(t, TCPStateCloseWait, conn.GetState())
|
||||||
|
// Client FIN-ACK (same dir as conn) → LastAck
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPFin|TCPAck, nil, 100, 0)
|
||||||
|
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||||
|
// Server final ACK (opposite dir) → Closed → tombstoned
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
|
||||||
|
|
||||||
|
require.True(t, conn.IsTombstone())
|
||||||
|
|
||||||
|
// New inbound connection on same ports
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
|
||||||
|
|
||||||
|
newConn := tracker.connections[key]
|
||||||
|
require.NotNil(t, newConn)
|
||||||
|
require.False(t, newConn.IsTombstone())
|
||||||
|
require.Equal(t, TCPStateSynReceived, newConn.GetState())
|
||||||
|
|
||||||
|
// Complete handshake: server SYN-ACK, then client ACK
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||||
|
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Late ACK on tombstoned connection is harmless", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish and close via passive close (server-initiated FIN → Closed → tombstoned)
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) // CloseWait
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // LastAck
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) // Closed
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.True(t, conn.IsTombstone())
|
||||||
|
|
||||||
|
// Late ACK should be rejected (tombstoned)
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.False(t, valid, "late ACK on tombstoned connection should be rejected")
|
||||||
|
|
||||||
|
// Late outbound ACK should not create a new connection (not a SYN)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
require.True(t, tracker.connections[key].IsTombstone(), "late outbound ACK should not replace tombstoned entry")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPPortReuseTimeWait(t *testing.T) {
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
t.Run("Outbound port reuse during TIME-WAIT (active close)", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Active close: client (outbound initiator) sends FIN first
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||||
|
|
||||||
|
// Server ACKs the FIN
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||||
|
|
||||||
|
// Server sends its own FIN
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
|
||||||
|
// Client sends final ACK (TIME-WAIT stays, not tombstoned)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
require.False(t, conn.IsTombstone(), "TIME-WAIT should not be tombstoned")
|
||||||
|
|
||||||
|
// New outbound SYN on the same port (port reuse during TIME-WAIT)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||||
|
|
||||||
|
// Per RFC 1122/6191, new SYN during TIME-WAIT should start a new connection
|
||||||
|
newConn := tracker.connections[key]
|
||||||
|
require.NotNil(t, newConn, "new connection should exist")
|
||||||
|
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
|
||||||
|
require.Equal(t, TCPStateSynSent, newConn.GetState(), "new connection should be in SYN-SENT")
|
||||||
|
|
||||||
|
// SYN-ACK for new connection should be valid
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||||
|
require.True(t, valid, "SYN-ACK for new connection should be accepted")
|
||||||
|
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Inbound SYN during TIME-WAIT falls through to normal tracking", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish outbound connection and close via active close → TIME-WAIT
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
|
||||||
|
// Inbound SYN on same ports during TIME-WAIT: IsValidInbound returns false
|
||||||
|
// so the filter falls through to ACL check + TrackInbound (which creates
|
||||||
|
// a new connection via track() → updateIfExists skips TIME-WAIT for SYN)
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, 0)
|
||||||
|
require.False(t, valid, "inbound SYN during TIME-WAIT should fail conntrack validation")
|
||||||
|
|
||||||
|
// Simulate what the filter does next: TrackInbound via the normal path
|
||||||
|
tracker.TrackInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, nil, 100, 0)
|
||||||
|
|
||||||
|
// The new inbound connection uses the inverted key (dst→src becomes src→dst in track)
|
||||||
|
invertedKey := ConnKey{SrcIP: dstIP, DstIP: srcIP, SrcPort: dstPort, DstPort: srcPort}
|
||||||
|
newConn := tracker.connections[invertedKey]
|
||||||
|
require.NotNil(t, newConn, "new inbound connection should be tracked")
|
||||||
|
require.Equal(t, TCPStateSynReceived, newConn.GetState())
|
||||||
|
require.False(t, newConn.IsTombstone())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Late retransmit during TIME-WAIT still allowed", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish and active close → TIME-WAIT
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
|
||||||
|
// Late ACK retransmits during TIME-WAIT should still be accepted
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid, "retransmitted ACK during TIME-WAIT should be accepted")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestTCPTimeoutHandling(t *testing.T) {
|
func TestTCPTimeoutHandling(t *testing.T) {
|
||||||
// Create tracker with a very short timeout for testing
|
// Create tracker with a very short timeout for testing
|
||||||
shortTimeout := 100 * time.Millisecond
|
shortTimeout := 100 * time.Millisecond
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -12,11 +13,13 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
@@ -24,6 +27,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
@@ -89,6 +93,7 @@ type Manager struct {
|
|||||||
incomingDenyRules map[netip.Addr]RuleSet
|
incomingDenyRules map[netip.Addr]RuleSet
|
||||||
incomingRules map[netip.Addr]RuleSet
|
incomingRules map[netip.Addr]RuleSet
|
||||||
routeRules RouteRules
|
routeRules RouteRules
|
||||||
|
routeRulesMap map[nbid.RuleID]*RouteRule
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
wgIface common.IFaceMapper
|
wgIface common.IFaceMapper
|
||||||
nativeFirewall firewall.Manager
|
nativeFirewall firewall.Manager
|
||||||
@@ -229,6 +234,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
flowLogger: flowLogger,
|
flowLogger: flowLogger,
|
||||||
netstack: netstack.IsEnabled(),
|
netstack: netstack.IsEnabled(),
|
||||||
localForwarding: enableLocalForwarding,
|
localForwarding: enableLocalForwarding,
|
||||||
|
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
|
||||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||||
portDNATRules: []portDNATRule{},
|
portDNATRules: []portDNATRule{},
|
||||||
netstackServices: make(map[serviceKey]struct{}),
|
netstackServices: make(map[serviceKey]struct{}),
|
||||||
@@ -480,11 +486,15 @@ func (m *Manager) addRouteFiltering(
|
|||||||
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleID := uuid.New().String()
|
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||||
|
|
||||||
|
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
|
||||||
|
return existingRule, nil
|
||||||
|
}
|
||||||
|
|
||||||
rule := RouteRule{
|
rule := RouteRule{
|
||||||
// TODO: consolidate these IDs
|
// TODO: consolidate these IDs
|
||||||
id: ruleID,
|
id: string(ruleKey),
|
||||||
mgmtId: id,
|
mgmtId: id,
|
||||||
sources: sources,
|
sources: sources,
|
||||||
dstSet: destination.Set,
|
dstSet: destination.Set,
|
||||||
@@ -499,6 +509,7 @@ func (m *Manager) addRouteFiltering(
|
|||||||
|
|
||||||
m.routeRules = append(m.routeRules, &rule)
|
m.routeRules = append(m.routeRules, &rule)
|
||||||
m.routeRules.Sort()
|
m.routeRules.Sort()
|
||||||
|
m.routeRulesMap[ruleKey] = &rule
|
||||||
|
|
||||||
return &rule, nil
|
return &rule, nil
|
||||||
}
|
}
|
||||||
@@ -515,15 +526,20 @@ func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
|
|||||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleID := rule.ID()
|
ruleKey := nbid.RuleID(rule.ID())
|
||||||
|
if _, ok := m.routeRulesMap[ruleKey]; !ok {
|
||||||
|
return fmt.Errorf("route rule not found: %s", ruleKey)
|
||||||
|
}
|
||||||
|
|
||||||
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
||||||
return r.id == ruleID
|
return r.id == string(ruleKey)
|
||||||
})
|
})
|
||||||
if idx < 0 {
|
if idx < 0 {
|
||||||
return fmt.Errorf("route rule not found: %s", ruleID)
|
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
||||||
|
delete(m.routeRulesMap, ruleKey)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -570,6 +586,48 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
|||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
|
// resetState clears all firewall rules and closes connection trackers.
|
||||||
|
// Must be called with m.mutex held.
|
||||||
|
func (m *Manager) resetState() {
|
||||||
|
maps.Clear(m.outgoingRules)
|
||||||
|
maps.Clear(m.incomingDenyRules)
|
||||||
|
maps.Clear(m.incomingRules)
|
||||||
|
maps.Clear(m.routeRulesMap)
|
||||||
|
m.routeRules = m.routeRules[:0]
|
||||||
|
|
||||||
|
if m.udpTracker != nil {
|
||||||
|
m.udpTracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.icmpTracker != nil {
|
||||||
|
m.icmpTracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.tcpTracker != nil {
|
||||||
|
m.tcpTracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if fwder := m.forwarder.Load(); fwder != nil {
|
||||||
|
fwder.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.logger != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := m.logger.Stop(ctx); err != nil {
|
||||||
|
log.Errorf("failed to shutdown logger: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||||
|
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||||
|
if m.nativeFirewall == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.SetupEBPFProxyNoTrack(proxyPort, wgPort)
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateSet updates the rule destinations associated with the given set
|
// UpdateSet updates the rule destinations associated with the given set
|
||||||
// by merging the existing prefixes with the new ones, then deduplicating.
|
// by merging the existing prefixes with the new ones, then deduplicating.
|
||||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
|
|||||||
376
client/firewall/uspfilter/filter_routeacl_test.go
Normal file
376
client/firewall/uspfilter/filter_routeacl_test.go
Normal file
@@ -0,0 +1,376 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestAddRouteFilteringReturnsExistingRule verifies that adding the same route
|
||||||
|
// filtering rule twice returns the same rule ID (idempotent behavior).
|
||||||
|
func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("100.64.1.0/24"),
|
||||||
|
netip.MustParsePrefix("100.64.2.0/24"),
|
||||||
|
}
|
||||||
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
|
// Add rule first time
|
||||||
|
rule1, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule1)
|
||||||
|
|
||||||
|
// Add the same rule again
|
||||||
|
rule2, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule2)
|
||||||
|
|
||||||
|
// These should be the same (idempotent) like nftables/iptables implementations
|
||||||
|
assert.Equal(t, rule1.ID(), rule2.ID(),
|
||||||
|
"Adding the same rule twice should return the same rule ID (idempotent)")
|
||||||
|
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 2, ruleCount,
|
||||||
|
"Should have exactly 2 rules (1 user rule + 1 block rule)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAddRouteFilteringDifferentRulesGetDifferentIDs verifies that rules with
|
||||||
|
// different parameters get distinct IDs.
|
||||||
|
func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||||
|
|
||||||
|
// Add first rule
|
||||||
|
rule1, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Add different rule (different destination)
|
||||||
|
rule2, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-2"),
|
||||||
|
sources,
|
||||||
|
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.NotEqual(t, rule1.ID(), rule2.ID(),
|
||||||
|
"Different rules should have different IDs")
|
||||||
|
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 3, ruleCount, "Should have 3 rules (2 user rules + 1 block rule)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRouteRuleUpdateDoesNotCauseGap verifies that re-adding the same route
|
||||||
|
// rule during a network map update does not disrupt existing traffic.
|
||||||
|
func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||||
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
|
rule1, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.1.5")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.10")
|
||||||
|
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||||
|
require.True(t, pass, "Traffic should pass with rule in place")
|
||||||
|
|
||||||
|
// Re-add same rule (simulates network map update)
|
||||||
|
rule2, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Idempotent IDs mean rule1.ID() == rule2.ID(), so the ACL manager
|
||||||
|
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
|
||||||
|
// would remove the only matching rule and cause a traffic gap.
|
||||||
|
if rule1.ID() != rule2.ID() {
|
||||||
|
err = manager.DeleteRouteRule(rule1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, passAfter := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||||
|
assert.True(t, passAfter,
|
||||||
|
"Traffic should still pass after rule update - no gap should occur")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
|
||||||
|
// exactly one drop rule for the WireGuard network prefix, and calling it again
|
||||||
|
// returns the same rule without duplicating.
|
||||||
|
func TestBlockInvalidRoutedIdempotent(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
dev := mocks.NewMockDevice(ctrl)
|
||||||
|
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||||
|
|
||||||
|
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||||
|
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: wgNet.Addr(),
|
||||||
|
Network: wgNet,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
GetDeviceFunc: func() *device.FilteredDevice {
|
||||||
|
return &device.FilteredDevice{Device: dev}
|
||||||
|
},
|
||||||
|
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||||
|
return &wgdevice.Device{}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Call blockInvalidRouted directly multiple times
|
||||||
|
rule1, err := manager.blockInvalidRouted(ifaceMock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule1)
|
||||||
|
|
||||||
|
rule2, err := manager.blockInvalidRouted(ifaceMock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule2)
|
||||||
|
|
||||||
|
rule3, err := manager.blockInvalidRouted(ifaceMock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule3)
|
||||||
|
|
||||||
|
// All should return the same rule
|
||||||
|
assert.Equal(t, rule1.ID(), rule2.ID(), "Second call should return same rule")
|
||||||
|
assert.Equal(t, rule2.ID(), rule3.ID(), "Third call should return same rule")
|
||||||
|
|
||||||
|
// Should have exactly 1 route rule
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 1, ruleCount, "Should have exactly 1 block rule after 3 calls")
|
||||||
|
|
||||||
|
// Verify the rule blocks traffic to the WG network
|
||||||
|
srcIP := netip.MustParseAddr("10.0.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.50")
|
||||||
|
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
|
||||||
|
assert.False(t, pass, "Block rule should deny traffic to WG prefix")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBlockRuleNotAccumulatedOnRepeatedEnableRouting verifies that calling
|
||||||
|
// EnableRouting multiple times (as happens on each route update) does not
|
||||||
|
// accumulate duplicate block rules in the routeRules slice.
|
||||||
|
func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
dev := mocks.NewMockDevice(ctrl)
|
||||||
|
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||||
|
|
||||||
|
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||||
|
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: wgNet.Addr(),
|
||||||
|
Network: wgNet,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
GetDeviceFunc: func() *device.FilteredDevice {
|
||||||
|
return &device.FilteredDevice{Device: dev}
|
||||||
|
},
|
||||||
|
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||||
|
return &wgdevice.Device{}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Call EnableRouting multiple times (simulating repeated route updates)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
require.NoError(t, manager.EnableRouting())
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 1, ruleCount,
|
||||||
|
"Repeated EnableRouting should not accumulate block rules")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRouteRuleCountStableAcrossUpdates verifies that adding the same route
|
||||||
|
// rule multiple times does not create duplicate entries.
|
||||||
|
func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||||
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
|
// Simulate 5 network map updates with the same route rule
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 2, ruleCount,
|
||||||
|
"Should have exactly 2 rules (1 user rule + 1 block rule) after 5 updates")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDeleteRouteRuleAfterIdempotentAdd verifies that deleting a route rule
|
||||||
|
// after adding it multiple times works correctly.
|
||||||
|
func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||||
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
|
// Add same rule twice
|
||||||
|
rule1, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rule2, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
|
||||||
|
|
||||||
|
// Delete using first reference
|
||||||
|
err = manager.DeleteRouteRule(rule1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify traffic no longer passes
|
||||||
|
srcIP := netip.MustParseAddr("100.64.1.5")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.10")
|
||||||
|
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||||
|
assert.False(t, pass, "Traffic should not pass after rule deletion")
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTestManager(t *testing.T) *Manager {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
dev := mocks.NewMockDevice(ctrl)
|
||||||
|
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||||
|
|
||||||
|
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||||
|
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: wgNet.Addr(),
|
||||||
|
Network: wgNet,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
GetDeviceFunc: func() *device.FilteredDevice {
|
||||||
|
return &device.FilteredDevice{Device: dev}
|
||||||
|
},
|
||||||
|
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||||
|
return &wgdevice.Device{}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, manager.EnableRouting())
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
return manager
|
||||||
|
}
|
||||||
@@ -263,6 +263,158 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
|
||||||
|
// to the deny map and can be cleanly deleted without leaving orphans.
|
||||||
|
func TestPeerRuleLifecycleDenyRules(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, m.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
addr := netip.MustParseAddr("192.168.1.1")
|
||||||
|
|
||||||
|
// Add multiple deny rules for different ports
|
||||||
|
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
denyCount := len(m.incomingDenyRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
|
||||||
|
|
||||||
|
// Delete the first deny rule
|
||||||
|
err = m.DeletePeerRule(rule1[0])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
denyCount = len(m.incomingDenyRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
|
||||||
|
|
||||||
|
// Delete the second deny rule
|
||||||
|
err = m.DeletePeerRule(rule2[0])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
_, exists := m.incomingDenyRules[addr]
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
|
||||||
|
// peer rules (simulating network map updates) does not leak rules in the maps.
|
||||||
|
func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, m.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
addr := netip.MustParseAddr("192.168.1.1")
|
||||||
|
|
||||||
|
// Simulate 10 network map updates: add rule, delete old, add new
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
// Add a deny rule
|
||||||
|
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Add an allow rule
|
||||||
|
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Delete them (simulating ACL manager cleanup)
|
||||||
|
for _, r := range rules {
|
||||||
|
require.NoError(t, m.DeletePeerRule(r))
|
||||||
|
}
|
||||||
|
for _, r := range allowRules {
|
||||||
|
require.NoError(t, m.DeletePeerRule(r))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
denyCount := len(m.incomingDenyRules[addr])
|
||||||
|
allowCount := len(m.incomingRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
|
||||||
|
require.Equal(t, 0, allowCount, "No allow rules should remain after cleanup")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMixedAllowDenyRulesSameIP verifies that allow and deny rules for the same
|
||||||
|
// IP are stored in separate maps and don't interfere with each other.
|
||||||
|
func TestMixedAllowDenyRulesSameIP(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, m.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
|
||||||
|
// Add allow rule for port 80
|
||||||
|
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Add deny rule for port 22
|
||||||
|
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
addr := netip.MustParseAddr("192.168.1.1")
|
||||||
|
m.mutex.RLock()
|
||||||
|
allowCount := len(m.incomingRules[addr])
|
||||||
|
denyCount := len(m.incomingDenyRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
|
||||||
|
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
|
||||||
|
|
||||||
|
// Delete allow rule should not affect deny rule
|
||||||
|
err = m.DeletePeerRule(allowRule[0])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
denyCountAfter := len(m.incomingDenyRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
|
||||||
|
|
||||||
|
// Delete deny rule
|
||||||
|
err = m.DeletePeerRule(denyRule[0])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
_, denyExists := m.incomingDenyRules[addr]
|
||||||
|
_, allowExists := m.incomingRules[addr]
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
require.False(t, denyExists, "Deny rules should be empty")
|
||||||
|
require.False(t, allowExists, "Allow rules should be empty")
|
||||||
|
}
|
||||||
|
|
||||||
func TestManagerReset(t *testing.T) {
|
func TestManagerReset(t *testing.T) {
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -16,9 +18,18 @@ const (
|
|||||||
maxBatchSize = 1024 * 16
|
maxBatchSize = 1024 * 16
|
||||||
maxMessageSize = 1024 * 2
|
maxMessageSize = 1024 * 2
|
||||||
defaultFlushInterval = 2 * time.Second
|
defaultFlushInterval = 2 * time.Second
|
||||||
logChannelSize = 1000
|
defaultLogChanSize = 1000
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func getLogChannelSize() int {
|
||||||
|
if v := os.Getenv("NB_USPFILTER_LOG_BUFFER"); v != "" {
|
||||||
|
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultLogChanSize
|
||||||
|
}
|
||||||
|
|
||||||
type Level uint32
|
type Level uint32
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -69,7 +80,7 @@ type Logger struct {
|
|||||||
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||||
l := &Logger{
|
l := &Logger{
|
||||||
output: logrusLogger.Out,
|
output: logrusLogger.Out,
|
||||||
msgChannel: make(chan logMessage, logChannelSize),
|
msgChannel: make(chan logMessage, getLogChannelSize()),
|
||||||
shutdown: make(chan struct{}),
|
shutdown: make(chan struct{}),
|
||||||
bufPool: sync.Pool{
|
bufPool: sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
|
|||||||
@@ -558,7 +558,7 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
|
host, portStr, err := net.SplitHostPort(val)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to parse endpoint: %v", err)
|
log.Errorf("failed to parse endpoint: %v", err)
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -29,8 +29,9 @@ type PacketFilter interface {
|
|||||||
type FilteredDevice struct {
|
type FilteredDevice struct {
|
||||||
tun.Device
|
tun.Device
|
||||||
|
|
||||||
filter PacketFilter
|
filter PacketFilter
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
closeOnce sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
// newDeviceFilter constructor function
|
// newDeviceFilter constructor function
|
||||||
@@ -40,6 +41,20 @@ func newDeviceFilter(device tun.Device) *FilteredDevice {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close closes the underlying tun device exactly once.
|
||||||
|
// wireguard-go's netTun.Close() panics on double-close due to a bare close(channel),
|
||||||
|
// and multiple code paths can trigger Close on the same device.
|
||||||
|
func (d *FilteredDevice) Close() error {
|
||||||
|
var err error
|
||||||
|
d.closeOnce.Do(func() {
|
||||||
|
err = d.Device.Close()
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Read wraps read method with filtering feature
|
// Read wraps read method with filtering feature
|
||||||
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||||
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
|
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
|
||||||
|
|||||||
@@ -82,7 +82,9 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
|
|||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tunIface.Close()
|
if cErr := tunIface.Close(); cErr != nil {
|
||||||
|
log.Debugf("failed to close tun device: %v", cErr)
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("error configuring interface: %s", err)
|
return nil, fmt.Errorf("error configuring interface: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/errors"
|
"github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
@@ -50,6 +51,7 @@ func ValidateMTU(mtu uint16) error {
|
|||||||
|
|
||||||
type wgProxyFactory interface {
|
type wgProxyFactory interface {
|
||||||
GetProxy() wgproxy.Proxy
|
GetProxy() wgproxy.Proxy
|
||||||
|
GetProxyPort() uint16
|
||||||
Free() error
|
Free() error
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,6 +82,12 @@ func (w *WGIface) GetProxy() wgproxy.Proxy {
|
|||||||
return w.wgProxyFactory.GetProxy()
|
return w.wgProxyFactory.GetProxy()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetProxyPort returns the proxy port used by the WireGuard proxy.
|
||||||
|
// Returns 0 if no proxy port is used (e.g., for userspace WireGuard).
|
||||||
|
func (w *WGIface) GetProxyPort() uint16 {
|
||||||
|
return w.wgProxyFactory.GetProxyPort()
|
||||||
|
}
|
||||||
|
|
||||||
// GetBind returns the EndpointManager userspace bind mode.
|
// GetBind returns the EndpointManager userspace bind mode.
|
||||||
func (w *WGIface) GetBind() device.EndpointManager {
|
func (w *WGIface) GetBind() device.EndpointManager {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
@@ -221,6 +229,10 @@ func (w *WGIface) Close() error {
|
|||||||
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if nbnetstack.IsEnabled() {
|
||||||
|
return errors.FormatErrorOrNil(result)
|
||||||
|
}
|
||||||
|
|
||||||
if err := w.waitUntilRemoved(); err != nil {
|
if err := w.waitUntilRemoved(); err != nil {
|
||||||
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
|
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
|
||||||
if err := w.Destroy(); err != nil {
|
if err := w.Destroy(); err != nil {
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return nsTunDev, tunNet, nil
|
return t.tundev, tunNet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *NetStackTun) Close() error {
|
func (t *NetStackTun) Close() error {
|
||||||
|
|||||||
@@ -114,34 +114,21 @@ func (p *ProxyBind) Pause() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
|
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
|
||||||
|
ep, err := addrToEndpoint(endpoint)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to start package redirection: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
p.pausedCond.L.Lock()
|
p.pausedCond.L.Lock()
|
||||||
p.paused = false
|
p.paused = false
|
||||||
|
|
||||||
ep, err := addrToEndpoint(endpoint)
|
p.wgCurrentUsed = ep
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to convert endpoint address: %v", err)
|
|
||||||
} else {
|
|
||||||
p.wgCurrentUsed = ep
|
|
||||||
}
|
|
||||||
|
|
||||||
p.pausedCond.Signal()
|
p.pausedCond.Signal()
|
||||||
p.pausedCond.L.Unlock()
|
p.pausedCond.L.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) {
|
|
||||||
if addr == nil {
|
|
||||||
return nil, errors.New("nil address")
|
|
||||||
}
|
|
||||||
|
|
||||||
ip, ok := netip.AddrFromSlice(addr.IP)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("convert %s to netip.Addr", addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port))
|
|
||||||
return &bind.Endpoint{AddrPort: addrPort}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ProxyBind) CloseConn() error {
|
func (p *ProxyBind) CloseConn() error {
|
||||||
if p.cancel == nil {
|
if p.cancel == nil {
|
||||||
return fmt.Errorf("proxy not started")
|
return fmt.Errorf("proxy not started")
|
||||||
@@ -225,3 +212,16 @@ func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
|
|||||||
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
|
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
|
||||||
return &netipAddr, nil
|
return &netipAddr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) {
|
||||||
|
if addr == nil {
|
||||||
|
return nil, fmt.Errorf("invalid address")
|
||||||
|
}
|
||||||
|
ip, ok := netip.AddrFromSlice(addr.IP)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("convert %s to netip.Addr", addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port))
|
||||||
|
return &bind.Endpoint{AddrPort: addrPort}, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -26,13 +24,10 @@ const (
|
|||||||
loopbackAddr = "127.0.0.1"
|
loopbackAddr = "127.0.0.1"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
localHostNetIP = net.ParseIP("127.0.0.1")
|
|
||||||
)
|
|
||||||
|
|
||||||
// WGEBPFProxy definition for proxy with EBPF support
|
// WGEBPFProxy definition for proxy with EBPF support
|
||||||
type WGEBPFProxy struct {
|
type WGEBPFProxy struct {
|
||||||
localWGListenPort int
|
localWGListenPort int
|
||||||
|
proxyPort int
|
||||||
mtu uint16
|
mtu uint16
|
||||||
|
|
||||||
ebpfManager ebpfMgr.Manager
|
ebpfManager ebpfMgr.Manager
|
||||||
@@ -40,7 +35,8 @@ type WGEBPFProxy struct {
|
|||||||
turnConnMutex sync.Mutex
|
turnConnMutex sync.Mutex
|
||||||
|
|
||||||
lastUsedPort uint16
|
lastUsedPort uint16
|
||||||
rawConn net.PacketConn
|
rawConnIPv4 net.PacketConn
|
||||||
|
rawConnIPv6 net.PacketConn
|
||||||
conn transport.UDPConn
|
conn transport.UDPConn
|
||||||
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
@@ -62,23 +58,39 @@ func NewWGEBPFProxy(wgPort int, mtu uint16) *WGEBPFProxy {
|
|||||||
// Listen load ebpf program and listen the proxy
|
// Listen load ebpf program and listen the proxy
|
||||||
func (p *WGEBPFProxy) Listen() error {
|
func (p *WGEBPFProxy) Listen() error {
|
||||||
pl := portLookup{}
|
pl := portLookup{}
|
||||||
wgPorxyPort, err := pl.searchFreePort()
|
proxyPort, err := pl.searchFreePort()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
p.proxyPort = proxyPort
|
||||||
|
|
||||||
|
// Prepare IPv4 raw socket (required)
|
||||||
|
p.rawConnIPv4, err = rawsocket.PrepareSenderRawSocketIPv4()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
p.rawConn, err = rawsocket.PrepareSenderRawSocket()
|
// Prepare IPv6 raw socket (optional)
|
||||||
|
p.rawConnIPv6, err = rawsocket.PrepareSenderRawSocketIPv6()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
log.Warnf("failed to prepare IPv6 raw socket, continuing with IPv4 only: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = p.ebpfManager.LoadWgProxy(wgPorxyPort, p.localWGListenPort)
|
err = p.ebpfManager.LoadWgProxy(proxyPort, p.localWGListenPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if closeErr := p.rawConnIPv4.Close(); closeErr != nil {
|
||||||
|
log.Warnf("failed to close IPv4 raw socket: %v", closeErr)
|
||||||
|
}
|
||||||
|
if p.rawConnIPv6 != nil {
|
||||||
|
if closeErr := p.rawConnIPv6.Close(); closeErr != nil {
|
||||||
|
log.Warnf("failed to close IPv6 raw socket: %v", closeErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
addr := net.UDPAddr{
|
addr := net.UDPAddr{
|
||||||
Port: wgPorxyPort,
|
Port: proxyPort,
|
||||||
IP: net.ParseIP(loopbackAddr),
|
IP: net.ParseIP(loopbackAddr),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,7 +106,7 @@ func (p *WGEBPFProxy) Listen() error {
|
|||||||
p.conn = conn
|
p.conn = conn
|
||||||
|
|
||||||
go p.proxyToRemote()
|
go p.proxyToRemote()
|
||||||
log.Infof("local wg proxy listening on: %d", wgPorxyPort)
|
log.Infof("local wg proxy listening on: %d", proxyPort)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -135,12 +147,25 @@ func (p *WGEBPFProxy) Free() error {
|
|||||||
result = multierror.Append(result, err)
|
result = multierror.Append(result, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := p.rawConn.Close(); err != nil {
|
if p.rawConnIPv4 != nil {
|
||||||
result = multierror.Append(result, err)
|
if err := p.rawConnIPv4.Close(); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.rawConnIPv6 != nil {
|
||||||
|
if err := p.rawConnIPv6.Close(); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nberrors.FormatErrorOrNil(result)
|
return nberrors.FormatErrorOrNil(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetProxyPort returns the proxy listening port.
|
||||||
|
func (p *WGEBPFProxy) GetProxyPort() uint16 {
|
||||||
|
return uint16(p.proxyPort)
|
||||||
|
}
|
||||||
|
|
||||||
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
|
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
|
||||||
// From this go routine has only one instance.
|
// From this go routine has only one instance.
|
||||||
func (p *WGEBPFProxy) proxyToRemote() {
|
func (p *WGEBPFProxy) proxyToRemote() {
|
||||||
@@ -216,34 +241,3 @@ generatePort:
|
|||||||
}
|
}
|
||||||
return p.lastUsedPort, nil
|
return p.lastUsedPort, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
|
|
||||||
payload := gopacket.Payload(data)
|
|
||||||
ipH := &layers.IPv4{
|
|
||||||
DstIP: localHostNetIP,
|
|
||||||
SrcIP: endpointAddr.IP,
|
|
||||||
Version: 4,
|
|
||||||
TTL: 64,
|
|
||||||
Protocol: layers.IPProtocolUDP,
|
|
||||||
}
|
|
||||||
udpH := &layers.UDP{
|
|
||||||
SrcPort: layers.UDPPort(endpointAddr.Port),
|
|
||||||
DstPort: layers.UDPPort(p.localWGListenPort),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := udpH.SetNetworkLayerForChecksum(ipH)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("set network layer for checksum: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
layerBuffer := gopacket.NewSerializeBuffer()
|
|
||||||
|
|
||||||
err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("serialize layers: %w", err)
|
|
||||||
}
|
|
||||||
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
|
|
||||||
return fmt.Errorf("write to raw conn: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -10,12 +10,89 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bufsize"
|
"github.com/netbirdio/netbird/client/iface/bufsize"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errIPv6ConnNotAvailable = errors.New("IPv6 endpoint but rawConnIPv6 is not available")
|
||||||
|
errIPv4ConnNotAvailable = errors.New("IPv4 endpoint but rawConnIPv4 is not available")
|
||||||
|
|
||||||
|
localHostNetIPv4 = net.ParseIP("127.0.0.1")
|
||||||
|
localHostNetIPv6 = net.ParseIP("::1")
|
||||||
|
|
||||||
|
serializeOpts = gopacket.SerializeOptions{
|
||||||
|
ComputeChecksums: true,
|
||||||
|
FixLengths: true,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// PacketHeaders holds pre-created headers and buffers for efficient packet sending
|
||||||
|
type PacketHeaders struct {
|
||||||
|
ipH gopacket.SerializableLayer
|
||||||
|
udpH *layers.UDP
|
||||||
|
layerBuffer gopacket.SerializeBuffer
|
||||||
|
localHostAddr net.IP
|
||||||
|
isIPv4 bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPacketHeaders(localWGListenPort int, endpoint *net.UDPAddr) (*PacketHeaders, error) {
|
||||||
|
var ipH gopacket.SerializableLayer
|
||||||
|
var networkLayer gopacket.NetworkLayer
|
||||||
|
var localHostAddr net.IP
|
||||||
|
var isIPv4 bool
|
||||||
|
|
||||||
|
// Check if source address is IPv4 or IPv6
|
||||||
|
if endpoint.IP.To4() != nil {
|
||||||
|
// IPv4 path
|
||||||
|
ipv4 := &layers.IPv4{
|
||||||
|
DstIP: localHostNetIPv4,
|
||||||
|
SrcIP: endpoint.IP,
|
||||||
|
Version: 4,
|
||||||
|
TTL: 64,
|
||||||
|
Protocol: layers.IPProtocolUDP,
|
||||||
|
}
|
||||||
|
ipH = ipv4
|
||||||
|
networkLayer = ipv4
|
||||||
|
localHostAddr = localHostNetIPv4
|
||||||
|
isIPv4 = true
|
||||||
|
} else {
|
||||||
|
// IPv6 path
|
||||||
|
ipv6 := &layers.IPv6{
|
||||||
|
DstIP: localHostNetIPv6,
|
||||||
|
SrcIP: endpoint.IP,
|
||||||
|
Version: 6,
|
||||||
|
HopLimit: 64,
|
||||||
|
NextHeader: layers.IPProtocolUDP,
|
||||||
|
}
|
||||||
|
ipH = ipv6
|
||||||
|
networkLayer = ipv6
|
||||||
|
localHostAddr = localHostNetIPv6
|
||||||
|
isIPv4 = false
|
||||||
|
}
|
||||||
|
|
||||||
|
udpH := &layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(endpoint.Port),
|
||||||
|
DstPort: layers.UDPPort(localWGListenPort),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil {
|
||||||
|
return nil, fmt.Errorf("set network layer for checksum: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &PacketHeaders{
|
||||||
|
ipH: ipH,
|
||||||
|
udpH: udpH,
|
||||||
|
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||||
|
localHostAddr: localHostAddr,
|
||||||
|
isIPv4: isIPv4,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||||
type ProxyWrapper struct {
|
type ProxyWrapper struct {
|
||||||
wgeBPFProxy *WGEBPFProxy
|
wgeBPFProxy *WGEBPFProxy
|
||||||
@@ -24,8 +101,10 @@ type ProxyWrapper struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
|
|
||||||
wgRelayedEndpointAddr *net.UDPAddr
|
wgRelayedEndpointAddr *net.UDPAddr
|
||||||
wgEndpointCurrentUsedAddr *net.UDPAddr
|
headers *PacketHeaders
|
||||||
|
headerCurrentUsed *PacketHeaders
|
||||||
|
rawConn net.PacketConn
|
||||||
|
|
||||||
paused bool
|
paused bool
|
||||||
pausedCond *sync.Cond
|
pausedCond *sync.Cond
|
||||||
@@ -41,15 +120,32 @@ func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
|
|||||||
closeListener: listener.NewCloseListener(),
|
closeListener: listener.NewCloseListener(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
|
||||||
|
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
|
||||||
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
|
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("add turn conn: %w", err)
|
return fmt.Errorf("add turn conn: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
headers, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, addr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create packet sender: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if required raw connection is available
|
||||||
|
if !headers.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
|
||||||
|
return errIPv6ConnNotAvailable
|
||||||
|
}
|
||||||
|
if headers.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
|
||||||
|
return errIPv4ConnNotAvailable
|
||||||
|
}
|
||||||
|
|
||||||
p.remoteConn = remoteConn
|
p.remoteConn = remoteConn
|
||||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||||
p.wgRelayedEndpointAddr = addr
|
p.wgRelayedEndpointAddr = addr
|
||||||
return err
|
p.headers = headers
|
||||||
|
p.rawConn = p.selectRawConn(headers)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
||||||
@@ -68,7 +164,8 @@ func (p *ProxyWrapper) Work() {
|
|||||||
p.pausedCond.L.Lock()
|
p.pausedCond.L.Lock()
|
||||||
p.paused = false
|
p.paused = false
|
||||||
|
|
||||||
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
|
p.headerCurrentUsed = p.headers
|
||||||
|
p.rawConn = p.selectRawConn(p.headerCurrentUsed)
|
||||||
|
|
||||||
if !p.isStarted {
|
if !p.isStarted {
|
||||||
p.isStarted = true
|
p.isStarted = true
|
||||||
@@ -91,12 +188,32 @@ func (p *ProxyWrapper) Pause() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
|
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
|
||||||
|
if endpoint == nil || endpoint.IP == nil {
|
||||||
|
log.Errorf("failed to start package redirection, endpoint is nil")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
header, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, endpoint)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to create packet headers: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if required raw connection is available
|
||||||
|
if !header.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
|
||||||
|
log.Error(errIPv6ConnNotAvailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if header.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
|
||||||
|
log.Error(errIPv4ConnNotAvailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
p.pausedCond.L.Lock()
|
p.pausedCond.L.Lock()
|
||||||
p.paused = false
|
p.paused = false
|
||||||
|
|
||||||
if endpoint != nil && endpoint.IP != nil {
|
p.headerCurrentUsed = header
|
||||||
p.wgEndpointCurrentUsedAddr = endpoint
|
p.rawConn = p.selectRawConn(header)
|
||||||
}
|
|
||||||
|
|
||||||
p.pausedCond.Signal()
|
p.pausedCond.Signal()
|
||||||
p.pausedCond.L.Unlock()
|
p.pausedCond.L.Unlock()
|
||||||
@@ -138,7 +255,7 @@ func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
|
|||||||
p.pausedCond.Wait()
|
p.pausedCond.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
|
err = p.sendPkg(buf[:n], p.headerCurrentUsed)
|
||||||
p.pausedCond.L.Unlock()
|
p.pausedCond.L.Unlock()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -164,3 +281,29 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
|
|||||||
}
|
}
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) sendPkg(data []byte, header *PacketHeaders) error {
|
||||||
|
defer func() {
|
||||||
|
if err := header.layerBuffer.Clear(); err != nil {
|
||||||
|
log.Errorf("failed to clear layer buffer: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
payload := gopacket.Payload(data)
|
||||||
|
|
||||||
|
if err := gopacket.SerializeLayers(header.layerBuffer, serializeOpts, header.ipH, header.udpH, payload); err != nil {
|
||||||
|
return fmt.Errorf("serialize layers: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := p.rawConn.WriteTo(header.layerBuffer.Bytes(), &net.IPAddr{IP: header.localHostAddr}); err != nil {
|
||||||
|
return fmt.Errorf("write to raw conn: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) selectRawConn(header *PacketHeaders) net.PacketConn {
|
||||||
|
if header.isIPv4 {
|
||||||
|
return p.wgeBPFProxy.rawConnIPv4
|
||||||
|
}
|
||||||
|
return p.wgeBPFProxy.rawConnIPv6
|
||||||
|
}
|
||||||
|
|||||||
@@ -54,6 +54,14 @@ func (w *KernelFactory) GetProxy() Proxy {
|
|||||||
return ebpf.NewProxyWrapper(w.ebpfProxy)
|
return ebpf.NewProxyWrapper(w.ebpfProxy)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetProxyPort returns the eBPF proxy port, or 0 if eBPF is not active.
|
||||||
|
func (w *KernelFactory) GetProxyPort() uint16 {
|
||||||
|
if w.ebpfProxy == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return w.ebpfProxy.GetProxyPort()
|
||||||
|
}
|
||||||
|
|
||||||
func (w *KernelFactory) Free() error {
|
func (w *KernelFactory) Free() error {
|
||||||
if w.ebpfProxy == nil {
|
if w.ebpfProxy == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -24,6 +24,11 @@ func (w *USPFactory) GetProxy() Proxy {
|
|||||||
return proxyBind.NewProxyBind(w.bind, w.mtu)
|
return proxyBind.NewProxyBind(w.bind, w.mtu)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetProxyPort returns 0 as userspace WireGuard doesn't use a separate proxy port.
|
||||||
|
func (w *USPFactory) GetProxyPort() uint16 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
func (w *USPFactory) Free() error {
|
func (w *USPFactory) Free() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,43 +8,87 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
func PrepareSenderRawSocket() (net.PacketConn, error) {
|
// PrepareSenderRawSocketIPv4 creates and configures a raw socket for sending IPv4 packets
|
||||||
|
func PrepareSenderRawSocketIPv4() (net.PacketConn, error) {
|
||||||
|
return prepareSenderRawSocket(syscall.AF_INET, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrepareSenderRawSocketIPv6 creates and configures a raw socket for sending IPv6 packets
|
||||||
|
func PrepareSenderRawSocketIPv6() (net.PacketConn, error) {
|
||||||
|
return prepareSenderRawSocket(syscall.AF_INET6, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func prepareSenderRawSocket(family int, isIPv4 bool) (net.PacketConn, error) {
|
||||||
// Create a raw socket.
|
// Create a raw socket.
|
||||||
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
fd, err := syscall.Socket(family, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("creating raw socket failed: %w", err)
|
return nil, fmt.Errorf("creating raw socket failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
|
// Set the header include option on the socket to tell the kernel that headers are included in the packet.
|
||||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
// For IPv4, we need to set IP_HDRINCL. For IPv6, we need to set IPV6_HDRINCL to accept application-provided IPv6 headers.
|
||||||
if err != nil {
|
if isIPv4 {
|
||||||
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, unix.IP_HDRINCL, 1)
|
||||||
|
if err != nil {
|
||||||
|
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||||
|
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IPV6, unix.IPV6_HDRINCL, 1)
|
||||||
|
if err != nil {
|
||||||
|
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||||
|
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("setting IPV6_HDRINCL failed: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bind the socket to the "lo" interface.
|
// Bind the socket to the "lo" interface.
|
||||||
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||||
|
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the fwmark on the socket.
|
// Set the fwmark on the socket.
|
||||||
err = nbnet.SetSocketOpt(fd)
|
err = nbnet.SetSocketOpt(fd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||||
|
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert the file descriptor to a PacketConn.
|
// Convert the file descriptor to a PacketConn.
|
||||||
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
||||||
if file == nil {
|
if file == nil {
|
||||||
|
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||||
|
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("converting fd to file failed")
|
return nil, fmt.Errorf("converting fd to file failed")
|
||||||
}
|
}
|
||||||
packetConn, err := net.FilePacketConn(file)
|
packetConn, err := net.FilePacketConn(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if closeErr := file.Close(); closeErr != nil {
|
||||||
|
log.Warnf("failed to close file: %v", closeErr)
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
|
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close the original file to release the FD (net.FilePacketConn duplicates it)
|
||||||
|
if closeErr := file.Close(); closeErr != nil {
|
||||||
|
log.Warnf("failed to close file after creating packet conn: %v", closeErr)
|
||||||
|
}
|
||||||
|
|
||||||
return packetConn, nil
|
return packetConn, nil
|
||||||
}
|
}
|
||||||
|
|||||||
353
client/iface/wgproxy/redirect_test.go
Normal file
353
client/iface/wgproxy/redirect_test.go
Normal file
@@ -0,0 +1,353 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
package wgproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// compareUDPAddr compares two UDP addresses, ignoring IPv6 zone IDs
|
||||||
|
// IPv6 link-local addresses include zone IDs (e.g., fe80::1%lo) which we should ignore
|
||||||
|
func compareUDPAddr(addr1, addr2 net.Addr) bool {
|
||||||
|
udpAddr1, ok1 := addr1.(*net.UDPAddr)
|
||||||
|
udpAddr2, ok2 := addr2.(*net.UDPAddr)
|
||||||
|
|
||||||
|
if !ok1 || !ok2 {
|
||||||
|
return addr1.String() == addr2.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compare IP and Port, ignoring zone
|
||||||
|
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
|
||||||
|
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
|
||||||
|
wgPort := 51850
|
||||||
|
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||||
|
if err := ebpfProxy.Listen(); err != nil {
|
||||||
|
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := ebpfProxy.Free(); err != nil {
|
||||||
|
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||||
|
|
||||||
|
// NetBird UDP address of the remote peer
|
||||||
|
nbAddr := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("100.108.111.177"),
|
||||||
|
Port: 38746,
|
||||||
|
}
|
||||||
|
|
||||||
|
p2pEndpoint := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("192.168.0.56"),
|
||||||
|
Port: 51820,
|
||||||
|
}
|
||||||
|
|
||||||
|
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
|
||||||
|
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
|
||||||
|
wgPort := 51851
|
||||||
|
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||||
|
if err := ebpfProxy.Listen(); err != nil {
|
||||||
|
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := ebpfProxy.Free(); err != nil {
|
||||||
|
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||||
|
|
||||||
|
// NetBird UDP address of the remote peer
|
||||||
|
nbAddr := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("100.108.111.177"),
|
||||||
|
Port: 38746,
|
||||||
|
}
|
||||||
|
|
||||||
|
p2pEndpoint := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("fe80::56"),
|
||||||
|
Port: 51820,
|
||||||
|
}
|
||||||
|
|
||||||
|
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
|
||||||
|
func TestRedirectAs_UDP_IPv4(t *testing.T) {
|
||||||
|
wgPort := 51852
|
||||||
|
proxy := udp.NewWGUDPProxy(wgPort, 1280)
|
||||||
|
|
||||||
|
// NetBird UDP address of the remote peer
|
||||||
|
nbAddr := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("100.108.111.177"),
|
||||||
|
Port: 38746,
|
||||||
|
}
|
||||||
|
|
||||||
|
p2pEndpoint := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("192.168.0.56"),
|
||||||
|
Port: 51820,
|
||||||
|
}
|
||||||
|
|
||||||
|
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRedirectAs_UDP_IPv6 tests RedirectAs with UDP proxy using IPv6 addresses
|
||||||
|
func TestRedirectAs_UDP_IPv6(t *testing.T) {
|
||||||
|
wgPort := 51853
|
||||||
|
proxy := udp.NewWGUDPProxy(wgPort, 1280)
|
||||||
|
|
||||||
|
// NetBird UDP address of the remote peer
|
||||||
|
nbAddr := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("100.108.111.177"),
|
||||||
|
Port: 38746,
|
||||||
|
}
|
||||||
|
|
||||||
|
p2pEndpoint := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("fe80::56"),
|
||||||
|
Port: 51820,
|
||||||
|
}
|
||||||
|
|
||||||
|
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
// testRedirectAs is a helper function that tests the RedirectAs functionality
|
||||||
|
// It verifies that:
|
||||||
|
// 1. Initial traffic from relay connection works
|
||||||
|
// 2. After calling RedirectAs, packets appear to come from the p2p endpoint
|
||||||
|
// 3. Multiple packets are correctly redirected with the new source address
|
||||||
|
func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *net.UDPAddr) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create WireGuard listeners on both IPv4 and IPv6 to support both P2P connection types
|
||||||
|
// In reality, WireGuard binds to a port and receives from both IPv4 and IPv6
|
||||||
|
wgListener4, err := net.ListenUDP("udp4", &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
Port: wgPort,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create IPv4 WireGuard listener: %v", err)
|
||||||
|
}
|
||||||
|
defer wgListener4.Close()
|
||||||
|
|
||||||
|
wgListener6, err := net.ListenUDP("udp6", &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("::1"),
|
||||||
|
Port: wgPort,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create IPv6 WireGuard listener: %v", err)
|
||||||
|
}
|
||||||
|
defer wgListener6.Close()
|
||||||
|
|
||||||
|
// Determine which listener to use based on the NetBird address IP version
|
||||||
|
// (this is where initial traffic will come from before RedirectAs is called)
|
||||||
|
var wgListener *net.UDPConn
|
||||||
|
if p2pEndpoint.IP.To4() == nil {
|
||||||
|
wgListener = wgListener6
|
||||||
|
} else {
|
||||||
|
wgListener = wgListener4
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create relay server and connection
|
||||||
|
relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
Port: 0, // Random port
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create relay server: %v", err)
|
||||||
|
}
|
||||||
|
defer relayServer.Close()
|
||||||
|
|
||||||
|
relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create relay connection: %v", err)
|
||||||
|
}
|
||||||
|
defer relayConn.Close()
|
||||||
|
|
||||||
|
// Add TURN connection to proxy
|
||||||
|
if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
|
||||||
|
t.Fatalf("failed to add TURN connection: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := proxy.CloseConn(); err != nil {
|
||||||
|
t.Errorf("failed to close proxy connection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Start the proxy
|
||||||
|
proxy.Work()
|
||||||
|
|
||||||
|
// Phase 1: Test initial relay traffic
|
||||||
|
msgFromRelay := []byte("hello from relay")
|
||||||
|
if _, err := relayServer.WriteTo(msgFromRelay, relayConn.LocalAddr()); err != nil {
|
||||||
|
t.Fatalf("failed to write to relay server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set read deadline to avoid hanging
|
||||||
|
if err := wgListener4.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||||
|
t.Fatalf("failed to set read deadline: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
n, _, err := wgListener4.ReadFrom(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read from WireGuard listener: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n != len(msgFromRelay) {
|
||||||
|
t.Errorf("expected %d bytes, got %d", len(msgFromRelay), n)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(buf[:n]) != string(msgFromRelay) {
|
||||||
|
t.Errorf("expected message %q, got %q", msgFromRelay, buf[:n])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 2: Redirect to p2p endpoint
|
||||||
|
proxy.RedirectAs(p2pEndpoint)
|
||||||
|
|
||||||
|
// Give the proxy a moment to process the redirect
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Phase 3: Test redirected traffic
|
||||||
|
redirectedMessages := [][]byte{
|
||||||
|
[]byte("redirected message 1"),
|
||||||
|
[]byte("redirected message 2"),
|
||||||
|
[]byte("redirected message 3"),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, msg := range redirectedMessages {
|
||||||
|
if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
|
||||||
|
t.Fatalf("failed to write redirected message %d: %v", i+1, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||||
|
t.Fatalf("failed to set read deadline: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
n, srcAddr, err := wgListener.ReadFrom(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read redirected message %d: %v", i+1, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify message content
|
||||||
|
if string(buf[:n]) != string(msg) {
|
||||||
|
t.Errorf("message %d: expected %q, got %q", i+1, msg, buf[:n])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify source address matches p2p endpoint (this is the key test)
|
||||||
|
// Use compareUDPAddr to ignore IPv6 zone IDs
|
||||||
|
if !compareUDPAddr(srcAddr, p2pEndpoint) {
|
||||||
|
t.Errorf("message %d: expected source address %s, got %s",
|
||||||
|
i+1, p2pEndpoint.String(), srcAddr.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
|
||||||
|
func TestRedirectAs_Multiple_Switches(t *testing.T) {
|
||||||
|
wgPort := 51856
|
||||||
|
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||||
|
if err := ebpfProxy.Listen(); err != nil {
|
||||||
|
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := ebpfProxy.Free(); err != nil {
|
||||||
|
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create WireGuard listener
|
||||||
|
wgListener, err := net.ListenUDP("udp4", &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
Port: wgPort,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create WireGuard listener: %v", err)
|
||||||
|
}
|
||||||
|
defer wgListener.Close()
|
||||||
|
|
||||||
|
// Create relay server and connection
|
||||||
|
relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
Port: 0,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create relay server: %v", err)
|
||||||
|
}
|
||||||
|
defer relayServer.Close()
|
||||||
|
|
||||||
|
relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create relay connection: %v", err)
|
||||||
|
}
|
||||||
|
defer relayConn.Close()
|
||||||
|
|
||||||
|
nbAddr := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("100.108.111.177"),
|
||||||
|
Port: 38746,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
|
||||||
|
t.Fatalf("failed to add TURN connection: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := proxy.CloseConn(); err != nil {
|
||||||
|
t.Errorf("failed to close proxy connection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
proxy.Work()
|
||||||
|
|
||||||
|
// Test switching between multiple endpoints - using addresses in local subnet
|
||||||
|
endpoints := []*net.UDPAddr{
|
||||||
|
{IP: net.ParseIP("192.168.0.100"), Port: 51820},
|
||||||
|
{IP: net.ParseIP("192.168.0.101"), Port: 51821},
|
||||||
|
{IP: net.ParseIP("192.168.0.102"), Port: 51822},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, endpoint := range endpoints {
|
||||||
|
proxy.RedirectAs(endpoint)
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
msg := []byte("test message")
|
||||||
|
if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
|
||||||
|
t.Fatalf("failed to write message for endpoint %d: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||||
|
t.Fatalf("failed to set read deadline: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
n, srcAddr, err := wgListener.ReadFrom(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read message for endpoint %d: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(buf[:n]) != string(msg) {
|
||||||
|
t.Errorf("endpoint %d: expected message %q, got %q", i, msg, buf[:n])
|
||||||
|
}
|
||||||
|
|
||||||
|
if !compareUDPAddr(srcAddr, endpoint) {
|
||||||
|
t.Errorf("endpoint %d: expected source %s, got %s",
|
||||||
|
i, endpoint.String(), srcAddr.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -56,7 +56,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
|
|||||||
// the connection is complete, an error is returned. Once successfully
|
// the connection is complete, an error is returned. Once successfully
|
||||||
// connected, any expiration of the context will not affect the
|
// connected, any expiration of the context will not affect the
|
||||||
// connection.
|
// connection.
|
||||||
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
|
||||||
dialer := net.Dialer{}
|
dialer := net.Dialer{}
|
||||||
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -19,37 +19,56 @@ var (
|
|||||||
FixLengths: true,
|
FixLengths: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
localHostNetIPAddr = &net.IPAddr{
|
localHostNetIPAddrV4 = &net.IPAddr{
|
||||||
IP: net.ParseIP("127.0.0.1"),
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
}
|
}
|
||||||
|
localHostNetIPAddrV6 = &net.IPAddr{
|
||||||
|
IP: net.ParseIP("::1"),
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
type SrcFaker struct {
|
type SrcFaker struct {
|
||||||
srcAddr *net.UDPAddr
|
srcAddr *net.UDPAddr
|
||||||
|
|
||||||
rawSocket net.PacketConn
|
rawSocket net.PacketConn
|
||||||
ipH gopacket.SerializableLayer
|
ipH gopacket.SerializableLayer
|
||||||
udpH gopacket.SerializableLayer
|
udpH gopacket.SerializableLayer
|
||||||
layerBuffer gopacket.SerializeBuffer
|
layerBuffer gopacket.SerializeBuffer
|
||||||
|
localHostAddr *net.IPAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
|
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
|
||||||
rawSocket, err := rawsocket.PrepareSenderRawSocket()
|
// Create only the raw socket for the address family we need
|
||||||
|
var rawSocket net.PacketConn
|
||||||
|
var err error
|
||||||
|
var localHostAddr *net.IPAddr
|
||||||
|
|
||||||
|
if srcAddr.IP.To4() != nil {
|
||||||
|
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv4()
|
||||||
|
localHostAddr = localHostNetIPAddrV4
|
||||||
|
} else {
|
||||||
|
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv6()
|
||||||
|
localHostAddr = localHostNetIPAddrV6
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
|
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if closeErr := rawSocket.Close(); closeErr != nil {
|
||||||
|
log.Warnf("failed to close raw socket: %v", closeErr)
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
f := &SrcFaker{
|
f := &SrcFaker{
|
||||||
srcAddr: srcAddr,
|
srcAddr: srcAddr,
|
||||||
rawSocket: rawSocket,
|
rawSocket: rawSocket,
|
||||||
ipH: ipH,
|
ipH: ipH,
|
||||||
udpH: udpH,
|
udpH: udpH,
|
||||||
layerBuffer: gopacket.NewSerializeBuffer(),
|
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||||
|
localHostAddr: localHostAddr,
|
||||||
}
|
}
|
||||||
|
|
||||||
return f, nil
|
return f, nil
|
||||||
@@ -72,7 +91,7 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("serialize layers: %w", err)
|
return 0, fmt.Errorf("serialize layers: %w", err)
|
||||||
}
|
}
|
||||||
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
|
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), f.localHostAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("write to raw conn: %w", err)
|
return 0, fmt.Errorf("write to raw conn: %w", err)
|
||||||
}
|
}
|
||||||
@@ -80,19 +99,40 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
|
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
|
||||||
ipH := &layers.IPv4{
|
var ipH gopacket.SerializableLayer
|
||||||
DstIP: net.ParseIP("127.0.0.1"),
|
var networkLayer gopacket.NetworkLayer
|
||||||
SrcIP: srcAddr.IP,
|
|
||||||
Version: 4,
|
// Check if source IP is IPv4 or IPv6
|
||||||
TTL: 64,
|
if srcAddr.IP.To4() != nil {
|
||||||
Protocol: layers.IPProtocolUDP,
|
// IPv4
|
||||||
|
ipv4 := &layers.IPv4{
|
||||||
|
DstIP: localHostNetIPAddrV4.IP,
|
||||||
|
SrcIP: srcAddr.IP,
|
||||||
|
Version: 4,
|
||||||
|
TTL: 64,
|
||||||
|
Protocol: layers.IPProtocolUDP,
|
||||||
|
}
|
||||||
|
ipH = ipv4
|
||||||
|
networkLayer = ipv4
|
||||||
|
} else {
|
||||||
|
// IPv6
|
||||||
|
ipv6 := &layers.IPv6{
|
||||||
|
DstIP: localHostNetIPAddrV6.IP,
|
||||||
|
SrcIP: srcAddr.IP,
|
||||||
|
Version: 6,
|
||||||
|
HopLimit: 64,
|
||||||
|
NextHeader: layers.IPProtocolUDP,
|
||||||
|
}
|
||||||
|
ipH = ipv6
|
||||||
|
networkLayer = ipv6
|
||||||
}
|
}
|
||||||
|
|
||||||
udpH := &layers.UDP{
|
udpH := &layers.UDP{
|
||||||
SrcPort: layers.UDPPort(srcAddr.Port),
|
SrcPort: layers.UDPPort(srcAddr.Port),
|
||||||
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
|
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
|
||||||
}
|
}
|
||||||
|
|
||||||
err := udpH.SetNetworkLayerForChecksum(ipH)
|
err := udpH.SetNetworkLayerForChecksum(networkLayer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
|
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -189,6 +189,212 @@ func TestDefaultManagerStateless(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestDenyRulesNotAccumulatedOnRepeatedApply verifies that applying the same
|
||||||
|
// deny rules repeatedly does not accumulate duplicate rules in the uspfilter.
|
||||||
|
// This tests the full ACL manager -> uspfilter integration.
|
||||||
|
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
|
||||||
|
networkMap := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "22",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "80",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FirewallRulesIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
|
IP: network.Addr(),
|
||||||
|
Network: network,
|
||||||
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, fw.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
|
// Apply the same rules 5 times (simulating repeated network map updates)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The ACL manager should track exactly 3 rule pairs (2 deny + 1 accept inbound)
|
||||||
|
assert.Equal(t, 3, len(acl.peerRulesPairs),
|
||||||
|
"Should have exactly 3 rule pairs after 5 identical updates")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDenyRulesCleanedUpOnRemoval verifies that deny rules are properly cleaned
|
||||||
|
// up when they're removed from the network map in a subsequent update.
|
||||||
|
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
|
IP: network.Addr(),
|
||||||
|
Network: network,
|
||||||
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, fw.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
|
// First update: add deny and accept rules
|
||||||
|
networkMap1 := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "22",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FirewallRulesIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
acl.ApplyFiltering(networkMap1, false)
|
||||||
|
assert.Equal(t, 2, len(acl.peerRulesPairs), "Should have 2 rules after first update")
|
||||||
|
|
||||||
|
// Second update: remove the deny rule, keep only accept
|
||||||
|
networkMap2 := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FirewallRulesIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
acl.ApplyFiltering(networkMap2, false)
|
||||||
|
assert.Equal(t, 1, len(acl.peerRulesPairs),
|
||||||
|
"Should have 1 rule after removing deny rule")
|
||||||
|
|
||||||
|
// Third update: remove all rules
|
||||||
|
networkMap3 := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{},
|
||||||
|
FirewallRulesIsEmpty: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
acl.ApplyFiltering(networkMap3, false)
|
||||||
|
assert.Equal(t, 0, len(acl.peerRulesPairs),
|
||||||
|
"Should have 0 rules after removing all rules")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRuleUpdateChangingAction verifies that when a rule's action changes from
|
||||||
|
// accept to deny (or vice versa), the old rule is properly removed and the new
|
||||||
|
// one added without leaking.
|
||||||
|
func TestRuleUpdateChangingAction(t *testing.T) {
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
|
IP: network.Addr(),
|
||||||
|
Network: network,
|
||||||
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, fw.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
|
// First update: accept rule
|
||||||
|
networkMap := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "22",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FirewallRulesIsEmpty: false,
|
||||||
|
}
|
||||||
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
assert.Equal(t, 1, len(acl.peerRulesPairs))
|
||||||
|
|
||||||
|
// Second update: change to deny (same IP/port/proto, different action)
|
||||||
|
networkMap.FirewallRules = []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "22",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
|
||||||
|
// Should still have exactly 1 rule (the old accept removed, new deny added)
|
||||||
|
assert.Equal(t, 1, len(acl.peerRulesPairs),
|
||||||
|
"Changing action should result in exactly 1 rule, not 2")
|
||||||
|
}
|
||||||
|
|
||||||
func TestPortInfoEmpty(t *testing.T) {
|
func TestPortInfoEmpty(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
499
client/internal/auth/auth.go
Normal file
499
client/internal/auth/auth.go
Normal file
@@ -0,0 +1,499 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Auth manages authentication operations with the management server
|
||||||
|
// It maintains a long-lived connection and automatically handles reconnection with backoff
|
||||||
|
type Auth struct {
|
||||||
|
mutex sync.RWMutex
|
||||||
|
client *mgm.GrpcClient
|
||||||
|
config *profilemanager.Config
|
||||||
|
privateKey wgtypes.Key
|
||||||
|
mgmURL *url.URL
|
||||||
|
mgmTLSEnabled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuth creates a new Auth instance that manages authentication flows
|
||||||
|
// It establishes a connection to the management server that will be reused for all operations
|
||||||
|
// The connection is automatically recreated with backoff if it becomes disconnected
|
||||||
|
func NewAuth(ctx context.Context, privateKey string, mgmURL *url.URL, config *profilemanager.Config) (*Auth, error) {
|
||||||
|
// Validate WireGuard private key
|
||||||
|
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine TLS setting based on URL scheme
|
||||||
|
mgmTLSEnabled := mgmURL.Scheme == "https"
|
||||||
|
|
||||||
|
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||||
|
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed connecting to Management Service %s: %v", mgmURL.String(), err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||||
|
|
||||||
|
return &Auth{
|
||||||
|
client: mgmClient,
|
||||||
|
config: config,
|
||||||
|
privateKey: myPrivateKey,
|
||||||
|
mgmURL: mgmURL,
|
||||||
|
mgmTLSEnabled: mgmTLSEnabled,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the management client connection
|
||||||
|
func (a *Auth) Close() error {
|
||||||
|
a.mutex.Lock()
|
||||||
|
defer a.mutex.Unlock()
|
||||||
|
|
||||||
|
if a.client == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return a.client.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSSOSupported checks if the management server supports SSO by attempting to retrieve auth flow configurations.
|
||||||
|
// Returns true if either PKCE or Device authorization flow is supported, false otherwise.
|
||||||
|
// This function encapsulates the SSO detection logic to avoid exposing gRPC error codes to upper layers.
|
||||||
|
// Automatically retries with backoff and reconnection on connection errors.
|
||||||
|
func (a *Auth) IsSSOSupported(ctx context.Context) (bool, error) {
|
||||||
|
var supportsSSO bool
|
||||||
|
|
||||||
|
err := a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||||
|
// Try PKCE flow first
|
||||||
|
_, err := a.getPKCEFlow(client)
|
||||||
|
if err == nil {
|
||||||
|
supportsSSO = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if PKCE is not supported
|
||||||
|
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||||
|
// PKCE not supported, try Device flow
|
||||||
|
_, err = a.getDeviceFlow(client)
|
||||||
|
if err == nil {
|
||||||
|
supportsSSO = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if Device flow is also not supported
|
||||||
|
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||||
|
// Neither PKCE nor Device flow is supported
|
||||||
|
supportsSSO = false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Device flow check returned an error other than NotFound/Unimplemented
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// PKCE flow check returned an error other than NotFound/Unimplemented
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
return supportsSSO, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOAuthFlow returns an OAuth flow (PKCE or Device) using the existing management connection
|
||||||
|
// This avoids creating a new connection to the management server
|
||||||
|
func (a *Auth) GetOAuthFlow(ctx context.Context, forceDeviceAuth bool) (OAuthFlow, error) {
|
||||||
|
var flow OAuthFlow
|
||||||
|
var err error
|
||||||
|
|
||||||
|
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||||
|
if forceDeviceAuth {
|
||||||
|
flow, err = a.getDeviceFlow(client)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try PKCE flow first
|
||||||
|
flow, err = a.getPKCEFlow(client)
|
||||||
|
if err != nil {
|
||||||
|
// If PKCE not supported, try Device flow
|
||||||
|
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||||
|
flow, err = a.getDeviceFlow(client)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return flow, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsLoginRequired checks if login is required by attempting to authenticate with the server
|
||||||
|
// Automatically retries with backoff and reconnection on connection errors.
|
||||||
|
func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
|
||||||
|
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var needsLogin bool
|
||||||
|
|
||||||
|
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||||
|
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||||
|
if isLoginNeeded(err) {
|
||||||
|
needsLogin = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
needsLogin = false
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
return needsLogin, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login attempts to log in or register the client with the management server
|
||||||
|
// Returns error and a boolean indicating if it's an authentication error (permission denied) that should stop retries.
|
||||||
|
// Automatically retries with backoff and reconnection on connection errors.
|
||||||
|
func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (error, bool) {
|
||||||
|
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
|
||||||
|
if err != nil {
|
||||||
|
return err, false
|
||||||
|
}
|
||||||
|
|
||||||
|
var isAuthError bool
|
||||||
|
|
||||||
|
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||||
|
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||||
|
if serverKey != nil && isRegistrationNeeded(err) {
|
||||||
|
log.Debugf("peer registration required")
|
||||||
|
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
||||||
|
if err != nil {
|
||||||
|
isAuthError = isPermissionDenied(err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if err != nil {
|
||||||
|
isAuthError = isPermissionDenied(err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
isAuthError = false
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return err, isAuthError
|
||||||
|
}
|
||||||
|
|
||||||
|
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
|
||||||
|
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
|
||||||
|
serverKey, err := client.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
|
||||||
|
if err != nil {
|
||||||
|
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
|
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
log.Errorf("failed to retrieve pkce flow: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
protoConfig := protoFlow.GetProviderConfig()
|
||||||
|
config := &PKCEAuthProviderConfig{
|
||||||
|
Audience: protoConfig.GetAudience(),
|
||||||
|
ClientID: protoConfig.GetClientID(),
|
||||||
|
ClientSecret: protoConfig.GetClientSecret(),
|
||||||
|
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||||
|
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
|
||||||
|
Scope: protoConfig.GetScope(),
|
||||||
|
RedirectURLs: protoConfig.GetRedirectURLs(),
|
||||||
|
UseIDToken: protoConfig.GetUseIDToken(),
|
||||||
|
ClientCertPair: a.config.ClientCertKeyPair,
|
||||||
|
DisablePromptLogin: protoConfig.GetDisablePromptLogin(),
|
||||||
|
LoginFlag: common.LoginFlag(protoConfig.GetLoginFlag()),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validatePKCEConfig(config); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
flow, err := NewPKCEAuthorizationFlow(*config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return flow, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
|
||||||
|
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
|
||||||
|
serverKey, err := client.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
|
||||||
|
if err != nil {
|
||||||
|
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
|
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
log.Errorf("failed to retrieve device flow: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
protoConfig := protoFlow.GetProviderConfig()
|
||||||
|
config := &DeviceAuthProviderConfig{
|
||||||
|
Audience: protoConfig.GetAudience(),
|
||||||
|
ClientID: protoConfig.GetClientID(),
|
||||||
|
ClientSecret: protoConfig.GetClientSecret(),
|
||||||
|
Domain: protoConfig.Domain,
|
||||||
|
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||||
|
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
|
||||||
|
Scope: protoConfig.GetScope(),
|
||||||
|
UseIDToken: protoConfig.GetUseIDToken(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep compatibility with older management versions
|
||||||
|
if config.Scope == "" {
|
||||||
|
config.Scope = "openid"
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validateDeviceAuthConfig(config); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
flow, err := NewDeviceAuthorizationFlow(*config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return flow, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// doMgmLogin performs the actual login operation with the management service
|
||||||
|
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
||||||
|
serverKey, err := client.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sysInfo := system.GetInfo(ctx)
|
||||||
|
a.setSystemInfoFlags(sysInfo)
|
||||||
|
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||||
|
return serverKey, loginResp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||||
|
// Otherwise tries to register with the provided setupKey via command line.
|
||||||
|
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
||||||
|
serverPublicKey, err := client.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
validSetupKey, err := uuid.Parse(setupKey)
|
||||||
|
if err != nil && jwtToken == "" {
|
||||||
|
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("sending peer registration request to Management Service")
|
||||||
|
info := system.GetInfo(ctx)
|
||||||
|
a.setSystemInfoFlags(info)
|
||||||
|
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed registering peer %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("peer has been successfully registered on Management Service")
|
||||||
|
|
||||||
|
return loginResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setSystemInfoFlags sets all configuration flags on the provided system info
|
||||||
|
func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
||||||
|
info.SetFlags(
|
||||||
|
a.config.RosenpassEnabled,
|
||||||
|
a.config.RosenpassPermissive,
|
||||||
|
a.config.ServerSSHAllowed,
|
||||||
|
a.config.DisableClientRoutes,
|
||||||
|
a.config.DisableServerRoutes,
|
||||||
|
a.config.DisableDNS,
|
||||||
|
a.config.DisableFirewall,
|
||||||
|
a.config.BlockLANAccess,
|
||||||
|
a.config.BlockInbound,
|
||||||
|
a.config.LazyConnectionEnabled,
|
||||||
|
a.config.EnableSSHRoot,
|
||||||
|
a.config.EnableSSHSFTP,
|
||||||
|
a.config.EnableSSHLocalPortForwarding,
|
||||||
|
a.config.EnableSSHRemotePortForwarding,
|
||||||
|
a.config.DisableSSHAuth,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// reconnect closes the current connection and creates a new one
|
||||||
|
// It checks if the brokenClient is still the current client before reconnecting
|
||||||
|
// to avoid multiple threads reconnecting unnecessarily
|
||||||
|
func (a *Auth) reconnect(ctx context.Context, brokenClient *mgm.GrpcClient) error {
|
||||||
|
a.mutex.Lock()
|
||||||
|
defer a.mutex.Unlock()
|
||||||
|
|
||||||
|
// Double-check: if client has already been replaced by another thread, skip reconnection
|
||||||
|
if a.client != brokenClient {
|
||||||
|
log.Debugf("client already reconnected by another thread, skipping")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new connection FIRST, before closing the old one
|
||||||
|
// This ensures a.client is never nil, preventing panics in other threads
|
||||||
|
log.Debugf("reconnecting to Management Service %s", a.mgmURL.String())
|
||||||
|
mgmClient, err := mgm.NewClient(ctx, a.mgmURL.Host, a.privateKey, a.mgmTLSEnabled)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed reconnecting to Management Service %s: %v", a.mgmURL.String(), err)
|
||||||
|
// Keep the old client if reconnection fails
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close old connection AFTER new one is successfully created
|
||||||
|
oldClient := a.client
|
||||||
|
a.client = mgmClient
|
||||||
|
|
||||||
|
if oldClient != nil {
|
||||||
|
if err := oldClient.Close(); err != nil {
|
||||||
|
log.Debugf("error closing old connection: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("successfully reconnected to Management service %s", a.mgmURL.String())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isConnectionError checks if the error is a connection-related error that should trigger reconnection
|
||||||
|
func isConnectionError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// These error codes indicate connection issues
|
||||||
|
return s.Code() == codes.Unavailable ||
|
||||||
|
s.Code() == codes.DeadlineExceeded ||
|
||||||
|
s.Code() == codes.Canceled ||
|
||||||
|
s.Code() == codes.Internal
|
||||||
|
}
|
||||||
|
|
||||||
|
// withRetry wraps an operation with exponential backoff retry logic
|
||||||
|
// It automatically reconnects on connection errors
|
||||||
|
func (a *Auth) withRetry(ctx context.Context, operation func(client *mgm.GrpcClient) error) error {
|
||||||
|
backoffSettings := &backoff.ExponentialBackOff{
|
||||||
|
InitialInterval: 500 * time.Millisecond,
|
||||||
|
RandomizationFactor: 0.5,
|
||||||
|
Multiplier: 1.5,
|
||||||
|
MaxInterval: 10 * time.Second,
|
||||||
|
MaxElapsedTime: 2 * time.Minute,
|
||||||
|
Stop: backoff.Stop,
|
||||||
|
Clock: backoff.SystemClock,
|
||||||
|
}
|
||||||
|
backoffSettings.Reset()
|
||||||
|
|
||||||
|
return backoff.RetryNotify(
|
||||||
|
func() error {
|
||||||
|
// Capture the client BEFORE the operation to ensure we track the correct client
|
||||||
|
a.mutex.RLock()
|
||||||
|
currentClient := a.client
|
||||||
|
a.mutex.RUnlock()
|
||||||
|
|
||||||
|
if currentClient == nil {
|
||||||
|
return status.Errorf(codes.Unavailable, "client is not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute operation with the captured client
|
||||||
|
err := operation(currentClient)
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it's a connection error, attempt reconnection using the client that was actually used
|
||||||
|
if isConnectionError(err) {
|
||||||
|
log.Warnf("connection error detected, attempting reconnection: %v", err)
|
||||||
|
|
||||||
|
if reconnectErr := a.reconnect(ctx, currentClient); reconnectErr != nil {
|
||||||
|
log.Errorf("reconnection failed: %v", reconnectErr)
|
||||||
|
return reconnectErr
|
||||||
|
}
|
||||||
|
// Return the original error to trigger retry with the new connection
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// For authentication errors, don't retry
|
||||||
|
if isAuthenticationError(err) {
|
||||||
|
return backoff.Permanent(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
backoff.WithContext(backoffSettings, ctx),
|
||||||
|
func(err error, duration time.Duration) {
|
||||||
|
log.Warnf("operation failed, retrying in %v: %v", duration, err)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isAuthenticationError checks if the error is an authentication-related error that should not be retried.
|
||||||
|
// Returns true if the error is InvalidArgument or PermissionDenied, indicating that retrying won't help.
|
||||||
|
func isAuthenticationError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPermissionDenied checks if the error is a PermissionDenied error.
|
||||||
|
// This is used to determine if early exit from backoff is needed (e.g., when the server responded but denied access).
|
||||||
|
func isPermissionDenied(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.Code() == codes.PermissionDenied
|
||||||
|
}
|
||||||
|
|
||||||
|
func isLoginNeeded(err error) bool {
|
||||||
|
return isAuthenticationError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRegistrationNeeded(err error) bool {
|
||||||
|
return isPermissionDenied(err)
|
||||||
|
}
|
||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -26,12 +25,56 @@ const (
|
|||||||
|
|
||||||
var _ OAuthFlow = &DeviceAuthorizationFlow{}
|
var _ OAuthFlow = &DeviceAuthorizationFlow{}
|
||||||
|
|
||||||
|
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
||||||
|
type DeviceAuthProviderConfig struct {
|
||||||
|
// ClientID An IDP application client id
|
||||||
|
ClientID string
|
||||||
|
// ClientSecret An IDP application client secret
|
||||||
|
ClientSecret string
|
||||||
|
// Domain An IDP API domain
|
||||||
|
// Deprecated. Use OIDCConfigEndpoint instead
|
||||||
|
Domain string
|
||||||
|
// Audience An Audience for to authorization validation
|
||||||
|
Audience string
|
||||||
|
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||||
|
TokenEndpoint string
|
||||||
|
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
||||||
|
DeviceAuthEndpoint string
|
||||||
|
// Scopes provides the scopes to be included in the token request
|
||||||
|
Scope string
|
||||||
|
// UseIDToken indicates if the id token should be used for authentication
|
||||||
|
UseIDToken bool
|
||||||
|
// LoginHint is used to pre-fill the email/username field during authentication
|
||||||
|
LoginHint string
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateDeviceAuthConfig validates device authorization provider configuration
|
||||||
|
func validateDeviceAuthConfig(config *DeviceAuthProviderConfig) error {
|
||||||
|
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||||
|
|
||||||
|
if config.Audience == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Audience")
|
||||||
|
}
|
||||||
|
if config.ClientID == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Client ID")
|
||||||
|
}
|
||||||
|
if config.TokenEndpoint == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
|
||||||
|
}
|
||||||
|
if config.DeviceAuthEndpoint == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Device Auth Endpoint")
|
||||||
|
}
|
||||||
|
if config.Scope == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Device Auth Scopes")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// DeviceAuthorizationFlow implements the OAuthFlow interface,
|
// DeviceAuthorizationFlow implements the OAuthFlow interface,
|
||||||
// for the Device Authorization Flow.
|
// for the Device Authorization Flow.
|
||||||
type DeviceAuthorizationFlow struct {
|
type DeviceAuthorizationFlow struct {
|
||||||
providerConfig internal.DeviceAuthProviderConfig
|
providerConfig DeviceAuthProviderConfig
|
||||||
|
HTTPClient HTTPClient
|
||||||
HTTPClient HTTPClient
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequestDeviceCodePayload used for request device code payload for auth0
|
// RequestDeviceCodePayload used for request device code payload for auth0
|
||||||
@@ -57,7 +100,7 @@ type TokenRequestResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewDeviceAuthorizationFlow returns device authorization flow client
|
// NewDeviceAuthorizationFlow returns device authorization flow client
|
||||||
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
func NewDeviceAuthorizationFlow(config DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
httpTransport.MaxIdleConns = 5
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
@@ -89,6 +132,11 @@ func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
|
|||||||
return d.providerConfig.ClientID
|
return d.providerConfig.ClientID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLoginHint sets the login hint for the device authorization flow
|
||||||
|
func (d *DeviceAuthorizationFlow) SetLoginHint(hint string) {
|
||||||
|
d.providerConfig.LoginHint = hint
|
||||||
|
}
|
||||||
|
|
||||||
// RequestAuthInfo requests a device code login flow information from Hosted
|
// RequestAuthInfo requests a device code login flow information from Hosted
|
||||||
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
||||||
form := url.Values{}
|
form := url.Values{}
|
||||||
@@ -199,14 +247,22 @@ func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestR
|
|||||||
}
|
}
|
||||||
|
|
||||||
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
||||||
// it retrieves the access token from Hosted's endpoint and validates it before returning
|
// it retrieves the access token from Hosted's endpoint and validates it before returning.
|
||||||
|
// The method creates a timeout context internally based on info.ExpiresIn.
|
||||||
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
||||||
|
// Create timeout context based on flow expiration
|
||||||
|
timeout := time.Duration(info.ExpiresIn) * time.Second
|
||||||
|
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
interval := time.Duration(info.Interval) * time.Second
|
interval := time.Duration(info.Interval) * time.Second
|
||||||
ticker := time.NewTicker(interval)
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-waitCtx.Done():
|
||||||
return TokenInfo{}, ctx.Err()
|
return TokenInfo{}, waitCtx.Err()
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
|
|
||||||
tokenResponse, err := d.requestToken(info)
|
tokenResponse, err := d.requestToken(info)
|
||||||
|
|||||||
@@ -12,8 +12,6 @@ import (
|
|||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockHTTPClient struct {
|
type mockHTTPClient struct {
|
||||||
@@ -115,18 +113,19 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
|||||||
err: testCase.inputReqError,
|
err: testCase.inputReqError,
|
||||||
}
|
}
|
||||||
|
|
||||||
deviceFlow := &DeviceAuthorizationFlow{
|
config := DeviceAuthProviderConfig{
|
||||||
providerConfig: internal.DeviceAuthProviderConfig{
|
Audience: expectedAudience,
|
||||||
Audience: expectedAudience,
|
ClientID: expectedClientID,
|
||||||
ClientID: expectedClientID,
|
Scope: expectedScope,
|
||||||
Scope: expectedScope,
|
TokenEndpoint: "test.hosted.com/token",
|
||||||
TokenEndpoint: "test.hosted.com/token",
|
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
UseIDToken: false,
|
||||||
UseIDToken: false,
|
|
||||||
},
|
|
||||||
HTTPClient: &httpClient,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
deviceFlow, err := NewDeviceAuthorizationFlow(config)
|
||||||
|
require.NoError(t, err, "creating device flow should not fail")
|
||||||
|
deviceFlow.HTTPClient = &httpClient
|
||||||
|
|
||||||
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
|
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
|
||||||
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
||||||
|
|
||||||
@@ -280,18 +279,19 @@ func TestHosted_WaitToken(t *testing.T) {
|
|||||||
countResBody: testCase.inputCountResBody,
|
countResBody: testCase.inputCountResBody,
|
||||||
}
|
}
|
||||||
|
|
||||||
deviceFlow := DeviceAuthorizationFlow{
|
config := DeviceAuthProviderConfig{
|
||||||
providerConfig: internal.DeviceAuthProviderConfig{
|
Audience: testCase.inputAudience,
|
||||||
Audience: testCase.inputAudience,
|
ClientID: clientID,
|
||||||
ClientID: clientID,
|
TokenEndpoint: "test.hosted.com/token",
|
||||||
TokenEndpoint: "test.hosted.com/token",
|
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
Scope: "openid",
|
||||||
Scope: "openid",
|
UseIDToken: false,
|
||||||
UseIDToken: false,
|
|
||||||
},
|
|
||||||
HTTPClient: &httpClient,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
deviceFlow, err := NewDeviceAuthorizationFlow(config)
|
||||||
|
require.NoError(t, err, "creating device flow should not fail")
|
||||||
|
deviceFlow.HTTPClient = &httpClient
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)
|
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -87,19 +86,33 @@ func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesk
|
|||||||
|
|
||||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||||
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
|
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
pkceFlowInfo, err := authClient.getPKCEFlow(authClient.client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pkceFlowInfo.ProviderConfig.LoginHint = hint
|
if hint != "" {
|
||||||
|
pkceFlowInfo.SetLoginHint(hint)
|
||||||
|
}
|
||||||
|
|
||||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
return pkceFlowInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
||||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
deviceFlowInfo, err := authClient.getDeviceFlow(authClient.client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch s, ok := gstatus.FromError(err); {
|
switch s, ok := gstatus.FromError(err); {
|
||||||
case ok && s.Code() == codes.NotFound:
|
case ok && s.Code() == codes.NotFound:
|
||||||
@@ -114,7 +127,9 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
deviceFlowInfo.ProviderConfig.LoginHint = hint
|
if hint != "" {
|
||||||
|
deviceFlowInfo.SetLoginHint(hint)
|
||||||
|
}
|
||||||
|
|
||||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
return deviceFlowInfo, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/templates"
|
"github.com/netbirdio/netbird/client/internal/templates"
|
||||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||||
)
|
)
|
||||||
@@ -35,17 +34,67 @@ const (
|
|||||||
defaultPKCETimeoutSeconds = 300
|
defaultPKCETimeoutSeconds = 300
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PKCEAuthProviderConfig has all attributes needed to initiate PKCE authorization flow
|
||||||
|
type PKCEAuthProviderConfig struct {
|
||||||
|
// ClientID An IDP application client id
|
||||||
|
ClientID string
|
||||||
|
// ClientSecret An IDP application client secret
|
||||||
|
ClientSecret string
|
||||||
|
// Audience An Audience for to authorization validation
|
||||||
|
Audience string
|
||||||
|
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||||
|
TokenEndpoint string
|
||||||
|
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
||||||
|
AuthorizationEndpoint string
|
||||||
|
// Scopes provides the scopes to be included in the token request
|
||||||
|
Scope string
|
||||||
|
// RedirectURL handles authorization code from IDP manager
|
||||||
|
RedirectURLs []string
|
||||||
|
// UseIDToken indicates if the id token should be used for authentication
|
||||||
|
UseIDToken bool
|
||||||
|
// ClientCertPair is used for mTLS authentication to the IDP
|
||||||
|
ClientCertPair *tls.Certificate
|
||||||
|
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
||||||
|
DisablePromptLogin bool
|
||||||
|
// LoginFlag is used to configure the PKCE flow login behavior
|
||||||
|
LoginFlag common.LoginFlag
|
||||||
|
// LoginHint is used to pre-fill the email/username field during authentication
|
||||||
|
LoginHint string
|
||||||
|
}
|
||||||
|
|
||||||
|
// validatePKCEConfig validates PKCE provider configuration
|
||||||
|
func validatePKCEConfig(config *PKCEAuthProviderConfig) error {
|
||||||
|
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||||
|
|
||||||
|
if config.ClientID == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Client ID")
|
||||||
|
}
|
||||||
|
if config.TokenEndpoint == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
|
||||||
|
}
|
||||||
|
if config.AuthorizationEndpoint == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Authorization Auth Endpoint")
|
||||||
|
}
|
||||||
|
if config.Scope == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "PKCE Auth Scopes")
|
||||||
|
}
|
||||||
|
if config.RedirectURLs == nil {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "PKCE Redirect URLs")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// PKCEAuthorizationFlow implements the OAuthFlow interface for
|
// PKCEAuthorizationFlow implements the OAuthFlow interface for
|
||||||
// the Authorization Code Flow with PKCE.
|
// the Authorization Code Flow with PKCE.
|
||||||
type PKCEAuthorizationFlow struct {
|
type PKCEAuthorizationFlow struct {
|
||||||
providerConfig internal.PKCEAuthProviderConfig
|
providerConfig PKCEAuthProviderConfig
|
||||||
state string
|
state string
|
||||||
codeVerifier string
|
codeVerifier string
|
||||||
oAuthConfig *oauth2.Config
|
oAuthConfig *oauth2.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
|
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
|
||||||
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
func NewPKCEAuthorizationFlow(config PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||||
var availableRedirectURL string
|
var availableRedirectURL string
|
||||||
|
|
||||||
excludedRanges := getSystemExcludedPortRanges()
|
excludedRanges := getSystemExcludedPortRanges()
|
||||||
@@ -124,10 +173,21 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLoginHint sets the login hint for the PKCE authorization flow
|
||||||
|
func (p *PKCEAuthorizationFlow) SetLoginHint(hint string) {
|
||||||
|
p.providerConfig.LoginHint = hint
|
||||||
|
}
|
||||||
|
|
||||||
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
|
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
|
||||||
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
|
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
|
||||||
// Once the token is received, it is converted to TokenInfo and validated before returning.
|
// Once the token is received, it is converted to TokenInfo and validated before returning.
|
||||||
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
|
// The method creates a timeout context internally based on info.ExpiresIn.
|
||||||
|
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
||||||
|
// Create timeout context based on flow expiration
|
||||||
|
timeout := time.Duration(info.ExpiresIn) * time.Second
|
||||||
|
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
tokenChan := make(chan *oauth2.Token, 1)
|
tokenChan := make(chan *oauth2.Token, 1)
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
|
|
||||||
@@ -138,7 +198,7 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
|
|||||||
|
|
||||||
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
|
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
|
||||||
defer func() {
|
defer func() {
|
||||||
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||||
@@ -149,8 +209,8 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
|
|||||||
go p.startServer(server, tokenChan, errChan)
|
go p.startServer(server, tokenChan, errChan)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-waitCtx.Done():
|
||||||
return TokenInfo{}, ctx.Err()
|
return TokenInfo{}, waitCtx.Err()
|
||||||
case token := <-tokenChan:
|
case token := <-tokenChan:
|
||||||
return p.parseOAuthToken(token)
|
return p.parseOAuthToken(token)
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
mgm "github.com/netbirdio/netbird/shared/management/client/common"
|
mgm "github.com/netbirdio/netbird/shared/management/client/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -50,7 +49,7 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
config := internal.PKCEAuthProviderConfig{
|
config := PKCEAuthProviderConfig{
|
||||||
ClientID: "test-client-id",
|
ClientID: "test-client-id",
|
||||||
Audience: "test-audience",
|
Audience: "test-audience",
|
||||||
TokenEndpoint: "https://test-token-endpoint.com/token",
|
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||||
|
|||||||
@@ -9,8 +9,6 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseExcludedPortRanges(t *testing.T) {
|
func TestParseExcludedPortRanges(t *testing.T) {
|
||||||
@@ -95,7 +93,7 @@ func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) {
|
|||||||
|
|
||||||
availablePort := 65432
|
availablePort := 65432
|
||||||
|
|
||||||
config := internal.PKCEAuthProviderConfig{
|
config := PKCEAuthProviderConfig{
|
||||||
ClientID: "test-client-id",
|
ClientID: "test-client-id",
|
||||||
Audience: "test-audience",
|
Audience: "test-audience",
|
||||||
TokenEndpoint: "https://test-token-endpoint.com/token",
|
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
@@ -244,7 +245,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
localPeerState := peer.LocalPeerState{
|
localPeerState := peer.LocalPeerState{
|
||||||
IP: loginResp.GetPeerConfig().GetAddress(),
|
IP: loginResp.GetPeerConfig().GetAddress(),
|
||||||
PubKey: myPrivateKey.PublicKey().String(),
|
PubKey: myPrivateKey.PublicKey().String(),
|
||||||
KernelInterface: device.WireGuardModuleIsLoaded(),
|
KernelInterface: device.WireGuardModuleIsLoaded() && !netstack.IsEnabled(),
|
||||||
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
||||||
}
|
}
|
||||||
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
||||||
|
|||||||
@@ -1,136 +0,0 @@
|
|||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
|
||||||
type DeviceAuthorizationFlow struct {
|
|
||||||
Provider string
|
|
||||||
ProviderConfig DeviceAuthProviderConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
|
||||||
type DeviceAuthProviderConfig struct {
|
|
||||||
// ClientID An IDP application client id
|
|
||||||
ClientID string
|
|
||||||
// ClientSecret An IDP application client secret
|
|
||||||
ClientSecret string
|
|
||||||
// Domain An IDP API domain
|
|
||||||
// Deprecated. Use OIDCConfigEndpoint instead
|
|
||||||
Domain string
|
|
||||||
// Audience An Audience for to authorization validation
|
|
||||||
Audience string
|
|
||||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
|
||||||
TokenEndpoint string
|
|
||||||
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
|
||||||
DeviceAuthEndpoint string
|
|
||||||
// Scopes provides the scopes to be included in the token request
|
|
||||||
Scope string
|
|
||||||
// UseIDToken indicates if the id token should be used for authentication
|
|
||||||
UseIDToken bool
|
|
||||||
// LoginHint is used to pre-fill the email/username field during authentication
|
|
||||||
LoginHint string
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
|
|
||||||
func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (DeviceAuthorizationFlow, error) {
|
|
||||||
// validate our peer's Wireguard PRIVATE key
|
|
||||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var mgmTLSEnabled bool
|
|
||||||
if mgmURL.Scheme == "https" {
|
|
||||||
mgmTLSEnabled = true
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
|
||||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err = mgmClient.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to close the Management service client %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
serverKey, err := mgmClient.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
protoDeviceAuthorizationFlow, err := mgmClient.GetDeviceAuthorizationFlow(*serverKey)
|
|
||||||
if err != nil {
|
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
|
||||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
log.Errorf("failed to retrieve device flow: %v", err)
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
deviceAuthorizationFlow := DeviceAuthorizationFlow{
|
|
||||||
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
|
||||||
|
|
||||||
ProviderConfig: DeviceAuthProviderConfig{
|
|
||||||
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
|
|
||||||
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
|
|
||||||
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
|
||||||
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
|
|
||||||
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
|
||||||
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
|
|
||||||
Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(),
|
|
||||||
UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// keep compatibility with older management versions
|
|
||||||
if deviceAuthorizationFlow.ProviderConfig.Scope == "" {
|
|
||||||
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
|
|
||||||
}
|
|
||||||
|
|
||||||
err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
|
||||||
if err != nil {
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return deviceAuthorizationFlow, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
|
|
||||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
|
||||||
if config.Audience == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Audience")
|
|
||||||
}
|
|
||||||
if config.ClientID == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
|
||||||
}
|
|
||||||
if config.TokenEndpoint == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
|
||||||
}
|
|
||||||
if config.DeviceAuthEndpoint == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
|
|
||||||
}
|
|
||||||
if config.Scope == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Device Auth Scopes")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -112,6 +112,54 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
|||||||
matchSubdomains: false,
|
matchSubdomains: false,
|
||||||
shouldMatch: false,
|
shouldMatch: false,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "single letter TLD exact match",
|
||||||
|
handlerDomain: "example.x.",
|
||||||
|
queryDomain: "example.x.",
|
||||||
|
isWildcard: false,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single letter TLD subdomain match",
|
||||||
|
handlerDomain: "example.x.",
|
||||||
|
queryDomain: "sub.example.x.",
|
||||||
|
isWildcard: false,
|
||||||
|
matchSubdomains: true,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single letter TLD wildcard match",
|
||||||
|
handlerDomain: "*.example.x.",
|
||||||
|
queryDomain: "sub.example.x.",
|
||||||
|
isWildcard: true,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two letter domain labels",
|
||||||
|
handlerDomain: "a.b.",
|
||||||
|
queryDomain: "a.b.",
|
||||||
|
isWildcard: false,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single character domain",
|
||||||
|
handlerDomain: "x.",
|
||||||
|
queryDomain: "x.",
|
||||||
|
isWildcard: false,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single character domain with subdomain match",
|
||||||
|
handlerDomain: "x.",
|
||||||
|
queryDomain: "sub.x.",
|
||||||
|
isWildcard: false,
|
||||||
|
matchSubdomains: true,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
|||||||
@@ -9,8 +9,10 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
@@ -38,6 +40,9 @@ const (
|
|||||||
type systemConfigurator struct {
|
type systemConfigurator struct {
|
||||||
createdKeys map[string]struct{}
|
createdKeys map[string]struct{}
|
||||||
systemDNSSettings SystemDNSSettings
|
systemDNSSettings SystemDNSSettings
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
origNameservers []netip.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager() (*systemConfigurator, error) {
|
func newHostManager() (*systemConfigurator, error) {
|
||||||
@@ -218,6 +223,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var dnsSettings SystemDNSSettings
|
var dnsSettings SystemDNSSettings
|
||||||
|
var serverAddresses []netip.Addr
|
||||||
inSearchDomainsArray := false
|
inSearchDomainsArray := false
|
||||||
inServerAddressesArray := false
|
inServerAddressesArray := false
|
||||||
|
|
||||||
@@ -244,9 +250,12 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
|||||||
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
|
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
|
||||||
} else if inServerAddressesArray {
|
} else if inServerAddressesArray {
|
||||||
address := strings.Split(line, " : ")[1]
|
address := strings.Split(line, " : ")[1]
|
||||||
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
|
if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() {
|
||||||
dnsSettings.ServerIP = ip.Unmap()
|
ip = ip.Unmap()
|
||||||
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
|
serverAddresses = append(serverAddresses, ip)
|
||||||
|
if !dnsSettings.ServerIP.IsValid() && ip.Is4() {
|
||||||
|
dnsSettings.ServerIP = ip
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -258,9 +267,19 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
|||||||
// default to 53 port
|
// default to 53 port
|
||||||
dnsSettings.ServerPort = DefaultPort
|
dnsSettings.ServerPort = DefaultPort
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.origNameservers = serverAddresses
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
return dnsSettings, nil
|
return dnsSettings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return slices.Clone(s.origNameservers)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
|
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
|
||||||
err := s.addDNSState(key, domains, ip, port, true)
|
err := s.addDNSState(key, domains, ip, port, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -109,3 +109,169 @@ func removeTestDNSKey(key string) error {
|
|||||||
_, err := cmd.CombinedOutput()
|
_, err := cmd.CombinedOutput()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetOriginalNameservers(t *testing.T) {
|
||||||
|
configurator := &systemConfigurator{
|
||||||
|
createdKeys: make(map[string]struct{}),
|
||||||
|
origNameservers: []netip.Addr{
|
||||||
|
netip.MustParseAddr("8.8.8.8"),
|
||||||
|
netip.MustParseAddr("1.1.1.1"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
servers := configurator.getOriginalNameservers()
|
||||||
|
assert.Len(t, servers, 2)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("8.8.8.8"), servers[0])
|
||||||
|
assert.Equal(t, netip.MustParseAddr("1.1.1.1"), servers[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOriginalNameserversFromSystem(t *testing.T) {
|
||||||
|
configurator := &systemConfigurator{
|
||||||
|
createdKeys: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := configurator.getSystemDNSSettings()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
servers := configurator.getOriginalNameservers()
|
||||||
|
|
||||||
|
require.NotEmpty(t, servers, "expected at least one DNS server from system configuration")
|
||||||
|
|
||||||
|
for _, server := range servers {
|
||||||
|
assert.True(t, server.IsValid(), "server address should be valid")
|
||||||
|
assert.False(t, server.IsUnspecified(), "server address should not be unspecified")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("found %d original nameservers: %v", len(servers), servers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Manager, func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
stateFile := filepath.Join(tmpDir, "state.json")
|
||||||
|
sm := statemanager.New(stateFile)
|
||||||
|
sm.RegisterState(&ShutdownState{})
|
||||||
|
sm.Start()
|
||||||
|
|
||||||
|
configurator := &systemConfigurator{
|
||||||
|
createdKeys: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||||
|
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||||
|
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||||
|
|
||||||
|
cleanup := func() {
|
||||||
|
_ = sm.Stop(context.Background())
|
||||||
|
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||||
|
_ = removeTestDNSKey(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return configurator, sm, cleanup
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOriginalNameserversNoTransition(t *testing.T) {
|
||||||
|
netbirdIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
routeAll bool
|
||||||
|
}{
|
||||||
|
{"routeall_false", false},
|
||||||
|
{"routeall_true", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
configurator, sm, cleanup := setupTestConfigurator(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
_, err := configurator.getSystemDNSSettings()
|
||||||
|
require.NoError(t, err)
|
||||||
|
initialServers := configurator.getOriginalNameservers()
|
||||||
|
t.Logf("Initial servers: %v", initialServers)
|
||||||
|
require.NotEmpty(t, initialServers)
|
||||||
|
|
||||||
|
for _, srv := range initialServers {
|
||||||
|
require.NotEqual(t, netbirdIP, srv, "initial servers should not contain NetBird IP")
|
||||||
|
}
|
||||||
|
|
||||||
|
config := HostDNSConfig{
|
||||||
|
ServerIP: netbirdIP,
|
||||||
|
ServerPort: 53,
|
||||||
|
RouteAll: tc.routeAll,
|
||||||
|
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 1; i <= 2; i++ {
|
||||||
|
err = configurator.applyDNSConfig(config, sm)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
servers := configurator.getOriginalNameservers()
|
||||||
|
t.Logf("After apply %d (RouteAll=%v): %v", i, tc.routeAll, servers)
|
||||||
|
assert.Equal(t, initialServers, servers)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOriginalNameserversRouteAllTransition(t *testing.T) {
|
||||||
|
netbirdIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
initialRoute bool
|
||||||
|
}{
|
||||||
|
{"start_with_routeall_false", false},
|
||||||
|
{"start_with_routeall_true", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
configurator, sm, cleanup := setupTestConfigurator(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
_, err := configurator.getSystemDNSSettings()
|
||||||
|
require.NoError(t, err)
|
||||||
|
initialServers := configurator.getOriginalNameservers()
|
||||||
|
t.Logf("Initial servers: %v", initialServers)
|
||||||
|
require.NotEmpty(t, initialServers)
|
||||||
|
|
||||||
|
config := HostDNSConfig{
|
||||||
|
ServerIP: netbirdIP,
|
||||||
|
ServerPort: 53,
|
||||||
|
RouteAll: tc.initialRoute,
|
||||||
|
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
|
||||||
|
}
|
||||||
|
|
||||||
|
// First apply
|
||||||
|
err = configurator.applyDNSConfig(config, sm)
|
||||||
|
require.NoError(t, err)
|
||||||
|
servers := configurator.getOriginalNameservers()
|
||||||
|
t.Logf("After first apply (RouteAll=%v): %v", tc.initialRoute, servers)
|
||||||
|
assert.Equal(t, initialServers, servers)
|
||||||
|
|
||||||
|
// Toggle RouteAll
|
||||||
|
config.RouteAll = !tc.initialRoute
|
||||||
|
err = configurator.applyDNSConfig(config, sm)
|
||||||
|
require.NoError(t, err)
|
||||||
|
servers = configurator.getOriginalNameservers()
|
||||||
|
t.Logf("After toggle (RouteAll=%v): %v", config.RouteAll, servers)
|
||||||
|
assert.Equal(t, initialServers, servers)
|
||||||
|
|
||||||
|
// Toggle back
|
||||||
|
config.RouteAll = tc.initialRoute
|
||||||
|
err = configurator.applyDNSConfig(config, sm)
|
||||||
|
require.NoError(t, err)
|
||||||
|
servers = configurator.getOriginalNameservers()
|
||||||
|
t.Logf("After toggle back (RouteAll=%v): %v", config.RouteAll, servers)
|
||||||
|
assert.Equal(t, initialServers, servers)
|
||||||
|
|
||||||
|
for _, srv := range servers {
|
||||||
|
assert.NotEqual(t, netbirdIP, srv, "servers should not contain NetBird IP")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,7 +6,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -27,6 +29,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const envSkipDNSProbe = "NB_SKIP_DNS_PROBE"
|
||||||
|
|
||||||
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||||
type ReadyListener interface {
|
type ReadyListener interface {
|
||||||
OnReady()
|
OnReady()
|
||||||
@@ -439,6 +443,17 @@ func (s *DefaultServer) SearchDomains() []string {
|
|||||||
// ProbeAvailability tests each upstream group's servers for availability
|
// ProbeAvailability tests each upstream group's servers for availability
|
||||||
// and deactivates the group if no server responds
|
// and deactivates the group if no server responds
|
||||||
func (s *DefaultServer) ProbeAvailability() {
|
func (s *DefaultServer) ProbeAvailability() {
|
||||||
|
if val := os.Getenv(envSkipDNSProbe); val != "" {
|
||||||
|
skipProbe, err := strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", envSkipDNSProbe, err)
|
||||||
|
}
|
||||||
|
if skipProbe {
|
||||||
|
log.Infof("skipping DNS probe due to %s", envSkipDNSProbe)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for _, mux := range s.dnsMuxMap {
|
for _, mux := range s.dnsMuxMap {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
@@ -615,7 +630,7 @@ func (s *DefaultServer) applyHostConfig() {
|
|||||||
s.registerFallback(config)
|
s.registerFallback(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
// registerFallback registers original nameservers as low-priority fallback handlers
|
// registerFallback registers original nameservers as low-priority fallback handlers.
|
||||||
func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
||||||
hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS)
|
hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -624,6 +639,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
|||||||
|
|
||||||
originalNameservers := hostMgrWithNS.getOriginalNameservers()
|
originalNameservers := hostMgrWithNS.getOriginalNameservers()
|
||||||
if len(originalNameservers) == 0 {
|
if len(originalNameservers) == 0 {
|
||||||
|
s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,15 +8,21 @@ import (
|
|||||||
|
|
||||||
type MockResponseWriter struct {
|
type MockResponseWriter struct {
|
||||||
WriteMsgFunc func(m *dns.Msg) error
|
WriteMsgFunc func(m *dns.Msg) error
|
||||||
|
lastResponse *dns.Msg
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
|
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
|
||||||
|
rw.lastResponse = m
|
||||||
if rw.WriteMsgFunc != nil {
|
if rw.WriteMsgFunc != nil {
|
||||||
return rw.WriteMsgFunc(m)
|
return rw.WriteMsgFunc(m)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rw *MockResponseWriter) GetLastResponse() *dns.Msg {
|
||||||
|
return rw.lastResponse
|
||||||
|
}
|
||||||
|
|
||||||
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
|
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
|
||||||
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
|
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||||
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
||||||
|
|||||||
@@ -190,50 +190,75 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
|
|||||||
return nberrors.FormatErrorOrNil(result)
|
return nberrors.FormatErrorOrNil(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
|
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg, startTime time.Time) {
|
||||||
if len(query.Question) == 0 {
|
if len(query.Question) == 0 {
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
question := query.Question[0]
|
question := query.Question[0]
|
||||||
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
|
qname := strings.ToLower(question.Name)
|
||||||
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
|
||||||
|
|
||||||
domain := strings.ToLower(question.Name)
|
logger.Tracef("question: domain=%s type=%s class=%s",
|
||||||
|
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
||||||
|
|
||||||
resp := query.SetReply(query)
|
resp := query.SetReply(query)
|
||||||
network := resutil.NetworkForQtype(question.Qtype)
|
network := resutil.NetworkForQtype(question.Qtype)
|
||||||
if network == "" {
|
if network == "" {
|
||||||
resp.Rcode = dns.RcodeNotImplemented
|
resp.Rcode = dns.RcodeNotImplemented
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
f.writeResponse(logger, w, resp, qname, startTime)
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
return
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
|
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
|
||||||
// query doesn't match any configured domain
|
|
||||||
if mostSpecificResId == "" {
|
if mostSpecificResId == "" {
|
||||||
resp.Rcode = dns.RcodeRefused
|
resp.Rcode = dns.RcodeRefused
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
f.writeResponse(logger, w, resp, qname, startTime)
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
return
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
|
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
|
||||||
if result.Err != nil {
|
if result.Err != nil {
|
||||||
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
|
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
|
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
|
||||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
|
resp.Answer = append(resp.Answer, resutil.IPsToRRs(qname, result.IPs, f.ttl)...)
|
||||||
f.cache.set(domain, question.Qtype, result.IPs)
|
f.cache.set(qname, question.Qtype, result.IPs)
|
||||||
|
|
||||||
return resp
|
f.writeResponse(logger, w, resp, qname, startTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
logger.Errorf("failed to write DNS response: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||||
|
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||||
|
}
|
||||||
|
|
||||||
|
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
|
||||||
|
type udpResponseWriter struct {
|
||||||
|
dns.ResponseWriter
|
||||||
|
query *dns.Msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
|
||||||
|
opt := u.query.IsEdns0()
|
||||||
|
maxSize := dns.MinMsgSize
|
||||||
|
if opt != nil {
|
||||||
|
maxSize = int(opt.UDPSize())
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Len() > maxSize {
|
||||||
|
resp.Truncate(maxSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
return u.ResponseWriter.WriteMsg(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
@@ -243,30 +268,7 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||||
})
|
})
|
||||||
|
|
||||||
resp := f.handleDNSQuery(logger, w, query)
|
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
|
||||||
if resp == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
opt := query.IsEdns0()
|
|
||||||
maxSize := dns.MinMsgSize
|
|
||||||
if opt != nil {
|
|
||||||
// client advertised a larger EDNS0 buffer
|
|
||||||
maxSize = int(opt.UDPSize())
|
|
||||||
}
|
|
||||||
|
|
||||||
// if our response is too big, truncate and set the TC bit
|
|
||||||
if resp.Len() > maxSize {
|
|
||||||
resp.Truncate(maxSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
|
||||||
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
@@ -276,18 +278,7 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||||
})
|
})
|
||||||
|
|
||||||
resp := f.handleDNSQuery(logger, w, query)
|
f.handleDNSQuery(logger, w, query, startTime)
|
||||||
if resp == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
|
||||||
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
||||||
@@ -334,6 +325,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
resp *dns.Msg,
|
resp *dns.Msg,
|
||||||
domain string,
|
domain string,
|
||||||
result resutil.LookupResult,
|
result resutil.LookupResult,
|
||||||
|
startTime time.Time,
|
||||||
) {
|
) {
|
||||||
qType := question.Qtype
|
qType := question.Qtype
|
||||||
qTypeName := dns.TypeToString[qType]
|
qTypeName := dns.TypeToString[qType]
|
||||||
@@ -343,9 +335,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
// NotFound: cache negative result and respond
|
// NotFound: cache negative result and respond
|
||||||
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
|
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
|
||||||
f.cache.set(domain, question.Qtype, nil)
|
f.cache.set(domain, question.Qtype, nil)
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -355,9 +345,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
||||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
|
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
|
||||||
resp.Rcode = dns.RcodeSuccess
|
resp.Rcode = dns.RcodeSuccess
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
logger.Errorf("failed to write cached DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -365,9 +353,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
|
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
|
||||||
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
|
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
|
||||||
resp.Rcode = verifyResult.Rcode
|
resp.Rcode = verifyResult.Rcode
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -375,15 +361,12 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
// No cache or verification failed. Log with or without the server field for more context.
|
// No cache or verification failed. Log with or without the server field for more context.
|
||||||
var dnsErr *net.DNSError
|
var dnsErr *net.DNSError
|
||||||
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
|
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
|
||||||
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
|
logger.Warnf("upstream failure: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
|
||||||
} else {
|
} else {
|
||||||
logger.Warnf(errResolveFailed, domain, result.Err)
|
logger.Warnf(errResolveFailed, domain, result.Err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write final failure response.
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
|
||||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getMatchingEntries retrieves the resource IDs for a given domain.
|
// getMatchingEntries retrieves the resource IDs for a given domain.
|
||||||
|
|||||||
@@ -318,8 +318,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
|||||||
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
if tt.shouldResolve {
|
if tt.shouldResolve {
|
||||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
|
||||||
@@ -329,10 +330,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
|||||||
mockFirewall.AssertExpectations(t)
|
mockFirewall.AssertExpectations(t)
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
} else {
|
} else {
|
||||||
if resp != nil {
|
require.NotNil(t, resp, "Expected response")
|
||||||
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
||||||
"Unauthorized domain should not return successful answers")
|
"Unauthorized domain should not return successful answers")
|
||||||
}
|
|
||||||
mockFirewall.AssertNotCalled(t, "UpdateSet")
|
mockFirewall.AssertNotCalled(t, "UpdateSet")
|
||||||
mockResolver.AssertNotCalled(t, "LookupNetIP")
|
mockResolver.AssertNotCalled(t, "LookupNetIP")
|
||||||
}
|
}
|
||||||
@@ -466,14 +466,16 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
|
|||||||
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery, time.Now())
|
||||||
|
|
||||||
// Verify response
|
// Verify response
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
if tt.shouldResolve {
|
if tt.shouldResolve {
|
||||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
require.NotEmpty(t, resp.Answer)
|
require.NotEmpty(t, resp.Answer)
|
||||||
} else if resp != nil {
|
} else {
|
||||||
|
require.NotNil(t, resp, "Expected response")
|
||||||
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
|
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
|
||||||
"Unauthorized domain should be refused or have no answers")
|
"Unauthorized domain should be refused or have no answers")
|
||||||
}
|
}
|
||||||
@@ -528,9 +530,10 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
|||||||
query.SetQuestion("example.com.", dns.TypeA)
|
query.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
// Verify response contains all IPs
|
// Verify response contains all IPs
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
require.NotNil(t, resp)
|
require.NotNil(t, resp)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
|
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
|
||||||
@@ -605,7 +608,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
// Check the response written to the writer
|
// Check the response written to the writer
|
||||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||||
@@ -675,7 +678,8 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
|||||||
q1 := &dns.Msg{}
|
q1 := &dns.Msg{}
|
||||||
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||||
w1 := &test.MockResponseWriter{}
|
w1 := &test.MockResponseWriter{}
|
||||||
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
|
||||||
|
resp1 := w1.GetLastResponse()
|
||||||
require.NotNil(t, resp1)
|
require.NotNil(t, resp1)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||||
require.Len(t, resp1.Answer, 1)
|
require.Len(t, resp1.Answer, 1)
|
||||||
@@ -683,13 +687,13 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
|||||||
// Second query: serve from cache after upstream failure
|
// Second query: serve from cache after upstream failure
|
||||||
q2 := &dns.Msg{}
|
q2 := &dns.Msg{}
|
||||||
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||||
var writtenResp *dns.Msg
|
w2 := &test.MockResponseWriter{}
|
||||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
|
||||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
|
||||||
|
|
||||||
require.NotNil(t, writtenResp, "expected response to be written")
|
resp2 := w2.GetLastResponse()
|
||||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
require.NotNil(t, resp2, "expected response to be written")
|
||||||
require.Len(t, writtenResp.Answer, 1)
|
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
|
||||||
|
require.Len(t, resp2.Answer, 1)
|
||||||
|
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
@@ -715,7 +719,8 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
|||||||
q1 := &dns.Msg{}
|
q1 := &dns.Msg{}
|
||||||
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
||||||
w1 := &test.MockResponseWriter{}
|
w1 := &test.MockResponseWriter{}
|
||||||
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
|
||||||
|
resp1 := w1.GetLastResponse()
|
||||||
require.NotNil(t, resp1)
|
require.NotNil(t, resp1)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||||
require.Len(t, resp1.Answer, 1)
|
require.Len(t, resp1.Answer, 1)
|
||||||
@@ -727,13 +732,13 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
|||||||
|
|
||||||
q2 := &dns.Msg{}
|
q2 := &dns.Msg{}
|
||||||
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
||||||
var writtenResp *dns.Msg
|
w2 := &test.MockResponseWriter{}
|
||||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
|
||||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
|
||||||
|
|
||||||
require.NotNil(t, writtenResp)
|
resp2 := w2.GetLastResponse()
|
||||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
require.NotNil(t, resp2)
|
||||||
require.Len(t, writtenResp.Answer, 1)
|
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
|
||||||
|
require.Len(t, resp2.Answer, 1)
|
||||||
|
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
@@ -784,8 +789,9 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
|||||||
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
require.NotNil(t, resp)
|
require.NotNil(t, resp)
|
||||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
|
||||||
@@ -897,26 +903,15 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
|||||||
query := &dns.Msg{}
|
query := &dns.Msg{}
|
||||||
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
|
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
|
||||||
|
|
||||||
var writtenResp *dns.Msg
|
mockWriter := &test.MockResponseWriter{}
|
||||||
mockWriter := &test.MockResponseWriter{
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
WriteMsgFunc: func(m *dns.Msg) error {
|
|
||||||
writtenResp = m
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
resp := mockWriter.GetLastResponse()
|
||||||
|
require.NotNil(t, resp, "Expected response to be written")
|
||||||
// If a response was returned, it means it should be written (happens in wrapper functions)
|
assert.Equal(t, tt.expectedCode, resp.Rcode, tt.description)
|
||||||
if resp != nil && writtenResp == nil {
|
|
||||||
writtenResp = resp
|
|
||||||
}
|
|
||||||
|
|
||||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
|
||||||
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
|
|
||||||
|
|
||||||
if tt.expectNoAnswer {
|
if tt.expectNoAnswer {
|
||||||
assert.Empty(t, writtenResp.Answer, "Response should have no answer records")
|
assert.Empty(t, resp.Answer, "Response should have no answer records")
|
||||||
}
|
}
|
||||||
|
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
@@ -931,15 +926,8 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
|||||||
query := &dns.Msg{}
|
query := &dns.Msg{}
|
||||||
// Don't set any question
|
// Don't set any question
|
||||||
|
|
||||||
writeCalled := false
|
mockWriter := &test.MockResponseWriter{}
|
||||||
mockWriter := &test.MockResponseWriter{
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
WriteMsgFunc: func(m *dns.Msg) error {
|
|
||||||
writeCalled = true
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
|
||||||
|
|
||||||
assert.Nil(t, resp, "Should return nil for empty query")
|
assert.Nil(t, mockWriter.GetLastResponse(), "Should not write response for empty query")
|
||||||
assert.False(t, writeCalled, "Should not write response for empty query")
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
@@ -505,6 +506,10 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
return fmt.Errorf("up wg interface: %w", err)
|
return fmt.Errorf("up wg interface: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set up notrack rules immediately after proxy is listening to prevent
|
||||||
|
// conntrack entries from being created before the rules are in place
|
||||||
|
e.setupWGProxyNoTrack()
|
||||||
|
|
||||||
// Set the WireGuard interface for rosenpass after interface is up
|
// Set the WireGuard interface for rosenpass after interface is up
|
||||||
if e.rpManager != nil {
|
if e.rpManager != nil {
|
||||||
e.rpManager.SetInterface(e.wgInterface)
|
e.rpManager.SetInterface(e.wgInterface)
|
||||||
@@ -539,11 +544,12 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
// monitor WireGuard interface lifecycle and restart engine on changes
|
// monitor WireGuard interface lifecycle and restart engine on changes
|
||||||
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
||||||
e.shutdownWg.Add(1)
|
e.shutdownWg.Add(1)
|
||||||
|
wgIfaceName := e.wgInterface.Name()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer e.shutdownWg.Done()
|
defer e.shutdownWg.Done()
|
||||||
|
|
||||||
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
|
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, wgIfaceName); shouldRestart {
|
||||||
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
||||||
e.triggerClientRestart()
|
e.triggerClientRestart()
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
@@ -569,9 +575,11 @@ func (e *Engine) createFirewall() error {
|
|||||||
|
|
||||||
var err error
|
var err error
|
||||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
||||||
if err != nil || e.firewall == nil {
|
if err != nil {
|
||||||
log.Errorf("failed creating firewall manager: %s", err)
|
return fmt.Errorf("create firewall manager: %w", err)
|
||||||
return nil
|
}
|
||||||
|
if e.firewall == nil {
|
||||||
|
return fmt.Errorf("create firewall manager: received nil manager")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := e.initFirewall(); err != nil {
|
if err := e.initFirewall(); err != nil {
|
||||||
@@ -617,6 +625,23 @@ func (e *Engine) initFirewall() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setupWGProxyNoTrack configures connection tracking exclusion for WireGuard proxy traffic.
|
||||||
|
// This prevents conntrack/MASQUERADE from affecting loopback traffic between WireGuard and the eBPF proxy.
|
||||||
|
func (e *Engine) setupWGProxyNoTrack() {
|
||||||
|
if e.firewall == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyPort := e.wgInterface.GetProxyPort()
|
||||||
|
if proxyPort == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.firewall.SetupEBPFProxyNoTrack(proxyPort, uint16(e.config.WgPort)); err != nil {
|
||||||
|
log.Warnf("failed to setup ebpf proxy notrack: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (e *Engine) blockLanAccess() {
|
func (e *Engine) blockLanAccess() {
|
||||||
if e.config.BlockInbound {
|
if e.config.BlockInbound {
|
||||||
// no need to set up extra deny rules if inbound is already blocked in general
|
// no need to set up extra deny rules if inbound is already blocked in general
|
||||||
@@ -805,6 +830,10 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||||
|
started := time.Now()
|
||||||
|
defer func() {
|
||||||
|
log.Infof("sync finished in %s", time.Since(started))
|
||||||
|
}()
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
@@ -994,7 +1023,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
state := e.statusRecorder.GetLocalPeerState()
|
state := e.statusRecorder.GetLocalPeerState()
|
||||||
state.IP = e.wgInterface.Address().String()
|
state.IP = e.wgInterface.Address().String()
|
||||||
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
||||||
state.KernelInterface = device.WireGuardModuleIsLoaded()
|
state.KernelInterface = !e.wgInterface.IsUserspaceBind()
|
||||||
state.FQDN = conf.GetFqdn()
|
state.FQDN = conf.GetFqdn()
|
||||||
|
|
||||||
e.statusRecorder.UpdateLocalPeerState(state)
|
e.statusRecorder.UpdateLocalPeerState(state)
|
||||||
@@ -1644,6 +1673,7 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
|||||||
|
|
||||||
func (e *Engine) close() {
|
func (e *Engine) close() {
|
||||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||||
|
|
||||||
if e.wgInterface != nil {
|
if e.wgInterface != nil {
|
||||||
if err := e.wgInterface.Close(); err != nil {
|
if err := e.wgInterface.Close(); err != nil {
|
||||||
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
|
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
|
||||||
@@ -1894,7 +1924,7 @@ func (e *Engine) triggerClientRestart() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) startNetworkMonitor() {
|
func (e *Engine) startNetworkMonitor() {
|
||||||
if !e.config.NetworkMonitor {
|
if !e.config.NetworkMonitor || nbnetstack.IsEnabled() {
|
||||||
log.Infof("Network monitor is disabled, not starting")
|
log.Infof("Network monitor is disabled, not starting")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||||
@@ -94,6 +95,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
|||||||
|
|
||||||
// updateSSHClientConfig updates the SSH client configuration with peer information
|
// updateSSHClientConfig updates the SSH client configuration with peer information
|
||||||
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
|
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
peerInfo := e.extractPeerSSHInfo(remotePeers)
|
peerInfo := e.extractPeerSSHInfo(remotePeers)
|
||||||
if len(peerInfo) == 0 {
|
if len(peerInfo) == 0 {
|
||||||
log.Debug("no SSH-enabled peers found, skipping SSH config update")
|
log.Debug("no SSH-enabled peers found, skipping SSH config update")
|
||||||
@@ -216,6 +221,10 @@ func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
|
|||||||
|
|
||||||
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
|
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
|
||||||
func (e *Engine) cleanupSSHConfig() {
|
func (e *Engine) cleanupSSHConfig() {
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
configMgr := sshconfig.New()
|
configMgr := sshconfig.New()
|
||||||
|
|
||||||
if err := configMgr.RemoveSSHClientConfig(); err != nil {
|
if err := configMgr.RemoveSSHClientConfig(); err != nil {
|
||||||
|
|||||||
@@ -107,6 +107,7 @@ type MockWGIface struct {
|
|||||||
GetStatsFunc func() (map[string]configurer.WGStats, error)
|
GetStatsFunc func() (map[string]configurer.WGStats, error)
|
||||||
GetInterfaceGUIDStringFunc func() (string, error)
|
GetInterfaceGUIDStringFunc func() (string, error)
|
||||||
GetProxyFunc func() wgproxy.Proxy
|
GetProxyFunc func() wgproxy.Proxy
|
||||||
|
GetProxyPortFunc func() uint16
|
||||||
GetNetFunc func() *netstack.Net
|
GetNetFunc func() *netstack.Net
|
||||||
LastActivitiesFunc func() map[string]monotime.Time
|
LastActivitiesFunc func() map[string]monotime.Time
|
||||||
}
|
}
|
||||||
@@ -203,6 +204,13 @@ func (m *MockWGIface) GetProxy() wgproxy.Proxy {
|
|||||||
return m.GetProxyFunc()
|
return m.GetProxyFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockWGIface) GetProxyPort() uint16 {
|
||||||
|
if m.GetProxyPortFunc != nil {
|
||||||
|
return m.GetProxyPortFunc()
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
func (m *MockWGIface) GetNet() *netstack.Net {
|
func (m *MockWGIface) GetNet() *netstack.Net {
|
||||||
return m.GetNetFunc()
|
return m.GetNetFunc()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ type wgIfaceBase interface {
|
|||||||
Up() (*udpmux.UniversalUDPMuxDefault, error)
|
Up() (*udpmux.UniversalUDPMuxDefault, error)
|
||||||
UpdateAddr(newAddr string) error
|
UpdateAddr(newAddr string) error
|
||||||
GetProxy() wgproxy.Proxy
|
GetProxy() wgproxy.Proxy
|
||||||
|
GetProxyPort() uint16
|
||||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||||
RemoveEndpointAddress(key string) error
|
RemoveEndpointAddress(key string) error
|
||||||
RemovePeer(peerKey string) error
|
RemovePeer(peerKey string) error
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
@@ -74,12 +75,13 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error)
|
|||||||
return NewUDPListener(m.wgIface, peerCfg)
|
return NewUDPListener(m.wgIface, peerCfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BindListener is only used on Windows and JS platforms:
|
// BindListener is used on Windows, JS, and netstack platforms:
|
||||||
// - JS: Cannot listen to UDP sockets
|
// - JS: Cannot listen to UDP sockets
|
||||||
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
|
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
|
||||||
// gateway points to, preventing them from reaching the loopback interface.
|
// gateway points to, preventing them from reaching the loopback interface.
|
||||||
// BindListener bypasses this by passing data directly through the bind.
|
// - Netstack: Allows multiple instances on the same host without port conflicts.
|
||||||
if runtime.GOOS != "windows" && runtime.GOOS != "js" {
|
// BindListener bypasses these issues by passing data directly through the bind.
|
||||||
|
if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() {
|
||||||
return NewUDPListener(m.wgIface, peerCfg)
|
return NewUDPListener(m.wgIface, peerCfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,201 +0,0 @@
|
|||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/url"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
|
||||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// IsLoginRequired check that the server is support SSO or not
|
|
||||||
func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, error) {
|
|
||||||
mgmURL := config.ManagementURL
|
|
||||||
mgmClient, err := getMgmClient(ctx, config.PrivateKey, mgmURL)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err = mgmClient.Close()
|
|
||||||
if err != nil {
|
|
||||||
cStatus, ok := status.FromError(err)
|
|
||||||
if !ok || ok && cStatus.Code() != codes.Canceled {
|
|
||||||
log.Warnf("failed to close the Management service client, err: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
|
||||||
|
|
||||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
|
|
||||||
if isLoginNeeded(err) {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Login or register the client
|
|
||||||
func Login(ctx context.Context, config *profilemanager.Config, setupKey string, jwtToken string) error {
|
|
||||||
mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err = mgmClient.Close()
|
|
||||||
if err != nil {
|
|
||||||
cStatus, ok := status.FromError(err)
|
|
||||||
if !ok || ok && cStatus.Code() != codes.Canceled {
|
|
||||||
log.Warnf("failed to close the Management service client, err: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
|
|
||||||
|
|
||||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
|
|
||||||
if serverKey != nil && isRegistrationNeeded(err) {
|
|
||||||
log.Debugf("peer registration required")
|
|
||||||
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
|
|
||||||
// validate our peer's Wireguard PRIVATE key
|
|
||||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var mgmTlsEnabled bool
|
|
||||||
if mgmURL.Scheme == "https" {
|
|
||||||
mgmTlsEnabled = true
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("connecting to the Management service %s", mgmURL.String())
|
|
||||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTlsEnabled)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed connecting to the Management service %s %v", mgmURL.String(), err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return mgmClient, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
|
||||||
serverKey, err := mgmClient.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sysInfo := system.GetInfo(ctx)
|
|
||||||
sysInfo.SetFlags(
|
|
||||||
config.RosenpassEnabled,
|
|
||||||
config.RosenpassPermissive,
|
|
||||||
config.ServerSSHAllowed,
|
|
||||||
config.DisableClientRoutes,
|
|
||||||
config.DisableServerRoutes,
|
|
||||||
config.DisableDNS,
|
|
||||||
config.DisableFirewall,
|
|
||||||
config.BlockLANAccess,
|
|
||||||
config.BlockInbound,
|
|
||||||
config.LazyConnectionEnabled,
|
|
||||||
config.EnableSSHRoot,
|
|
||||||
config.EnableSSHSFTP,
|
|
||||||
config.EnableSSHLocalPortForwarding,
|
|
||||||
config.EnableSSHRemotePortForwarding,
|
|
||||||
config.DisableSSHAuth,
|
|
||||||
)
|
|
||||||
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
|
|
||||||
return serverKey, loginResp, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
|
||||||
// Otherwise tries to register with the provided setupKey via command line.
|
|
||||||
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
|
|
||||||
validSetupKey, err := uuid.Parse(setupKey)
|
|
||||||
if err != nil && jwtToken == "" {
|
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("sending peer registration request to Management Service")
|
|
||||||
info := system.GetInfo(ctx)
|
|
||||||
info.SetFlags(
|
|
||||||
config.RosenpassEnabled,
|
|
||||||
config.RosenpassPermissive,
|
|
||||||
config.ServerSSHAllowed,
|
|
||||||
config.DisableClientRoutes,
|
|
||||||
config.DisableServerRoutes,
|
|
||||||
config.DisableDNS,
|
|
||||||
config.DisableFirewall,
|
|
||||||
config.BlockLANAccess,
|
|
||||||
config.BlockInbound,
|
|
||||||
config.LazyConnectionEnabled,
|
|
||||||
config.EnableSSHRoot,
|
|
||||||
config.EnableSSHSFTP,
|
|
||||||
config.EnableSSHLocalPortForwarding,
|
|
||||||
config.EnableSSHRemotePortForwarding,
|
|
||||||
config.DisableSSHAuth,
|
|
||||||
)
|
|
||||||
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed registering peer %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("peer has been successfully registered on Management Service")
|
|
||||||
|
|
||||||
return loginResp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isLoginNeeded(err error) bool {
|
|
||||||
if err == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
s, ok := status.FromError(err)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func isRegistrationNeeded(err error) bool {
|
|
||||||
if err == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
s, ok := status.FromError(err)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if s.Code() == codes.PermissionDenied {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
@@ -390,6 +390,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
|||||||
}
|
}
|
||||||
|
|
||||||
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
|
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
|
||||||
|
conn.enableWgWatcherIfNeeded()
|
||||||
|
|
||||||
presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey)
|
presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey)
|
||||||
if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil {
|
if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil {
|
||||||
conn.handleConfigurationFailure(err, wgProxy)
|
conn.handleConfigurationFailure(err, wgProxy)
|
||||||
@@ -402,8 +404,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
|||||||
conn.wgProxyRelay.RedirectAs(ep)
|
conn.wgProxyRelay.RedirectAs(ep)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.enableWgWatcherIfNeeded()
|
|
||||||
|
|
||||||
conn.currentConnPriority = priority
|
conn.currentConnPriority = priority
|
||||||
conn.statusICE.SetConnected()
|
conn.statusICE.SetConnected()
|
||||||
conn.updateIceState(iceConnInfo)
|
conn.updateIceState(iceConnInfo)
|
||||||
@@ -501,6 +501,9 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
|
|
||||||
wgProxy.Work()
|
wgProxy.Work()
|
||||||
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
|
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
|
||||||
|
|
||||||
|
conn.enableWgWatcherIfNeeded()
|
||||||
|
|
||||||
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
|
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
|
||||||
if err := wgProxy.CloseConn(); err != nil {
|
if err := wgProxy.CloseConn(); err != nil {
|
||||||
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
||||||
@@ -509,8 +512,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.enableWgWatcherIfNeeded()
|
|
||||||
|
|
||||||
wgConfigWorkaround()
|
wgConfigWorkaround()
|
||||||
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
||||||
conn.currentConnPriority = conntype.Relay
|
conn.currentConnPriority = conntype.Relay
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package ice
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -32,24 +33,6 @@ type ThreadSafeAgent struct {
|
|||||||
once sync.Once
|
once sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *ThreadSafeAgent) Close() error {
|
|
||||||
var err error
|
|
||||||
a.once.Do(func() {
|
|
||||||
done := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
done <- a.Agent.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err = <-done:
|
|
||||||
case <-time.After(iceAgentCloseTimeout):
|
|
||||||
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
||||||
iceKeepAlive := iceKeepAlive()
|
iceKeepAlive := iceKeepAlive()
|
||||||
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
||||||
@@ -93,9 +76,41 @@ func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, c
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if agent == nil {
|
||||||
|
return nil, fmt.Errorf("ice.NewAgent returned nil agent without error")
|
||||||
|
}
|
||||||
|
|
||||||
return &ThreadSafeAgent{Agent: agent}, nil
|
return &ThreadSafeAgent{Agent: agent}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *ThreadSafeAgent) Close() error {
|
||||||
|
var err error
|
||||||
|
a.once.Do(func() {
|
||||||
|
// Defensive check to prevent nil pointer dereference
|
||||||
|
// This can happen during sleep/wake transitions or memory corruption scenarios
|
||||||
|
// github.com/netbirdio/netbird/client/internal/peer/ice.(*ThreadSafeAgent).Close(0x40006883f0?)
|
||||||
|
// [signal 0xc0000005 code=0x0 addr=0x0 pc=0x7ff7e73af83c]
|
||||||
|
agent := a.Agent
|
||||||
|
if agent == nil {
|
||||||
|
log.Warnf("ICE agent is nil during close, skipping")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- agent.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err = <-done:
|
||||||
|
case <-time.After(iceAgentCloseTimeout):
|
||||||
|
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func GenerateICECredentials() (string, string, error) {
|
func GenerateICECredentials() (string, string, error) {
|
||||||
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
|
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -107,8 +107,10 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
}
|
}
|
||||||
w.log.Debugf("agent already exists, recreate the connection")
|
w.log.Debugf("agent already exists, recreate the connection")
|
||||||
w.agentDialerCancel()
|
w.agentDialerCancel()
|
||||||
if err := w.agent.Close(); err != nil {
|
if w.agent != nil {
|
||||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
if err := w.agent.Close(); err != nil {
|
||||||
|
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionID, err := NewICESessionID()
|
sessionID, err := NewICESessionID()
|
||||||
|
|||||||
@@ -1,138 +0,0 @@
|
|||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
|
|
||||||
type PKCEAuthorizationFlow struct {
|
|
||||||
ProviderConfig PKCEAuthProviderConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// PKCEAuthProviderConfig has all attributes needed to initiate pkce authorization flow
|
|
||||||
type PKCEAuthProviderConfig struct {
|
|
||||||
// ClientID An IDP application client id
|
|
||||||
ClientID string
|
|
||||||
// ClientSecret An IDP application client secret
|
|
||||||
ClientSecret string
|
|
||||||
// Audience An Audience for to authorization validation
|
|
||||||
Audience string
|
|
||||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
|
||||||
TokenEndpoint string
|
|
||||||
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
|
||||||
AuthorizationEndpoint string
|
|
||||||
// Scopes provides the scopes to be included in the token request
|
|
||||||
Scope string
|
|
||||||
// RedirectURL handles authorization code from IDP manager
|
|
||||||
RedirectURLs []string
|
|
||||||
// UseIDToken indicates if the id token should be used for authentication
|
|
||||||
UseIDToken bool
|
|
||||||
// ClientCertPair is used for mTLS authentication to the IDP
|
|
||||||
ClientCertPair *tls.Certificate
|
|
||||||
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
|
||||||
DisablePromptLogin bool
|
|
||||||
// LoginFlag is used to configure the PKCE flow login behavior
|
|
||||||
LoginFlag common.LoginFlag
|
|
||||||
// LoginHint is used to pre-fill the email/username field during authentication
|
|
||||||
LoginHint string
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
|
||||||
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) {
|
|
||||||
// validate our peer's Wireguard PRIVATE key
|
|
||||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var mgmTLSEnabled bool
|
|
||||||
if mgmURL.Scheme == "https" {
|
|
||||||
mgmTLSEnabled = true
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
|
||||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err = mgmClient.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to close the Management service client %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
serverKey, err := mgmClient.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
protoPKCEAuthorizationFlow, err := mgmClient.GetPKCEAuthorizationFlow(*serverKey)
|
|
||||||
if err != nil {
|
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
|
||||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
log.Errorf("failed to retrieve pkce flow: %v", err)
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
authFlow := PKCEAuthorizationFlow{
|
|
||||||
ProviderConfig: PKCEAuthProviderConfig{
|
|
||||||
Audience: protoPKCEAuthorizationFlow.GetProviderConfig().GetAudience(),
|
|
||||||
ClientID: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientID(),
|
|
||||||
ClientSecret: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
|
||||||
TokenEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
|
||||||
AuthorizationEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetAuthorizationEndpoint(),
|
|
||||||
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
|
|
||||||
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
|
|
||||||
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
|
||||||
ClientCertPair: clientCert,
|
|
||||||
DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(),
|
|
||||||
LoginFlag: common.LoginFlag(protoPKCEAuthorizationFlow.GetProviderConfig().GetLoginFlag()),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
err = isPKCEProviderConfigValid(authFlow.ProviderConfig)
|
|
||||||
if err != nil {
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return authFlow, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error {
|
|
||||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
|
||||||
if config.ClientID == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
|
||||||
}
|
|
||||||
if config.TokenEndpoint == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
|
||||||
}
|
|
||||||
if config.AuthorizationEndpoint == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Authorization Auth Endpoint")
|
|
||||||
}
|
|
||||||
if config.Scope == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "PKCE Auth Scopes")
|
|
||||||
}
|
|
||||||
if config.RedirectURLs == nil {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "PKCE Redirect URLs")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -252,7 +252,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if config.AdminURL == nil {
|
if config.AdminURL == nil {
|
||||||
log.Infof("using default Admin URL %s", DefaultManagementURL)
|
log.Infof("using default Admin URL %s", DefaultAdminURL)
|
||||||
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
|
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
|
|||||||
@@ -173,12 +173,21 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
||||||
|
var once sync.Once
|
||||||
|
var wgIface *net.Interface
|
||||||
|
toInterface := func() *net.Interface {
|
||||||
|
once.Do(func() {
|
||||||
|
wgIface = m.wgInterface.ToInterface()
|
||||||
|
})
|
||||||
|
return wgIface
|
||||||
|
}
|
||||||
|
|
||||||
m.routeRefCounter = refcounter.New(
|
m.routeRefCounter = refcounter.New(
|
||||||
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
|
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
|
||||||
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface())
|
return struct{}{}, m.sysOps.AddVPNRoute(prefix, toInterface())
|
||||||
},
|
},
|
||||||
func(prefix netip.Prefix, _ struct{}) error {
|
func(prefix netip.Prefix, _ struct{}) error {
|
||||||
return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface())
|
return m.sysOps.RemoveVPNRoute(prefix, toInterface())
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,16 +4,17 @@ package systemops
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
||||||
func filterRoutesByFlags(routeMessageFlags int) bool {
|
func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||||
if routeMessageFlags&syscall.RTF_UP == 0 {
|
if routeMessageFlags&unix.RTF_UP == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
|
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE|unix.RTF_WASCLONED) != 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -24,42 +25,51 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
|
|||||||
func formatBSDFlags(flags int) string {
|
func formatBSDFlags(flags int) string {
|
||||||
var flagStrs []string
|
var flagStrs []string
|
||||||
|
|
||||||
if flags&syscall.RTF_UP != 0 {
|
if flags&unix.RTF_UP != 0 {
|
||||||
flagStrs = append(flagStrs, "U")
|
flagStrs = append(flagStrs, "U")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_GATEWAY != 0 {
|
if flags&unix.RTF_GATEWAY != 0 {
|
||||||
flagStrs = append(flagStrs, "G")
|
flagStrs = append(flagStrs, "G")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_HOST != 0 {
|
if flags&unix.RTF_HOST != 0 {
|
||||||
flagStrs = append(flagStrs, "H")
|
flagStrs = append(flagStrs, "H")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_REJECT != 0 {
|
if flags&unix.RTF_REJECT != 0 {
|
||||||
flagStrs = append(flagStrs, "R")
|
flagStrs = append(flagStrs, "R")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_DYNAMIC != 0 {
|
if flags&unix.RTF_DYNAMIC != 0 {
|
||||||
flagStrs = append(flagStrs, "D")
|
flagStrs = append(flagStrs, "D")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_MODIFIED != 0 {
|
if flags&unix.RTF_MODIFIED != 0 {
|
||||||
flagStrs = append(flagStrs, "M")
|
flagStrs = append(flagStrs, "M")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_STATIC != 0 {
|
if flags&unix.RTF_STATIC != 0 {
|
||||||
flagStrs = append(flagStrs, "S")
|
flagStrs = append(flagStrs, "S")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_LLINFO != 0 {
|
if flags&unix.RTF_LLINFO != 0 {
|
||||||
flagStrs = append(flagStrs, "L")
|
flagStrs = append(flagStrs, "L")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_LOCAL != 0 {
|
if flags&unix.RTF_LOCAL != 0 {
|
||||||
flagStrs = append(flagStrs, "l")
|
flagStrs = append(flagStrs, "l")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_BLACKHOLE != 0 {
|
if flags&unix.RTF_BLACKHOLE != 0 {
|
||||||
flagStrs = append(flagStrs, "B")
|
flagStrs = append(flagStrs, "B")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_CLONING != 0 {
|
if flags&unix.RTF_CLONING != 0 {
|
||||||
flagStrs = append(flagStrs, "C")
|
flagStrs = append(flagStrs, "C")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_WASCLONED != 0 {
|
if flags&unix.RTF_WASCLONED != 0 {
|
||||||
flagStrs = append(flagStrs, "W")
|
flagStrs = append(flagStrs, "W")
|
||||||
}
|
}
|
||||||
|
if flags&unix.RTF_PROTO1 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "1")
|
||||||
|
}
|
||||||
|
if flags&unix.RTF_PROTO2 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "2")
|
||||||
|
}
|
||||||
|
if flags&unix.RTF_PROTO3 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "3")
|
||||||
|
}
|
||||||
|
|
||||||
if len(flagStrs) == 0 {
|
if len(flagStrs) == 0 {
|
||||||
return "-"
|
return "-"
|
||||||
|
|||||||
@@ -4,17 +4,18 @@ package systemops
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
||||||
func filterRoutesByFlags(routeMessageFlags int) bool {
|
func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||||
if routeMessageFlags&syscall.RTF_UP == 0 {
|
if routeMessageFlags&unix.RTF_UP == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0
|
// NOTE: RTF_WASCLONED deprecated in FreeBSD 8.0
|
||||||
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 {
|
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE) != 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -25,37 +26,46 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
|
|||||||
func formatBSDFlags(flags int) string {
|
func formatBSDFlags(flags int) string {
|
||||||
var flagStrs []string
|
var flagStrs []string
|
||||||
|
|
||||||
if flags&syscall.RTF_UP != 0 {
|
if flags&unix.RTF_UP != 0 {
|
||||||
flagStrs = append(flagStrs, "U")
|
flagStrs = append(flagStrs, "U")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_GATEWAY != 0 {
|
if flags&unix.RTF_GATEWAY != 0 {
|
||||||
flagStrs = append(flagStrs, "G")
|
flagStrs = append(flagStrs, "G")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_HOST != 0 {
|
if flags&unix.RTF_HOST != 0 {
|
||||||
flagStrs = append(flagStrs, "H")
|
flagStrs = append(flagStrs, "H")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_REJECT != 0 {
|
if flags&unix.RTF_REJECT != 0 {
|
||||||
flagStrs = append(flagStrs, "R")
|
flagStrs = append(flagStrs, "R")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_DYNAMIC != 0 {
|
if flags&unix.RTF_DYNAMIC != 0 {
|
||||||
flagStrs = append(flagStrs, "D")
|
flagStrs = append(flagStrs, "D")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_MODIFIED != 0 {
|
if flags&unix.RTF_MODIFIED != 0 {
|
||||||
flagStrs = append(flagStrs, "M")
|
flagStrs = append(flagStrs, "M")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_STATIC != 0 {
|
if flags&unix.RTF_STATIC != 0 {
|
||||||
flagStrs = append(flagStrs, "S")
|
flagStrs = append(flagStrs, "S")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_LLINFO != 0 {
|
if flags&unix.RTF_LLINFO != 0 {
|
||||||
flagStrs = append(flagStrs, "L")
|
flagStrs = append(flagStrs, "L")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_LOCAL != 0 {
|
if flags&unix.RTF_LOCAL != 0 {
|
||||||
flagStrs = append(flagStrs, "l")
|
flagStrs = append(flagStrs, "l")
|
||||||
}
|
}
|
||||||
if flags&syscall.RTF_BLACKHOLE != 0 {
|
if flags&unix.RTF_BLACKHOLE != 0 {
|
||||||
flagStrs = append(flagStrs, "B")
|
flagStrs = append(flagStrs, "B")
|
||||||
}
|
}
|
||||||
// Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0
|
// Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0
|
||||||
|
if flags&unix.RTF_PROTO1 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "1")
|
||||||
|
}
|
||||||
|
if flags&unix.RTF_PROTO2 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "2")
|
||||||
|
}
|
||||||
|
if flags&unix.RTF_PROTO3 != 0 {
|
||||||
|
flagStrs = append(flagStrs, "3")
|
||||||
|
}
|
||||||
|
|
||||||
if len(flagStrs) == 0 {
|
if len(flagStrs) == 0 {
|
||||||
return "-"
|
return "-"
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine
|
// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine
|
||||||
@@ -35,6 +37,11 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes
|
|||||||
return false, errors.New("not supported on mobile platforms")
|
return false, errors.New("not supported on mobile platforms")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
log.Debugf("Interface monitor: skipped in netstack mode")
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
if ifaceName == "" {
|
if ifaceName == "" {
|
||||||
log.Debugf("Interface monitor: empty interface name, skipping monitor")
|
log.Debugf("Interface monitor: empty interface name, skipping monitor")
|
||||||
return false, errors.New("empty interface name")
|
return false, errors.New("empty interface name")
|
||||||
|
|||||||
@@ -263,7 +263,14 @@ func (c *Client) IsLoginRequired() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
needsLogin, err := internal.IsLoginRequired(ctx, cfg)
|
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("IsLoginRequired: failed to create auth client: %v", err)
|
||||||
|
return true // Assume login is required if we can't create auth client
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
needsLogin, err := authClient.IsLoginRequired(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("IsLoginRequired: check failed: %v", err)
|
log.Errorf("IsLoginRequired: check failed: %v", err)
|
||||||
// If the check fails, assume login is required to be safe
|
// If the check fails, assume login is required to be safe
|
||||||
@@ -314,16 +321,19 @@ func (c *Client) LoginForMobile() string {
|
|||||||
|
|
||||||
// This could cause a potential race condition with loading the extension which need to be handled on swift side
|
// This could cause a potential race condition with loading the extension which need to be handled on swift side
|
||||||
go func() {
|
go func() {
|
||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
tokenInfo, err := oAuthFlow.WaitToken(ctx, flowInfo)
|
||||||
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
|
|
||||||
defer cancel()
|
|
||||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("LoginForMobile: WaitToken failed: %v", err)
|
log.Errorf("LoginForMobile: WaitToken failed: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
jwtToken := tokenInfo.GetTokenToUse()
|
jwtToken := tokenInfo.GetTokenToUse()
|
||||||
if err := internal.Login(ctx, cfg, "", jwtToken); err != nil {
|
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("LoginForMobile: failed to create auth client: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
if err, _ := authClient.Login(ctx, "", jwtToken); err != nil {
|
||||||
log.Errorf("LoginForMobile: Login failed: %v", err)
|
log.Errorf("LoginForMobile: Login failed: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,13 +7,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
gstatus "google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/cmd"
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/auth"
|
"github.com/netbirdio/netbird/client/internal/auth"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
@@ -90,34 +85,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||||
supportsSSO := true
|
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
if err != nil {
|
||||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
return false, fmt.Errorf("failed to create auth client: %v", err)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
}
|
||||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
defer authClient.Close()
|
||||||
s, ok := gstatus.FromError(err)
|
|
||||||
if !ok {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
|
|
||||||
supportsSSO = false
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
|
||||||
}
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to check SSO support: %v", err)
|
||||||
return err
|
}
|
||||||
})
|
|
||||||
|
|
||||||
if !supportsSSO {
|
if !supportsSSO {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
|
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
|
||||||
// which are blocked by the tvOS sandbox in App Group containers
|
// which are blocked by the tvOS sandbox in App Group containers
|
||||||
err = profilemanager.DirectWriteOutConfig(a.cfgPath, a.config)
|
err = profilemanager.DirectWriteOutConfig(a.cfgPath, a.config)
|
||||||
@@ -141,19 +123,17 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
||||||
|
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
//nolint
|
//nolint
|
||||||
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||||
|
err, _ = authClient.Login(ctxWithValues, setupKey, "")
|
||||||
err := a.withBackOff(a.ctx, func() error {
|
|
||||||
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
|
|
||||||
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
|
||||||
// we got an answer from management, exit backoff earlier
|
|
||||||
return backoff.Permanent(backoffErr)
|
|
||||||
}
|
|
||||||
return backoffErr
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
|
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
|
||||||
@@ -164,15 +144,16 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
|
|||||||
// LoginSync performs a synchronous login check without UI interaction
|
// LoginSync performs a synchronous login check without UI interaction
|
||||||
// Used for background VPN connection where user should already be authenticated
|
// Used for background VPN connection where user should already be authenticated
|
||||||
func (a *Auth) LoginSync() error {
|
func (a *Auth) LoginSync() error {
|
||||||
var needsLogin bool
|
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
// check if we need to generate JWT token
|
// check if we need to generate JWT token
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
needsLogin, err := authClient.IsLoginRequired(a.ctx)
|
||||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
|
|
||||||
return
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("failed to check login requirement: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
@@ -180,15 +161,12 @@ func (a *Auth) LoginSync() error {
|
|||||||
return fmt.Errorf("not authenticated")
|
return fmt.Errorf("not authenticated")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = a.withBackOff(a.ctx, func() error {
|
err, isAuthError := authClient.Login(a.ctx, "", jwtToken)
|
||||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
|
||||||
// PermissionDenied means registration is required or peer is blocked
|
|
||||||
return backoff.Permanent(err)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if isAuthError {
|
||||||
|
// PermissionDenied means registration is required or peer is blocked
|
||||||
|
return fmt.Errorf("authentication error: %v", err)
|
||||||
|
}
|
||||||
return fmt.Errorf("login failed: %v", err)
|
return fmt.Errorf("login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -225,8 +203,6 @@ func (a *Auth) LoginWithDeviceName(resultListener ErrListener, urlOpener URLOpen
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName string) error {
|
func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName string) error {
|
||||||
var needsLogin bool
|
|
||||||
|
|
||||||
// Create context with device name if provided
|
// Create context with device name if provided
|
||||||
ctx := a.ctx
|
ctx := a.ctx
|
||||||
if deviceName != "" {
|
if deviceName != "" {
|
||||||
@@ -234,33 +210,33 @@ func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName strin
|
|||||||
ctx = context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
ctx = context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if we need to generate JWT token
|
authClient, err := auth.NewAuth(ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
err := a.withBackOff(ctx, func() (err error) {
|
|
||||||
needsLogin, err = internal.IsLoginRequired(ctx, a.config)
|
|
||||||
return
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
// check if we need to generate JWT token
|
||||||
|
needsLogin, err := authClient.IsLoginRequired(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to check login requirement: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
if needsLogin {
|
if needsLogin {
|
||||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, forceDeviceAuth)
|
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, forceDeviceAuth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||||
}
|
}
|
||||||
jwtToken = tokenInfo.GetTokenToUse()
|
jwtToken = tokenInfo.GetTokenToUse()
|
||||||
}
|
}
|
||||||
|
|
||||||
err = a.withBackOff(ctx, func() error {
|
err, isAuthError := authClient.Login(ctx, "", jwtToken)
|
||||||
err := internal.Login(ctx, a.config, "", jwtToken)
|
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
|
||||||
// PermissionDenied means registration is required or peer is blocked
|
|
||||||
return backoff.Permanent(err)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if isAuthError {
|
||||||
|
// PermissionDenied means registration is required or peer is blocked
|
||||||
|
return fmt.Errorf("authentication error: %v", err)
|
||||||
|
}
|
||||||
return fmt.Errorf("login failed: %v", err)
|
return fmt.Errorf("login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -285,10 +261,10 @@ func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName strin
|
|||||||
|
|
||||||
const authInfoRequestTimeout = 30 * time.Second
|
const authInfoRequestTimeout = 30 * time.Second
|
||||||
|
|
||||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
|
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, forceDeviceAuth, "")
|
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, forceDeviceAuth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use a bounded timeout for the auth info request to prevent indefinite hangs
|
// Use a bounded timeout for the auth info request to prevent indefinite hangs
|
||||||
@@ -313,15 +289,6 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool)
|
|||||||
return &tokenInfo, nil
|
return &tokenInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
|
|
||||||
return backoff.RetryNotify(
|
|
||||||
bf,
|
|
||||||
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
|
|
||||||
func(err error, duration time.Duration) {
|
|
||||||
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetConfigJSON returns the current config as a JSON string.
|
// GetConfigJSON returns the current config as a JSON string.
|
||||||
// This can be used by the caller to persist the config via alternative storage
|
// This can be used by the caller to persist the config via alternative storage
|
||||||
// mechanisms (e.g., UserDefaults on tvOS where file writes are blocked).
|
// mechanisms (e.g., UserDefaults on tvOS where file writes are blocked).
|
||||||
|
|||||||
@@ -253,10 +253,17 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
|
|||||||
|
|
||||||
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
|
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
|
||||||
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
|
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
|
||||||
var status internal.StatusType
|
authClient, err := auth.NewAuth(ctx, s.config.PrivateKey, s.config.ManagementURL, s.config)
|
||||||
err := internal.Login(ctx, s.config, setupKey, jwtToken)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
log.Errorf("failed to create auth client: %v", err)
|
||||||
|
return internal.StatusLoginFailed, err
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
var status internal.StatusType
|
||||||
|
err, isAuthError := authClient.Login(ctx, setupKey, jwtToken)
|
||||||
|
if err != nil {
|
||||||
|
if isAuthError {
|
||||||
log.Warnf("failed login: %v", err)
|
log.Warnf("failed login: %v", err)
|
||||||
status = internal.StatusNeedsLogin
|
status = internal.StatusNeedsLogin
|
||||||
} else {
|
} else {
|
||||||
@@ -581,8 +588,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
|||||||
s.oauthAuthFlow.waitCancel()
|
s.oauthAuthFlow.waitCancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
waitTimeout := time.Until(s.oauthAuthFlow.expiresAt)
|
waitCTX, cancel := context.WithCancel(ctx)
|
||||||
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
|
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
|
|||||||
@@ -207,8 +207,6 @@ func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) {
|
func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) {
|
||||||
// Create a backend session to mirror the client's session request.
|
|
||||||
// This keeps the connection alive on the server side while port forwarding channels operate.
|
|
||||||
serverSession, err := sshClient.NewSession()
|
serverSession, err := sshClient.NewSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
|
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
|
||||||
@@ -216,10 +214,28 @@ func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *c
|
|||||||
}
|
}
|
||||||
defer func() { _ = serverSession.Close() }()
|
defer func() { _ = serverSession.Close() }()
|
||||||
|
|
||||||
<-session.Context().Done()
|
serverSession.Stdin = session
|
||||||
|
serverSession.Stdout = session
|
||||||
|
serverSession.Stderr = session.Stderr()
|
||||||
|
|
||||||
if err := session.Exit(0); err != nil {
|
if err := serverSession.Shell(); err != nil {
|
||||||
log.Debugf("session exit: %v", err)
|
log.Debugf("start shell: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- serverSession.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-session.Context().Done():
|
||||||
|
return
|
||||||
|
case err := <-done:
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("shell session: %v", err)
|
||||||
|
p.handleProxyExitCode(session, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// handleCommand executes an SSH command with privilege validation
|
// handleExecution executes an SSH command or shell with privilege validation
|
||||||
func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) {
|
func (s *Server) handleExecution(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) {
|
||||||
hasPty := winCh != nil
|
hasPty := winCh != nil
|
||||||
|
|
||||||
commandType := "command"
|
commandType := "command"
|
||||||
@@ -23,7 +23,7 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege
|
|||||||
|
|
||||||
logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
|
logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
|
||||||
|
|
||||||
execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty)
|
execCmd, cleanup, err := s.createCommand(logger, privilegeResult, session, hasPty)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("%s creation failed: %v", commandType, err)
|
logger.Errorf("%s creation failed: %v", commandType, err)
|
||||||
|
|
||||||
@@ -51,13 +51,12 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege
|
|||||||
|
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
ptyReq, _, _ := session.Pty()
|
|
||||||
if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) {
|
if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) {
|
||||||
logger.Debugf("%s execution completed", commandType)
|
logger.Debugf("%s execution completed", commandType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
|
func (s *Server) createCommand(logger *log.Entry, privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
|
||||||
localUser := privilegeResult.User
|
localUser := privilegeResult.User
|
||||||
if localUser == nil {
|
if localUser == nil {
|
||||||
return nil, nil, errors.New("no user in privilege result")
|
return nil, nil, errors.New("no user in privilege result")
|
||||||
@@ -66,28 +65,28 @@ func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh
|
|||||||
// If PTY requested but su doesn't support --pty, skip su and use executor
|
// If PTY requested but su doesn't support --pty, skip su and use executor
|
||||||
// This ensures PTY functionality is provided (executor runs within our allocated PTY)
|
// This ensures PTY functionality is provided (executor runs within our allocated PTY)
|
||||||
if hasPty && !s.suSupportsPty {
|
if hasPty && !s.suSupportsPty {
|
||||||
log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
|
logger.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
|
||||||
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
|
cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
||||||
}
|
}
|
||||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
|
||||||
return cmd, cleanup, nil
|
return cmd, cleanup, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try su first for system integration (PAM/audit) when privileged
|
// Try su first for system integration (PAM/audit) when privileged
|
||||||
cmd, err := s.createSuCommand(session, localUser, hasPty)
|
cmd, err := s.createSuCommand(logger, session, localUser, hasPty)
|
||||||
if err != nil || privilegeResult.UsedFallback {
|
if err != nil || privilegeResult.UsedFallback {
|
||||||
log.Debugf("su command failed, falling back to executor: %v", err)
|
logger.Debugf("su command failed, falling back to executor: %v", err)
|
||||||
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
|
cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
||||||
}
|
}
|
||||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
|
||||||
return cmd, cleanup, nil
|
return cmd, cleanup, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
|
||||||
return cmd, func() {}, nil
|
return cmd, func() {}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,17 +15,17 @@ import (
|
|||||||
var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
|
var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
|
||||||
|
|
||||||
// createSuCommand is not supported on JS/WASM
|
// createSuCommand is not supported on JS/WASM
|
||||||
func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
|
func (s *Server) createSuCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
|
||||||
return nil, errNotSupported
|
return nil, errNotSupported
|
||||||
}
|
}
|
||||||
|
|
||||||
// createExecutorCommand is not supported on JS/WASM
|
// createExecutorCommand is not supported on JS/WASM
|
||||||
func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
|
func (s *Server) createExecutorCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
|
||||||
return nil, nil, errNotSupported
|
return nil, nil, errNotSupported
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepareCommandEnv is not supported on JS/WASM
|
// prepareCommandEnv is not supported on JS/WASM
|
||||||
func (s *Server) prepareCommandEnv(_ *user.User, _ ssh.Session) []string {
|
func (s *Server) prepareCommandEnv(_ *log.Entry, _ *user.User, _ ssh.Session) []string {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"os/user"
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -99,40 +100,52 @@ func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool {
|
|||||||
return isUtilLinux
|
return isUtilLinux
|
||||||
}
|
}
|
||||||
|
|
||||||
// createSuCommand creates a command using su -l -c for privilege switching
|
// createSuCommand creates a command using su - for privilege switching.
|
||||||
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
func (s *Server) createSuCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
||||||
|
if err := validateUsername(localUser.Username); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
|
||||||
|
}
|
||||||
|
|
||||||
suPath, err := exec.LookPath("su")
|
suPath, err := exec.LookPath("su")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("su command not available: %w", err)
|
return nil, fmt.Errorf("su command not available: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
command := session.RawCommand()
|
args := []string{"-"}
|
||||||
if command == "" {
|
|
||||||
return nil, fmt.Errorf("no command specified for su execution")
|
|
||||||
}
|
|
||||||
|
|
||||||
args := []string{"-l"}
|
|
||||||
if hasPty && s.suSupportsPty {
|
if hasPty && s.suSupportsPty {
|
||||||
args = append(args, "--pty")
|
args = append(args, "--pty")
|
||||||
}
|
}
|
||||||
args = append(args, localUser.Username, "-c", command)
|
args = append(args, localUser.Username)
|
||||||
|
|
||||||
|
command := session.RawCommand()
|
||||||
|
if command != "" {
|
||||||
|
args = append(args, "-c", command)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debugf("creating su command: %s %v", suPath, args)
|
||||||
cmd := exec.CommandContext(session.Context(), suPath, args...)
|
cmd := exec.CommandContext(session.Context(), suPath, args...)
|
||||||
cmd.Dir = localUser.HomeDir
|
cmd.Dir = localUser.HomeDir
|
||||||
|
|
||||||
return cmd, nil
|
return cmd, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getShellCommandArgs returns the shell command and arguments for executing a command string
|
// getShellCommandArgs returns the shell command and arguments for executing a command string.
|
||||||
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
||||||
if cmdString == "" {
|
if cmdString == "" {
|
||||||
return []string{shell, "-l"}
|
return []string{shell}
|
||||||
}
|
}
|
||||||
return []string{shell, "-l", "-c", cmdString}
|
return []string{shell, "-c", cmdString}
|
||||||
|
}
|
||||||
|
|
||||||
|
// createShellCommand creates an exec.Cmd configured as a login shell by setting argv[0] to "-shellname".
|
||||||
|
func (s *Server) createShellCommand(ctx context.Context, shell string, args []string) *exec.Cmd {
|
||||||
|
cmd := exec.CommandContext(ctx, shell, args[1:]...)
|
||||||
|
cmd.Args[0] = "-" + filepath.Base(shell)
|
||||||
|
return cmd
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepareCommandEnv prepares environment variables for command execution on Unix
|
// prepareCommandEnv prepares environment variables for command execution on Unix
|
||||||
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
|
func (s *Server) prepareCommandEnv(_ *log.Entry, localUser *user.User, session ssh.Session) []string {
|
||||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||||
env = append(env, prepareSSHEnv(session)...)
|
env = append(env, prepareSSHEnv(session)...)
|
||||||
for _, v := range session.Environ() {
|
for _, v := range session.Environ() {
|
||||||
@@ -154,7 +167,7 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, e
|
|||||||
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
|
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||||
execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session)
|
execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Pty command creation failed: %v", err)
|
logger.Errorf("Pty command creation failed: %v", err)
|
||||||
@@ -244,11 +257,6 @@ func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *pty
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer func() {
|
|
||||||
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
|
|
||||||
logger.Debugf("session close error: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
if _, err := io.Copy(session, ptmx); err != nil {
|
if _, err := io.Copy(session, ptmx); err != nil {
|
||||||
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
|
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
|
||||||
logger.Warnf("Pty output copy error: %v", err)
|
logger.Warnf("Pty output copy error: %v", err)
|
||||||
@@ -268,7 +276,7 @@ func (s *Server) waitForPtyCompletion(logger *log.Entry, session ssh.Session, ex
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done)
|
s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done)
|
||||||
case err := <-done:
|
case err := <-done:
|
||||||
s.handlePtyCommandCompletion(logger, session, err)
|
s.handlePtyCommandCompletion(logger, session, ptyMgr, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -296,17 +304,20 @@ func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Ses
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, err error) {
|
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Debugf("Pty command execution failed: %v", err)
|
logger.Debugf("Pty command execution failed: %v", err)
|
||||||
s.handleSessionExit(session, err, logger)
|
s.handleSessionExit(session, err, logger)
|
||||||
return
|
} else {
|
||||||
|
logger.Debugf("Pty command completed successfully")
|
||||||
|
if err := session.Exit(0); err != nil {
|
||||||
|
logSessionExitError(logger, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Normal completion
|
// Close PTY to unblock io.Copy goroutines
|
||||||
logger.Debugf("Pty command completed successfully")
|
if err := ptyMgr.Close(); err != nil {
|
||||||
if err := session.Exit(0); err != nil {
|
logger.Debugf("Pty close after completion: %v", err)
|
||||||
logSessionExitError(logger, err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -20,32 +20,32 @@ import (
|
|||||||
|
|
||||||
// getUserEnvironment retrieves the Windows environment for the target user.
|
// getUserEnvironment retrieves the Windows environment for the target user.
|
||||||
// Follows OpenSSH's resilient approach with graceful degradation on failures.
|
// Follows OpenSSH's resilient approach with graceful degradation on failures.
|
||||||
func (s *Server) getUserEnvironment(username, domain string) ([]string, error) {
|
func (s *Server) getUserEnvironment(logger *log.Entry, username, domain string) ([]string, error) {
|
||||||
userToken, err := s.getUserToken(username, domain)
|
userToken, err := s.getUserToken(logger, username, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get user token: %w", err)
|
return nil, fmt.Errorf("get user token: %w", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := windows.CloseHandle(userToken); err != nil {
|
if err := windows.CloseHandle(userToken); err != nil {
|
||||||
log.Debugf("close user token: %v", err)
|
logger.Debugf("close user token: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return s.getUserEnvironmentWithToken(userToken, username, domain)
|
return s.getUserEnvironmentWithToken(logger, userToken, username, domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getUserEnvironmentWithToken retrieves the Windows environment using an existing token.
|
// getUserEnvironmentWithToken retrieves the Windows environment using an existing token.
|
||||||
func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) {
|
func (s *Server) getUserEnvironmentWithToken(logger *log.Entry, userToken windows.Handle, username, domain string) ([]string, error) {
|
||||||
userProfile, err := s.loadUserProfile(userToken, username, domain)
|
userProfile, err := s.loadUserProfile(userToken, username, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
|
logger.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
|
||||||
userProfile = fmt.Sprintf("C:\\Users\\%s", username)
|
userProfile = fmt.Sprintf("C:\\Users\\%s", username)
|
||||||
}
|
}
|
||||||
|
|
||||||
envMap := make(map[string]string)
|
envMap := make(map[string]string)
|
||||||
|
|
||||||
if err := s.loadSystemEnvironment(envMap); err != nil {
|
if err := s.loadSystemEnvironment(envMap); err != nil {
|
||||||
log.Debugf("failed to load system environment from registry: %v", err)
|
logger.Debugf("failed to load system environment from registry: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.setUserEnvironmentVariables(envMap, userProfile, username, domain)
|
s.setUserEnvironmentVariables(envMap, userProfile, username, domain)
|
||||||
@@ -59,8 +59,8 @@ func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getUserToken creates a user token for the specified user.
|
// getUserToken creates a user token for the specified user.
|
||||||
func (s *Server) getUserToken(username, domain string) (windows.Handle, error) {
|
func (s *Server) getUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
|
||||||
privilegeDropper := NewPrivilegeDropper()
|
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
|
||||||
token, err := privilegeDropper.createToken(username, domain)
|
token, err := privilegeDropper.createToken(username, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("generate S4U user token: %w", err)
|
return 0, fmt.Errorf("generate S4U user token: %w", err)
|
||||||
@@ -242,9 +242,9 @@ func (s *Server) setUserEnvironmentVariables(envMap map[string]string, userProfi
|
|||||||
}
|
}
|
||||||
|
|
||||||
// prepareCommandEnv prepares environment variables for command execution on Windows
|
// prepareCommandEnv prepares environment variables for command execution on Windows
|
||||||
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
|
func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, session ssh.Session) []string {
|
||||||
username, domain := s.parseUsername(localUser.Username)
|
username, domain := s.parseUsername(localUser.Username)
|
||||||
userEnv, err := s.getUserEnvironment(username, domain)
|
userEnv, err := s.getUserEnvironment(logger, username, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
|
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
|
||||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||||
@@ -267,22 +267,16 @@ func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []
|
|||||||
return env
|
return env
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool {
|
||||||
if privilegeResult.User == nil {
|
if privilegeResult.User == nil {
|
||||||
logger.Errorf("no user in privilege result")
|
logger.Errorf("no user in privilege result")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := session.Command()
|
|
||||||
shell := getUserShell(privilegeResult.User.Uid)
|
shell := getUserShell(privilegeResult.User.Uid)
|
||||||
|
logger.Infof("starting interactive shell: %s", shell)
|
||||||
|
|
||||||
if len(cmd) == 0 {
|
s.executeCommandWithPty(logger, session, nil, privilegeResult, ptyReq, nil)
|
||||||
logger.Infof("starting interactive shell: %s", shell)
|
|
||||||
} else {
|
|
||||||
logger.Infof("executing command: %s", safeLogCommand(cmd))
|
|
||||||
}
|
|
||||||
|
|
||||||
s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd)
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -294,11 +288,6 @@ func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
|||||||
return []string{shell, "-Command", cmdString}
|
return []string{shell, "-Command", cmdString}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) {
|
|
||||||
logger.Info("starting interactive shell")
|
|
||||||
s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand())
|
|
||||||
}
|
|
||||||
|
|
||||||
type PtyExecutionRequest struct {
|
type PtyExecutionRequest struct {
|
||||||
Shell string
|
Shell string
|
||||||
Command string
|
Command string
|
||||||
@@ -308,25 +297,25 @@ type PtyExecutionRequest struct {
|
|||||||
Domain string
|
Domain string
|
||||||
}
|
}
|
||||||
|
|
||||||
func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, req PtyExecutionRequest) error {
|
func executePtyCommandWithUserToken(logger *log.Entry, session ssh.Session, req PtyExecutionRequest) error {
|
||||||
log.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
|
logger.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
|
||||||
req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height)
|
req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height)
|
||||||
|
|
||||||
privilegeDropper := NewPrivilegeDropper()
|
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
|
||||||
userToken, err := privilegeDropper.createToken(req.Username, req.Domain)
|
userToken, err := privilegeDropper.createToken(req.Username, req.Domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create user token: %w", err)
|
return fmt.Errorf("create user token: %w", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := windows.CloseHandle(userToken); err != nil {
|
if err := windows.CloseHandle(userToken); err != nil {
|
||||||
log.Debugf("close user token: %v", err)
|
logger.Debugf("close user token: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
server := &Server{}
|
server := &Server{}
|
||||||
userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain)
|
userEnv, err := server.getUserEnvironmentWithToken(logger, userToken, req.Username, req.Domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
|
logger.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
|
||||||
userEnv = os.Environ()
|
userEnv = os.Environ()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -348,8 +337,8 @@ func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, re
|
|||||||
Environment: userEnv,
|
Environment: userEnv,
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
|
logger.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
|
||||||
return winpty.ExecutePtyWithUserToken(ctx, session, ptyConfig, userConfig)
|
return winpty.ExecutePtyWithUserToken(session, ptyConfig, userConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getUserHomeFromEnv(env []string) string {
|
func getUserHomeFromEnv(env []string) string {
|
||||||
@@ -371,10 +360,8 @@ func (s *Server) killProcessGroup(cmd *exec.Cmd) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logger := log.WithField("pid", cmd.Process.Pid)
|
|
||||||
|
|
||||||
if err := cmd.Process.Kill(); err != nil {
|
if err := cmd.Process.Kill(); err != nil {
|
||||||
logger.Debugf("kill process failed: %v", err)
|
log.Debugf("kill process %d failed: %v", cmd.Process.Pid, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -389,21 +376,7 @@ func (s *Server) detectUtilLinuxLogin(context.Context) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
|
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
|
||||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, _ *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool {
|
||||||
command := session.RawCommand()
|
|
||||||
if command == "" {
|
|
||||||
logger.Error("no command specified for PTY execution")
|
|
||||||
if err := session.Exit(1); err != nil {
|
|
||||||
logSessionExitError(logger, err)
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command)
|
|
||||||
}
|
|
||||||
|
|
||||||
// executeConPtyCommand executes a command using ConPty (common for interactive and command execution)
|
|
||||||
func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool {
|
|
||||||
localUser := privilegeResult.User
|
localUser := privilegeResult.User
|
||||||
if localUser == nil {
|
if localUser == nil {
|
||||||
logger.Errorf("no user in privilege result")
|
logger.Errorf("no user in privilege result")
|
||||||
@@ -415,14 +388,14 @@ func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, pr
|
|||||||
|
|
||||||
req := PtyExecutionRequest{
|
req := PtyExecutionRequest{
|
||||||
Shell: shell,
|
Shell: shell,
|
||||||
Command: command,
|
Command: session.RawCommand(),
|
||||||
Width: ptyReq.Window.Width,
|
Width: ptyReq.Window.Width,
|
||||||
Height: ptyReq.Window.Height,
|
Height: ptyReq.Window.Height,
|
||||||
Username: username,
|
Username: username,
|
||||||
Domain: domain,
|
Domain: domain,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil {
|
if err := executePtyCommandWithUserToken(logger, session, req); err != nil {
|
||||||
logger.Errorf("ConPty execution failed: %v", err)
|
logger.Errorf("ConPty execution failed: %v", err)
|
||||||
if err := session.Exit(1); err != nil {
|
if err := session.Exit(1); err != nil {
|
||||||
logSessionExitError(logger, err)
|
logSessionExitError(logger, err)
|
||||||
|
|||||||
@@ -4,12 +4,15 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -23,25 +26,67 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestMain handles package-level setup and cleanup
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
|
// On platforms where su doesn't support --pty (macOS, FreeBSD, Windows), the SSH server
|
||||||
// This happens when running tests as non-privileged user with fallback
|
// spawns an executor subprocess via os.Executable(). During tests, this invokes the test
|
||||||
|
// binary with "ssh exec" args. We handle that here to properly execute commands and
|
||||||
|
// propagate exit codes.
|
||||||
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
|
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
|
||||||
// Just exit with error to break the recursion
|
runTestExecutor()
|
||||||
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
|
return
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run tests
|
|
||||||
code := m.Run()
|
code := m.Run()
|
||||||
|
|
||||||
// Cleanup any created test users
|
|
||||||
testutil.CleanupTestUsers()
|
testutil.CleanupTestUsers()
|
||||||
|
|
||||||
os.Exit(code)
|
os.Exit(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// runTestExecutor emulates the netbird executor for tests.
|
||||||
|
// Parses --shell and --cmd args, runs the command, and exits with the correct code.
|
||||||
|
func runTestExecutor() {
|
||||||
|
if os.Getenv("_NETBIRD_TEST_EXECUTOR") != "" {
|
||||||
|
fmt.Fprintf(os.Stderr, "executor recursion detected\n")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
os.Setenv("_NETBIRD_TEST_EXECUTOR", "1")
|
||||||
|
|
||||||
|
shell := "/bin/sh"
|
||||||
|
var command string
|
||||||
|
for i := 3; i < len(os.Args); i++ {
|
||||||
|
switch os.Args[i] {
|
||||||
|
case "--shell":
|
||||||
|
if i+1 < len(os.Args) {
|
||||||
|
shell = os.Args[i+1]
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
case "--cmd":
|
||||||
|
if i+1 < len(os.Args) {
|
||||||
|
command = os.Args[i+1]
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd *exec.Cmd
|
||||||
|
if command == "" {
|
||||||
|
cmd = exec.Command(shell)
|
||||||
|
} else {
|
||||||
|
cmd = exec.Command(shell, "-c", command)
|
||||||
|
}
|
||||||
|
cmd.Args[0] = "-" + filepath.Base(shell)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||||
|
os.Exit(exitErr.ExitCode())
|
||||||
|
}
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
|
||||||
// TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client
|
// TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client
|
||||||
func TestSSHServerCompatibility(t *testing.T) {
|
func TestSSHServerCompatibility(t *testing.T) {
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
@@ -405,6 +450,171 @@ func createTempKeyFile(t *testing.T, privateKey []byte) (string, func()) {
|
|||||||
return createTempKeyFileFromBytes(t, privateKey)
|
return createTempKeyFileFromBytes(t, privateKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestSSHPtyModes tests different PTY allocation modes (-T, -t, -tt flags)
|
||||||
|
// This ensures our implementation matches OpenSSH behavior for:
|
||||||
|
// - ssh host command (no PTY - default when no TTY)
|
||||||
|
// - ssh -T host command (explicit no PTY)
|
||||||
|
// - ssh -t host command (force PTY)
|
||||||
|
// - ssh -T host (no PTY shell - our implementation)
|
||||||
|
func TestSSHPtyModes(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping SSH PTY mode tests in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isSSHClientAvailable() {
|
||||||
|
t.Skip("SSH client not available on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||||
|
t.Skip("Skipping Windows SSH PTY tests in CI due to S4U authentication issues")
|
||||||
|
}
|
||||||
|
|
||||||
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
serverConfig := &Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
|
JWT: nil,
|
||||||
|
}
|
||||||
|
server := New(serverConfig)
|
||||||
|
server.SetAllowRootLogin(true)
|
||||||
|
|
||||||
|
serverAddr := StartTestServer(t, server)
|
||||||
|
defer func() {
|
||||||
|
err := server.Stop()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
clientKeyFile, cleanupKey := createTempKeyFileFromBytes(t, clientPrivKeyOpenSSH)
|
||||||
|
defer cleanupKey()
|
||||||
|
|
||||||
|
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
username := testutil.GetTestUsername(t)
|
||||||
|
|
||||||
|
baseArgs := []string{
|
||||||
|
"-i", clientKeyFile,
|
||||||
|
"-p", portStr,
|
||||||
|
"-o", "StrictHostKeyChecking=no",
|
||||||
|
"-o", "UserKnownHostsFile=/dev/null",
|
||||||
|
"-o", "ConnectTimeout=5",
|
||||||
|
"-o", "BatchMode=yes",
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("command_default_no_pty", func(t *testing.T) {
|
||||||
|
args := append(slices.Clone(baseArgs), fmt.Sprintf("%s@%s", username, host), "echo", "no_pty_default")
|
||||||
|
cmd := exec.Command("ssh", args...)
|
||||||
|
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
require.NoError(t, err, "Command (default no PTY) failed: %s", output)
|
||||||
|
assert.Contains(t, string(output), "no_pty_default")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("command_explicit_no_pty", func(t *testing.T) {
|
||||||
|
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "echo", "explicit_no_pty")
|
||||||
|
cmd := exec.Command("ssh", args...)
|
||||||
|
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
require.NoError(t, err, "Command (-T explicit no PTY) failed: %s", output)
|
||||||
|
assert.Contains(t, string(output), "explicit_no_pty")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("command_force_pty", func(t *testing.T) {
|
||||||
|
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "echo", "force_pty")
|
||||||
|
cmd := exec.Command("ssh", args...)
|
||||||
|
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
require.NoError(t, err, "Command (-tt force PTY) failed: %s", output)
|
||||||
|
assert.Contains(t, string(output), "force_pty")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("shell_explicit_no_pty", func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host))
|
||||||
|
cmd := exec.CommandContext(ctx, "ssh", args...)
|
||||||
|
|
||||||
|
stdin, err := cmd.StdinPipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
stdout, err := cmd.StdoutPipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, cmd.Start(), "Shell (-T no PTY) start failed")
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer stdin.Close()
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
_, err := stdin.Write([]byte("echo shell_no_pty_test\n"))
|
||||||
|
assert.NoError(t, err, "write echo command")
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
_, err = stdin.Write([]byte("exit 0\n"))
|
||||||
|
assert.NoError(t, err, "write exit command")
|
||||||
|
}()
|
||||||
|
|
||||||
|
output, _ := io.ReadAll(stdout)
|
||||||
|
err = cmd.Wait()
|
||||||
|
|
||||||
|
require.NoError(t, err, "Shell (-T no PTY) failed: %s", output)
|
||||||
|
assert.Contains(t, string(output), "shell_no_pty_test")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("exit_code_preserved_no_pty", func(t *testing.T) {
|
||||||
|
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "exit", "42")
|
||||||
|
cmd := exec.Command("ssh", args...)
|
||||||
|
|
||||||
|
err := cmd.Run()
|
||||||
|
require.Error(t, err, "Command should exit with non-zero")
|
||||||
|
|
||||||
|
var exitErr *exec.ExitError
|
||||||
|
require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err)
|
||||||
|
assert.Equal(t, 42, exitErr.ExitCode(), "Exit code should be preserved with -T")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("exit_code_preserved_with_pty", func(t *testing.T) {
|
||||||
|
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "sh -c 'exit 43'")
|
||||||
|
cmd := exec.Command("ssh", args...)
|
||||||
|
|
||||||
|
err := cmd.Run()
|
||||||
|
require.Error(t, err, "PTY command should exit with non-zero")
|
||||||
|
|
||||||
|
var exitErr *exec.ExitError
|
||||||
|
require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err)
|
||||||
|
assert.Equal(t, 43, exitErr.ExitCode(), "Exit code should be preserved with -tt")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("stderr_works_no_pty", func(t *testing.T) {
|
||||||
|
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host),
|
||||||
|
"sh -c 'echo stdout_msg; echo stderr_msg >&2'")
|
||||||
|
cmd := exec.Command("ssh", args...)
|
||||||
|
|
||||||
|
var stdout, stderr strings.Builder
|
||||||
|
cmd.Stdout = &stdout
|
||||||
|
cmd.Stderr = &stderr
|
||||||
|
|
||||||
|
require.NoError(t, cmd.Run(), "stderr test failed")
|
||||||
|
assert.Contains(t, stdout.String(), "stdout_msg", "stdout should have stdout_msg")
|
||||||
|
assert.Contains(t, stderr.String(), "stderr_msg", "stderr should have stderr_msg")
|
||||||
|
assert.NotContains(t, stdout.String(), "stderr_msg", "stdout should NOT have stderr_msg")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("stderr_merged_with_pty", func(t *testing.T) {
|
||||||
|
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host),
|
||||||
|
"sh -c 'echo stdout_msg; echo stderr_msg >&2'")
|
||||||
|
cmd := exec.Command("ssh", args...)
|
||||||
|
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
require.NoError(t, err, "PTY stderr test failed: %s", output)
|
||||||
|
assert.Contains(t, string(output), "stdout_msg")
|
||||||
|
assert.Contains(t, string(output), "stderr_msg")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// TestSSHServerFeatureCompatibility tests specific SSH features for compatibility
|
// TestSSHServerFeatureCompatibility tests specific SSH features for compatibility
|
||||||
func TestSSHServerFeatureCompatibility(t *testing.T) {
|
func TestSSHServerFeatureCompatibility(t *testing.T) {
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
@@ -35,11 +36,35 @@ type ExecutorConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// PrivilegeDropper handles secure privilege dropping in child processes
|
// PrivilegeDropper handles secure privilege dropping in child processes
|
||||||
type PrivilegeDropper struct{}
|
type PrivilegeDropper struct {
|
||||||
|
logger *log.Entry
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper
|
||||||
|
type PrivilegeDropperOption func(*PrivilegeDropper)
|
||||||
|
|
||||||
// NewPrivilegeDropper creates a new privilege dropper
|
// NewPrivilegeDropper creates a new privilege dropper
|
||||||
func NewPrivilegeDropper() *PrivilegeDropper {
|
func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper {
|
||||||
return &PrivilegeDropper{}
|
pd := &PrivilegeDropper{}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(pd)
|
||||||
|
}
|
||||||
|
return pd
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithLogger sets the logger for the PrivilegeDropper
|
||||||
|
func WithLogger(logger *log.Entry) PrivilegeDropperOption {
|
||||||
|
return func(pd *PrivilegeDropper) {
|
||||||
|
pd.logger = logger
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// log returns the logger, falling back to standard logger if none set
|
||||||
|
func (pd *PrivilegeDropper) log() *log.Entry {
|
||||||
|
if pd.logger != nil {
|
||||||
|
return pd.logger
|
||||||
|
}
|
||||||
|
return log.NewEntry(log.StandardLogger())
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping
|
// CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping
|
||||||
@@ -83,7 +108,7 @@ func (pd *PrivilegeDropper) CreateExecutorCommand(ctx context.Context, config Ex
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
|
pd.log().Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
|
||||||
return exec.CommandContext(ctx, netbirdPath, args...), nil
|
return exec.CommandContext(ctx, netbirdPath, args...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -206,17 +231,22 @@ func (pd *PrivilegeDropper) ExecuteWithPrivilegeDrop(ctx context.Context, config
|
|||||||
|
|
||||||
var execCmd *exec.Cmd
|
var execCmd *exec.Cmd
|
||||||
if config.Command == "" {
|
if config.Command == "" {
|
||||||
os.Exit(ExitCodeSuccess)
|
execCmd = exec.CommandContext(ctx, config.Shell)
|
||||||
|
} else {
|
||||||
|
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
|
||||||
}
|
}
|
||||||
|
execCmd.Args[0] = "-" + filepath.Base(config.Shell)
|
||||||
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
|
|
||||||
execCmd.Stdin = os.Stdin
|
execCmd.Stdin = os.Stdin
|
||||||
execCmd.Stdout = os.Stdout
|
execCmd.Stdout = os.Stdout
|
||||||
execCmd.Stderr = os.Stderr
|
execCmd.Stderr = os.Stderr
|
||||||
|
|
||||||
cmdParts := strings.Fields(config.Command)
|
if config.Command == "" {
|
||||||
safeCmd := safeLogCommand(cmdParts)
|
log.Tracef("executing login shell: %s", execCmd.Path)
|
||||||
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
|
} else {
|
||||||
|
cmdParts := strings.Fields(config.Command)
|
||||||
|
safeCmd := safeLogCommand(cmdParts)
|
||||||
|
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
|
||||||
|
}
|
||||||
if err := execCmd.Run(); err != nil {
|
if err := execCmd.Run(); err != nil {
|
||||||
var exitError *exec.ExitError
|
var exitError *exec.ExitError
|
||||||
if errors.As(err, &exitError) {
|
if errors.As(err, &exitError) {
|
||||||
|
|||||||
@@ -28,22 +28,45 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type WindowsExecutorConfig struct {
|
type WindowsExecutorConfig struct {
|
||||||
Username string
|
Username string
|
||||||
Domain string
|
Domain string
|
||||||
WorkingDir string
|
WorkingDir string
|
||||||
Shell string
|
Shell string
|
||||||
Command string
|
Command string
|
||||||
Args []string
|
Args []string
|
||||||
Interactive bool
|
Pty bool
|
||||||
Pty bool
|
PtyWidth int
|
||||||
PtyWidth int
|
PtyHeight int
|
||||||
PtyHeight int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type PrivilegeDropper struct{}
|
type PrivilegeDropper struct {
|
||||||
|
logger *log.Entry
|
||||||
|
}
|
||||||
|
|
||||||
func NewPrivilegeDropper() *PrivilegeDropper {
|
// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper
|
||||||
return &PrivilegeDropper{}
|
type PrivilegeDropperOption func(*PrivilegeDropper)
|
||||||
|
|
||||||
|
func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper {
|
||||||
|
pd := &PrivilegeDropper{}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(pd)
|
||||||
|
}
|
||||||
|
return pd
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithLogger sets the logger for the PrivilegeDropper
|
||||||
|
func WithLogger(logger *log.Entry) PrivilegeDropperOption {
|
||||||
|
return func(pd *PrivilegeDropper) {
|
||||||
|
pd.logger = logger
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// log returns the logger, falling back to standard logger if none set
|
||||||
|
func (pd *PrivilegeDropper) log() *log.Entry {
|
||||||
|
if pd.logger != nil {
|
||||||
|
return pd.logger
|
||||||
|
}
|
||||||
|
return log.NewEntry(log.StandardLogger())
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -56,7 +79,6 @@ const (
|
|||||||
|
|
||||||
// Common error messages
|
// Common error messages
|
||||||
commandFlag = "-Command"
|
commandFlag = "-Command"
|
||||||
closeTokenErrorMsg = "close token error: %v" // #nosec G101 -- This is an error message template, not credentials
|
|
||||||
convertUsernameError = "convert username to UTF16: %w"
|
convertUsernameError = "convert username to UTF16: %w"
|
||||||
convertDomainError = "convert domain to UTF16: %w"
|
convertDomainError = "convert domain to UTF16: %w"
|
||||||
)
|
)
|
||||||
@@ -80,7 +102,7 @@ func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, co
|
|||||||
shellArgs = []string{shell}
|
shellArgs = []string{shell}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
|
pd.log().Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
|
||||||
|
|
||||||
cmd, token, err := pd.CreateWindowsProcessAsUser(
|
cmd, token, err := pd.CreateWindowsProcessAsUser(
|
||||||
ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir)
|
ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir)
|
||||||
@@ -180,10 +202,10 @@ func newLsaString(s string) lsaString {
|
|||||||
|
|
||||||
// generateS4UUserToken creates a Windows token using S4U authentication
|
// generateS4UUserToken creates a Windows token using S4U authentication
|
||||||
// This is the exact approach OpenSSH for Windows uses for public key authentication
|
// This is the exact approach OpenSSH for Windows uses for public key authentication
|
||||||
func generateS4UUserToken(username, domain string) (windows.Handle, error) {
|
func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
|
||||||
userCpn := buildUserCpn(username, domain)
|
userCpn := buildUserCpn(username, domain)
|
||||||
|
|
||||||
pd := NewPrivilegeDropper()
|
pd := NewPrivilegeDropper(WithLogger(logger))
|
||||||
isDomainUser := !pd.isLocalUser(domain)
|
isDomainUser := !pd.isLocalUser(domain)
|
||||||
|
|
||||||
lsaHandle, err := initializeLsaConnection()
|
lsaHandle, err := initializeLsaConnection()
|
||||||
@@ -197,12 +219,12 @@ func generateS4UUserToken(username, domain string) (windows.Handle, error) {
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
logonInfo, logonInfoSize, err := prepareS4ULogonStructure(username, domain, isDomainUser)
|
logonInfo, logonInfoSize, err := prepareS4ULogonStructure(logger, username, domain, isDomainUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return performS4ULogon(lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
|
return performS4ULogon(logger, lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildUserCpn constructs the user principal name
|
// buildUserCpn constructs the user principal name
|
||||||
@@ -310,21 +332,21 @@ func lookupPrincipalName(username, domain string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// prepareS4ULogonStructure creates the appropriate S4U logon structure
|
// prepareS4ULogonStructure creates the appropriate S4U logon structure
|
||||||
func prepareS4ULogonStructure(username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
|
func prepareS4ULogonStructure(logger *log.Entry, username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
|
||||||
if isDomainUser {
|
if isDomainUser {
|
||||||
return prepareDomainS4ULogon(username, domain)
|
return prepareDomainS4ULogon(logger, username, domain)
|
||||||
}
|
}
|
||||||
return prepareLocalS4ULogon(username)
|
return prepareLocalS4ULogon(logger, username)
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepareDomainS4ULogon creates S4U logon structure for domain users
|
// prepareDomainS4ULogon creates S4U logon structure for domain users
|
||||||
func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, error) {
|
func prepareDomainS4ULogon(logger *log.Entry, username, domain string) (unsafe.Pointer, uintptr, error) {
|
||||||
upn, err := lookupPrincipalName(username, domain)
|
upn, err := lookupPrincipalName(username, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("lookup principal name: %w", err)
|
return nil, 0, fmt.Errorf("lookup principal name: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
|
logger.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
|
||||||
|
|
||||||
upnUtf16, err := windows.UTF16FromString(upn)
|
upnUtf16, err := windows.UTF16FromString(upn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -357,8 +379,8 @@ func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// prepareLocalS4ULogon creates S4U logon structure for local users
|
// prepareLocalS4ULogon creates S4U logon structure for local users
|
||||||
func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) {
|
func prepareLocalS4ULogon(logger *log.Entry, username string) (unsafe.Pointer, uintptr, error) {
|
||||||
log.Debugf("using Msv1_0S4ULogon for local user: %s", username)
|
logger.Debugf("using Msv1_0S4ULogon for local user: %s", username)
|
||||||
|
|
||||||
usernameUtf16, err := windows.UTF16FromString(username)
|
usernameUtf16, err := windows.UTF16FromString(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -406,11 +428,11 @@ func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// performS4ULogon executes the S4U logon operation
|
// performS4ULogon executes the S4U logon operation
|
||||||
func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
|
func performS4ULogon(logger *log.Entry, lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
|
||||||
var tokenSource tokenSource
|
var tokenSource tokenSource
|
||||||
copy(tokenSource.SourceName[:], "netbird")
|
copy(tokenSource.SourceName[:], "netbird")
|
||||||
if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 {
|
if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 {
|
||||||
log.Debugf("AllocateLocallyUniqueId failed")
|
logger.Debugf("AllocateLocallyUniqueId failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
originName := newLsaString("netbird")
|
originName := newLsaString("netbird")
|
||||||
@@ -441,7 +463,7 @@ func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo u
|
|||||||
|
|
||||||
if profile != 0 {
|
if profile != 0 {
|
||||||
if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess {
|
if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess {
|
||||||
log.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
|
logger.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -449,7 +471,7 @@ func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo u
|
|||||||
return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus)
|
return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("created S4U %s token for user %s",
|
logger.Debugf("created S4U %s token for user %s",
|
||||||
map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn)
|
map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn)
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
@@ -497,8 +519,8 @@ func (pd *PrivilegeDropper) isLocalUser(domain string) bool {
|
|||||||
|
|
||||||
// authenticateLocalUser handles authentication for local users
|
// authenticateLocalUser handles authentication for local users
|
||||||
func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) {
|
func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) {
|
||||||
log.Debugf("using S4U authentication for local user %s", fullUsername)
|
pd.log().Debugf("using S4U authentication for local user %s", fullUsername)
|
||||||
token, err := generateS4UUserToken(username, ".")
|
token, err := generateS4UUserToken(pd.log(), username, ".")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err)
|
return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err)
|
||||||
}
|
}
|
||||||
@@ -507,12 +529,12 @@ func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string)
|
|||||||
|
|
||||||
// authenticateDomainUser handles authentication for domain users
|
// authenticateDomainUser handles authentication for domain users
|
||||||
func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) {
|
func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) {
|
||||||
log.Debugf("using S4U authentication for domain user %s", fullUsername)
|
pd.log().Debugf("using S4U authentication for domain user %s", fullUsername)
|
||||||
token, err := generateS4UUserToken(username, domain)
|
token, err := generateS4UUserToken(pd.log(), username, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err)
|
return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err)
|
||||||
}
|
}
|
||||||
log.Debugf("Successfully created S4U token for domain user %s", fullUsername)
|
pd.log().Debugf("successfully created S4U token for domain user %s", fullUsername)
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -526,7 +548,7 @@ func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, exec
|
|||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := windows.CloseHandle(token); err != nil {
|
if err := windows.CloseHandle(token); err != nil {
|
||||||
log.Debugf("close impersonation token: %v", err)
|
pd.log().Debugf("close impersonation token: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -564,7 +586,7 @@ func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceTo
|
|||||||
return cmd, primaryToken, nil
|
return cmd, primaryToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createSuCommand creates a command using su -l -c for privilege switching (Windows stub)
|
// createSuCommand creates a command using su - for privilege switching (Windows stub).
|
||||||
func (s *Server) createSuCommand(ssh.Session, *user.User, bool) (*exec.Cmd, error) {
|
func (s *Server) createSuCommand(*log.Entry, ssh.Session, *user.User, bool) (*exec.Cmd, error) {
|
||||||
return nil, fmt.Errorf("su command not available on Windows")
|
return nil, fmt.Errorf("su command not available on Windows")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ func TestJWTEnforcement(t *testing.T) {
|
|||||||
server.SetAllowRootLogin(true)
|
server.SetAllowRootLogin(true)
|
||||||
|
|
||||||
serverAddr := StartTestServer(t, server)
|
serverAddr := StartTestServer(t, server)
|
||||||
defer require.NoError(t, server.Stop())
|
defer func() { require.NoError(t, server.Stop()) }()
|
||||||
|
|
||||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -88,7 +88,7 @@ func TestJWTEnforcement(t *testing.T) {
|
|||||||
serverNoJWT.SetAllowRootLogin(true)
|
serverNoJWT.SetAllowRootLogin(true)
|
||||||
|
|
||||||
serverAddrNoJWT := StartTestServer(t, serverNoJWT)
|
serverAddrNoJWT := StartTestServer(t, serverNoJWT)
|
||||||
defer require.NoError(t, serverNoJWT.Stop())
|
defer func() { require.NoError(t, serverNoJWT.Stop()) }()
|
||||||
|
|
||||||
hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT)
|
hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -213,7 +213,7 @@ func TestJWTDetection(t *testing.T) {
|
|||||||
server.SetAllowRootLogin(true)
|
server.SetAllowRootLogin(true)
|
||||||
|
|
||||||
serverAddr := StartTestServer(t, server)
|
serverAddr := StartTestServer(t, server)
|
||||||
defer require.NoError(t, server.Stop())
|
defer func() { require.NoError(t, server.Stop()) }()
|
||||||
|
|
||||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -341,7 +341,7 @@ func TestJWTFailClose(t *testing.T) {
|
|||||||
server.SetAllowRootLogin(true)
|
server.SetAllowRootLogin(true)
|
||||||
|
|
||||||
serverAddr := StartTestServer(t, server)
|
serverAddr := StartTestServer(t, server)
|
||||||
defer require.NoError(t, server.Stop())
|
defer func() { require.NoError(t, server.Stop()) }()
|
||||||
|
|
||||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -596,7 +596,7 @@ func TestJWTAuthentication(t *testing.T) {
|
|||||||
server.UpdateSSHAuth(authConfig)
|
server.UpdateSSHAuth(authConfig)
|
||||||
|
|
||||||
serverAddr := StartTestServer(t, server)
|
serverAddr := StartTestServer(t, server)
|
||||||
defer require.NoError(t, server.Stop())
|
defer func() { require.NoError(t, server.Stop()) }()
|
||||||
|
|
||||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -715,7 +715,7 @@ func TestJWTMultipleAudiences(t *testing.T) {
|
|||||||
server.UpdateSSHAuth(authConfig)
|
server.UpdateSSHAuth(authConfig)
|
||||||
|
|
||||||
serverAddr := StartTestServer(t, server)
|
serverAddr := StartTestServer(t, server)
|
||||||
defer require.NoError(t, server.Stop())
|
defer func() { require.NoError(t, server.Stop()) }()
|
||||||
|
|
||||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -271,13 +271,6 @@ func (s *Server) isRemotePortForwardingAllowed() bool {
|
|||||||
return s.allowRemotePortForwarding
|
return s.allowRemotePortForwarding
|
||||||
}
|
}
|
||||||
|
|
||||||
// isPortForwardingEnabled checks if any port forwarding (local or remote) is enabled
|
|
||||||
func (s *Server) isPortForwardingEnabled() bool {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
return s.allowLocalPortForwarding || s.allowRemotePortForwarding
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseTcpipForwardRequest parses the SSH request payload
|
// parseTcpipForwardRequest parses the SSH request payload
|
||||||
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
|
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
|
||||||
var payload tcpipForwardMsg
|
var payload tcpipForwardMsg
|
||||||
|
|||||||
@@ -335,7 +335,7 @@ func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
|
|||||||
sessions = append(sessions, info)
|
sessions = append(sessions, info)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add authenticated connections without sessions (e.g., -N/-T or port-forwarding only)
|
// Add authenticated connections without sessions (e.g., -N or port-forwarding only)
|
||||||
for key, connState := range s.connections {
|
for key, connState := range s.connections {
|
||||||
remoteAddr := string(key)
|
remoteAddr := string(key)
|
||||||
if reportedAddrs[remoteAddr] {
|
if reportedAddrs[remoteAddr] {
|
||||||
|
|||||||
@@ -483,12 +483,11 @@ func TestServer_IsPrivilegedUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_PortForwardingOnlySession(t *testing.T) {
|
func TestServer_NonPtyShellSession(t *testing.T) {
|
||||||
// Test that sessions without PTY and command are allowed when port forwarding is enabled
|
// Test that non-PTY shell sessions (ssh -T) work regardless of port forwarding settings.
|
||||||
currentUser, err := user.Current()
|
currentUser, err := user.Current()
|
||||||
require.NoError(t, err, "Should be able to get current user")
|
require.NoError(t, err, "Should be able to get current user")
|
||||||
|
|
||||||
// Generate host key for server
|
|
||||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -496,36 +495,26 @@ func TestServer_PortForwardingOnlySession(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
allowLocalForwarding bool
|
allowLocalForwarding bool
|
||||||
allowRemoteForwarding bool
|
allowRemoteForwarding bool
|
||||||
expectAllowed bool
|
|
||||||
description string
|
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "session_allowed_with_local_forwarding",
|
name: "shell_with_local_forwarding_enabled",
|
||||||
allowLocalForwarding: true,
|
allowLocalForwarding: true,
|
||||||
allowRemoteForwarding: false,
|
allowRemoteForwarding: false,
|
||||||
expectAllowed: true,
|
|
||||||
description: "Port-forwarding-only session should be allowed when local forwarding is enabled",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "session_allowed_with_remote_forwarding",
|
name: "shell_with_remote_forwarding_enabled",
|
||||||
allowLocalForwarding: false,
|
allowLocalForwarding: false,
|
||||||
allowRemoteForwarding: true,
|
allowRemoteForwarding: true,
|
||||||
expectAllowed: true,
|
|
||||||
description: "Port-forwarding-only session should be allowed when remote forwarding is enabled",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "session_allowed_with_both",
|
name: "shell_with_both_forwarding_enabled",
|
||||||
allowLocalForwarding: true,
|
allowLocalForwarding: true,
|
||||||
allowRemoteForwarding: true,
|
allowRemoteForwarding: true,
|
||||||
expectAllowed: true,
|
|
||||||
description: "Port-forwarding-only session should be allowed when both forwarding types enabled",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "session_denied_without_forwarding",
|
name: "shell_with_forwarding_disabled",
|
||||||
allowLocalForwarding: false,
|
allowLocalForwarding: false,
|
||||||
allowRemoteForwarding: false,
|
allowRemoteForwarding: false,
|
||||||
expectAllowed: false,
|
|
||||||
description: "Port-forwarding-only session should be denied when all forwarding is disabled",
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -545,7 +534,6 @@ func TestServer_PortForwardingOnlySession(t *testing.T) {
|
|||||||
_ = server.Stop()
|
_ = server.Stop()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Connect to the server without requesting PTY or command
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@@ -557,20 +545,10 @@ func TestServer_PortForwardingOnlySession(t *testing.T) {
|
|||||||
_ = client.Close()
|
_ = client.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Execute a command without PTY - this simulates ssh -T with no command
|
// Execute without PTY and no command - simulates ssh -T (shell without PTY)
|
||||||
// The server should either allow it (port forwarding enabled) or reject it
|
// Should always succeed regardless of port forwarding settings
|
||||||
output, err := client.ExecuteCommand(ctx, "")
|
_, err = client.ExecuteCommand(ctx, "")
|
||||||
if tt.expectAllowed {
|
assert.NoError(t, err, "Non-PTY shell session should be allowed")
|
||||||
// When allowed, the session stays open until cancelled
|
|
||||||
// ExecuteCommand with empty command should return without error
|
|
||||||
assert.NoError(t, err, "Session should be allowed when port forwarding is enabled")
|
|
||||||
assert.NotContains(t, output, "port forwarding is disabled",
|
|
||||||
"Output should not contain port forwarding disabled message")
|
|
||||||
} else if err != nil {
|
|
||||||
// When denied, we expect an error message about port forwarding being disabled
|
|
||||||
assert.Contains(t, err.Error(), "port forwarding is disabled",
|
|
||||||
"Should get port forwarding disabled message")
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -405,12 +405,14 @@ func TestSSHServer_WindowsShellHandling(t *testing.T) {
|
|||||||
assert.Equal(t, "-Command", args[1])
|
assert.Equal(t, "-Command", args[1])
|
||||||
assert.Equal(t, "echo test", args[2])
|
assert.Equal(t, "echo test", args[2])
|
||||||
} else {
|
} else {
|
||||||
// Test Unix shell behavior
|
|
||||||
args := server.getShellCommandArgs("/bin/sh", "echo test")
|
args := server.getShellCommandArgs("/bin/sh", "echo test")
|
||||||
assert.Equal(t, "/bin/sh", args[0])
|
assert.Equal(t, "/bin/sh", args[0])
|
||||||
assert.Equal(t, "-l", args[1])
|
assert.Equal(t, "-c", args[1])
|
||||||
assert.Equal(t, "-c", args[2])
|
assert.Equal(t, "echo test", args[2])
|
||||||
assert.Equal(t, "echo test", args[3])
|
|
||||||
|
args = server.getShellCommandArgs("/bin/sh", "")
|
||||||
|
assert.Equal(t, "/bin/sh", args[0])
|
||||||
|
assert.Len(t, args, 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -62,54 +62,12 @@ func (s *Server) sessionHandler(session ssh.Session) {
|
|||||||
ptyReq, winCh, isPty := session.Pty()
|
ptyReq, winCh, isPty := session.Pty()
|
||||||
hasCommand := len(session.Command()) > 0
|
hasCommand := len(session.Command()) > 0
|
||||||
|
|
||||||
switch {
|
if isPty && !hasCommand {
|
||||||
case isPty && hasCommand:
|
// ssh <host> - PTY interactive session (login)
|
||||||
// ssh -t <host> <cmd> - Pty command execution
|
s.handlePtyLogin(logger, session, privilegeResult, ptyReq, winCh)
|
||||||
s.handleCommand(logger, session, privilegeResult, winCh)
|
} else {
|
||||||
case isPty:
|
// ssh <host> <cmd>, ssh -t <host> <cmd>, ssh -T <host> - command or shell execution
|
||||||
// ssh <host> - Pty interactive session (login)
|
s.handleExecution(logger, session, privilegeResult, ptyReq, winCh)
|
||||||
s.handlePty(logger, session, privilegeResult, ptyReq, winCh)
|
|
||||||
case hasCommand:
|
|
||||||
// ssh <host> <cmd> - non-Pty command execution
|
|
||||||
s.handleCommand(logger, session, privilegeResult, nil)
|
|
||||||
default:
|
|
||||||
// ssh -T (or ssh -N) - no PTY, no command
|
|
||||||
s.handleNonInteractiveSession(logger, session)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleNonInteractiveSession handles sessions that have no PTY and no command.
|
|
||||||
// These are typically used for port forwarding (ssh -L/-R) or tunneling (ssh -N).
|
|
||||||
func (s *Server) handleNonInteractiveSession(logger *log.Entry, session ssh.Session) {
|
|
||||||
s.updateSessionType(session, cmdNonInteractive)
|
|
||||||
|
|
||||||
if !s.isPortForwardingEnabled() {
|
|
||||||
if _, err := io.WriteString(session, "port forwarding is disabled on this server\n"); err != nil {
|
|
||||||
logger.Debugf(errWriteSession, err)
|
|
||||||
}
|
|
||||||
if err := session.Exit(1); err != nil {
|
|
||||||
logSessionExitError(logger, err)
|
|
||||||
}
|
|
||||||
logger.Infof("rejected non-interactive session: port forwarding disabled")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
<-session.Context().Done()
|
|
||||||
|
|
||||||
if err := session.Exit(0); err != nil {
|
|
||||||
logSessionExitError(logger, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) updateSessionType(session ssh.Session, sessionType string) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
for _, state := range s.sessions {
|
|
||||||
if state.session == session {
|
|
||||||
state.sessionType = sessionType
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// handlePty is not supported on JS/WASM
|
// handlePtyLogin is not supported on JS/WASM
|
||||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool {
|
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool {
|
||||||
errorMsg := "PTY sessions are not supported on WASM/JS platform\n"
|
errorMsg := "PTY sessions are not supported on WASM/JS platform\n"
|
||||||
if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil {
|
if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil {
|
||||||
logger.Debugf(errWriteSession, err)
|
logger.Debugf(errWriteSession, err)
|
||||||
|
|||||||
@@ -8,19 +8,18 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// StartTestServer starts the SSH server and returns the address it's listening on.
|
||||||
func StartTestServer(t *testing.T, server *Server) string {
|
func StartTestServer(t *testing.T, server *Server) string {
|
||||||
started := make(chan string, 1)
|
started := make(chan string, 1)
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
// Use port 0 to let the OS assign a free port
|
|
||||||
addrPort := netip.MustParseAddrPort("127.0.0.1:0")
|
addrPort := netip.MustParseAddrPort("127.0.0.1:0")
|
||||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
if err := server.Start(context.Background(), addrPort); err != nil {
|
||||||
errChan <- err
|
errChan <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the actual listening address from the server
|
|
||||||
actualAddr := server.Addr()
|
actualAddr := server.Addr()
|
||||||
if actualAddr == nil {
|
if actualAddr == nil {
|
||||||
errChan <- fmt.Errorf("server started but no listener address available")
|
errChan <- fmt.Errorf("server started but no listener address available")
|
||||||
|
|||||||
@@ -181,8 +181,8 @@ func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) {
|
|||||||
|
|
||||||
// createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping.
|
// createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping.
|
||||||
// Returns the command and a cleanup function (no-op on Unix).
|
// Returns the command and a cleanup function (no-op on Unix).
|
||||||
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
||||||
log.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
logger.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
||||||
|
|
||||||
if err := validateUsername(localUser.Username); err != nil {
|
if err := validateUsername(localUser.Username); err != nil {
|
||||||
return nil, nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
|
return nil, nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
|
||||||
@@ -192,7 +192,7 @@ func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("parse user credentials: %w", err)
|
return nil, nil, fmt.Errorf("parse user credentials: %w", err)
|
||||||
}
|
}
|
||||||
privilegeDropper := NewPrivilegeDropper()
|
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
|
||||||
config := ExecutorConfig{
|
config := ExecutorConfig{
|
||||||
UID: uid,
|
UID: uid,
|
||||||
GID: gid,
|
GID: gid,
|
||||||
@@ -233,7 +233,7 @@ func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.Use
|
|||||||
shell := getUserShell(localUser.Uid)
|
shell := getUserShell(localUser.Uid)
|
||||||
args := s.getShellCommandArgs(shell, session.RawCommand())
|
args := s.getShellCommandArgs(shell, session.RawCommand())
|
||||||
|
|
||||||
cmd := exec.CommandContext(session.Context(), args[0], args[1:]...)
|
cmd := s.createShellCommand(session.Context(), shell, args)
|
||||||
cmd.Dir = localUser.HomeDir
|
cmd.Dir = localUser.HomeDir
|
||||||
cmd.Env = s.preparePtyEnv(localUser, ptyReq, session)
|
cmd.Env = s.preparePtyEnv(localUser, ptyReq, session)
|
||||||
|
|
||||||
|
|||||||
@@ -88,20 +88,20 @@ func validateUsernameFormat(username string) error {
|
|||||||
|
|
||||||
// createExecutorCommand creates a command using Windows executor for privilege dropping.
|
// createExecutorCommand creates a command using Windows executor for privilege dropping.
|
||||||
// Returns the command and a cleanup function that must be called after starting the process.
|
// Returns the command and a cleanup function that must be called after starting the process.
|
||||||
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
||||||
log.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
logger.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
||||||
|
|
||||||
username, _ := s.parseUsername(localUser.Username)
|
username, _ := s.parseUsername(localUser.Username)
|
||||||
if err := validateUsername(username); err != nil {
|
if err := validateUsername(username); err != nil {
|
||||||
return nil, nil, fmt.Errorf("invalid username %q: %w", username, err)
|
return nil, nil, fmt.Errorf("invalid username %q: %w", username, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.createUserSwitchCommand(localUser, session, hasPty)
|
return s.createUserSwitchCommand(logger, session, localUser)
|
||||||
}
|
}
|
||||||
|
|
||||||
// createUserSwitchCommand creates a command with Windows user switching.
|
// createUserSwitchCommand creates a command with Windows user switching.
|
||||||
// Returns the command and a cleanup function that must be called after starting the process.
|
// Returns the command and a cleanup function that must be called after starting the process.
|
||||||
func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Session, interactive bool) (*exec.Cmd, func(), error) {
|
func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session, localUser *user.User) (*exec.Cmd, func(), error) {
|
||||||
username, domain := s.parseUsername(localUser.Username)
|
username, domain := s.parseUsername(localUser.Username)
|
||||||
|
|
||||||
shell := getUserShell(localUser.Uid)
|
shell := getUserShell(localUser.Uid)
|
||||||
@@ -113,15 +113,14 @@ func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Sessi
|
|||||||
}
|
}
|
||||||
|
|
||||||
config := WindowsExecutorConfig{
|
config := WindowsExecutorConfig{
|
||||||
Username: username,
|
Username: username,
|
||||||
Domain: domain,
|
Domain: domain,
|
||||||
WorkingDir: localUser.HomeDir,
|
WorkingDir: localUser.HomeDir,
|
||||||
Shell: shell,
|
Shell: shell,
|
||||||
Command: command,
|
Command: command,
|
||||||
Interactive: interactive || (rawCmd == ""),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dropper := NewPrivilegeDropper()
|
dropper := NewPrivilegeDropper(WithLogger(logger))
|
||||||
cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config)
|
cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
@@ -130,7 +129,7 @@ func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Sessi
|
|||||||
cleanup := func() {
|
cleanup := func() {
|
||||||
if token != 0 {
|
if token != 0 {
|
||||||
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
|
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
|
||||||
log.Debugf("close primary token: %v", err)
|
logger.Debugf("close primary token: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// ExecutePtyWithUserToken executes a command with ConPty using user token.
|
// ExecutePtyWithUserToken executes a command with ConPty using user token.
|
||||||
func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error {
|
func ExecutePtyWithUserToken(session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error {
|
||||||
args := buildShellArgs(ptyConfig.Shell, ptyConfig.Command)
|
args := buildShellArgs(ptyConfig.Shell, ptyConfig.Command)
|
||||||
commandLine := buildCommandLine(args)
|
commandLine := buildCommandLine(args)
|
||||||
|
|
||||||
@@ -64,7 +64,7 @@ func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig
|
|||||||
Pty: ptyConfig,
|
Pty: ptyConfig,
|
||||||
User: userConfig,
|
User: userConfig,
|
||||||
Session: session,
|
Session: session,
|
||||||
Context: ctx,
|
Context: session.Context(),
|
||||||
}
|
}
|
||||||
|
|
||||||
return executeConPtyWithConfig(commandLine, config)
|
return executeConPtyWithConfig(commandLine, config)
|
||||||
|
|||||||
@@ -63,6 +63,8 @@ func (h *eventHandler) listen(ctx context.Context) {
|
|||||||
h.handleNetworksClick()
|
h.handleNetworksClick()
|
||||||
case <-h.client.mNotifications.ClickedCh:
|
case <-h.client.mNotifications.ClickedCh:
|
||||||
h.handleNotificationsClick()
|
h.handleNotificationsClick()
|
||||||
|
case <-systray.TrayOpenedCh:
|
||||||
|
h.client.updateExitNodes()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -341,7 +341,6 @@ func (s *serviceClient) updateExitNodes() {
|
|||||||
log.Errorf("get client: %v", err)
|
log.Errorf("get client: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
exitNodes, err := s.getExitNodes(conn)
|
exitNodes, err := s.getExitNodes(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("get exit nodes: %v", err)
|
log.Errorf("get exit nodes: %v", err)
|
||||||
|
|||||||
5
combined/Dockerfile
Normal file
5
combined/Dockerfile
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
FROM ubuntu:24.04
|
||||||
|
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
|
||||||
|
ENTRYPOINT [ "/go/bin/netbird-server" ]
|
||||||
|
CMD ["--config", "/etc/netbird/config.yaml"]
|
||||||
|
COPY netbird-server /go/bin/netbird-server
|
||||||
25
combined/Dockerfile.multistage
Normal file
25
combined/Dockerfile.multistage
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
FROM golang:1.25-bookworm AS builder
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Install build dependencies
|
||||||
|
RUN apt-get update && apt-get install -y gcc libc6-dev git && rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
COPY go.mod go.sum ./
|
||||||
|
RUN go mod download
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
# Build with version info from git (matching goreleaser ldflags)
|
||||||
|
RUN CGO_ENABLED=1 GOOS=linux go build \
|
||||||
|
-ldflags="-s -w \
|
||||||
|
-X github.com/netbirdio/netbird/version.version=$(git describe --tags --always --dirty 2>/dev/null || echo 'dev') \
|
||||||
|
-X main.commit=$(git rev-parse --short HEAD 2>/dev/null || echo 'unknown') \
|
||||||
|
-X main.date=$(date -u +%Y-%m-%dT%H:%M:%SZ) \
|
||||||
|
-X main.builtBy=docker" \
|
||||||
|
-o netbird-server ./combined
|
||||||
|
|
||||||
|
FROM ubuntu:24.04
|
||||||
|
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
|
||||||
|
ENTRYPOINT [ "/go/bin/netbird-server" ]
|
||||||
|
CMD ["--config", "/etc/netbird/config.yaml"]
|
||||||
|
COPY --from=builder /app/netbird-server /go/bin/netbird-server
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user