mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
Compare commits
2 Commits
fix/filter
...
log/conn-d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
97ad3307dd | ||
|
|
f6cc27d675 |
@@ -1,6 +0,0 @@
|
||||
.env
|
||||
.env.*
|
||||
*.pem
|
||||
*.key
|
||||
*.crt
|
||||
*.p12
|
||||
14
.github/ISSUE_TEMPLATE/config.yml
vendored
14
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -1,14 +0,0 @@
|
||||
blank_issues_enabled: true
|
||||
contact_links:
|
||||
- name: Community Support
|
||||
url: https://forum.netbird.io/
|
||||
about: Community support forum
|
||||
- name: Cloud Support
|
||||
url: https://docs.netbird.io/help/report-bug-issues
|
||||
about: Contact us for support
|
||||
- name: Client/Connection Troubleshooting
|
||||
url: https://docs.netbird.io/help/troubleshooting-client
|
||||
about: See our client troubleshooting guide for help addressing common issues
|
||||
- name: Self-host Troubleshooting
|
||||
url: https://docs.netbird.io/selfhosted/troubleshooting
|
||||
about: See our self-host troubleshooting guide for help addressing common issues
|
||||
10
.github/workflows/check-license-dependencies.yml
vendored
10
.github/workflows/check-license-dependencies.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
|
||||
- name: Check for problematic license dependencies
|
||||
run: |
|
||||
echo "Checking for dependencies on management/, signal/, relay/, and proxy/ packages..."
|
||||
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
||||
echo ""
|
||||
|
||||
# Find all directories except the problematic ones and system dirs
|
||||
@@ -31,7 +31,7 @@ jobs:
|
||||
while IFS= read -r dir; do
|
||||
echo "=== Checking $dir ==="
|
||||
# Search for problematic imports, excluding test files
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||
if [ -n "$RESULTS" ]; then
|
||||
echo "❌ Found problematic dependencies:"
|
||||
echo "$RESULTS"
|
||||
@@ -39,11 +39,11 @@ jobs:
|
||||
else
|
||||
echo "✓ No problematic dependencies found"
|
||||
fi
|
||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name "combined" -not -name ".git*" | sort)
|
||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
|
||||
|
||||
echo ""
|
||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||
echo "❌ Found dependencies on management/, signal/, relay/, or proxy/ packages"
|
||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
||||
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
||||
exit 1
|
||||
else
|
||||
@@ -88,7 +88,7 @@ jobs:
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\)" | head -1)
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
|
||||
2
.github/workflows/golang-test-darwin.yml
vendored
2
.github/workflows/golang-test-darwin.yml
vendored
@@ -43,5 +43,5 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
|
||||
|
||||
|
||||
1
.github/workflows/golang-test-freebsd.yml
vendored
1
.github/workflows/golang-test-freebsd.yml
vendored
@@ -46,5 +46,6 @@ jobs:
|
||||
time go test -timeout 1m -failfast ./client/iface/...
|
||||
time go test -timeout 1m -failfast ./route/...
|
||||
time go test -timeout 1m -failfast ./sharedsock/...
|
||||
time go test -timeout 1m -failfast ./signal/...
|
||||
time go test -timeout 1m -failfast ./util/...
|
||||
time go test -timeout 1m -failfast ./version/...
|
||||
|
||||
98
.github/workflows/golang-test-linux.yml
vendored
98
.github/workflows/golang-test-linux.yml
vendored
@@ -97,16 +97,6 @@ jobs:
|
||||
working-directory: relay
|
||||
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
|
||||
|
||||
- name: Build combined
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
working-directory: combined
|
||||
run: CGO_ENABLED=1 go build .
|
||||
|
||||
- name: Build combined 386
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
working-directory: combined
|
||||
run: CGO_ENABLED=1 GOARCH=386 go build -o combined-386 .
|
||||
|
||||
test:
|
||||
name: "Client / Unit"
|
||||
needs: [build-cache]
|
||||
@@ -154,7 +144,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
@@ -214,7 +204,7 @@ jobs:
|
||||
sh -c ' \
|
||||
apk update; apk add --no-cache \
|
||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
|
||||
'
|
||||
|
||||
test_relay:
|
||||
@@ -271,53 +261,6 @@ jobs:
|
||||
-exec 'sudo' \
|
||||
-timeout 10m -p 1 ./relay/... ./shared/relay/...
|
||||
|
||||
test_proxy:
|
||||
name: "Proxy / Unit"
|
||||
needs: [build-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
${{ env.modcache }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
go test -timeout 10m -p 1 ./proxy/...
|
||||
|
||||
test_signal:
|
||||
name: "Signal / Unit"
|
||||
needs: [build-cache]
|
||||
@@ -409,19 +352,12 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: docker login for root user
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
env:
|
||||
DOCKER_USER: ${{ secrets.DOCKER_USER }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
|
||||
|
||||
- name: download mysql image
|
||||
if: matrix.store == 'mysql'
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
@@ -504,18 +440,15 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: docker login for root user
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
env:
|
||||
DOCKER_USER: ${{ secrets.DOCKER_USER }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
|
||||
- name: download mysql image
|
||||
if: matrix.store == 'mysql'
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
@@ -596,18 +529,15 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: docker login for root user
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
env:
|
||||
DOCKER_USER: ${{ secrets.DOCKER_USER }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
|
||||
- name: download mysql image
|
||||
if: matrix.store == 'mysql'
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
|
||||
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@@ -63,7 +63,7 @@ jobs:
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' })" >> $env:GITHUB_ENV
|
||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV
|
||||
|
||||
- name: test
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
||||
|
||||
4
.github/workflows/golangci-lint.yml
vendored
4
.github/workflows/golangci-lint.yml
vendored
@@ -19,8 +19,8 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
|
||||
skip: go.mod,go.sum
|
||||
golangci:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
51
.github/workflows/pr-title-check.yml
vendored
51
.github/workflows/pr-title-check.yml
vendored
@@ -1,51 +0,0 @@
|
||||
name: PR Title Check
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, edited, synchronize, reopened]
|
||||
|
||||
jobs:
|
||||
check-title:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Validate PR title prefix
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const title = context.payload.pull_request.title;
|
||||
const allowedTags = [
|
||||
'management',
|
||||
'client',
|
||||
'signal',
|
||||
'proxy',
|
||||
'relay',
|
||||
'misc',
|
||||
'infrastructure',
|
||||
'self-hosted',
|
||||
'doc',
|
||||
];
|
||||
|
||||
const pattern = /^\[([^\]]+)\]\s+.+/;
|
||||
const match = title.match(pattern);
|
||||
|
||||
if (!match) {
|
||||
core.setFailed(
|
||||
`PR title must start with a tag in brackets.\n` +
|
||||
`Example: [client] fix something\n` +
|
||||
`Allowed tags: ${allowedTags.join(', ')}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const tags = match[1].split(',').map(t => t.trim().toLowerCase());
|
||||
|
||||
const invalid = tags.filter(t => !allowedTags.includes(t));
|
||||
if (invalid.length > 0) {
|
||||
core.setFailed(
|
||||
`Invalid tag(s): ${invalid.join(', ')}\n` +
|
||||
`Allowed tags: ${allowedTags.join(', ')}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
console.log(`Valid PR title tags: [${tags.join(', ')}]`);
|
||||
18
.github/workflows/release.yml
vendored
18
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.1.1"
|
||||
SIGN_PIPE_VER: "v0.1.0"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
@@ -160,7 +160,7 @@ jobs:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Log in to the GitHub container registry
|
||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
@@ -176,7 +176,6 @@ jobs:
|
||||
- name: Generate windows syso arm64
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
@@ -186,19 +185,6 @@ jobs:
|
||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
- name: Tag and push PR images (amd64 only)
|
||||
if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository
|
||||
run: |
|
||||
PR_TAG="pr-${{ github.event.pull_request.number }}"
|
||||
echo '${{ steps.goreleaser.outputs.artifacts }}' | \
|
||||
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
|
||||
grep '^ghcr.io/' | while read -r SRC; do
|
||||
IMG_NAME="${SRC%%:*}"
|
||||
DST="${IMG_NAME}:${PR_TAG}"
|
||||
echo "Tagging ${SRC} -> ${DST}"
|
||||
docker tag "$SRC" "$DST"
|
||||
docker push "$DST"
|
||||
done
|
||||
- name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -2,7 +2,6 @@
|
||||
.run
|
||||
*.iml
|
||||
dist/
|
||||
!proxy/web/dist/
|
||||
bin/
|
||||
.env
|
||||
conf.json
|
||||
|
||||
181
.goreleaser.yaml
181
.goreleaser.yaml
@@ -106,26 +106,6 @@ builds:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird-server
|
||||
dir: combined
|
||||
env:
|
||||
- CGO_ENABLED=1
|
||||
- >-
|
||||
{{- if eq .Runtime.Goos "linux" }}
|
||||
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
|
||||
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
|
||||
{{- end }}
|
||||
binary: netbird-server
|
||||
goos:
|
||||
- linux
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird-upload
|
||||
dir: upload-server
|
||||
env: [CGO_ENABLED=0]
|
||||
@@ -140,20 +120,6 @@ builds:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird-proxy
|
||||
dir: proxy/cmd/proxy
|
||||
env: [CGO_ENABLED=0]
|
||||
binary: netbird-proxy
|
||||
goos:
|
||||
- linux
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
universal_binaries:
|
||||
- id: netbird
|
||||
|
||||
@@ -554,104 +520,6 @@ dockers:
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-server
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: combined/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-server
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: combined/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-server
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: combined/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-proxy
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: proxy/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-proxy
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: proxy/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-proxy
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: proxy/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
docker_manifests:
|
||||
- name_template: netbirdio/netbird:{{ .Version }}
|
||||
image_templates:
|
||||
@@ -730,18 +598,6 @@ docker_manifests:
|
||||
- netbirdio/upload:{{ .Version }}-arm
|
||||
- netbirdio/upload:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/netbird-server:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/netbird-server:latest
|
||||
image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
@@ -819,43 +675,6 @@ docker_manifests:
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird-server:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird-server:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/reverse-proxy:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/reverse-proxy:latest
|
||||
image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/reverse-proxy:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/reverse-proxy:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
brews:
|
||||
- ids:
|
||||
- default
|
||||
|
||||
2
LICENSE
2
LICENSE
@@ -1,4 +1,4 @@
|
||||
This BSD‑3‑Clause license applies to all parts of the repository except for the directories management/, signal/, relay/ and combined/.
|
||||
This BSD‑3‑Clause license applies to all parts of the repository except for the directories management/, signal/ and relay/.
|
||||
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
|
||||
|
||||
BSD 3-Clause License
|
||||
|
||||
@@ -60,8 +60,8 @@
|
||||
|
||||
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||
|
||||
### Self-Host NetBird (Video)
|
||||
[](https://youtu.be/bZAgpT6nzaQ)
|
||||
### NetBird on Lawrence Systems (Video)
|
||||
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
||||
|
||||
### Key features
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
||||
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
||||
|
||||
FROM alpine:3.23.3
|
||||
FROM alpine:3.23.2
|
||||
# iproute2: busybox doesn't display ip rules properly
|
||||
RUN apk add --no-cache \
|
||||
bash \
|
||||
|
||||
@@ -1,19 +1,10 @@
|
||||
package android
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
import "github.com/netbirdio/netbird/client/internal/peer"
|
||||
|
||||
var (
|
||||
// EnvKeyNBForceRelay Exported for Android java client to force relay connections
|
||||
// EnvKeyNBForceRelay Exported for Android java client
|
||||
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
|
||||
|
||||
// EnvKeyNBLazyConn Exported for Android java client to configure lazy connection
|
||||
EnvKeyNBLazyConn = lazyconn.EnvEnableLazyConn
|
||||
|
||||
// EnvKeyNBInactivityThreshold Exported for Android java client to configure connection inactivity threshold
|
||||
EnvKeyNBInactivityThreshold = lazyconn.EnvInactivityThreshold
|
||||
)
|
||||
|
||||
// EnvList wraps a Go map for export to Java
|
||||
|
||||
@@ -3,7 +3,15 @@ package android
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/cmd"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
@@ -76,21 +84,34 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
||||
}
|
||||
|
||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
supportsSSO := true
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
s, ok := gstatus.FromError(err)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
|
||||
supportsSSO = false
|
||||
err = nil
|
||||
}
|
||||
|
||||
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check SSO support: %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
if !supportsSSO {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
||||
return true, err
|
||||
}
|
||||
@@ -108,17 +129,19 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
|
||||
}
|
||||
|
||||
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
||||
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
//nolint
|
||||
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||
err, _ = authClient.Login(ctxWithValues, setupKey, "")
|
||||
|
||||
err := a.withBackOff(a.ctx, func() error {
|
||||
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
|
||||
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
||||
// we got an answer from management, exit backoff earlier
|
||||
return backoff.Permanent(backoffErr)
|
||||
}
|
||||
return backoffErr
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
||||
@@ -137,41 +160,49 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidT
|
||||
}
|
||||
|
||||
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
|
||||
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
var needsLogin bool
|
||||
|
||||
// check if we need to generate JWT token
|
||||
needsLogin, err := authClient.IsLoginRequired(a.ctx)
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
|
||||
return
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check login requirement: %v", err)
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
jwtToken := ""
|
||||
if needsLogin {
|
||||
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, isAndroidTV)
|
||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||
}
|
||||
jwtToken = tokenInfo.GetTokenToUse()
|
||||
}
|
||||
|
||||
err, _ = authClient.Login(a.ctx, "", jwtToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
}
|
||||
err = a.withBackOff(a.ctx, func() error {
|
||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
||||
|
||||
go urlOpener.OnLoginSuccess()
|
||||
if err == nil {
|
||||
go urlOpener.OnLoginSuccess()
|
||||
}
|
||||
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, isAndroidTV)
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
||||
@@ -181,10 +212,22 @@ func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener
|
||||
|
||||
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||
|
||||
tokenInfo, err := oAuthFlow.WaitToken(a.ctx, flowInfo)
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
||||
defer cancel()
|
||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||
}
|
||||
|
||||
return &tokenInfo, nil
|
||||
}
|
||||
|
||||
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
|
||||
return backoff.RetryNotify(
|
||||
bf,
|
||||
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
|
||||
func(err error, duration time.Duration) {
|
||||
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,194 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/signal"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var pinRegexp = regexp.MustCompile(`^\d{6}$`)
|
||||
|
||||
var (
|
||||
exposePin string
|
||||
exposePassword string
|
||||
exposeUserGroups []string
|
||||
exposeDomain string
|
||||
exposeNamePrefix string
|
||||
exposeProtocol string
|
||||
)
|
||||
|
||||
var exposeCmd = &cobra.Command{
|
||||
Use: "expose <port>",
|
||||
Short: "Expose a local port via the NetBird reverse proxy",
|
||||
Args: cobra.ExactArgs(1),
|
||||
Example: "netbird expose --with-password safe-pass 8080",
|
||||
RunE: exposeFn,
|
||||
}
|
||||
|
||||
func init() {
|
||||
exposeCmd.Flags().StringVar(&exposePin, "with-pin", "", "Protect the exposed service with a 6-digit PIN (e.g. --with-pin 123456)")
|
||||
exposeCmd.Flags().StringVar(&exposePassword, "with-password", "", "Protect the exposed service with a password (e.g. --with-password my-secret)")
|
||||
exposeCmd.Flags().StringSliceVar(&exposeUserGroups, "with-user-groups", nil, "Restrict access to specific user groups with SSO (e.g. --with-user-groups devops,Backend)")
|
||||
exposeCmd.Flags().StringVar(&exposeDomain, "with-custom-domain", "", "Custom domain for the exposed service, must be configured to your account (e.g. --with-custom-domain myapp.example.com)")
|
||||
exposeCmd.Flags().StringVar(&exposeNamePrefix, "with-name-prefix", "", "Prefix for the generated service name (e.g. --with-name-prefix my-app)")
|
||||
exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use, http/https is supported (e.g. --protocol http)")
|
||||
}
|
||||
|
||||
func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) {
|
||||
port, err := strconv.ParseUint(portStr, 10, 32)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid port number: %s", portStr)
|
||||
}
|
||||
if port == 0 || port > 65535 {
|
||||
return 0, fmt.Errorf("invalid port number: must be between 1 and 65535")
|
||||
}
|
||||
|
||||
if !isProtocolValid(exposeProtocol) {
|
||||
return 0, fmt.Errorf("unsupported protocol %q: only 'http' or 'https' are supported", exposeProtocol)
|
||||
}
|
||||
|
||||
if exposePin != "" && !pinRegexp.MatchString(exposePin) {
|
||||
return 0, fmt.Errorf("invalid pin: must be exactly 6 digits")
|
||||
}
|
||||
|
||||
if cmd.Flags().Changed("with-password") && exposePassword == "" {
|
||||
return 0, fmt.Errorf("password cannot be empty")
|
||||
}
|
||||
|
||||
if cmd.Flags().Changed("with-user-groups") && len(exposeUserGroups) == 0 {
|
||||
return 0, fmt.Errorf("user groups cannot be empty")
|
||||
}
|
||||
|
||||
return port, nil
|
||||
}
|
||||
|
||||
func isProtocolValid(exposeProtocol string) bool {
|
||||
return strings.ToLower(exposeProtocol) == "http" || strings.ToLower(exposeProtocol) == "https"
|
||||
}
|
||||
|
||||
func exposeFn(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
|
||||
log.Errorf("failed initializing log %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Root().SilenceUsage = false
|
||||
|
||||
port, err := validateExposeFlags(cmd, args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Root().SilenceUsage = true
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigCh
|
||||
cancel()
|
||||
}()
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to daemon: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("failed to close daemon connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
protocol, err := toExposeProtocol(exposeProtocol)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stream, err := client.ExposeService(ctx, &proto.ExposeServiceRequest{
|
||||
Port: uint32(port),
|
||||
Protocol: protocol,
|
||||
Pin: exposePin,
|
||||
Password: exposePassword,
|
||||
UserGroups: exposeUserGroups,
|
||||
Domain: exposeDomain,
|
||||
NamePrefix: exposeNamePrefix,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("expose service: %w", err)
|
||||
}
|
||||
|
||||
if err := handleExposeReady(cmd, stream, port); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return waitForExposeEvents(cmd, ctx, stream)
|
||||
}
|
||||
|
||||
func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
|
||||
switch strings.ToLower(exposeProtocol) {
|
||||
case "http":
|
||||
return proto.ExposeProtocol_EXPOSE_HTTP, nil
|
||||
case "https":
|
||||
return proto.ExposeProtocol_EXPOSE_HTTPS, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported protocol %q: only 'http' or 'https' are supported", exposeProtocol)
|
||||
}
|
||||
}
|
||||
|
||||
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
|
||||
event, err := stream.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("receive expose event: %w", err)
|
||||
}
|
||||
|
||||
switch e := event.Event.(type) {
|
||||
case *proto.ExposeServiceEvent_Ready:
|
||||
cmd.Println("Service exposed successfully!")
|
||||
cmd.Printf(" Name: %s\n", e.Ready.ServiceName)
|
||||
cmd.Printf(" URL: %s\n", e.Ready.ServiceUrl)
|
||||
cmd.Printf(" Domain: %s\n", e.Ready.Domain)
|
||||
cmd.Printf(" Protocol: %s\n", exposeProtocol)
|
||||
cmd.Printf(" Port: %d\n", port)
|
||||
cmd.Println()
|
||||
cmd.Println("Press Ctrl+C to stop exposing.")
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unexpected expose event: %T", event.Event)
|
||||
}
|
||||
}
|
||||
|
||||
func waitForExposeEvents(cmd *cobra.Command, ctx context.Context, stream proto.DaemonService_ExposeServiceClient) error {
|
||||
for {
|
||||
_, err := stream.Recv()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
cmd.Println("\nService stopped.")
|
||||
//nolint:nilerr
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return fmt.Errorf("connection to daemon closed unexpectedly")
|
||||
}
|
||||
return fmt.Errorf("stream error: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -276,15 +277,18 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
|
||||
}
|
||||
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
|
||||
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
needsLogin := false
|
||||
|
||||
needsLogin, err := authClient.IsLoginRequired(ctx)
|
||||
err := WithBackOff(func() error {
|
||||
err := internal.Login(ctx, config, "", "")
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
needsLogin = true
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("check login required: %v", err)
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
jwtToken := ""
|
||||
@@ -296,9 +300,23 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
||||
jwtToken = tokenInfo.GetTokenToUse()
|
||||
}
|
||||
|
||||
err, _ = authClient.Login(ctx, setupKey, jwtToken)
|
||||
var lastError error
|
||||
|
||||
err = WithBackOff(func() error {
|
||||
err := internal.Login(ctx, config, setupKey, jwtToken)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
lastError = err
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
|
||||
if lastError != nil {
|
||||
return fmt.Errorf("login failed: %v", lastError)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -326,7 +344,11 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
||||
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
||||
|
||||
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
|
||||
defer c()
|
||||
|
||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
@@ -81,15 +80,6 @@ var (
|
||||
Short: "",
|
||||
Long: "",
|
||||
SilenceUsage: true,
|
||||
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(cmd.Root())
|
||||
|
||||
// Don't resolve for service commands — they create the socket, not connect to it.
|
||||
if !isServiceCmd(cmd) {
|
||||
daemonAddr = daddr.ResolveUnixDaemonAddr(daemonAddr)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -154,7 +144,6 @@ func init() {
|
||||
rootCmd.AddCommand(forwardingRulesCmd)
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
rootCmd.AddCommand(profileCmd)
|
||||
rootCmd.AddCommand(exposeCmd)
|
||||
|
||||
networksCMD.AddCommand(routesListCmd)
|
||||
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
||||
@@ -396,6 +385,7 @@ func migrateToNetbird(oldPath, newPath string) bool {
|
||||
}
|
||||
|
||||
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||
@@ -408,13 +398,3 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// isServiceCmd returns true if cmd is the "service" command or a child of it.
|
||||
func isServiceCmd(cmd *cobra.Command) bool {
|
||||
for c := cmd; c != nil; c = c.Parent() {
|
||||
if c.Name() == "service" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||
@@ -31,14 +30,6 @@ var (
|
||||
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||
)
|
||||
|
||||
// PeerConnStatus is a peer's connection status.
|
||||
type PeerConnStatus = peer.ConnStatus
|
||||
|
||||
const (
|
||||
// PeerStatusConnected indicates the peer is in connected state.
|
||||
PeerStatusConnected = peer.StatusConnected
|
||||
)
|
||||
|
||||
// Client manages a netbird embedded client instance.
|
||||
type Client struct {
|
||||
deviceName string
|
||||
@@ -77,10 +68,6 @@ type Options struct {
|
||||
StatePath string
|
||||
// DisableClientRoutes disables the client routes
|
||||
DisableClientRoutes bool
|
||||
// BlockInbound blocks all inbound connections from peers
|
||||
BlockInbound bool
|
||||
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
|
||||
WireguardPort *int
|
||||
}
|
||||
|
||||
// validateCredentials checks that exactly one credential type is provided
|
||||
@@ -149,8 +136,6 @@ func New(opts Options) (*Client, error) {
|
||||
PreSharedKey: &opts.PreSharedKey,
|
||||
DisableServerRoutes: &t,
|
||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||
BlockInbound: &opts.BlockInbound,
|
||||
WireguardPort: opts.WireguardPort,
|
||||
}
|
||||
if opts.ConfigPath != "" {
|
||||
config, err = profilemanager.UpdateOrCreateConfig(input)
|
||||
@@ -170,7 +155,6 @@ func New(opts Options) (*Client, error) {
|
||||
setupKey: opts.SetupKey,
|
||||
jwtToken: opts.JWTToken,
|
||||
config: config,
|
||||
recorder: peer.NewRecorder(config.ManagementURL.String()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -192,17 +176,13 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||
|
||||
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create auth client: %w", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
||||
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
|
||||
return fmt.Errorf("login: %w", err)
|
||||
}
|
||||
client := internal.NewConnectClient(ctx, c.config, c.recorder, false)
|
||||
|
||||
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
||||
c.recorder = recorder
|
||||
client := internal.NewConnectClient(ctx, c.config, recorder, false)
|
||||
client.SetSyncResponsePersistence(true)
|
||||
|
||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||
@@ -355,9 +335,14 @@ func (c *Client) NewHTTPClient() *http.Client {
|
||||
// Status returns the current status of the client.
|
||||
func (c *Client) Status() (peer.FullStatus, error) {
|
||||
c.mu.Lock()
|
||||
recorder := c.recorder
|
||||
connect := c.connect
|
||||
c.mu.Unlock()
|
||||
|
||||
if recorder == nil {
|
||||
return peer.FullStatus{}, errors.New("client not started")
|
||||
}
|
||||
|
||||
if connect != nil {
|
||||
engine := connect.Engine()
|
||||
if engine != nil {
|
||||
@@ -365,7 +350,7 @@ func (c *Client) Status() (peer.FullStatus, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return c.recorder.GetFullStatus(), nil
|
||||
return recorder.GetFullStatus(), nil
|
||||
}
|
||||
|
||||
// GetLatestSyncResponse returns the latest sync response from the management server.
|
||||
|
||||
@@ -83,10 +83,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
return fmt.Errorf("acl manager init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.initNoTrackChain(); err != nil {
|
||||
return fmt.Errorf("init notrack chain: %w", err)
|
||||
}
|
||||
|
||||
// persist early to ensure cleanup of chains
|
||||
go func() {
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
@@ -181,10 +177,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := m.cleanupNoTrackChain(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
|
||||
}
|
||||
|
||||
if err := m.aclMgr.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
||||
}
|
||||
@@ -285,125 +277,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
const (
|
||||
chainNameRaw = "NETBIRD-RAW"
|
||||
chainOUTPUT = "OUTPUT"
|
||||
tableRaw = "raw"
|
||||
)
|
||||
|
||||
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||
// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which
|
||||
// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
|
||||
//
|
||||
// Traffic flows that need NOTRACK:
|
||||
//
|
||||
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
|
||||
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
|
||||
// Matched by: sport=wgPort
|
||||
//
|
||||
// 2. Egress: Proxy -> WireGuard (via raw socket)
|
||||
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
|
||||
// Matched by: dport=wgPort
|
||||
//
|
||||
// 3. Ingress: Packets to WireGuard
|
||||
// dst=127.0.0.1:wgPort
|
||||
// Matched by: dport=wgPort
|
||||
//
|
||||
// 4. Ingress: Packets to proxy (after eBPF rewrite)
|
||||
// dst=127.0.0.1:proxyPort
|
||||
// Matched by: dport=proxyPort
|
||||
//
|
||||
// Rules are cleaned up when the firewall manager is closed.
|
||||
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
wgPortStr := fmt.Sprintf("%d", wgPort)
|
||||
proxyPortStr := fmt.Sprintf("%d", proxyPort)
|
||||
|
||||
// Egress rules: match outgoing loopback UDP packets
|
||||
outputRuleSport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--sport", wgPortStr, "-j", "NOTRACK"}
|
||||
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleSport...); err != nil {
|
||||
return fmt.Errorf("add output sport notrack rule: %w", err)
|
||||
}
|
||||
|
||||
outputRuleDport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"}
|
||||
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleDport...); err != nil {
|
||||
return fmt.Errorf("add output dport notrack rule: %w", err)
|
||||
}
|
||||
|
||||
// Ingress rules: match incoming loopback UDP packets
|
||||
preroutingRuleWg := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"}
|
||||
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleWg...); err != nil {
|
||||
return fmt.Errorf("add prerouting wg notrack rule: %w", err)
|
||||
}
|
||||
|
||||
preroutingRuleProxy := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", proxyPortStr, "-j", "NOTRACK"}
|
||||
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleProxy...); err != nil {
|
||||
return fmt.Errorf("add prerouting proxy notrack rule: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) initNoTrackChain() error {
|
||||
if err := m.cleanupNoTrackChain(); err != nil {
|
||||
log.Debugf("cleanup notrack chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.ipv4Client.NewChain(tableRaw, chainNameRaw); err != nil {
|
||||
return fmt.Errorf("create chain: %w", err)
|
||||
}
|
||||
|
||||
jumpRule := []string{"-j", chainNameRaw}
|
||||
|
||||
if err := m.ipv4Client.InsertUnique(tableRaw, chainOUTPUT, 1, jumpRule...); err != nil {
|
||||
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
||||
log.Debugf("delete orphan chain: %v", delErr)
|
||||
}
|
||||
return fmt.Errorf("add output jump rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.ipv4Client.InsertUnique(tableRaw, chainPREROUTING, 1, jumpRule...); err != nil {
|
||||
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); delErr != nil {
|
||||
log.Debugf("delete output jump rule: %v", delErr)
|
||||
}
|
||||
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
||||
log.Debugf("delete orphan chain: %v", delErr)
|
||||
}
|
||||
return fmt.Errorf("add prerouting jump rule: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) cleanupNoTrackChain() error {
|
||||
exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check chain exists: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
jumpRule := []string{"-j", chainNameRaw}
|
||||
|
||||
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); err != nil {
|
||||
return fmt.Errorf("remove output jump rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPREROUTING, jumpRule...); err != nil {
|
||||
return fmt.Errorf("remove prerouting jump rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.ipv4Client.ClearAndDeleteChain(tableRaw, chainNameRaw); err != nil {
|
||||
return fmt.Errorf("clear and delete chain: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getConntrackEstablished() []string {
|
||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
@@ -168,10 +168,6 @@ type Manager interface {
|
||||
|
||||
// RemoveInboundDNAT removes inbound DNAT rule
|
||||
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
|
||||
// This prevents conntrack from interfering with WireGuard proxy communication.
|
||||
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
@@ -49,10 +48,8 @@ type Manager struct {
|
||||
rConn *nftables.Conn
|
||||
wgIface iFaceMapper
|
||||
|
||||
router *router
|
||||
aclManager *AclManager
|
||||
notrackOutputChain *nftables.Chain
|
||||
notrackPreroutingChain *nftables.Chain
|
||||
router *router
|
||||
aclManager *AclManager
|
||||
}
|
||||
|
||||
// Create nftables firewall manager
|
||||
@@ -94,10 +91,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
return fmt.Errorf("acl manager init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.initNoTrackChains(workTable); err != nil {
|
||||
return fmt.Errorf("init notrack chains: %w", err)
|
||||
}
|
||||
|
||||
stateManager.RegisterState(&ShutdownState{})
|
||||
|
||||
// We only need to record minimal interface state for potential recreation.
|
||||
@@ -295,15 +288,7 @@ func (m *Manager) Flush() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if err := m.aclManager.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.refreshNoTrackChains(); err != nil {
|
||||
log.Errorf("failed to refresh notrack chains: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return m.aclManager.Flush()
|
||||
}
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
@@ -346,176 +331,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
const (
|
||||
chainNameRawOutput = "netbird-raw-out"
|
||||
chainNameRawPrerouting = "netbird-raw-pre"
|
||||
)
|
||||
|
||||
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||
// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which
|
||||
// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
|
||||
//
|
||||
// Traffic flows that need NOTRACK:
|
||||
//
|
||||
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
|
||||
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
|
||||
// Matched by: sport=wgPort
|
||||
//
|
||||
// 2. Egress: Proxy -> WireGuard (via raw socket)
|
||||
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
|
||||
// Matched by: dport=wgPort
|
||||
//
|
||||
// 3. Ingress: Packets to WireGuard
|
||||
// dst=127.0.0.1:wgPort
|
||||
// Matched by: dport=wgPort
|
||||
//
|
||||
// 4. Ingress: Packets to proxy (after eBPF rewrite)
|
||||
// dst=127.0.0.1:proxyPort
|
||||
// Matched by: dport=proxyPort
|
||||
//
|
||||
// Rules are cleaned up when the firewall manager is closed.
|
||||
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.notrackOutputChain == nil || m.notrackPreroutingChain == nil {
|
||||
return fmt.Errorf("notrack chains not initialized")
|
||||
}
|
||||
|
||||
proxyPortBytes := binaryutil.BigEndian.PutUint16(proxyPort)
|
||||
wgPortBytes := binaryutil.BigEndian.PutUint16(wgPort)
|
||||
loopback := []byte{127, 0, 0, 1}
|
||||
|
||||
// Egress rules: match outgoing loopback UDP packets
|
||||
m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.notrackOutputChain.Table,
|
||||
Chain: m.notrackOutputChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 0, Len: 2},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // sport=wgPort
|
||||
&expr.Counter{},
|
||||
&expr.Notrack{},
|
||||
},
|
||||
})
|
||||
m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.notrackOutputChain.Table,
|
||||
Chain: m.notrackOutputChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort
|
||||
&expr.Counter{},
|
||||
&expr.Notrack{},
|
||||
},
|
||||
})
|
||||
|
||||
// Ingress rules: match incoming loopback UDP packets
|
||||
m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.notrackPreroutingChain.Table,
|
||||
Chain: m.notrackPreroutingChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort
|
||||
&expr.Counter{},
|
||||
&expr.Notrack{},
|
||||
},
|
||||
})
|
||||
m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.notrackPreroutingChain.Table,
|
||||
Chain: m.notrackPreroutingChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: proxyPortBytes}, // dport=proxyPort
|
||||
&expr.Counter{},
|
||||
&expr.Notrack{},
|
||||
},
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush notrack rules: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) initNoTrackChains(table *nftables.Table) error {
|
||||
m.notrackOutputChain = m.rConn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRawOutput,
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookOutput,
|
||||
Priority: nftables.ChainPriorityRaw,
|
||||
})
|
||||
|
||||
m.notrackPreroutingChain = m.rConn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRawPrerouting,
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityRaw,
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush chain creation: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) refreshNoTrackChains() error {
|
||||
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %w", err)
|
||||
}
|
||||
|
||||
tableName := getTableName()
|
||||
for _, c := range chains {
|
||||
if c.Table.Name != tableName {
|
||||
continue
|
||||
}
|
||||
switch c.Name {
|
||||
case chainNameRawOutput:
|
||||
m.notrackOutputChain = c
|
||||
case chainNameRawPrerouting:
|
||||
m.notrackPreroutingChain = c
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
|
||||
@@ -483,12 +483,7 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
}
|
||||
|
||||
if nftRule.Handle == 0 {
|
||||
log.Warnf("route rule %s has no handle, removing stale entry", ruleKey)
|
||||
if err := r.decrementSetCounter(nftRule); err != nil {
|
||||
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
return nil
|
||||
return fmt.Errorf("route rule %s has no handle", ruleKey)
|
||||
}
|
||||
|
||||
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
||||
@@ -665,32 +660,13 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
r.rollbackRules(pair)
|
||||
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
|
||||
// TODO: rollback ipset counter
|
||||
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
|
||||
func (r *router) rollbackRules(pair firewall.RouterPair) {
|
||||
keys := []string{
|
||||
firewall.GenKey(firewall.ForwardingFormat, pair),
|
||||
firewall.GenKey(firewall.PreroutingFormat, pair),
|
||||
firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair)),
|
||||
}
|
||||
for _, key := range keys {
|
||||
rule, ok := r.rules[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
log.Warnf("rollback set counter for %s: %v", key, err)
|
||||
}
|
||||
delete(r.rules, key)
|
||||
}
|
||||
}
|
||||
|
||||
// addNatRule inserts a nftables rule to the conn client flush queue
|
||||
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||
@@ -952,30 +928,18 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
|
||||
rule, exists := r.rules[ruleKey]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleKey)
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement set counter: %w", err)
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement set counter: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1365,89 +1329,65 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if pair.Masquerade {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
|
||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
|
||||
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
|
||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||
}
|
||||
|
||||
// Set counters are decremented in the sub-methods above before flush. If flush fails,
|
||||
// counters will be off until the next successful removal or refresh cycle.
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||
|
||||
rule, exists := r.rules[ruleKey]
|
||||
if !exists {
|
||||
log.Debugf("prerouting rule %s not found", ruleKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleKey)
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement set counter: %w", err)
|
||||
// TODO: rollback set counter
|
||||
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshRulesMap rebuilds the rule map from the kernel. This removes stale entries
|
||||
// (e.g. from failed flushes) and updates handles for all existing rules.
|
||||
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement set counter: %w", err)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("prerouting rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
|
||||
// duplicates and to get missing attributes that we don't have when adding new rules
|
||||
func (r *router) refreshRulesMap() error {
|
||||
var merr *multierror.Error
|
||||
newRules := make(map[string]*nftables.Rule)
|
||||
for _, chain := range r.chains {
|
||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
|
||||
// preserve existing entries for this chain since we can't verify their state
|
||||
for k, v := range r.rules {
|
||||
if v.Chain != nil && v.Chain.Name == chain.Name {
|
||||
newRules[k] = v
|
||||
}
|
||||
}
|
||||
continue
|
||||
return fmt.Errorf("list rules: %w", err)
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
newRules[string(rule.UserData)] = rule
|
||||
r.rules[string(rule.UserData)] = rule
|
||||
}
|
||||
}
|
||||
}
|
||||
r.rules = newRules
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
@@ -1689,34 +1629,20 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
var needsFlush bool
|
||||
|
||||
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||
if dnatRule.Handle == 0 {
|
||||
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleKey+dnatSuffix)
|
||||
delete(r.rules, ruleKey+dnatSuffix)
|
||||
} else if err := r.conn.DelRule(dnatRule); err != nil {
|
||||
if err := r.conn.DelRule(dnatRule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
|
||||
} else {
|
||||
needsFlush = true
|
||||
}
|
||||
}
|
||||
|
||||
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
||||
if masqRule.Handle == 0 {
|
||||
log.Warnf("snat rule %s has no handle, removing stale entry", ruleKey+snatSuffix)
|
||||
delete(r.rules, ruleKey+snatSuffix)
|
||||
} else if err := r.conn.DelRule(masqRule); err != nil {
|
||||
if err := r.conn.DelRule(masqRule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
|
||||
} else {
|
||||
needsFlush = true
|
||||
}
|
||||
}
|
||||
|
||||
if needsFlush {
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||
}
|
||||
|
||||
if merr == nil {
|
||||
@@ -1831,25 +1757,16 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
||||
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
rule, exists := r.rules[ruleID]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
|
||||
if rule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/test"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -720,137 +719,3 @@ func deleteWorkTable() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
workTable, err := createWorkTable()
|
||||
require.NoError(t, err)
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, r.init(workTable))
|
||||
defer func() { require.NoError(t, r.Reset()) }()
|
||||
|
||||
// Add a real rule to the kernel
|
||||
ruleKey, err := r.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
||||
firewall.ProtocolTCP,
|
||||
nil,
|
||||
&firewall.Port{Values: []uint16{80}},
|
||||
firewall.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, r.DeleteRouteRule(ruleKey))
|
||||
})
|
||||
|
||||
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
|
||||
staleKey := "stale-rule-that-does-not-exist"
|
||||
r.rules[staleKey] = &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingFw],
|
||||
Handle: 0,
|
||||
UserData: []byte(staleKey),
|
||||
}
|
||||
|
||||
require.Contains(t, r.rules, staleKey, "stale entry should be in map before refresh")
|
||||
|
||||
err = r.refreshRulesMap()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotContains(t, r.rules, staleKey, "stale entry should be removed after refresh")
|
||||
|
||||
realRule, ok := r.rules[ruleKey.ID()]
|
||||
assert.True(t, ok, "real rule should still exist after refresh")
|
||||
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
|
||||
}
|
||||
|
||||
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
workTable, err := createWorkTable()
|
||||
require.NoError(t, err)
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, r.init(workTable))
|
||||
defer func() { require.NoError(t, r.Reset()) }()
|
||||
|
||||
// Inject a stale entry with Handle=0
|
||||
staleKey := "stale-route-rule"
|
||||
r.rules[staleKey] = &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingFw],
|
||||
Handle: 0,
|
||||
UserData: []byte(staleKey),
|
||||
}
|
||||
|
||||
// DeleteRouteRule should not return an error for stale handles
|
||||
err = r.DeleteRouteRule(id.RuleID(staleKey))
|
||||
assert.NoError(t, err, "deleting a stale rule should not error")
|
||||
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
|
||||
}
|
||||
|
||||
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
pair := firewall.RouterPair{
|
||||
ID: "staletest",
|
||||
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||
Masquerade: true,
|
||||
}
|
||||
|
||||
rtr := manager.router
|
||||
|
||||
// First add succeeds
|
||||
err = rtr.AddNatRule(pair)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, rtr.RemoveNatRule(pair))
|
||||
})
|
||||
|
||||
// Corrupt the handle to simulate stale state
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||
if rule, exists := rtr.rules[natRuleKey]; exists {
|
||||
rule.Handle = 0
|
||||
}
|
||||
inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
|
||||
if rule, exists := rtr.rules[inverseKey]; exists {
|
||||
rule.Handle = 0
|
||||
}
|
||||
|
||||
// Adding the same rule again should succeed despite stale handles
|
||||
err = rtr.AddNatRule(pair)
|
||||
assert.NoError(t, err, "AddNatRule should succeed even with stale entries")
|
||||
|
||||
// Verify rules exist in kernel
|
||||
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||
require.NoError(t, err)
|
||||
|
||||
found := 0
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
found++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, found, "NAT rule should exist in kernel")
|
||||
}
|
||||
|
||||
@@ -3,6 +3,12 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -11,7 +17,33 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.resetState()
|
||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
}
|
||||
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := m.logger.Stop(ctx); err != nil {
|
||||
log.Errorf("failed to shutdown logger: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.Close(stateManager)
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -23,7 +26,33 @@ func (m *Manager) Close(*statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.resetState()
|
||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
}
|
||||
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := m.logger.Stop(ctx); err != nil {
|
||||
log.Errorf("failed to shutdown logger: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !isWindowsFirewallReachable() {
|
||||
return nil
|
||||
|
||||
@@ -115,17 +115,6 @@ func (t *TCPConnTrack) IsTombstone() bool {
|
||||
return t.tombstone.Load()
|
||||
}
|
||||
|
||||
// IsSupersededBy returns true if this connection should be replaced by a new one
|
||||
// carrying the given flags. Tombstoned connections are always superseded; TIME-WAIT
|
||||
// connections are superseded by a pure SYN (a new connection attempt for the same
|
||||
// four-tuple, as contemplated by RFC 1122 §4.2.2.13 and RFC 6191).
|
||||
func (t *TCPConnTrack) IsSupersededBy(flags uint8) bool {
|
||||
if t.tombstone.Load() {
|
||||
return true
|
||||
}
|
||||
return flags&TCPSyn != 0 && flags&TCPAck == 0 && TCPState(t.state.Load()) == TCPStateTimeWait
|
||||
}
|
||||
|
||||
// SetTombstone safely marks the connection for deletion
|
||||
func (t *TCPConnTrack) SetTombstone() {
|
||||
t.tombstone.Store(true)
|
||||
@@ -180,7 +169,7 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if exists && !conn.IsSupersededBy(flags) {
|
||||
if exists {
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
return key, uint16(conn.DNATOrigPort.Load()), true
|
||||
}
|
||||
@@ -252,7 +241,7 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if !exists || conn.IsSupersededBy(flags) {
|
||||
if !exists || conn.IsTombstone() {
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -485,261 +485,6 @@ func TestTCPAbnormalSequences(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestTCPPortReuseTombstone verifies that a new connection on a port with a
|
||||
// tombstoned (closed) conntrack entry is properly tracked. Without the fix,
|
||||
// updateIfExists treats tombstoned entries as live, causing track() to skip
|
||||
// creating a new connection. The subsequent SYN-ACK then fails IsValidInbound
|
||||
// because the entry is tombstoned, and the response packet gets dropped by ACL.
|
||||
func TestTCPPortReuseTombstone(t *testing.T) {
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
t.Run("Outbound port reuse after graceful close", func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
// Establish and gracefully close a connection (server-initiated close)
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Server sends FIN
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
|
||||
// Client sends FIN-ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
|
||||
// Server sends final ACK
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
|
||||
// Connection should be tombstoned
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn, "old connection should still be in map")
|
||||
require.True(t, conn.IsTombstone(), "old connection should be tombstoned")
|
||||
|
||||
// Now reuse the same port for a new connection
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||
|
||||
// The old tombstoned entry should be replaced with a new one
|
||||
newConn := tracker.connections[key]
|
||||
require.NotNil(t, newConn, "new connection should exist")
|
||||
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
|
||||
require.Equal(t, TCPStateSynSent, newConn.GetState())
|
||||
|
||||
// SYN-ACK for the new connection should be valid
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||
require.True(t, valid, "SYN-ACK for new connection on reused port should be accepted")
|
||||
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||
|
||||
// Data transfer should work
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 500)
|
||||
require.True(t, valid, "data should be allowed on new connection")
|
||||
})
|
||||
|
||||
t.Run("Outbound port reuse after RST", func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
// Establish and RST a connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
|
||||
conn := tracker.connections[key]
|
||||
require.True(t, conn.IsTombstone(), "RST connection should be tombstoned")
|
||||
|
||||
// Reuse the same port
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||
|
||||
newConn := tracker.connections[key]
|
||||
require.NotNil(t, newConn)
|
||||
require.False(t, newConn.IsTombstone())
|
||||
require.Equal(t, TCPStateSynSent, newConn.GetState())
|
||||
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||
require.True(t, valid, "SYN-ACK should be accepted after RST tombstone")
|
||||
})
|
||||
|
||||
t.Run("Inbound port reuse after close", func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
clientIP := srcIP
|
||||
serverIP := dstIP
|
||||
clientPort := srcPort
|
||||
serverPort := dstPort
|
||||
key := ConnKey{SrcIP: clientIP, DstIP: serverIP, SrcPort: clientPort, DstPort: serverPort}
|
||||
|
||||
// Inbound connection: client SYN → server SYN-ACK → client ACK
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
|
||||
// Server-initiated close to reach Closed/tombstoned:
|
||||
// Server FIN (opposite dir) → CloseWait
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPFin|TCPAck, 100)
|
||||
require.Equal(t, TCPStateCloseWait, conn.GetState())
|
||||
// Client FIN-ACK (same dir as conn) → LastAck
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPFin|TCPAck, nil, 100, 0)
|
||||
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||
// Server final ACK (opposite dir) → Closed → tombstoned
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
|
||||
|
||||
require.True(t, conn.IsTombstone())
|
||||
|
||||
// New inbound connection on same ports
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
|
||||
|
||||
newConn := tracker.connections[key]
|
||||
require.NotNil(t, newConn)
|
||||
require.False(t, newConn.IsTombstone())
|
||||
require.Equal(t, TCPStateSynReceived, newConn.GetState())
|
||||
|
||||
// Complete handshake: server SYN-ACK, then client ACK
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||
})
|
||||
|
||||
t.Run("Late ACK on tombstoned connection is harmless", func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
// Establish and close via passive close (server-initiated FIN → Closed → tombstoned)
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) // CloseWait
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // LastAck
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) // Closed
|
||||
|
||||
conn := tracker.connections[key]
|
||||
require.True(t, conn.IsTombstone())
|
||||
|
||||
// Late ACK should be rejected (tombstoned)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.False(t, valid, "late ACK on tombstoned connection should be rejected")
|
||||
|
||||
// Late outbound ACK should not create a new connection (not a SYN)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
require.True(t, tracker.connections[key].IsTombstone(), "late outbound ACK should not replace tombstoned entry")
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPPortReuseTimeWait(t *testing.T) {
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
t.Run("Outbound port reuse during TIME-WAIT (active close)", func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
// Establish connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Active close: client (outbound initiator) sends FIN first
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||
|
||||
// Server ACKs the FIN
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||
|
||||
// Server sends its own FIN
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
|
||||
// Client sends final ACK (TIME-WAIT stays, not tombstoned)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
require.False(t, conn.IsTombstone(), "TIME-WAIT should not be tombstoned")
|
||||
|
||||
// New outbound SYN on the same port (port reuse during TIME-WAIT)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||
|
||||
// Per RFC 1122/6191, new SYN during TIME-WAIT should start a new connection
|
||||
newConn := tracker.connections[key]
|
||||
require.NotNil(t, newConn, "new connection should exist")
|
||||
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
|
||||
require.Equal(t, TCPStateSynSent, newConn.GetState(), "new connection should be in SYN-SENT")
|
||||
|
||||
// SYN-ACK for new connection should be valid
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||
require.True(t, valid, "SYN-ACK for new connection should be accepted")
|
||||
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||
})
|
||||
|
||||
t.Run("Inbound SYN during TIME-WAIT falls through to normal tracking", func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
// Establish outbound connection and close via active close → TIME-WAIT
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
|
||||
// Inbound SYN on same ports during TIME-WAIT: IsValidInbound returns false
|
||||
// so the filter falls through to ACL check + TrackInbound (which creates
|
||||
// a new connection via track() → updateIfExists skips TIME-WAIT for SYN)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, 0)
|
||||
require.False(t, valid, "inbound SYN during TIME-WAIT should fail conntrack validation")
|
||||
|
||||
// Simulate what the filter does next: TrackInbound via the normal path
|
||||
tracker.TrackInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, nil, 100, 0)
|
||||
|
||||
// The new inbound connection uses the inverted key (dst→src becomes src→dst in track)
|
||||
invertedKey := ConnKey{SrcIP: dstIP, DstIP: srcIP, SrcPort: dstPort, DstPort: srcPort}
|
||||
newConn := tracker.connections[invertedKey]
|
||||
require.NotNil(t, newConn, "new inbound connection should be tracked")
|
||||
require.Equal(t, TCPStateSynReceived, newConn.GetState())
|
||||
require.False(t, newConn.IsTombstone())
|
||||
})
|
||||
|
||||
t.Run("Late retransmit during TIME-WAIT still allowed", func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
// Establish and active close → TIME-WAIT
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
|
||||
// Late ACK retransmits during TIME-WAIT should still be accepted
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid, "retransmitted ACK during TIME-WAIT should be accepted")
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPTimeoutHandling(t *testing.T) {
|
||||
// Create tracker with a very short timeout for testing
|
||||
shortTimeout := 100 * time.Millisecond
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -13,13 +12,11 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
@@ -27,7 +24,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
@@ -93,7 +89,6 @@ type Manager struct {
|
||||
incomingDenyRules map[netip.Addr]RuleSet
|
||||
incomingRules map[netip.Addr]RuleSet
|
||||
routeRules RouteRules
|
||||
routeRulesMap map[nbid.RuleID]*RouteRule
|
||||
decoders sync.Pool
|
||||
wgIface common.IFaceMapper
|
||||
nativeFirewall firewall.Manager
|
||||
@@ -234,7 +229,6 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
flowLogger: flowLogger,
|
||||
netstack: netstack.IsEnabled(),
|
||||
localForwarding: enableLocalForwarding,
|
||||
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
|
||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||
portDNATRules: []portDNATRule{},
|
||||
netstackServices: make(map[serviceKey]struct{}),
|
||||
@@ -486,15 +480,11 @@ func (m *Manager) addRouteFiltering(
|
||||
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
|
||||
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
|
||||
return existingRule, nil
|
||||
}
|
||||
ruleID := uuid.New().String()
|
||||
|
||||
rule := RouteRule{
|
||||
// TODO: consolidate these IDs
|
||||
id: string(ruleKey),
|
||||
id: ruleID,
|
||||
mgmtId: id,
|
||||
sources: sources,
|
||||
dstSet: destination.Set,
|
||||
@@ -509,7 +499,6 @@ func (m *Manager) addRouteFiltering(
|
||||
|
||||
m.routeRules = append(m.routeRules, &rule)
|
||||
m.routeRules.Sort()
|
||||
m.routeRulesMap[ruleKey] = &rule
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
@@ -526,20 +515,15 @@ func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
|
||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||
}
|
||||
|
||||
ruleKey := nbid.RuleID(rule.ID())
|
||||
if _, ok := m.routeRulesMap[ruleKey]; !ok {
|
||||
return fmt.Errorf("route rule not found: %s", ruleKey)
|
||||
}
|
||||
|
||||
ruleID := rule.ID()
|
||||
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
||||
return r.id == string(ruleKey)
|
||||
return r.id == ruleID
|
||||
})
|
||||
if idx < 0 {
|
||||
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
|
||||
return fmt.Errorf("route rule not found: %s", ruleID)
|
||||
}
|
||||
|
||||
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
||||
delete(m.routeRulesMap, ruleKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -586,48 +570,6 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// resetState clears all firewall rules and closes connection trackers.
|
||||
// Must be called with m.mutex held.
|
||||
func (m *Manager) resetState() {
|
||||
maps.Clear(m.outgoingRules)
|
||||
maps.Clear(m.incomingDenyRules)
|
||||
maps.Clear(m.incomingRules)
|
||||
maps.Clear(m.routeRulesMap)
|
||||
m.routeRules = m.routeRules[:0]
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
}
|
||||
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := m.logger.Stop(ctx); err != nil {
|
||||
log.Errorf("failed to shutdown logger: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil
|
||||
}
|
||||
return m.nativeFirewall.SetupEBPFProxyNoTrack(proxyPort, wgPort)
|
||||
}
|
||||
|
||||
// UpdateSet updates the rule destinations associated with the given set
|
||||
// by merging the existing prefixes with the new ones, then deduplicating.
|
||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
|
||||
@@ -1,376 +0,0 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// TestAddRouteFilteringReturnsExistingRule verifies that adding the same route
|
||||
// filtering rule twice returns the same rule ID (idempotent behavior).
|
||||
func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
|
||||
manager := setupTestManager(t)
|
||||
|
||||
sources := []netip.Prefix{
|
||||
netip.MustParsePrefix("100.64.1.0/24"),
|
||||
netip.MustParsePrefix("100.64.2.0/24"),
|
||||
}
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
// Add rule first time
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule1)
|
||||
|
||||
// Add the same rule again
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule2)
|
||||
|
||||
// These should be the same (idempotent) like nftables/iptables implementations
|
||||
assert.Equal(t, rule1.ID(), rule2.ID(),
|
||||
"Adding the same rule twice should return the same rule ID (idempotent)")
|
||||
|
||||
manager.mutex.RLock()
|
||||
ruleCount := len(manager.routeRules)
|
||||
manager.mutex.RUnlock()
|
||||
|
||||
assert.Equal(t, 2, ruleCount,
|
||||
"Should have exactly 2 rules (1 user rule + 1 block rule)")
|
||||
}
|
||||
|
||||
// TestAddRouteFilteringDifferentRulesGetDifferentIDs verifies that rules with
|
||||
// different parameters get distinct IDs.
|
||||
func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
|
||||
manager := setupTestManager(t)
|
||||
|
||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||
|
||||
// Add first rule
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add different rule (different destination)
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-2"),
|
||||
sources,
|
||||
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, rule1.ID(), rule2.ID(),
|
||||
"Different rules should have different IDs")
|
||||
|
||||
manager.mutex.RLock()
|
||||
ruleCount := len(manager.routeRules)
|
||||
manager.mutex.RUnlock()
|
||||
|
||||
assert.Equal(t, 3, ruleCount, "Should have 3 rules (2 user rules + 1 block rule)")
|
||||
}
|
||||
|
||||
// TestRouteRuleUpdateDoesNotCauseGap verifies that re-adding the same route
|
||||
// rule during a network map update does not disrupt existing traffic.
|
||||
func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
||||
manager := setupTestManager(t)
|
||||
|
||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.1.5")
|
||||
dstIP := netip.MustParseAddr("192.168.1.10")
|
||||
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||
require.True(t, pass, "Traffic should pass with rule in place")
|
||||
|
||||
// Re-add same rule (simulates network map update)
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Idempotent IDs mean rule1.ID() == rule2.ID(), so the ACL manager
|
||||
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
|
||||
// would remove the only matching rule and cause a traffic gap.
|
||||
if rule1.ID() != rule2.ID() {
|
||||
err = manager.DeleteRouteRule(rule1)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
_, passAfter := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||
assert.True(t, passAfter,
|
||||
"Traffic should still pass after rule update - no gap should occur")
|
||||
}
|
||||
|
||||
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
|
||||
// exactly one drop rule for the WireGuard network prefix, and calling it again
|
||||
// returns the same rule without duplicating.
|
||||
func TestBlockInvalidRoutedIdempotent(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
dev := mocks.NewMockDevice(ctrl)
|
||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||
|
||||
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: wgNet.Addr(),
|
||||
Network: wgNet,
|
||||
}
|
||||
},
|
||||
GetDeviceFunc: func() *device.FilteredDevice {
|
||||
return &device.FilteredDevice{Device: dev}
|
||||
},
|
||||
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||
return &wgdevice.Device{}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
// Call blockInvalidRouted directly multiple times
|
||||
rule1, err := manager.blockInvalidRouted(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule1)
|
||||
|
||||
rule2, err := manager.blockInvalidRouted(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule2)
|
||||
|
||||
rule3, err := manager.blockInvalidRouted(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule3)
|
||||
|
||||
// All should return the same rule
|
||||
assert.Equal(t, rule1.ID(), rule2.ID(), "Second call should return same rule")
|
||||
assert.Equal(t, rule2.ID(), rule3.ID(), "Third call should return same rule")
|
||||
|
||||
// Should have exactly 1 route rule
|
||||
manager.mutex.RLock()
|
||||
ruleCount := len(manager.routeRules)
|
||||
manager.mutex.RUnlock()
|
||||
|
||||
assert.Equal(t, 1, ruleCount, "Should have exactly 1 block rule after 3 calls")
|
||||
|
||||
// Verify the rule blocks traffic to the WG network
|
||||
srcIP := netip.MustParseAddr("10.0.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.50")
|
||||
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
|
||||
assert.False(t, pass, "Block rule should deny traffic to WG prefix")
|
||||
}
|
||||
|
||||
// TestBlockRuleNotAccumulatedOnRepeatedEnableRouting verifies that calling
|
||||
// EnableRouting multiple times (as happens on each route update) does not
|
||||
// accumulate duplicate block rules in the routeRules slice.
|
||||
func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
dev := mocks.NewMockDevice(ctrl)
|
||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||
|
||||
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: wgNet.Addr(),
|
||||
Network: wgNet,
|
||||
}
|
||||
},
|
||||
GetDeviceFunc: func() *device.FilteredDevice {
|
||||
return &device.FilteredDevice{Device: dev}
|
||||
},
|
||||
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||
return &wgdevice.Device{}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
// Call EnableRouting multiple times (simulating repeated route updates)
|
||||
for i := 0; i < 5; i++ {
|
||||
require.NoError(t, manager.EnableRouting())
|
||||
}
|
||||
|
||||
manager.mutex.RLock()
|
||||
ruleCount := len(manager.routeRules)
|
||||
manager.mutex.RUnlock()
|
||||
|
||||
assert.Equal(t, 1, ruleCount,
|
||||
"Repeated EnableRouting should not accumulate block rules")
|
||||
}
|
||||
|
||||
// TestRouteRuleCountStableAcrossUpdates verifies that adding the same route
|
||||
// rule multiple times does not create duplicate entries.
|
||||
func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
|
||||
manager := setupTestManager(t)
|
||||
|
||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
// Simulate 5 network map updates with the same route rule
|
||||
for i := 0; i < 5; i++ {
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
}
|
||||
|
||||
manager.mutex.RLock()
|
||||
ruleCount := len(manager.routeRules)
|
||||
manager.mutex.RUnlock()
|
||||
|
||||
assert.Equal(t, 2, ruleCount,
|
||||
"Should have exactly 2 rules (1 user rule + 1 block rule) after 5 updates")
|
||||
}
|
||||
|
||||
// TestDeleteRouteRuleAfterIdempotentAdd verifies that deleting a route rule
|
||||
// after adding it multiple times works correctly.
|
||||
func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
||||
manager := setupTestManager(t)
|
||||
|
||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
// Add same rule twice
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
|
||||
|
||||
// Delete using first reference
|
||||
err = manager.DeleteRouteRule(rule1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify traffic no longer passes
|
||||
srcIP := netip.MustParseAddr("100.64.1.5")
|
||||
dstIP := netip.MustParseAddr("192.168.1.10")
|
||||
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||
assert.False(t, pass, "Traffic should not pass after rule deletion")
|
||||
}
|
||||
|
||||
func setupTestManager(t *testing.T) *Manager {
|
||||
t.Helper()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
dev := mocks.NewMockDevice(ctrl)
|
||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||
|
||||
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: wgNet.Addr(),
|
||||
Network: wgNet,
|
||||
}
|
||||
},
|
||||
GetDeviceFunc: func() *device.FilteredDevice {
|
||||
return &device.FilteredDevice{Device: dev}
|
||||
},
|
||||
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||
return &wgdevice.Device{}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.EnableRouting())
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
return manager
|
||||
}
|
||||
@@ -263,158 +263,6 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
|
||||
// to the deny map and can be cleanly deleted without leaving orphans.
|
||||
func TestPeerRuleLifecycleDenyRules(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, m.Close(nil))
|
||||
}()
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
addr := netip.MustParseAddr("192.168.1.1")
|
||||
|
||||
// Add multiple deny rules for different ports
|
||||
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCount := len(m.incomingDenyRules[addr])
|
||||
m.mutex.RUnlock()
|
||||
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
|
||||
|
||||
// Delete the first deny rule
|
||||
err = m.DeletePeerRule(rule1[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCount = len(m.incomingDenyRules[addr])
|
||||
m.mutex.RUnlock()
|
||||
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
|
||||
|
||||
// Delete the second deny rule
|
||||
err = m.DeletePeerRule(rule2[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
_, exists := m.incomingDenyRules[addr]
|
||||
m.mutex.RUnlock()
|
||||
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
|
||||
}
|
||||
|
||||
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
|
||||
// peer rules (simulating network map updates) does not leak rules in the maps.
|
||||
func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, m.Close(nil))
|
||||
}()
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
addr := netip.MustParseAddr("192.168.1.1")
|
||||
|
||||
// Simulate 10 network map updates: add rule, delete old, add new
|
||||
for i := 0; i < 10; i++ {
|
||||
// Add a deny rule
|
||||
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add an allow rule
|
||||
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete them (simulating ACL manager cleanup)
|
||||
for _, r := range rules {
|
||||
require.NoError(t, m.DeletePeerRule(r))
|
||||
}
|
||||
for _, r := range allowRules {
|
||||
require.NoError(t, m.DeletePeerRule(r))
|
||||
}
|
||||
}
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCount := len(m.incomingDenyRules[addr])
|
||||
allowCount := len(m.incomingRules[addr])
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
|
||||
require.Equal(t, 0, allowCount, "No allow rules should remain after cleanup")
|
||||
}
|
||||
|
||||
// TestMixedAllowDenyRulesSameIP verifies that allow and deny rules for the same
|
||||
// IP are stored in separate maps and don't interfere with each other.
|
||||
func TestMixedAllowDenyRulesSameIP(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, m.Close(nil))
|
||||
}()
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
|
||||
// Add allow rule for port 80
|
||||
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add deny rule for port 22
|
||||
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
addr := netip.MustParseAddr("192.168.1.1")
|
||||
m.mutex.RLock()
|
||||
allowCount := len(m.incomingRules[addr])
|
||||
denyCount := len(m.incomingDenyRules[addr])
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
|
||||
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
|
||||
|
||||
// Delete allow rule should not affect deny rule
|
||||
err = m.DeletePeerRule(allowRule[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCountAfter := len(m.incomingDenyRules[addr])
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
|
||||
|
||||
// Delete deny rule
|
||||
err = m.DeletePeerRule(denyRule[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
_, denyExists := m.incomingDenyRules[addr]
|
||||
_, allowExists := m.incomingRules[addr]
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.False(t, denyExists, "Deny rules should be empty")
|
||||
require.False(t, allowExists, "Allow rules should be empty")
|
||||
}
|
||||
|
||||
func TestManagerReset(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
|
||||
@@ -5,8 +5,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -18,18 +16,9 @@ const (
|
||||
maxBatchSize = 1024 * 16
|
||||
maxMessageSize = 1024 * 2
|
||||
defaultFlushInterval = 2 * time.Second
|
||||
defaultLogChanSize = 1000
|
||||
logChannelSize = 1000
|
||||
)
|
||||
|
||||
func getLogChannelSize() int {
|
||||
if v := os.Getenv("NB_USPFILTER_LOG_BUFFER"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return defaultLogChanSize
|
||||
}
|
||||
|
||||
type Level uint32
|
||||
|
||||
const (
|
||||
@@ -80,7 +69,7 @@ type Logger struct {
|
||||
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||
l := &Logger{
|
||||
output: logrusLogger.Out,
|
||||
msgChannel: make(chan logMessage, getLogChannelSize()),
|
||||
msgChannel: make(chan logMessage, logChannelSize),
|
||||
shutdown: make(chan struct{}),
|
||||
bufPool: sync.Pool{
|
||||
New: func() any {
|
||||
|
||||
@@ -358,9 +358,9 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||
// Fast path for IPv4 addresses (4 bytes) - most common case
|
||||
if len(oldBytes) == 4 && len(newBytes) == 4 {
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4])) //nolint:gosec // length checked above
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4])) //nolint:gosec // length checked above
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
|
||||
} else {
|
||||
// Fallback for other lengths
|
||||
for i := 0; i < len(oldBytes)-1; i += 2 {
|
||||
|
||||
@@ -5,18 +5,20 @@ package configurer
|
||||
import (
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
)
|
||||
|
||||
func openUAPI(deviceName string) (net.Listener, error) {
|
||||
uapiSock, err := ipc.UAPIOpen(deviceName)
|
||||
if err != nil {
|
||||
log.Errorf("failed to open uapi socket: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
listener, err := ipc.UAPIListen(deviceName, uapiSock)
|
||||
if err != nil {
|
||||
_ = uapiSock.Close()
|
||||
log.Errorf("failed to listen on uapi socket: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -54,14 +54,6 @@ func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder
|
||||
return wgCfg
|
||||
}
|
||||
|
||||
func NewUSPConfigurerNoUAPI(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
|
||||
return &WGUSPConfigurer{
|
||||
device: device,
|
||||
deviceName: deviceName,
|
||||
activityRecorder: activityRecorder,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
|
||||
log.Debugf("adding Wireguard private key")
|
||||
key, err := wgtypes.ParseKey(privateKey)
|
||||
@@ -566,7 +558,7 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
host, portStr, err := net.SplitHostPort(val)
|
||||
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse endpoint: %v", err)
|
||||
continue
|
||||
|
||||
@@ -29,9 +29,8 @@ type PacketFilter interface {
|
||||
type FilteredDevice struct {
|
||||
tun.Device
|
||||
|
||||
filter PacketFilter
|
||||
mutex sync.RWMutex
|
||||
closeOnce sync.Once
|
||||
filter PacketFilter
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// newDeviceFilter constructor function
|
||||
@@ -41,20 +40,6 @@ func newDeviceFilter(device tun.Device) *FilteredDevice {
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the underlying tun device exactly once.
|
||||
// wireguard-go's netTun.Close() panics on double-close due to a bare close(channel),
|
||||
// and multiple code paths can trigger Close on the same device.
|
||||
func (d *FilteredDevice) Close() error {
|
||||
var err error
|
||||
d.closeOnce.Do(func() {
|
||||
err = d.Device.Close()
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read wraps read method with filtering feature
|
||||
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
|
||||
|
||||
@@ -79,12 +79,10 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
|
||||
device.NewLogger(wgLogLevel(), "[netbird] "),
|
||||
)
|
||||
|
||||
t.configurer = configurer.NewUSPConfigurerNoUAPI(t.device, t.name, t.bind.ActivityRecorder())
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
if cErr := tunIface.Close(); cErr != nil {
|
||||
log.Debugf("failed to close tun device: %v", cErr)
|
||||
}
|
||||
_ = tunIface.Close()
|
||||
return nil, fmt.Errorf("error configuring interface: %s", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
@@ -51,7 +50,6 @@ func ValidateMTU(mtu uint16) error {
|
||||
|
||||
type wgProxyFactory interface {
|
||||
GetProxy() wgproxy.Proxy
|
||||
GetProxyPort() uint16
|
||||
Free() error
|
||||
}
|
||||
|
||||
@@ -82,12 +80,6 @@ func (w *WGIface) GetProxy() wgproxy.Proxy {
|
||||
return w.wgProxyFactory.GetProxy()
|
||||
}
|
||||
|
||||
// GetProxyPort returns the proxy port used by the WireGuard proxy.
|
||||
// Returns 0 if no proxy port is used (e.g., for userspace WireGuard).
|
||||
func (w *WGIface) GetProxyPort() uint16 {
|
||||
return w.wgProxyFactory.GetProxyPort()
|
||||
}
|
||||
|
||||
// GetBind returns the EndpointManager userspace bind mode.
|
||||
func (w *WGIface) GetBind() device.EndpointManager {
|
||||
w.mu.Lock()
|
||||
@@ -229,10 +221,6 @@ func (w *WGIface) Close() error {
|
||||
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
||||
}
|
||||
|
||||
if nbnetstack.IsEnabled() {
|
||||
return errors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
if err := w.waitUntilRemoved(); err != nil {
|
||||
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
|
||||
if err := w.Destroy(); err != nil {
|
||||
|
||||
@@ -66,7 +66,7 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
||||
}
|
||||
}()
|
||||
|
||||
return t.tundev, tunNet, nil
|
||||
return nsTunDev, tunNet, nil
|
||||
}
|
||||
|
||||
func (t *NetStackTun) Close() error {
|
||||
|
||||
@@ -22,11 +22,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// cgnatPrefix is the RFC 6598 Carrier-Grade NAT range (100.64.0.0/10).
|
||||
// Addresses in this range are used by CNI plugins (Cilium, Calico, etc.) for pod networking
|
||||
// and are not suitable for direct peer-to-peer connectivity between hosts.
|
||||
var cgnatPrefix = netip.MustParsePrefix("100.64.0.0/10")
|
||||
|
||||
// FilterFn is a function that filters out candidates based on the address.
|
||||
// If it returns true, the address is to be filtered. It also returns the prefix of matching route.
|
||||
type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
|
||||
@@ -180,15 +175,6 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
||||
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
}
|
||||
|
||||
// Filter addresses in the RFC 6598 CGNAT range (100.64.0.0/10) that are not part of the
|
||||
// NetBird WireGuard network. These addresses are commonly assigned by Kubernetes CNI plugins
|
||||
// (Cilium, Calico, etc.) for pod networking and are not routable between hosts.
|
||||
if cgnatPrefix.Contains(a) && !u.address.Network.Contains(a) {
|
||||
u.addrCache.Store(addr.String(), true)
|
||||
log.Infof("Address %s is in the CGNAT range (%s), likely a CNI pod address, refusing to write", addr, cgnatPrefix)
|
||||
return fmt.Errorf("address %s is in the CGNAT range (%s), refusing to write", addr, cgnatPrefix)
|
||||
}
|
||||
|
||||
if isRouted, prefix, err := u.filterFn(a); err != nil {
|
||||
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
|
||||
} else {
|
||||
|
||||
@@ -114,21 +114,34 @@ func (p *ProxyBind) Pause() {
|
||||
}
|
||||
|
||||
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
|
||||
ep, err := addrToEndpoint(endpoint)
|
||||
if err != nil {
|
||||
log.Errorf("failed to start package redirection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
|
||||
p.wgCurrentUsed = ep
|
||||
ep, err := addrToEndpoint(endpoint)
|
||||
if err != nil {
|
||||
log.Errorf("failed to convert endpoint address: %v", err)
|
||||
} else {
|
||||
p.wgCurrentUsed = ep
|
||||
}
|
||||
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) {
|
||||
if addr == nil {
|
||||
return nil, errors.New("nil address")
|
||||
}
|
||||
|
||||
ip, ok := netip.AddrFromSlice(addr.IP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("convert %s to netip.Addr", addr)
|
||||
}
|
||||
|
||||
addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port))
|
||||
return &bind.Endpoint{AddrPort: addrPort}, nil
|
||||
}
|
||||
|
||||
func (p *ProxyBind) CloseConn() error {
|
||||
if p.cancel == nil {
|
||||
return fmt.Errorf("proxy not started")
|
||||
@@ -212,16 +225,3 @@ func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
|
||||
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
|
||||
return &netipAddr, nil
|
||||
}
|
||||
|
||||
func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) {
|
||||
if addr == nil {
|
||||
return nil, fmt.Errorf("invalid address")
|
||||
}
|
||||
ip, ok := netip.AddrFromSlice(addr.IP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("convert %s to netip.Addr", addr)
|
||||
}
|
||||
|
||||
addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port))
|
||||
return &bind.Endpoint{AddrPort: addrPort}, nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -24,10 +26,13 @@ const (
|
||||
loopbackAddr = "127.0.0.1"
|
||||
)
|
||||
|
||||
var (
|
||||
localHostNetIP = net.ParseIP("127.0.0.1")
|
||||
)
|
||||
|
||||
// WGEBPFProxy definition for proxy with EBPF support
|
||||
type WGEBPFProxy struct {
|
||||
localWGListenPort int
|
||||
proxyPort int
|
||||
mtu uint16
|
||||
|
||||
ebpfManager ebpfMgr.Manager
|
||||
@@ -35,8 +40,7 @@ type WGEBPFProxy struct {
|
||||
turnConnMutex sync.Mutex
|
||||
|
||||
lastUsedPort uint16
|
||||
rawConnIPv4 net.PacketConn
|
||||
rawConnIPv6 net.PacketConn
|
||||
rawConn net.PacketConn
|
||||
conn transport.UDPConn
|
||||
|
||||
ctx context.Context
|
||||
@@ -58,39 +62,23 @@ func NewWGEBPFProxy(wgPort int, mtu uint16) *WGEBPFProxy {
|
||||
// Listen load ebpf program and listen the proxy
|
||||
func (p *WGEBPFProxy) Listen() error {
|
||||
pl := portLookup{}
|
||||
proxyPort, err := pl.searchFreePort()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.proxyPort = proxyPort
|
||||
|
||||
// Prepare IPv4 raw socket (required)
|
||||
p.rawConnIPv4, err = rawsocket.PrepareSenderRawSocketIPv4()
|
||||
wgPorxyPort, err := pl.searchFreePort()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Prepare IPv6 raw socket (optional)
|
||||
p.rawConnIPv6, err = rawsocket.PrepareSenderRawSocketIPv6()
|
||||
p.rawConn, err = rawsocket.PrepareSenderRawSocket()
|
||||
if err != nil {
|
||||
log.Warnf("failed to prepare IPv6 raw socket, continuing with IPv4 only: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = p.ebpfManager.LoadWgProxy(proxyPort, p.localWGListenPort)
|
||||
err = p.ebpfManager.LoadWgProxy(wgPorxyPort, p.localWGListenPort)
|
||||
if err != nil {
|
||||
if closeErr := p.rawConnIPv4.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close IPv4 raw socket: %v", closeErr)
|
||||
}
|
||||
if p.rawConnIPv6 != nil {
|
||||
if closeErr := p.rawConnIPv6.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close IPv6 raw socket: %v", closeErr)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
addr := net.UDPAddr{
|
||||
Port: proxyPort,
|
||||
Port: wgPorxyPort,
|
||||
IP: net.ParseIP(loopbackAddr),
|
||||
}
|
||||
|
||||
@@ -106,7 +94,7 @@ func (p *WGEBPFProxy) Listen() error {
|
||||
p.conn = conn
|
||||
|
||||
go p.proxyToRemote()
|
||||
log.Infof("local wg proxy listening on: %d", proxyPort)
|
||||
log.Infof("local wg proxy listening on: %d", wgPorxyPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -147,25 +135,12 @@ func (p *WGEBPFProxy) Free() error {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
if p.rawConnIPv4 != nil {
|
||||
if err := p.rawConnIPv4.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
|
||||
if p.rawConnIPv6 != nil {
|
||||
if err := p.rawConnIPv6.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := p.rawConn.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// GetProxyPort returns the proxy listening port.
|
||||
func (p *WGEBPFProxy) GetProxyPort() uint16 {
|
||||
return uint16(p.proxyPort)
|
||||
}
|
||||
|
||||
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
|
||||
// From this go routine has only one instance.
|
||||
func (p *WGEBPFProxy) proxyToRemote() {
|
||||
@@ -241,3 +216,34 @@ generatePort:
|
||||
}
|
||||
return p.lastUsedPort, nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
|
||||
payload := gopacket.Payload(data)
|
||||
ipH := &layers.IPv4{
|
||||
DstIP: localHostNetIP,
|
||||
SrcIP: endpointAddr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
udpH := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(endpointAddr.Port),
|
||||
DstPort: layers.UDPPort(p.localWGListenPort),
|
||||
}
|
||||
|
||||
err := udpH.SetNetworkLayerForChecksum(ipH)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set network layer for checksum: %w", err)
|
||||
}
|
||||
|
||||
layerBuffer := gopacket.NewSerializeBuffer()
|
||||
|
||||
err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("serialize layers: %w", err)
|
||||
}
|
||||
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
|
||||
return fmt.Errorf("write to raw conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -10,89 +10,12 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bufsize"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
||||
)
|
||||
|
||||
var (
|
||||
errIPv6ConnNotAvailable = errors.New("IPv6 endpoint but rawConnIPv6 is not available")
|
||||
errIPv4ConnNotAvailable = errors.New("IPv4 endpoint but rawConnIPv4 is not available")
|
||||
|
||||
localHostNetIPv4 = net.ParseIP("127.0.0.1")
|
||||
localHostNetIPv6 = net.ParseIP("::1")
|
||||
|
||||
serializeOpts = gopacket.SerializeOptions{
|
||||
ComputeChecksums: true,
|
||||
FixLengths: true,
|
||||
}
|
||||
)
|
||||
|
||||
// PacketHeaders holds pre-created headers and buffers for efficient packet sending
|
||||
type PacketHeaders struct {
|
||||
ipH gopacket.SerializableLayer
|
||||
udpH *layers.UDP
|
||||
layerBuffer gopacket.SerializeBuffer
|
||||
localHostAddr net.IP
|
||||
isIPv4 bool
|
||||
}
|
||||
|
||||
func NewPacketHeaders(localWGListenPort int, endpoint *net.UDPAddr) (*PacketHeaders, error) {
|
||||
var ipH gopacket.SerializableLayer
|
||||
var networkLayer gopacket.NetworkLayer
|
||||
var localHostAddr net.IP
|
||||
var isIPv4 bool
|
||||
|
||||
// Check if source address is IPv4 or IPv6
|
||||
if endpoint.IP.To4() != nil {
|
||||
// IPv4 path
|
||||
ipv4 := &layers.IPv4{
|
||||
DstIP: localHostNetIPv4,
|
||||
SrcIP: endpoint.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
ipH = ipv4
|
||||
networkLayer = ipv4
|
||||
localHostAddr = localHostNetIPv4
|
||||
isIPv4 = true
|
||||
} else {
|
||||
// IPv6 path
|
||||
ipv6 := &layers.IPv6{
|
||||
DstIP: localHostNetIPv6,
|
||||
SrcIP: endpoint.IP,
|
||||
Version: 6,
|
||||
HopLimit: 64,
|
||||
NextHeader: layers.IPProtocolUDP,
|
||||
}
|
||||
ipH = ipv6
|
||||
networkLayer = ipv6
|
||||
localHostAddr = localHostNetIPv6
|
||||
isIPv4 = false
|
||||
}
|
||||
|
||||
udpH := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(endpoint.Port),
|
||||
DstPort: layers.UDPPort(localWGListenPort),
|
||||
}
|
||||
|
||||
if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil {
|
||||
return nil, fmt.Errorf("set network layer for checksum: %w", err)
|
||||
}
|
||||
|
||||
return &PacketHeaders{
|
||||
ipH: ipH,
|
||||
udpH: udpH,
|
||||
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||
localHostAddr: localHostAddr,
|
||||
isIPv4: isIPv4,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||
type ProxyWrapper struct {
|
||||
wgeBPFProxy *WGEBPFProxy
|
||||
@@ -101,10 +24,8 @@ type ProxyWrapper struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
wgRelayedEndpointAddr *net.UDPAddr
|
||||
headers *PacketHeaders
|
||||
headerCurrentUsed *PacketHeaders
|
||||
rawConn net.PacketConn
|
||||
wgRelayedEndpointAddr *net.UDPAddr
|
||||
wgEndpointCurrentUsedAddr *net.UDPAddr
|
||||
|
||||
paused bool
|
||||
pausedCond *sync.Cond
|
||||
@@ -120,32 +41,15 @@ func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
|
||||
closeListener: listener.NewCloseListener(),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
|
||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add turn conn: %w", err)
|
||||
}
|
||||
|
||||
headers, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create packet sender: %w", err)
|
||||
}
|
||||
|
||||
// Check if required raw connection is available
|
||||
if !headers.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
|
||||
return errIPv6ConnNotAvailable
|
||||
}
|
||||
if headers.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
|
||||
return errIPv4ConnNotAvailable
|
||||
}
|
||||
|
||||
p.remoteConn = remoteConn
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
p.wgRelayedEndpointAddr = addr
|
||||
p.headers = headers
|
||||
p.rawConn = p.selectRawConn(headers)
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
||||
@@ -164,8 +68,7 @@ func (p *ProxyWrapper) Work() {
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
|
||||
p.headerCurrentUsed = p.headers
|
||||
p.rawConn = p.selectRawConn(p.headerCurrentUsed)
|
||||
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
|
||||
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
@@ -188,32 +91,12 @@ func (p *ProxyWrapper) Pause() {
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
|
||||
if endpoint == nil || endpoint.IP == nil {
|
||||
log.Errorf("failed to start package redirection, endpoint is nil")
|
||||
return
|
||||
}
|
||||
|
||||
header, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, endpoint)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create packet headers: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if required raw connection is available
|
||||
if !header.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
|
||||
log.Error(errIPv6ConnNotAvailable)
|
||||
return
|
||||
}
|
||||
if header.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
|
||||
log.Error(errIPv4ConnNotAvailable)
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
|
||||
p.headerCurrentUsed = header
|
||||
p.rawConn = p.selectRawConn(header)
|
||||
if endpoint != nil && endpoint.IP != nil {
|
||||
p.wgEndpointCurrentUsedAddr = endpoint
|
||||
}
|
||||
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
@@ -255,7 +138,7 @@ func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
|
||||
p.pausedCond.Wait()
|
||||
}
|
||||
|
||||
err = p.sendPkg(buf[:n], p.headerCurrentUsed)
|
||||
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
|
||||
p.pausedCond.L.Unlock()
|
||||
|
||||
if err != nil {
|
||||
@@ -281,29 +164,3 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) sendPkg(data []byte, header *PacketHeaders) error {
|
||||
defer func() {
|
||||
if err := header.layerBuffer.Clear(); err != nil {
|
||||
log.Errorf("failed to clear layer buffer: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
payload := gopacket.Payload(data)
|
||||
|
||||
if err := gopacket.SerializeLayers(header.layerBuffer, serializeOpts, header.ipH, header.udpH, payload); err != nil {
|
||||
return fmt.Errorf("serialize layers: %w", err)
|
||||
}
|
||||
|
||||
if _, err := p.rawConn.WriteTo(header.layerBuffer.Bytes(), &net.IPAddr{IP: header.localHostAddr}); err != nil {
|
||||
return fmt.Errorf("write to raw conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) selectRawConn(header *PacketHeaders) net.PacketConn {
|
||||
if header.isIPv4 {
|
||||
return p.wgeBPFProxy.rawConnIPv4
|
||||
}
|
||||
return p.wgeBPFProxy.rawConnIPv6
|
||||
}
|
||||
|
||||
@@ -54,14 +54,6 @@ func (w *KernelFactory) GetProxy() Proxy {
|
||||
return ebpf.NewProxyWrapper(w.ebpfProxy)
|
||||
}
|
||||
|
||||
// GetProxyPort returns the eBPF proxy port, or 0 if eBPF is not active.
|
||||
func (w *KernelFactory) GetProxyPort() uint16 {
|
||||
if w.ebpfProxy == nil {
|
||||
return 0
|
||||
}
|
||||
return w.ebpfProxy.GetProxyPort()
|
||||
}
|
||||
|
||||
func (w *KernelFactory) Free() error {
|
||||
if w.ebpfProxy == nil {
|
||||
return nil
|
||||
|
||||
@@ -24,11 +24,6 @@ func (w *USPFactory) GetProxy() Proxy {
|
||||
return proxyBind.NewProxyBind(w.bind, w.mtu)
|
||||
}
|
||||
|
||||
// GetProxyPort returns 0 as userspace WireGuard doesn't use a separate proxy port.
|
||||
func (w *USPFactory) GetProxyPort() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (w *USPFactory) Free() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,87 +8,43 @@ import (
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
// PrepareSenderRawSocketIPv4 creates and configures a raw socket for sending IPv4 packets
|
||||
func PrepareSenderRawSocketIPv4() (net.PacketConn, error) {
|
||||
return prepareSenderRawSocket(syscall.AF_INET, true)
|
||||
}
|
||||
|
||||
// PrepareSenderRawSocketIPv6 creates and configures a raw socket for sending IPv6 packets
|
||||
func PrepareSenderRawSocketIPv6() (net.PacketConn, error) {
|
||||
return prepareSenderRawSocket(syscall.AF_INET6, false)
|
||||
}
|
||||
|
||||
func prepareSenderRawSocket(family int, isIPv4 bool) (net.PacketConn, error) {
|
||||
func PrepareSenderRawSocket() (net.PacketConn, error) {
|
||||
// Create a raw socket.
|
||||
fd, err := syscall.Socket(family, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating raw socket failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the header include option on the socket to tell the kernel that headers are included in the packet.
|
||||
// For IPv4, we need to set IP_HDRINCL. For IPv6, we need to set IPV6_HDRINCL to accept application-provided IPv6 headers.
|
||||
if isIPv4 {
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, unix.IP_HDRINCL, 1)
|
||||
if err != nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IPV6, unix.IPV6_HDRINCL, 1)
|
||||
if err != nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("setting IPV6_HDRINCL failed: %w", err)
|
||||
}
|
||||
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
||||
}
|
||||
|
||||
// Bind the socket to the "lo" interface.
|
||||
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
||||
if err != nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the fwmark on the socket.
|
||||
err = nbnet.SetSocketOpt(fd)
|
||||
if err != nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
||||
}
|
||||
|
||||
// Convert the file descriptor to a PacketConn.
|
||||
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
||||
if file == nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("converting fd to file failed")
|
||||
}
|
||||
packetConn, err := net.FilePacketConn(file)
|
||||
if err != nil {
|
||||
if closeErr := file.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close file: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
|
||||
}
|
||||
|
||||
// Close the original file to release the FD (net.FilePacketConn duplicates it)
|
||||
if closeErr := file.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close file after creating packet conn: %v", closeErr)
|
||||
}
|
||||
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
@@ -1,353 +0,0 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||
)
|
||||
|
||||
// compareUDPAddr compares two UDP addresses, ignoring IPv6 zone IDs
|
||||
// IPv6 link-local addresses include zone IDs (e.g., fe80::1%lo) which we should ignore
|
||||
func compareUDPAddr(addr1, addr2 net.Addr) bool {
|
||||
udpAddr1, ok1 := addr1.(*net.UDPAddr)
|
||||
udpAddr2, ok2 := addr2.(*net.UDPAddr)
|
||||
|
||||
if !ok1 || !ok2 {
|
||||
return addr1.String() == addr2.String()
|
||||
}
|
||||
|
||||
// Compare IP and Port, ignoring zone
|
||||
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
|
||||
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
|
||||
wgPort := 51850
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("192.168.0.56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
|
||||
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
|
||||
wgPort := 51851
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("fe80::56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
|
||||
func TestRedirectAs_UDP_IPv4(t *testing.T) {
|
||||
wgPort := 51852
|
||||
proxy := udp.NewWGUDPProxy(wgPort, 1280)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("192.168.0.56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_UDP_IPv6 tests RedirectAs with UDP proxy using IPv6 addresses
|
||||
func TestRedirectAs_UDP_IPv6(t *testing.T) {
|
||||
wgPort := 51853
|
||||
proxy := udp.NewWGUDPProxy(wgPort, 1280)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("fe80::56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// testRedirectAs is a helper function that tests the RedirectAs functionality
|
||||
// It verifies that:
|
||||
// 1. Initial traffic from relay connection works
|
||||
// 2. After calling RedirectAs, packets appear to come from the p2p endpoint
|
||||
// 3. Multiple packets are correctly redirected with the new source address
|
||||
func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *net.UDPAddr) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create WireGuard listeners on both IPv4 and IPv6 to support both P2P connection types
|
||||
// In reality, WireGuard binds to a port and receives from both IPv4 and IPv6
|
||||
wgListener4, err := net.ListenUDP("udp4", &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: wgPort,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create IPv4 WireGuard listener: %v", err)
|
||||
}
|
||||
defer wgListener4.Close()
|
||||
|
||||
wgListener6, err := net.ListenUDP("udp6", &net.UDPAddr{
|
||||
IP: net.ParseIP("::1"),
|
||||
Port: wgPort,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create IPv6 WireGuard listener: %v", err)
|
||||
}
|
||||
defer wgListener6.Close()
|
||||
|
||||
// Determine which listener to use based on the NetBird address IP version
|
||||
// (this is where initial traffic will come from before RedirectAs is called)
|
||||
var wgListener *net.UDPConn
|
||||
if p2pEndpoint.IP.To4() == nil {
|
||||
wgListener = wgListener6
|
||||
} else {
|
||||
wgListener = wgListener4
|
||||
}
|
||||
|
||||
// Create relay server and connection
|
||||
relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 0, // Random port
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create relay server: %v", err)
|
||||
}
|
||||
defer relayServer.Close()
|
||||
|
||||
relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create relay connection: %v", err)
|
||||
}
|
||||
defer relayConn.Close()
|
||||
|
||||
// Add TURN connection to proxy
|
||||
if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
|
||||
t.Fatalf("failed to add TURN connection: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := proxy.CloseConn(); err != nil {
|
||||
t.Errorf("failed to close proxy connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Start the proxy
|
||||
proxy.Work()
|
||||
|
||||
// Phase 1: Test initial relay traffic
|
||||
msgFromRelay := []byte("hello from relay")
|
||||
if _, err := relayServer.WriteTo(msgFromRelay, relayConn.LocalAddr()); err != nil {
|
||||
t.Fatalf("failed to write to relay server: %v", err)
|
||||
}
|
||||
|
||||
// Set read deadline to avoid hanging
|
||||
if err := wgListener4.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Fatalf("failed to set read deadline: %v", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
n, _, err := wgListener4.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from WireGuard listener: %v", err)
|
||||
}
|
||||
|
||||
if n != len(msgFromRelay) {
|
||||
t.Errorf("expected %d bytes, got %d", len(msgFromRelay), n)
|
||||
}
|
||||
|
||||
if string(buf[:n]) != string(msgFromRelay) {
|
||||
t.Errorf("expected message %q, got %q", msgFromRelay, buf[:n])
|
||||
}
|
||||
|
||||
// Phase 2: Redirect to p2p endpoint
|
||||
proxy.RedirectAs(p2pEndpoint)
|
||||
|
||||
// Give the proxy a moment to process the redirect
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Phase 3: Test redirected traffic
|
||||
redirectedMessages := [][]byte{
|
||||
[]byte("redirected message 1"),
|
||||
[]byte("redirected message 2"),
|
||||
[]byte("redirected message 3"),
|
||||
}
|
||||
|
||||
for i, msg := range redirectedMessages {
|
||||
if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
|
||||
t.Fatalf("failed to write redirected message %d: %v", i+1, err)
|
||||
}
|
||||
|
||||
if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Fatalf("failed to set read deadline: %v", err)
|
||||
}
|
||||
|
||||
n, srcAddr, err := wgListener.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read redirected message %d: %v", i+1, err)
|
||||
}
|
||||
|
||||
// Verify message content
|
||||
if string(buf[:n]) != string(msg) {
|
||||
t.Errorf("message %d: expected %q, got %q", i+1, msg, buf[:n])
|
||||
}
|
||||
|
||||
// Verify source address matches p2p endpoint (this is the key test)
|
||||
// Use compareUDPAddr to ignore IPv6 zone IDs
|
||||
if !compareUDPAddr(srcAddr, p2pEndpoint) {
|
||||
t.Errorf("message %d: expected source address %s, got %s",
|
||||
i+1, p2pEndpoint.String(), srcAddr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
|
||||
func TestRedirectAs_Multiple_Switches(t *testing.T) {
|
||||
wgPort := 51856
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create WireGuard listener
|
||||
wgListener, err := net.ListenUDP("udp4", &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: wgPort,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create WireGuard listener: %v", err)
|
||||
}
|
||||
defer wgListener.Close()
|
||||
|
||||
// Create relay server and connection
|
||||
relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 0,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create relay server: %v", err)
|
||||
}
|
||||
defer relayServer.Close()
|
||||
|
||||
relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create relay connection: %v", err)
|
||||
}
|
||||
defer relayConn.Close()
|
||||
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
|
||||
t.Fatalf("failed to add TURN connection: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := proxy.CloseConn(); err != nil {
|
||||
t.Errorf("failed to close proxy connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy.Work()
|
||||
|
||||
// Test switching between multiple endpoints - using addresses in local subnet
|
||||
endpoints := []*net.UDPAddr{
|
||||
{IP: net.ParseIP("192.168.0.100"), Port: 51820},
|
||||
{IP: net.ParseIP("192.168.0.101"), Port: 51821},
|
||||
{IP: net.ParseIP("192.168.0.102"), Port: 51822},
|
||||
}
|
||||
|
||||
for i, endpoint := range endpoints {
|
||||
proxy.RedirectAs(endpoint)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
msg := []byte("test message")
|
||||
if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
|
||||
t.Fatalf("failed to write message for endpoint %d: %v", i, err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Fatalf("failed to set read deadline: %v", err)
|
||||
}
|
||||
|
||||
n, srcAddr, err := wgListener.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read message for endpoint %d: %v", i, err)
|
||||
}
|
||||
|
||||
if string(buf[:n]) != string(msg) {
|
||||
t.Errorf("endpoint %d: expected message %q, got %q", i, msg, buf[:n])
|
||||
}
|
||||
|
||||
if !compareUDPAddr(srcAddr, endpoint) {
|
||||
t.Errorf("endpoint %d: expected source %s, got %s",
|
||||
i, endpoint.String(), srcAddr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -56,7 +56,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
|
||||
// the connection is complete, an error is returned. Once successfully
|
||||
// connected, any expiration of the context will not affect the
|
||||
// connection.
|
||||
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
|
||||
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||
dialer := net.Dialer{}
|
||||
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||
if err != nil {
|
||||
|
||||
@@ -19,56 +19,37 @@ var (
|
||||
FixLengths: true,
|
||||
}
|
||||
|
||||
localHostNetIPAddrV4 = &net.IPAddr{
|
||||
localHostNetIPAddr = &net.IPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
}
|
||||
localHostNetIPAddrV6 = &net.IPAddr{
|
||||
IP: net.ParseIP("::1"),
|
||||
}
|
||||
)
|
||||
|
||||
type SrcFaker struct {
|
||||
srcAddr *net.UDPAddr
|
||||
|
||||
rawSocket net.PacketConn
|
||||
ipH gopacket.SerializableLayer
|
||||
udpH gopacket.SerializableLayer
|
||||
layerBuffer gopacket.SerializeBuffer
|
||||
localHostAddr *net.IPAddr
|
||||
rawSocket net.PacketConn
|
||||
ipH gopacket.SerializableLayer
|
||||
udpH gopacket.SerializableLayer
|
||||
layerBuffer gopacket.SerializeBuffer
|
||||
}
|
||||
|
||||
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
|
||||
// Create only the raw socket for the address family we need
|
||||
var rawSocket net.PacketConn
|
||||
var err error
|
||||
var localHostAddr *net.IPAddr
|
||||
|
||||
if srcAddr.IP.To4() != nil {
|
||||
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv4()
|
||||
localHostAddr = localHostNetIPAddrV4
|
||||
} else {
|
||||
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv6()
|
||||
localHostAddr = localHostNetIPAddrV6
|
||||
}
|
||||
rawSocket, err := rawsocket.PrepareSenderRawSocket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
|
||||
if err != nil {
|
||||
if closeErr := rawSocket.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket: %v", closeErr)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f := &SrcFaker{
|
||||
srcAddr: srcAddr,
|
||||
rawSocket: rawSocket,
|
||||
ipH: ipH,
|
||||
udpH: udpH,
|
||||
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||
localHostAddr: localHostAddr,
|
||||
srcAddr: srcAddr,
|
||||
rawSocket: rawSocket,
|
||||
ipH: ipH,
|
||||
udpH: udpH,
|
||||
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||
}
|
||||
|
||||
return f, nil
|
||||
@@ -91,7 +72,7 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("serialize layers: %w", err)
|
||||
}
|
||||
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), f.localHostAddr)
|
||||
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("write to raw conn: %w", err)
|
||||
}
|
||||
@@ -99,40 +80,19 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
|
||||
}
|
||||
|
||||
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
|
||||
var ipH gopacket.SerializableLayer
|
||||
var networkLayer gopacket.NetworkLayer
|
||||
|
||||
// Check if source IP is IPv4 or IPv6
|
||||
if srcAddr.IP.To4() != nil {
|
||||
// IPv4
|
||||
ipv4 := &layers.IPv4{
|
||||
DstIP: localHostNetIPAddrV4.IP,
|
||||
SrcIP: srcAddr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
ipH = ipv4
|
||||
networkLayer = ipv4
|
||||
} else {
|
||||
// IPv6
|
||||
ipv6 := &layers.IPv6{
|
||||
DstIP: localHostNetIPAddrV6.IP,
|
||||
SrcIP: srcAddr.IP,
|
||||
Version: 6,
|
||||
HopLimit: 64,
|
||||
NextHeader: layers.IPProtocolUDP,
|
||||
}
|
||||
ipH = ipv6
|
||||
networkLayer = ipv6
|
||||
ipH := &layers.IPv4{
|
||||
DstIP: net.ParseIP("127.0.0.1"),
|
||||
SrcIP: srcAddr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
|
||||
udpH := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(srcAddr.Port),
|
||||
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
|
||||
}
|
||||
|
||||
err := udpH.SetNetworkLayerForChecksum(networkLayer)
|
||||
err := udpH.SetNetworkLayerForChecksum(ipH)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
|
||||
}
|
||||
|
||||
@@ -189,212 +189,6 @@ func TestDefaultManagerStateless(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestDenyRulesNotAccumulatedOnRepeatedApply verifies that applying the same
|
||||
// deny rules repeatedly does not accumulate duplicate rules in the uspfilter.
|
||||
// This tests the full ACL manager -> uspfilter integration.
|
||||
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "22",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "80",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
},
|
||||
FirewallRulesIsEmpty: false,
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, fw.Close(nil))
|
||||
}()
|
||||
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
// Apply the same rules 5 times (simulating repeated network map updates)
|
||||
for i := 0; i < 5; i++ {
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
}
|
||||
|
||||
// The ACL manager should track exactly 3 rule pairs (2 deny + 1 accept inbound)
|
||||
assert.Equal(t, 3, len(acl.peerRulesPairs),
|
||||
"Should have exactly 3 rule pairs after 5 identical updates")
|
||||
}
|
||||
|
||||
// TestDenyRulesCleanedUpOnRemoval verifies that deny rules are properly cleaned
|
||||
// up when they're removed from the network map in a subsequent update.
|
||||
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, fw.Close(nil))
|
||||
}()
|
||||
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
// First update: add deny and accept rules
|
||||
networkMap1 := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "22",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
},
|
||||
FirewallRulesIsEmpty: false,
|
||||
}
|
||||
|
||||
acl.ApplyFiltering(networkMap1, false)
|
||||
assert.Equal(t, 2, len(acl.peerRulesPairs), "Should have 2 rules after first update")
|
||||
|
||||
// Second update: remove the deny rule, keep only accept
|
||||
networkMap2 := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
},
|
||||
FirewallRulesIsEmpty: false,
|
||||
}
|
||||
|
||||
acl.ApplyFiltering(networkMap2, false)
|
||||
assert.Equal(t, 1, len(acl.peerRulesPairs),
|
||||
"Should have 1 rule after removing deny rule")
|
||||
|
||||
// Third update: remove all rules
|
||||
networkMap3 := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{},
|
||||
FirewallRulesIsEmpty: true,
|
||||
}
|
||||
|
||||
acl.ApplyFiltering(networkMap3, false)
|
||||
assert.Equal(t, 0, len(acl.peerRulesPairs),
|
||||
"Should have 0 rules after removing all rules")
|
||||
}
|
||||
|
||||
// TestRuleUpdateChangingAction verifies that when a rule's action changes from
|
||||
// accept to deny (or vice versa), the old rule is properly removed and the new
|
||||
// one added without leaking.
|
||||
func TestRuleUpdateChangingAction(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, fw.Close(nil))
|
||||
}()
|
||||
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
// First update: accept rule
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "22",
|
||||
},
|
||||
},
|
||||
FirewallRulesIsEmpty: false,
|
||||
}
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
assert.Equal(t, 1, len(acl.peerRulesPairs))
|
||||
|
||||
// Second update: change to deny (same IP/port/proto, different action)
|
||||
networkMap.FirewallRules = []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "22",
|
||||
},
|
||||
}
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
|
||||
// Should still have exactly 1 rule (the old accept removed, new deny added)
|
||||
assert.Equal(t, 1, len(acl.peerRulesPairs),
|
||||
"Changing action should result in exactly 1 rule, not 2")
|
||||
}
|
||||
|
||||
func TestPortInfoEmpty(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -1,499 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// Auth manages authentication operations with the management server
|
||||
// It maintains a long-lived connection and automatically handles reconnection with backoff
|
||||
type Auth struct {
|
||||
mutex sync.RWMutex
|
||||
client *mgm.GrpcClient
|
||||
config *profilemanager.Config
|
||||
privateKey wgtypes.Key
|
||||
mgmURL *url.URL
|
||||
mgmTLSEnabled bool
|
||||
}
|
||||
|
||||
// NewAuth creates a new Auth instance that manages authentication flows
|
||||
// It establishes a connection to the management server that will be reused for all operations
|
||||
// The connection is automatically recreated with backoff if it becomes disconnected
|
||||
func NewAuth(ctx context.Context, privateKey string, mgmURL *url.URL, config *profilemanager.Config) (*Auth, error) {
|
||||
// Validate WireGuard private key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Determine TLS setting based on URL scheme
|
||||
mgmTLSEnabled := mgmURL.Scheme == "https"
|
||||
|
||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to Management Service %s: %v", mgmURL.String(), err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
return &Auth{
|
||||
client: mgmClient,
|
||||
config: config,
|
||||
privateKey: myPrivateKey,
|
||||
mgmURL: mgmURL,
|
||||
mgmTLSEnabled: mgmTLSEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes the management client connection
|
||||
func (a *Auth) Close() error {
|
||||
a.mutex.Lock()
|
||||
defer a.mutex.Unlock()
|
||||
|
||||
if a.client == nil {
|
||||
return nil
|
||||
}
|
||||
return a.client.Close()
|
||||
}
|
||||
|
||||
// IsSSOSupported checks if the management server supports SSO by attempting to retrieve auth flow configurations.
|
||||
// Returns true if either PKCE or Device authorization flow is supported, false otherwise.
|
||||
// This function encapsulates the SSO detection logic to avoid exposing gRPC error codes to upper layers.
|
||||
// Automatically retries with backoff and reconnection on connection errors.
|
||||
func (a *Auth) IsSSOSupported(ctx context.Context) (bool, error) {
|
||||
var supportsSSO bool
|
||||
|
||||
err := a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
// Try PKCE flow first
|
||||
_, err := a.getPKCEFlow(client)
|
||||
if err == nil {
|
||||
supportsSSO = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if PKCE is not supported
|
||||
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
// PKCE not supported, try Device flow
|
||||
_, err = a.getDeviceFlow(client)
|
||||
if err == nil {
|
||||
supportsSSO = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if Device flow is also not supported
|
||||
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
// Neither PKCE nor Device flow is supported
|
||||
supportsSSO = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// Device flow check returned an error other than NotFound/Unimplemented
|
||||
return err
|
||||
}
|
||||
|
||||
// PKCE flow check returned an error other than NotFound/Unimplemented
|
||||
return err
|
||||
})
|
||||
|
||||
return supportsSSO, err
|
||||
}
|
||||
|
||||
// GetOAuthFlow returns an OAuth flow (PKCE or Device) using the existing management connection
|
||||
// This avoids creating a new connection to the management server
|
||||
func (a *Auth) GetOAuthFlow(ctx context.Context, forceDeviceAuth bool) (OAuthFlow, error) {
|
||||
var flow OAuthFlow
|
||||
var err error
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
if forceDeviceAuth {
|
||||
flow, err = a.getDeviceFlow(client)
|
||||
return err
|
||||
}
|
||||
|
||||
// Try PKCE flow first
|
||||
flow, err = a.getPKCEFlow(client)
|
||||
if err != nil {
|
||||
// If PKCE not supported, try Device flow
|
||||
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
flow, err = a.getDeviceFlow(client)
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return flow, err
|
||||
}
|
||||
|
||||
// IsLoginRequired checks if login is required by attempting to authenticate with the server
|
||||
// Automatically retries with backoff and reconnection on connection errors.
|
||||
func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
var needsLogin bool
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if isLoginNeeded(err) {
|
||||
needsLogin = true
|
||||
return nil
|
||||
}
|
||||
needsLogin = false
|
||||
return err
|
||||
})
|
||||
|
||||
return needsLogin, err
|
||||
}
|
||||
|
||||
// Login attempts to log in or register the client with the management server
|
||||
// Returns error and a boolean indicating if it's an authentication error (permission denied) that should stop retries.
|
||||
// Automatically retries with backoff and reconnection on connection errors.
|
||||
func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (error, bool) {
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
|
||||
if err != nil {
|
||||
return err, false
|
||||
}
|
||||
|
||||
var isAuthError bool
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if serverKey != nil && isRegistrationNeeded(err) {
|
||||
log.Debugf("peer registration required")
|
||||
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
||||
if err != nil {
|
||||
isAuthError = isPermissionDenied(err)
|
||||
return err
|
||||
}
|
||||
} else if err != nil {
|
||||
isAuthError = isPermissionDenied(err)
|
||||
return err
|
||||
}
|
||||
|
||||
isAuthError = false
|
||||
return nil
|
||||
})
|
||||
|
||||
return err, isAuthError
|
||||
}
|
||||
|
||||
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
|
||||
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
log.Errorf("failed to retrieve pkce flow: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoConfig := protoFlow.GetProviderConfig()
|
||||
config := &PKCEAuthProviderConfig{
|
||||
Audience: protoConfig.GetAudience(),
|
||||
ClientID: protoConfig.GetClientID(),
|
||||
ClientSecret: protoConfig.GetClientSecret(),
|
||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
|
||||
Scope: protoConfig.GetScope(),
|
||||
RedirectURLs: protoConfig.GetRedirectURLs(),
|
||||
UseIDToken: protoConfig.GetUseIDToken(),
|
||||
ClientCertPair: a.config.ClientCertKeyPair,
|
||||
DisablePromptLogin: protoConfig.GetDisablePromptLogin(),
|
||||
LoginFlag: common.LoginFlag(protoConfig.GetLoginFlag()),
|
||||
}
|
||||
|
||||
if err := validatePKCEConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
flow, err := NewPKCEAuthorizationFlow(*config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return flow, nil
|
||||
}
|
||||
|
||||
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
|
||||
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
log.Errorf("failed to retrieve device flow: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoConfig := protoFlow.GetProviderConfig()
|
||||
config := &DeviceAuthProviderConfig{
|
||||
Audience: protoConfig.GetAudience(),
|
||||
ClientID: protoConfig.GetClientID(),
|
||||
ClientSecret: protoConfig.GetClientSecret(),
|
||||
Domain: protoConfig.Domain,
|
||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
|
||||
Scope: protoConfig.GetScope(),
|
||||
UseIDToken: protoConfig.GetUseIDToken(),
|
||||
}
|
||||
|
||||
// Keep compatibility with older management versions
|
||||
if config.Scope == "" {
|
||||
config.Scope = "openid"
|
||||
}
|
||||
|
||||
if err := validateDeviceAuthConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
flow, err := NewDeviceAuthorizationFlow(*config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return flow, nil
|
||||
}
|
||||
|
||||
// doMgmLogin performs the actual login operation with the management service
|
||||
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
a.setSystemInfoFlags(sysInfo)
|
||||
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||
return serverKey, loginResp, err
|
||||
}
|
||||
|
||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||
// Otherwise tries to register with the provided setupKey via command line.
|
||||
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
||||
serverPublicKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
validSetupKey, err := uuid.Parse(setupKey)
|
||||
if err != nil && jwtToken == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||
}
|
||||
|
||||
log.Debugf("sending peer registration request to Management Service")
|
||||
info := system.GetInfo(ctx)
|
||||
a.setSystemInfoFlags(info)
|
||||
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
||||
if err != nil {
|
||||
log.Errorf("failed registering peer %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("peer has been successfully registered on Management Service")
|
||||
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
// setSystemInfoFlags sets all configuration flags on the provided system info
|
||||
func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
||||
info.SetFlags(
|
||||
a.config.RosenpassEnabled,
|
||||
a.config.RosenpassPermissive,
|
||||
a.config.ServerSSHAllowed,
|
||||
a.config.DisableClientRoutes,
|
||||
a.config.DisableServerRoutes,
|
||||
a.config.DisableDNS,
|
||||
a.config.DisableFirewall,
|
||||
a.config.BlockLANAccess,
|
||||
a.config.BlockInbound,
|
||||
a.config.LazyConnectionEnabled,
|
||||
a.config.EnableSSHRoot,
|
||||
a.config.EnableSSHSFTP,
|
||||
a.config.EnableSSHLocalPortForwarding,
|
||||
a.config.EnableSSHRemotePortForwarding,
|
||||
a.config.DisableSSHAuth,
|
||||
)
|
||||
}
|
||||
|
||||
// reconnect closes the current connection and creates a new one
|
||||
// It checks if the brokenClient is still the current client before reconnecting
|
||||
// to avoid multiple threads reconnecting unnecessarily
|
||||
func (a *Auth) reconnect(ctx context.Context, brokenClient *mgm.GrpcClient) error {
|
||||
a.mutex.Lock()
|
||||
defer a.mutex.Unlock()
|
||||
|
||||
// Double-check: if client has already been replaced by another thread, skip reconnection
|
||||
if a.client != brokenClient {
|
||||
log.Debugf("client already reconnected by another thread, skipping")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create new connection FIRST, before closing the old one
|
||||
// This ensures a.client is never nil, preventing panics in other threads
|
||||
log.Debugf("reconnecting to Management Service %s", a.mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, a.mgmURL.Host, a.privateKey, a.mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed reconnecting to Management Service %s: %v", a.mgmURL.String(), err)
|
||||
// Keep the old client if reconnection fails
|
||||
return err
|
||||
}
|
||||
|
||||
// Close old connection AFTER new one is successfully created
|
||||
oldClient := a.client
|
||||
a.client = mgmClient
|
||||
|
||||
if oldClient != nil {
|
||||
if err := oldClient.Close(); err != nil {
|
||||
log.Debugf("error closing old connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("successfully reconnected to Management service %s", a.mgmURL.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
// isConnectionError checks if the error is a connection-related error that should trigger reconnection
|
||||
func isConnectionError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
// These error codes indicate connection issues
|
||||
return s.Code() == codes.Unavailable ||
|
||||
s.Code() == codes.DeadlineExceeded ||
|
||||
s.Code() == codes.Canceled ||
|
||||
s.Code() == codes.Internal
|
||||
}
|
||||
|
||||
// withRetry wraps an operation with exponential backoff retry logic
|
||||
// It automatically reconnects on connection errors
|
||||
func (a *Auth) withRetry(ctx context.Context, operation func(client *mgm.GrpcClient) error) error {
|
||||
backoffSettings := &backoff.ExponentialBackOff{
|
||||
InitialInterval: 500 * time.Millisecond,
|
||||
RandomizationFactor: 0.5,
|
||||
Multiplier: 1.5,
|
||||
MaxInterval: 10 * time.Second,
|
||||
MaxElapsedTime: 2 * time.Minute,
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}
|
||||
backoffSettings.Reset()
|
||||
|
||||
return backoff.RetryNotify(
|
||||
func() error {
|
||||
// Capture the client BEFORE the operation to ensure we track the correct client
|
||||
a.mutex.RLock()
|
||||
currentClient := a.client
|
||||
a.mutex.RUnlock()
|
||||
|
||||
if currentClient == nil {
|
||||
return status.Errorf(codes.Unavailable, "client is not initialized")
|
||||
}
|
||||
|
||||
// Execute operation with the captured client
|
||||
err := operation(currentClient)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If it's a connection error, attempt reconnection using the client that was actually used
|
||||
if isConnectionError(err) {
|
||||
log.Warnf("connection error detected, attempting reconnection: %v", err)
|
||||
|
||||
if reconnectErr := a.reconnect(ctx, currentClient); reconnectErr != nil {
|
||||
log.Errorf("reconnection failed: %v", reconnectErr)
|
||||
return reconnectErr
|
||||
}
|
||||
// Return the original error to trigger retry with the new connection
|
||||
return err
|
||||
}
|
||||
|
||||
// For authentication errors, don't retry
|
||||
if isAuthenticationError(err) {
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
|
||||
return err
|
||||
},
|
||||
backoff.WithContext(backoffSettings, ctx),
|
||||
func(err error, duration time.Duration) {
|
||||
log.Warnf("operation failed, retrying in %v: %v", duration, err)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// isAuthenticationError checks if the error is an authentication-related error that should not be retried.
|
||||
// Returns true if the error is InvalidArgument or PermissionDenied, indicating that retrying won't help.
|
||||
func isAuthenticationError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied
|
||||
}
|
||||
|
||||
// isPermissionDenied checks if the error is a PermissionDenied error.
|
||||
// This is used to determine if early exit from backoff is needed (e.g., when the server responded but denied access).
|
||||
func isPermissionDenied(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return s.Code() == codes.PermissionDenied
|
||||
}
|
||||
|
||||
func isLoginNeeded(err error) bool {
|
||||
return isAuthenticationError(err)
|
||||
}
|
||||
|
||||
func isRegistrationNeeded(err error) bool {
|
||||
return isPermissionDenied(err)
|
||||
}
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
)
|
||||
|
||||
@@ -25,56 +26,12 @@ const (
|
||||
|
||||
var _ OAuthFlow = &DeviceAuthorizationFlow{}
|
||||
|
||||
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
||||
type DeviceAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Domain An IDP API domain
|
||||
// Deprecated. Use OIDCConfigEndpoint instead
|
||||
Domain string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
||||
DeviceAuthEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// validateDeviceAuthConfig validates device authorization provider configuration
|
||||
func validateDeviceAuthConfig(config *DeviceAuthProviderConfig) error {
|
||||
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
|
||||
if config.Audience == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Audience")
|
||||
}
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
|
||||
}
|
||||
if config.DeviceAuthEndpoint == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Device Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Device Auth Scopes")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeviceAuthorizationFlow implements the OAuthFlow interface,
|
||||
// for the Device Authorization Flow.
|
||||
type DeviceAuthorizationFlow struct {
|
||||
providerConfig DeviceAuthProviderConfig
|
||||
HTTPClient HTTPClient
|
||||
providerConfig internal.DeviceAuthProviderConfig
|
||||
|
||||
HTTPClient HTTPClient
|
||||
}
|
||||
|
||||
// RequestDeviceCodePayload used for request device code payload for auth0
|
||||
@@ -100,7 +57,7 @@ type TokenRequestResponse struct {
|
||||
}
|
||||
|
||||
// NewDeviceAuthorizationFlow returns device authorization flow client
|
||||
func NewDeviceAuthorizationFlow(config DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
||||
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
@@ -132,11 +89,6 @@ func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
|
||||
return d.providerConfig.ClientID
|
||||
}
|
||||
|
||||
// SetLoginHint sets the login hint for the device authorization flow
|
||||
func (d *DeviceAuthorizationFlow) SetLoginHint(hint string) {
|
||||
d.providerConfig.LoginHint = hint
|
||||
}
|
||||
|
||||
// RequestAuthInfo requests a device code login flow information from Hosted
|
||||
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
||||
form := url.Values{}
|
||||
@@ -247,22 +199,14 @@ func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestR
|
||||
}
|
||||
|
||||
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
||||
// it retrieves the access token from Hosted's endpoint and validates it before returning.
|
||||
// The method creates a timeout context internally based on info.ExpiresIn.
|
||||
// it retrieves the access token from Hosted's endpoint and validates it before returning
|
||||
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
||||
// Create timeout context based on flow expiration
|
||||
timeout := time.Duration(info.ExpiresIn) * time.Second
|
||||
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
interval := time.Duration(info.Interval) * time.Second
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-waitCtx.Done():
|
||||
return TokenInfo{}, waitCtx.Err()
|
||||
case <-ctx.Done():
|
||||
return TokenInfo{}, ctx.Err()
|
||||
case <-ticker.C:
|
||||
|
||||
tokenResponse, err := d.requestToken(info)
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
type mockHTTPClient struct {
|
||||
@@ -113,19 +115,18 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
||||
err: testCase.inputReqError,
|
||||
}
|
||||
|
||||
config := DeviceAuthProviderConfig{
|
||||
Audience: expectedAudience,
|
||||
ClientID: expectedClientID,
|
||||
Scope: expectedScope,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
UseIDToken: false,
|
||||
deviceFlow := &DeviceAuthorizationFlow{
|
||||
providerConfig: internal.DeviceAuthProviderConfig{
|
||||
Audience: expectedAudience,
|
||||
ClientID: expectedClientID,
|
||||
Scope: expectedScope,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
UseIDToken: false,
|
||||
},
|
||||
HTTPClient: &httpClient,
|
||||
}
|
||||
|
||||
deviceFlow, err := NewDeviceAuthorizationFlow(config)
|
||||
require.NoError(t, err, "creating device flow should not fail")
|
||||
deviceFlow.HTTPClient = &httpClient
|
||||
|
||||
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
|
||||
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
||||
|
||||
@@ -279,19 +280,18 @@ func TestHosted_WaitToken(t *testing.T) {
|
||||
countResBody: testCase.inputCountResBody,
|
||||
}
|
||||
|
||||
config := DeviceAuthProviderConfig{
|
||||
Audience: testCase.inputAudience,
|
||||
ClientID: clientID,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
Scope: "openid",
|
||||
UseIDToken: false,
|
||||
deviceFlow := DeviceAuthorizationFlow{
|
||||
providerConfig: internal.DeviceAuthProviderConfig{
|
||||
Audience: testCase.inputAudience,
|
||||
ClientID: clientID,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
Scope: "openid",
|
||||
UseIDToken: false,
|
||||
},
|
||||
HTTPClient: &httpClient,
|
||||
}
|
||||
|
||||
deviceFlow, err := NewDeviceAuthorizationFlow(config)
|
||||
require.NoError(t, err, "creating device flow should not fail")
|
||||
deviceFlow.HTTPClient = &httpClient
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
||||
defer cancel()
|
||||
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
@@ -86,33 +87,19 @@ func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesk
|
||||
|
||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
pkceFlowInfo, err := authClient.getPKCEFlow(authClient.client)
|
||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||
}
|
||||
|
||||
if hint != "" {
|
||||
pkceFlowInfo.SetLoginHint(hint)
|
||||
}
|
||||
pkceFlowInfo.ProviderConfig.LoginHint = hint
|
||||
|
||||
return pkceFlowInfo, nil
|
||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
||||
}
|
||||
|
||||
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
deviceFlowInfo, err := authClient.getDeviceFlow(authClient.client)
|
||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
switch s, ok := gstatus.FromError(err); {
|
||||
case ok && s.Code() == codes.NotFound:
|
||||
@@ -127,9 +114,7 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
|
||||
}
|
||||
}
|
||||
|
||||
if hint != "" {
|
||||
deviceFlowInfo.SetLoginHint(hint)
|
||||
}
|
||||
deviceFlowInfo.ProviderConfig.LoginHint = hint
|
||||
|
||||
return deviceFlowInfo, nil
|
||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/templates"
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
)
|
||||
@@ -34,67 +35,17 @@ const (
|
||||
defaultPKCETimeoutSeconds = 300
|
||||
)
|
||||
|
||||
// PKCEAuthProviderConfig has all attributes needed to initiate PKCE authorization flow
|
||||
type PKCEAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
||||
AuthorizationEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// RedirectURL handles authorization code from IDP manager
|
||||
RedirectURLs []string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// ClientCertPair is used for mTLS authentication to the IDP
|
||||
ClientCertPair *tls.Certificate
|
||||
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
||||
DisablePromptLogin bool
|
||||
// LoginFlag is used to configure the PKCE flow login behavior
|
||||
LoginFlag common.LoginFlag
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// validatePKCEConfig validates PKCE provider configuration
|
||||
func validatePKCEConfig(config *PKCEAuthProviderConfig) error {
|
||||
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
|
||||
}
|
||||
if config.AuthorizationEndpoint == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Authorization Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "PKCE Auth Scopes")
|
||||
}
|
||||
if config.RedirectURLs == nil {
|
||||
return fmt.Errorf(errorMsgFormat, "PKCE Redirect URLs")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PKCEAuthorizationFlow implements the OAuthFlow interface for
|
||||
// the Authorization Code Flow with PKCE.
|
||||
type PKCEAuthorizationFlow struct {
|
||||
providerConfig PKCEAuthProviderConfig
|
||||
providerConfig internal.PKCEAuthProviderConfig
|
||||
state string
|
||||
codeVerifier string
|
||||
oAuthConfig *oauth2.Config
|
||||
}
|
||||
|
||||
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
|
||||
func NewPKCEAuthorizationFlow(config PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||
var availableRedirectURL string
|
||||
|
||||
excludedRanges := getSystemExcludedPortRanges()
|
||||
@@ -173,21 +124,10 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetLoginHint sets the login hint for the PKCE authorization flow
|
||||
func (p *PKCEAuthorizationFlow) SetLoginHint(hint string) {
|
||||
p.providerConfig.LoginHint = hint
|
||||
}
|
||||
|
||||
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
|
||||
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
|
||||
// Once the token is received, it is converted to TokenInfo and validated before returning.
|
||||
// The method creates a timeout context internally based on info.ExpiresIn.
|
||||
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
||||
// Create timeout context based on flow expiration
|
||||
timeout := time.Duration(info.ExpiresIn) * time.Second
|
||||
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
|
||||
tokenChan := make(chan *oauth2.Token, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
@@ -198,7 +138,7 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo
|
||||
|
||||
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
|
||||
defer func() {
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||
@@ -209,8 +149,8 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo
|
||||
go p.startServer(server, tokenChan, errChan)
|
||||
|
||||
select {
|
||||
case <-waitCtx.Done():
|
||||
return TokenInfo{}, waitCtx.Err()
|
||||
case <-ctx.Done():
|
||||
return TokenInfo{}, ctx.Err()
|
||||
case token := <-tokenChan:
|
||||
return p.parseOAuthToken(token)
|
||||
case err := <-errChan:
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client/common"
|
||||
)
|
||||
|
||||
@@ -49,7 +50,7 @@ func TestPromptLogin(t *testing.T) {
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
config := PKCEAuthProviderConfig{
|
||||
config := internal.PKCEAuthProviderConfig{
|
||||
ClientID: "test-client-id",
|
||||
Audience: "test-audience",
|
||||
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
func TestParseExcludedPortRanges(t *testing.T) {
|
||||
@@ -93,7 +95,7 @@ func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) {
|
||||
|
||||
availablePort := 65432
|
||||
|
||||
config := PKCEAuthProviderConfig{
|
||||
config := internal.PKCEAuthProviderConfig{
|
||||
ClientID: "test-client-id",
|
||||
Audience: "test-audience",
|
||||
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
@@ -245,7 +244,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
localPeerState := peer.LocalPeerState{
|
||||
IP: loginResp.GetPeerConfig().GetAddress(),
|
||||
PubKey: myPrivateKey.PublicKey().String(),
|
||||
KernelInterface: device.WireGuardModuleIsLoaded() && !netstack.IsEnabled(),
|
||||
KernelInterface: device.WireGuardModuleIsLoaded(),
|
||||
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
||||
}
|
||||
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
||||
@@ -331,11 +330,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
state.Set(StatusConnected)
|
||||
|
||||
if runningChan != nil {
|
||||
select {
|
||||
case <-runningChan:
|
||||
default:
|
||||
close(runningChan)
|
||||
}
|
||||
close(runningChan)
|
||||
runningChan = nil
|
||||
}
|
||||
|
||||
<-engineCtx.Done()
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
//go:build !windows && !ios && !android
|
||||
|
||||
package daemonaddr
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var scanDir = "/var/run/netbird"
|
||||
|
||||
// setScanDir overrides the scan directory (used by tests).
|
||||
func setScanDir(dir string) {
|
||||
scanDir = dir
|
||||
}
|
||||
|
||||
// ResolveUnixDaemonAddr checks whether the default Unix socket exists and, if not,
|
||||
// scans /var/run/netbird/ for a single .sock file to use instead. This handles the
|
||||
// mismatch between the netbird@.service template (which places the socket under
|
||||
// /var/run/netbird/<instance>.sock) and the CLI default (/var/run/netbird.sock).
|
||||
func ResolveUnixDaemonAddr(addr string) string {
|
||||
if !strings.HasPrefix(addr, "unix://") {
|
||||
return addr
|
||||
}
|
||||
|
||||
sockPath := strings.TrimPrefix(addr, "unix://")
|
||||
if _, err := os.Stat(sockPath); err == nil {
|
||||
return addr
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(scanDir)
|
||||
if err != nil {
|
||||
return addr
|
||||
}
|
||||
|
||||
var found []string
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(e.Name(), ".sock") {
|
||||
found = append(found, filepath.Join(scanDir, e.Name()))
|
||||
}
|
||||
}
|
||||
|
||||
switch len(found) {
|
||||
case 1:
|
||||
resolved := "unix://" + found[0]
|
||||
log.Debugf("Default daemon socket not found, using discovered socket: %s", resolved)
|
||||
return resolved
|
||||
case 0:
|
||||
return addr
|
||||
default:
|
||||
log.Warnf("Default daemon socket not found and multiple sockets discovered in %s; pass --daemon-addr explicitly", scanDir)
|
||||
return addr
|
||||
}
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
//go:build windows || ios || android
|
||||
|
||||
package daemonaddr
|
||||
|
||||
// ResolveUnixDaemonAddr is a no-op on platforms that don't use Unix sockets.
|
||||
func ResolveUnixDaemonAddr(addr string) string {
|
||||
return addr
|
||||
}
|
||||
@@ -1,121 +0,0 @@
|
||||
//go:build !windows && !ios && !android
|
||||
|
||||
package daemonaddr
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// createSockFile creates a regular file with a .sock extension.
|
||||
// ResolveUnixDaemonAddr uses os.Stat (not net.Dial), so a regular file is
|
||||
// sufficient and avoids Unix socket path-length limits on macOS.
|
||||
func createSockFile(t *testing.T, path string) {
|
||||
t.Helper()
|
||||
if err := os.WriteFile(path, nil, 0o600); err != nil {
|
||||
t.Fatalf("failed to create test sock file at %s: %v", path, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnixDaemonAddr_DefaultExists(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
sock := filepath.Join(tmp, "netbird.sock")
|
||||
createSockFile(t, sock)
|
||||
|
||||
addr := "unix://" + sock
|
||||
got := ResolveUnixDaemonAddr(addr)
|
||||
if got != addr {
|
||||
t.Errorf("expected %s, got %s", addr, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnixDaemonAddr_SingleDiscovered(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
|
||||
// Default socket does not exist
|
||||
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||
|
||||
// Create a scan dir with one socket
|
||||
sd := filepath.Join(tmp, "netbird")
|
||||
if err := os.MkdirAll(sd, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
instanceSock := filepath.Join(sd, "main.sock")
|
||||
createSockFile(t, instanceSock)
|
||||
|
||||
origScanDir := scanDir
|
||||
setScanDir(sd)
|
||||
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||
|
||||
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||
expected := "unix://" + instanceSock
|
||||
if got != expected {
|
||||
t.Errorf("expected %s, got %s", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnixDaemonAddr_MultipleDiscovered(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
|
||||
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||
|
||||
sd := filepath.Join(tmp, "netbird")
|
||||
if err := os.MkdirAll(sd, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
createSockFile(t, filepath.Join(sd, "main.sock"))
|
||||
createSockFile(t, filepath.Join(sd, "other.sock"))
|
||||
|
||||
origScanDir := scanDir
|
||||
setScanDir(sd)
|
||||
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||
|
||||
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||
if got != defaultAddr {
|
||||
t.Errorf("expected original %s, got %s", defaultAddr, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnixDaemonAddr_NoSocketsFound(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
|
||||
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||
|
||||
sd := filepath.Join(tmp, "netbird")
|
||||
if err := os.MkdirAll(sd, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
origScanDir := scanDir
|
||||
setScanDir(sd)
|
||||
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||
|
||||
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||
if got != defaultAddr {
|
||||
t.Errorf("expected original %s, got %s", defaultAddr, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnixDaemonAddr_NonUnixAddr(t *testing.T) {
|
||||
addr := "tcp://127.0.0.1:41731"
|
||||
got := ResolveUnixDaemonAddr(addr)
|
||||
if got != addr {
|
||||
t.Errorf("expected %s, got %s", addr, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnixDaemonAddr_ScanDirMissing(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
|
||||
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||
|
||||
origScanDir := scanDir
|
||||
setScanDir(filepath.Join(tmp, "nonexistent"))
|
||||
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||
|
||||
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||
if got != defaultAddr {
|
||||
t.Errorf("expected original %s, got %s", defaultAddr, got)
|
||||
}
|
||||
}
|
||||
136
client/internal/device_auth.go
Normal file
136
client/internal/device_auth.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
)
|
||||
|
||||
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
||||
type DeviceAuthorizationFlow struct {
|
||||
Provider string
|
||||
ProviderConfig DeviceAuthProviderConfig
|
||||
}
|
||||
|
||||
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
||||
type DeviceAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Domain An IDP API domain
|
||||
// Deprecated. Use OIDCConfigEndpoint instead
|
||||
Domain string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
||||
DeviceAuthEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
|
||||
func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (DeviceAuthorizationFlow, error) {
|
||||
// validate our peer's Wireguard PRIVATE key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
var mgmTLSEnabled bool
|
||||
if mgmURL.Scheme == "https" {
|
||||
mgmTLSEnabled = true
|
||||
}
|
||||
|
||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
serverKey, err := mgmClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
protoDeviceAuthorizationFlow, err := mgmClient.GetDeviceAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
log.Errorf("failed to retrieve device flow: %v", err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
deviceAuthorizationFlow := DeviceAuthorizationFlow{
|
||||
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
||||
|
||||
ProviderConfig: DeviceAuthProviderConfig{
|
||||
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
|
||||
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
|
||||
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
||||
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
|
||||
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
||||
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
|
||||
Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(),
|
||||
UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
||||
},
|
||||
}
|
||||
|
||||
// keep compatibility with older management versions
|
||||
if deviceAuthorizationFlow.ProviderConfig.Scope == "" {
|
||||
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
|
||||
}
|
||||
|
||||
err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
||||
if err != nil {
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
return deviceAuthorizationFlow, nil
|
||||
}
|
||||
|
||||
func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
|
||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
if config.Audience == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Audience")
|
||||
}
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
||||
}
|
||||
if config.DeviceAuthEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Device Auth Scopes")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -112,54 +112,6 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
||||
matchSubdomains: false,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "single letter TLD exact match",
|
||||
handlerDomain: "example.x.",
|
||||
queryDomain: "example.x.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "single letter TLD subdomain match",
|
||||
handlerDomain: "example.x.",
|
||||
queryDomain: "sub.example.x.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: true,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "single letter TLD wildcard match",
|
||||
handlerDomain: "*.example.x.",
|
||||
queryDomain: "sub.example.x.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "two letter domain labels",
|
||||
handlerDomain: "a.b.",
|
||||
queryDomain: "a.b.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "single character domain",
|
||||
handlerDomain: "x.",
|
||||
queryDomain: "x.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "single character domain with subdomain match",
|
||||
handlerDomain: "x.",
|
||||
queryDomain: "sub.x.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: true,
|
||||
shouldMatch: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -9,13 +9,9 @@ import (
|
||||
"io"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
@@ -24,7 +20,6 @@ import (
|
||||
|
||||
const (
|
||||
netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS"
|
||||
netbirdDNSStateKeyIndexedFormat = "State:/Network/Service/NetBird-%s-%d/DNS"
|
||||
globalIPv4State = "State:/Network/Global/IPv4"
|
||||
primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS"
|
||||
keySupplementalMatchDomains = "SupplementalMatchDomains"
|
||||
@@ -38,22 +33,11 @@ const (
|
||||
searchSuffix = "Search"
|
||||
matchSuffix = "Match"
|
||||
localSuffix = "Local"
|
||||
|
||||
// maxDomainsPerResolverEntry is the max number of domains per scutil resolver key.
|
||||
// scutil's d.add has maxArgs=101 (key + * + 99 values), so 99 is the hard cap.
|
||||
maxDomainsPerResolverEntry = 50
|
||||
|
||||
// maxDomainBytesPerResolverEntry is the max total bytes of domain strings per key.
|
||||
// scutil has an undocumented ~2048 byte value buffer; we stay well under it.
|
||||
maxDomainBytesPerResolverEntry = 1500
|
||||
)
|
||||
|
||||
type systemConfigurator struct {
|
||||
createdKeys map[string]struct{}
|
||||
systemDNSSettings SystemDNSSettings
|
||||
|
||||
mu sync.RWMutex
|
||||
origNameservers []netip.Addr
|
||||
}
|
||||
|
||||
func newHostManager() (*systemConfigurator, error) {
|
||||
@@ -95,23 +79,28 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
||||
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
|
||||
}
|
||||
|
||||
if err := s.removeKeysContaining(matchSuffix); err != nil {
|
||||
log.Warnf("failed to remove old match keys: %v", err)
|
||||
}
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
var err error
|
||||
if len(matchDomains) != 0 {
|
||||
if err := s.addBatchedDomains(matchSuffix, matchDomains, config.ServerIP, config.ServerPort, false); err != nil {
|
||||
return fmt.Errorf("add match domains: %w", err)
|
||||
}
|
||||
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
|
||||
} else {
|
||||
log.Infof("removing match domains from the system")
|
||||
err = s.removeKeyFromSystemConfig(matchKey)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("add match domains: %w", err)
|
||||
}
|
||||
s.updateState(stateManager)
|
||||
|
||||
if err := s.removeKeysContaining(searchSuffix); err != nil {
|
||||
log.Warnf("failed to remove old search keys: %v", err)
|
||||
}
|
||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||
if len(searchDomains) != 0 {
|
||||
if err := s.addBatchedDomains(searchSuffix, searchDomains, config.ServerIP, config.ServerPort, true); err != nil {
|
||||
return fmt.Errorf("add search domains: %w", err)
|
||||
}
|
||||
err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.ServerIP, config.ServerPort)
|
||||
} else {
|
||||
log.Infof("removing search domains from the system")
|
||||
err = s.removeKeyFromSystemConfig(searchKey)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("add search domains: %w", err)
|
||||
}
|
||||
s.updateState(stateManager)
|
||||
|
||||
@@ -155,7 +144,8 @@ func (s *systemConfigurator) restoreHostDNS() error {
|
||||
|
||||
func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
|
||||
if len(s.createdKeys) == 0 {
|
||||
return s.discoverExistingKeys()
|
||||
// return defaults for startup calls
|
||||
return []string{getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix), getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)}
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(s.createdKeys))
|
||||
@@ -165,47 +155,6 @@ func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
|
||||
return keys
|
||||
}
|
||||
|
||||
// discoverExistingKeys probes scutil for all NetBird DNS keys that may exist.
|
||||
// This handles the case where createdKeys is empty (e.g., state file lost after unclean shutdown).
|
||||
func (s *systemConfigurator) discoverExistingKeys() []string {
|
||||
dnsKeys, err := getSystemDNSKeys()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get system DNS keys: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
var keys []string
|
||||
|
||||
for _, suffix := range []string{searchSuffix, matchSuffix, localSuffix} {
|
||||
key := getKeyWithInput(netbirdDNSStateKeyFormat, suffix)
|
||||
if strings.Contains(dnsKeys, key) {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
|
||||
for _, suffix := range []string{searchSuffix, matchSuffix} {
|
||||
for i := 0; ; i++ {
|
||||
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i)
|
||||
if !strings.Contains(dnsKeys, key) {
|
||||
break
|
||||
}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
// getSystemDNSKeys gets all DNS keys
|
||||
func getSystemDNSKeys() (string, error) {
|
||||
command := "list .*DNS\nquit\n"
|
||||
out, err := runSystemConfigCommand(command)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
||||
line := buildRemoveKeyOperation(key)
|
||||
_, err := runSystemConfigCommand(wrapCommand(line))
|
||||
@@ -230,11 +179,12 @@ func (s *systemConfigurator) addLocalDNS() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
domainsStr := strings.Join(s.systemDNSSettings.Domains, " ")
|
||||
if err := s.addDNSState(localKey, domainsStr, s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort, true); err != nil {
|
||||
return fmt.Errorf("add local dns state: %w", err)
|
||||
if err := s.addSearchDomains(
|
||||
localKey,
|
||||
strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort,
|
||||
); err != nil {
|
||||
return fmt.Errorf("add search domains: %w", err)
|
||||
}
|
||||
s.createdKeys[localKey] = struct{}{}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -268,7 +218,6 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
||||
}
|
||||
|
||||
var dnsSettings SystemDNSSettings
|
||||
var serverAddresses []netip.Addr
|
||||
inSearchDomainsArray := false
|
||||
inServerAddressesArray := false
|
||||
|
||||
@@ -295,12 +244,9 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
||||
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
|
||||
} else if inServerAddressesArray {
|
||||
address := strings.Split(line, " : ")[1]
|
||||
if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() {
|
||||
ip = ip.Unmap()
|
||||
serverAddresses = append(serverAddresses, ip)
|
||||
if !dnsSettings.ServerIP.IsValid() && ip.Is4() {
|
||||
dnsSettings.ServerIP = ip
|
||||
}
|
||||
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
|
||||
dnsSettings.ServerIP = ip.Unmap()
|
||||
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -312,90 +258,31 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
||||
// default to 53 port
|
||||
dnsSettings.ServerPort = DefaultPort
|
||||
|
||||
s.mu.Lock()
|
||||
s.origNameservers = serverAddresses
|
||||
s.mu.Unlock()
|
||||
|
||||
return dnsSettings, nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return slices.Clone(s.origNameservers)
|
||||
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
|
||||
err := s.addDNSState(key, domains, ip, port, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add dns state: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("added %d search domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
|
||||
|
||||
s.createdKeys[key] = struct{}{}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// splitDomainsIntoBatches splits domains into batches respecting both element count and byte size limits.
|
||||
func splitDomainsIntoBatches(domains []string) [][]string {
|
||||
if len(domains) == 0 {
|
||||
return nil
|
||||
func (s *systemConfigurator) addMatchDomains(key, domains string, dnsServer netip.Addr, port int) error {
|
||||
err := s.addDNSState(key, domains, dnsServer, port, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add dns state: %w", err)
|
||||
}
|
||||
|
||||
var batches [][]string
|
||||
var current []string
|
||||
currentBytes := 0
|
||||
log.Infof("added %d match domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
|
||||
|
||||
for _, d := range domains {
|
||||
domainLen := len(d)
|
||||
newBytes := currentBytes + domainLen
|
||||
if currentBytes > 0 {
|
||||
newBytes++ // space separator
|
||||
}
|
||||
|
||||
if len(current) > 0 && (len(current) >= maxDomainsPerResolverEntry || newBytes > maxDomainBytesPerResolverEntry) {
|
||||
batches = append(batches, current)
|
||||
current = nil
|
||||
currentBytes = 0
|
||||
}
|
||||
|
||||
current = append(current, d)
|
||||
if currentBytes > 0 {
|
||||
currentBytes += 1 + domainLen
|
||||
} else {
|
||||
currentBytes = domainLen
|
||||
}
|
||||
}
|
||||
|
||||
if len(current) > 0 {
|
||||
batches = append(batches, current)
|
||||
}
|
||||
|
||||
return batches
|
||||
}
|
||||
|
||||
// removeKeysContaining removes all created keys that contain the given substring.
|
||||
func (s *systemConfigurator) removeKeysContaining(suffix string) error {
|
||||
var toRemove []string
|
||||
for key := range s.createdKeys {
|
||||
if strings.Contains(key, suffix) {
|
||||
toRemove = append(toRemove, key)
|
||||
}
|
||||
}
|
||||
var multiErr *multierror.Error
|
||||
for _, key := range toRemove {
|
||||
if err := s.removeKeyFromSystemConfig(key); err != nil {
|
||||
multiErr = multierror.Append(multiErr, fmt.Errorf("couldn't remove key %s: %w", key, err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(multiErr)
|
||||
}
|
||||
|
||||
// addBatchedDomains splits domains into batches and creates indexed scutil keys for each batch.
|
||||
func (s *systemConfigurator) addBatchedDomains(suffix string, domains []string, ip netip.Addr, port int, enableSearch bool) error {
|
||||
batches := splitDomainsIntoBatches(domains)
|
||||
|
||||
for i, batch := range batches {
|
||||
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i)
|
||||
domainsStr := strings.Join(batch, " ")
|
||||
|
||||
if err := s.addDNSState(key, domainsStr, ip, port, enableSearch); err != nil {
|
||||
return fmt.Errorf("add dns state for batch %d: %w", i, err)
|
||||
}
|
||||
|
||||
s.createdKeys[key] = struct{}{}
|
||||
}
|
||||
|
||||
log.Infof("added %d %s domains across %d resolver entries", len(domains), suffix, len(batches))
|
||||
s.createdKeys[key] = struct{}{}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -458,6 +345,7 @@ func (s *systemConfigurator) flushDNSCache() error {
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("restart mDNSResponder: %w, output: %s", err, out)
|
||||
}
|
||||
|
||||
log.Info("flushed DNS cache")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,10 +3,7 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
@@ -52,22 +49,17 @@ func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
|
||||
|
||||
require.NoError(t, sm.PersistState(context.Background()))
|
||||
|
||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||
|
||||
// Collect all created keys for cleanup verification
|
||||
createdKeys := make([]string, 0, len(configurator.createdKeys))
|
||||
for key := range configurator.createdKeys {
|
||||
createdKeys = append(createdKeys, key)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
for _, key := range createdKeys {
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
_ = removeTestDNSKey(key)
|
||||
}
|
||||
_ = removeTestDNSKey(localKey)
|
||||
}()
|
||||
|
||||
for _, key := range createdKeys {
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
exists, err := checkDNSKeyExists(key)
|
||||
require.NoError(t, err)
|
||||
if exists {
|
||||
@@ -91,223 +83,13 @@ func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
|
||||
err = shutdownState.Cleanup()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, key := range createdKeys {
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
exists, err := checkDNSKeyExists(key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
|
||||
}
|
||||
}
|
||||
|
||||
// generateShortDomains generates domains like a.com, b.com, ..., aa.com, ab.com, etc.
|
||||
func generateShortDomains(count int) []string {
|
||||
domains := make([]string, 0, count)
|
||||
for i := range count {
|
||||
label := ""
|
||||
n := i
|
||||
for {
|
||||
label = string(rune('a'+n%26)) + label
|
||||
n = n/26 - 1
|
||||
if n < 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
domains = append(domains, label+".com")
|
||||
}
|
||||
return domains
|
||||
}
|
||||
|
||||
// generateLongDomains generates domains like subdomain-000.department.organization-name.example.com
|
||||
func generateLongDomains(count int) []string {
|
||||
domains := make([]string, 0, count)
|
||||
for i := range count {
|
||||
domains = append(domains, fmt.Sprintf("subdomain-%03d.department.organization-name.example.com", i))
|
||||
}
|
||||
return domains
|
||||
}
|
||||
|
||||
// readDomainsFromKey reads the SupplementalMatchDomains array back from scutil for a given key.
|
||||
func readDomainsFromKey(t *testing.T, key string) []string {
|
||||
t.Helper()
|
||||
|
||||
cmd := exec.Command(scutilPath)
|
||||
cmd.Stdin = strings.NewReader(fmt.Sprintf("open\nshow %s\nquit\n", key))
|
||||
out, err := cmd.Output()
|
||||
require.NoError(t, err, "scutil show should succeed")
|
||||
|
||||
var domains []string
|
||||
inArray := false
|
||||
scanner := bufio.NewScanner(bytes.NewReader(out))
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if strings.HasPrefix(line, "SupplementalMatchDomains") && strings.Contains(line, "<array>") {
|
||||
inArray = true
|
||||
continue
|
||||
}
|
||||
if inArray {
|
||||
if line == "}" {
|
||||
break
|
||||
}
|
||||
// lines look like: "0 : a.com"
|
||||
parts := strings.SplitN(line, " : ", 2)
|
||||
if len(parts) == 2 {
|
||||
domains = append(domains, parts[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
require.NoError(t, scanner.Err())
|
||||
return domains
|
||||
}
|
||||
|
||||
func TestSplitDomainsIntoBatches(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
domains []string
|
||||
expectedCount int
|
||||
checkAllPresent bool
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
domains: nil,
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "under_limit",
|
||||
domains: generateShortDomains(10),
|
||||
expectedCount: 1,
|
||||
checkAllPresent: true,
|
||||
},
|
||||
{
|
||||
name: "at_element_limit",
|
||||
domains: generateShortDomains(50),
|
||||
expectedCount: 1,
|
||||
checkAllPresent: true,
|
||||
},
|
||||
{
|
||||
name: "over_element_limit",
|
||||
domains: generateShortDomains(51),
|
||||
expectedCount: 2,
|
||||
checkAllPresent: true,
|
||||
},
|
||||
{
|
||||
name: "triple_element_limit",
|
||||
domains: generateShortDomains(150),
|
||||
expectedCount: 3,
|
||||
checkAllPresent: true,
|
||||
},
|
||||
{
|
||||
name: "long_domains_hit_byte_limit",
|
||||
domains: generateLongDomains(50),
|
||||
checkAllPresent: true,
|
||||
},
|
||||
{
|
||||
name: "500_short_domains",
|
||||
domains: generateShortDomains(500),
|
||||
expectedCount: 10,
|
||||
checkAllPresent: true,
|
||||
},
|
||||
{
|
||||
name: "500_long_domains",
|
||||
domains: generateLongDomains(500),
|
||||
checkAllPresent: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
batches := splitDomainsIntoBatches(tc.domains)
|
||||
|
||||
if tc.expectedCount > 0 {
|
||||
assert.Len(t, batches, tc.expectedCount, "expected %d batches", tc.expectedCount)
|
||||
}
|
||||
|
||||
// Verify each batch respects limits
|
||||
for i, batch := range batches {
|
||||
assert.LessOrEqual(t, len(batch), maxDomainsPerResolverEntry,
|
||||
"batch %d exceeds element limit", i)
|
||||
|
||||
totalBytes := 0
|
||||
for j, d := range batch {
|
||||
if j > 0 {
|
||||
totalBytes++
|
||||
}
|
||||
totalBytes += len(d)
|
||||
}
|
||||
assert.LessOrEqual(t, totalBytes, maxDomainBytesPerResolverEntry,
|
||||
"batch %d exceeds byte limit (%d bytes)", i, totalBytes)
|
||||
}
|
||||
|
||||
if tc.checkAllPresent {
|
||||
var all []string
|
||||
for _, batch := range batches {
|
||||
all = append(all, batch...)
|
||||
}
|
||||
assert.Equal(t, tc.domains, all, "all domains should be present in order")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMatchDomainBatching writes increasing numbers of domains via the batching mechanism
|
||||
// and verifies all domains are readable across multiple scutil keys.
|
||||
func TestMatchDomainBatching(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping scutil integration test in short mode")
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
count int
|
||||
generator func(int) []string
|
||||
}{
|
||||
{"short_10", 10, generateShortDomains},
|
||||
{"short_50", 50, generateShortDomains},
|
||||
{"short_100", 100, generateShortDomains},
|
||||
{"short_200", 200, generateShortDomains},
|
||||
{"short_500", 500, generateShortDomains},
|
||||
{"long_10", 10, generateLongDomains},
|
||||
{"long_50", 50, generateLongDomains},
|
||||
{"long_100", 100, generateLongDomains},
|
||||
{"long_200", 200, generateLongDomains},
|
||||
{"long_500", 500, generateLongDomains},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
configurator := &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
defer func() {
|
||||
for key := range configurator.createdKeys {
|
||||
_ = removeTestDNSKey(key)
|
||||
}
|
||||
}()
|
||||
|
||||
domains := tc.generator(tc.count)
|
||||
err := configurator.addBatchedDomains(matchSuffix, domains, netip.MustParseAddr("100.64.0.1"), 53, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
batches := splitDomainsIntoBatches(domains)
|
||||
t.Logf("wrote %d domains across %d batched keys", tc.count, len(batches))
|
||||
|
||||
// Read back all domains from all batched keys
|
||||
var got []string
|
||||
for i := range batches {
|
||||
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, matchSuffix, i)
|
||||
exists, err := checkDNSKeyExists(key)
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists, "key %s should exist", key)
|
||||
|
||||
got = append(got, readDomainsFromKey(t, key)...)
|
||||
}
|
||||
|
||||
t.Logf("read back %d/%d domains from %d keys", len(got), tc.count, len(batches))
|
||||
assert.Equal(t, tc.count, len(got), "all domains should be readable")
|
||||
assert.Equal(t, domains, got, "domains should match in order")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func checkDNSKeyExists(key string) (bool, error) {
|
||||
cmd := exec.Command(scutilPath)
|
||||
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
|
||||
@@ -327,169 +109,3 @@ func removeTestDNSKey(key string) error {
|
||||
_, err := cmd.CombinedOutput()
|
||||
return err
|
||||
}
|
||||
|
||||
func TestGetOriginalNameservers(t *testing.T) {
|
||||
configurator := &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
origNameservers: []netip.Addr{
|
||||
netip.MustParseAddr("8.8.8.8"),
|
||||
netip.MustParseAddr("1.1.1.1"),
|
||||
},
|
||||
}
|
||||
|
||||
servers := configurator.getOriginalNameservers()
|
||||
assert.Len(t, servers, 2)
|
||||
assert.Equal(t, netip.MustParseAddr("8.8.8.8"), servers[0])
|
||||
assert.Equal(t, netip.MustParseAddr("1.1.1.1"), servers[1])
|
||||
}
|
||||
|
||||
func TestGetOriginalNameserversFromSystem(t *testing.T) {
|
||||
configurator := &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
_, err := configurator.getSystemDNSSettings()
|
||||
require.NoError(t, err)
|
||||
|
||||
servers := configurator.getOriginalNameservers()
|
||||
|
||||
require.NotEmpty(t, servers, "expected at least one DNS server from system configuration")
|
||||
|
||||
for _, server := range servers {
|
||||
assert.True(t, server.IsValid(), "server address should be valid")
|
||||
assert.False(t, server.IsUnspecified(), "server address should not be unspecified")
|
||||
}
|
||||
|
||||
t.Logf("found %d original nameservers: %v", len(servers), servers)
|
||||
}
|
||||
|
||||
func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Manager, func()) {
|
||||
t.Helper()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
stateFile := filepath.Join(tmpDir, "state.json")
|
||||
sm := statemanager.New(stateFile)
|
||||
sm.RegisterState(&ShutdownState{})
|
||||
sm.Start()
|
||||
|
||||
configurator := &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
_ = sm.Stop(context.Background())
|
||||
for key := range configurator.createdKeys {
|
||||
_ = removeTestDNSKey(key)
|
||||
}
|
||||
// Also clean up old-format keys and local key in case they exist
|
||||
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix))
|
||||
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix))
|
||||
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix))
|
||||
}
|
||||
|
||||
return configurator, sm, cleanup
|
||||
}
|
||||
|
||||
func TestOriginalNameserversNoTransition(t *testing.T) {
|
||||
netbirdIP := netip.MustParseAddr("100.64.0.1")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
routeAll bool
|
||||
}{
|
||||
{"routeall_false", false},
|
||||
{"routeall_true", true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
configurator, sm, cleanup := setupTestConfigurator(t)
|
||||
defer cleanup()
|
||||
|
||||
_, err := configurator.getSystemDNSSettings()
|
||||
require.NoError(t, err)
|
||||
initialServers := configurator.getOriginalNameservers()
|
||||
t.Logf("Initial servers: %v", initialServers)
|
||||
require.NotEmpty(t, initialServers)
|
||||
|
||||
for _, srv := range initialServers {
|
||||
require.NotEqual(t, netbirdIP, srv, "initial servers should not contain NetBird IP")
|
||||
}
|
||||
|
||||
config := HostDNSConfig{
|
||||
ServerIP: netbirdIP,
|
||||
ServerPort: 53,
|
||||
RouteAll: tc.routeAll,
|
||||
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
|
||||
}
|
||||
|
||||
for i := 1; i <= 2; i++ {
|
||||
err = configurator.applyDNSConfig(config, sm)
|
||||
require.NoError(t, err)
|
||||
|
||||
servers := configurator.getOriginalNameservers()
|
||||
t.Logf("After apply %d (RouteAll=%v): %v", i, tc.routeAll, servers)
|
||||
assert.Equal(t, initialServers, servers)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOriginalNameserversRouteAllTransition(t *testing.T) {
|
||||
netbirdIP := netip.MustParseAddr("100.64.0.1")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initialRoute bool
|
||||
}{
|
||||
{"start_with_routeall_false", false},
|
||||
{"start_with_routeall_true", true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
configurator, sm, cleanup := setupTestConfigurator(t)
|
||||
defer cleanup()
|
||||
|
||||
_, err := configurator.getSystemDNSSettings()
|
||||
require.NoError(t, err)
|
||||
initialServers := configurator.getOriginalNameservers()
|
||||
t.Logf("Initial servers: %v", initialServers)
|
||||
require.NotEmpty(t, initialServers)
|
||||
|
||||
config := HostDNSConfig{
|
||||
ServerIP: netbirdIP,
|
||||
ServerPort: 53,
|
||||
RouteAll: tc.initialRoute,
|
||||
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
|
||||
}
|
||||
|
||||
// First apply
|
||||
err = configurator.applyDNSConfig(config, sm)
|
||||
require.NoError(t, err)
|
||||
servers := configurator.getOriginalNameservers()
|
||||
t.Logf("After first apply (RouteAll=%v): %v", tc.initialRoute, servers)
|
||||
assert.Equal(t, initialServers, servers)
|
||||
|
||||
// Toggle RouteAll
|
||||
config.RouteAll = !tc.initialRoute
|
||||
err = configurator.applyDNSConfig(config, sm)
|
||||
require.NoError(t, err)
|
||||
servers = configurator.getOriginalNameservers()
|
||||
t.Logf("After toggle (RouteAll=%v): %v", config.RouteAll, servers)
|
||||
assert.Equal(t, initialServers, servers)
|
||||
|
||||
// Toggle back
|
||||
config.RouteAll = tc.initialRoute
|
||||
err = configurator.applyDNSConfig(config, sm)
|
||||
require.NoError(t, err)
|
||||
servers = configurator.getOriginalNameservers()
|
||||
t.Logf("After toggle back (RouteAll=%v): %v", config.RouteAll, servers)
|
||||
assert.Equal(t, initialServers, servers)
|
||||
|
||||
for _, srv := range servers {
|
||||
assert.NotEqual(t, netbirdIP, srv, "servers should not contain NetBird IP")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,8 +42,6 @@ const (
|
||||
dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
|
||||
dnsPolicyConfigConfigOptionsValue = 0x8
|
||||
|
||||
nrptMaxDomainsPerRule = 50
|
||||
|
||||
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
||||
interfaceConfigNameServerKey = "NameServer"
|
||||
interfaceConfigSearchListKey = "SearchList"
|
||||
@@ -200,11 +198,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
|
||||
if len(matchDomains) != 0 {
|
||||
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
|
||||
// Update count even on error to ensure cleanup covers partially created rules
|
||||
r.nrptEntryCount = count
|
||||
if err != nil {
|
||||
return fmt.Errorf("add dns match policy: %w", err)
|
||||
}
|
||||
r.nrptEntryCount = count
|
||||
} else {
|
||||
r.nrptEntryCount = 0
|
||||
}
|
||||
@@ -242,33 +239,23 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
|
||||
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) {
|
||||
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
||||
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
||||
for i, domain := range domains {
|
||||
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
|
||||
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
|
||||
|
||||
// We need to batch domains into chunks and create one NRPT rule per batch.
|
||||
ruleIndex := 0
|
||||
for i := 0; i < len(domains); i += nrptMaxDomainsPerRule {
|
||||
end := i + nrptMaxDomainsPerRule
|
||||
if end > len(domains) {
|
||||
end = len(domains)
|
||||
singleDomain := []string{domain}
|
||||
|
||||
if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil {
|
||||
return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err)
|
||||
}
|
||||
batchDomains := domains[i:end]
|
||||
|
||||
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, ruleIndex)
|
||||
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, ruleIndex)
|
||||
|
||||
if err := r.configureDNSPolicy(localPath, batchDomains, ip); err != nil {
|
||||
return ruleIndex, fmt.Errorf("configure DNS Local policy for rule %d: %w", ruleIndex, err)
|
||||
}
|
||||
|
||||
// Increment immediately so the caller's cleanup path knows about this rule
|
||||
ruleIndex++
|
||||
|
||||
if r.gpo {
|
||||
if err := r.configureDNSPolicy(gpoPath, batchDomains, ip); err != nil {
|
||||
return ruleIndex, fmt.Errorf("configure gpo DNS policy for rule %d: %w", ruleIndex-1, err)
|
||||
if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil {
|
||||
return i, fmt.Errorf("configure gpo DNS policy: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("added NRPT rule %d with %d domains", ruleIndex-1, len(batchDomains))
|
||||
log.Debugf("added NRPT entry for domain: %s", domain)
|
||||
}
|
||||
|
||||
if r.gpo {
|
||||
@@ -277,8 +264,8 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("added %d NRPT rules for %d domains", ruleIndex, len(domains))
|
||||
return ruleIndex, nil
|
||||
log.Infof("added %d separate NRPT entries. Domain list: %s", len(domains), domains)
|
||||
return len(domains), nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
|
||||
// when the number of match domains decreases between configuration changes.
|
||||
// With batching enabled (50 domains per rule), we need enough domains to create multiple rules.
|
||||
func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping registry integration test in short mode")
|
||||
@@ -38,60 +37,51 @@ func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
||||
gpo: false,
|
||||
}
|
||||
|
||||
// Create 125 domains which will result in 3 NRPT rules (50+50+25)
|
||||
domains125 := make([]DomainConfig, 125)
|
||||
for i := 0; i < 125; i++ {
|
||||
domains125[i] = DomainConfig{
|
||||
Domain: fmt.Sprintf("domain%d.com", i+1),
|
||||
MatchOnly: true,
|
||||
}
|
||||
}
|
||||
|
||||
config125 := HostDNSConfig{
|
||||
config5 := HostDNSConfig{
|
||||
ServerIP: testIP,
|
||||
Domains: domains125,
|
||||
Domains: []DomainConfig{
|
||||
{Domain: "domain1.com", MatchOnly: true},
|
||||
{Domain: "domain2.com", MatchOnly: true},
|
||||
{Domain: "domain3.com", MatchOnly: true},
|
||||
{Domain: "domain4.com", MatchOnly: true},
|
||||
{Domain: "domain5.com", MatchOnly: true},
|
||||
},
|
||||
}
|
||||
|
||||
err = cfg.applyDNSConfig(config125, nil)
|
||||
err = cfg.applyDNSConfig(config5, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify 3 NRPT rules exist
|
||||
assert.Equal(t, 3, cfg.nrptEntryCount, "Should create 3 NRPT rules for 125 domains")
|
||||
for i := 0; i < 3; i++ {
|
||||
// Verify all 5 entries exist
|
||||
for i := 0; i < 5; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "NRPT rule %d should exist after first config", i)
|
||||
assert.True(t, exists, "Entry %d should exist after first config", i)
|
||||
}
|
||||
|
||||
// Reduce to 75 domains which will result in 2 NRPT rules (50+25)
|
||||
domains75 := make([]DomainConfig, 75)
|
||||
for i := 0; i < 75; i++ {
|
||||
domains75[i] = DomainConfig{
|
||||
Domain: fmt.Sprintf("domain%d.com", i+1),
|
||||
MatchOnly: true,
|
||||
}
|
||||
}
|
||||
|
||||
config75 := HostDNSConfig{
|
||||
config2 := HostDNSConfig{
|
||||
ServerIP: testIP,
|
||||
Domains: domains75,
|
||||
Domains: []DomainConfig{
|
||||
{Domain: "domain1.com", MatchOnly: true},
|
||||
{Domain: "domain2.com", MatchOnly: true},
|
||||
},
|
||||
}
|
||||
|
||||
err = cfg.applyDNSConfig(config75, nil)
|
||||
err = cfg.applyDNSConfig(config2, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify first 2 NRPT rules exist
|
||||
assert.Equal(t, 2, cfg.nrptEntryCount, "Should create 2 NRPT rules for 75 domains")
|
||||
// Verify first 2 entries exist
|
||||
for i := 0; i < 2; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "NRPT rule %d should exist after second config", i)
|
||||
assert.True(t, exists, "Entry %d should exist after second config", i)
|
||||
}
|
||||
|
||||
// Verify rule 2 is cleaned up
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, 2))
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "NRPT rule 2 should NOT exist after reducing to 75 domains")
|
||||
// Verify entries 2-4 are cleaned up
|
||||
for i := 2; i < 5; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i)
|
||||
}
|
||||
}
|
||||
|
||||
func registryKeyExists(path string) (bool, error) {
|
||||
@@ -107,106 +97,6 @@ func registryKeyExists(path string) (bool, error) {
|
||||
}
|
||||
|
||||
func cleanupRegistryKeys(*testing.T) {
|
||||
// Clean up more entries to account for batching tests with many domains
|
||||
cfg := ®istryConfigurator{nrptEntryCount: 20}
|
||||
cfg := ®istryConfigurator{nrptEntryCount: 10}
|
||||
_ = cfg.removeDNSMatchPolicies()
|
||||
}
|
||||
|
||||
// TestNRPTDomainBatching verifies that domains are correctly batched into NRPT rules.
|
||||
func TestNRPTDomainBatching(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping registry integration test in short mode")
|
||||
}
|
||||
|
||||
defer cleanupRegistryKeys(t)
|
||||
cleanupRegistryKeys(t)
|
||||
|
||||
testIP := netip.MustParseAddr("100.64.0.1")
|
||||
|
||||
// Create a test interface registry key so updateSearchDomains doesn't fail
|
||||
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
|
||||
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
|
||||
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
|
||||
require.NoError(t, err, "Should create test interface registry key")
|
||||
testKey.Close()
|
||||
defer func() {
|
||||
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
|
||||
}()
|
||||
|
||||
cfg := ®istryConfigurator{
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
domainCount int
|
||||
expectedRuleCount int
|
||||
}{
|
||||
{
|
||||
name: "Less than 50 domains (single rule)",
|
||||
domainCount: 30,
|
||||
expectedRuleCount: 1,
|
||||
},
|
||||
{
|
||||
name: "Exactly 50 domains (single rule)",
|
||||
domainCount: 50,
|
||||
expectedRuleCount: 1,
|
||||
},
|
||||
{
|
||||
name: "51 domains (two rules)",
|
||||
domainCount: 51,
|
||||
expectedRuleCount: 2,
|
||||
},
|
||||
{
|
||||
name: "100 domains (two rules)",
|
||||
domainCount: 100,
|
||||
expectedRuleCount: 2,
|
||||
},
|
||||
{
|
||||
name: "125 domains (three rules: 50+50+25)",
|
||||
domainCount: 125,
|
||||
expectedRuleCount: 3,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Clean up before each subtest
|
||||
cleanupRegistryKeys(t)
|
||||
|
||||
// Generate domains
|
||||
domains := make([]DomainConfig, tc.domainCount)
|
||||
for i := 0; i < tc.domainCount; i++ {
|
||||
domains[i] = DomainConfig{
|
||||
Domain: fmt.Sprintf("domain%d.com", i+1),
|
||||
MatchOnly: true,
|
||||
}
|
||||
}
|
||||
|
||||
config := HostDNSConfig{
|
||||
ServerIP: testIP,
|
||||
Domains: domains,
|
||||
}
|
||||
|
||||
err := cfg.applyDNSConfig(config, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify that exactly expectedRuleCount rules were created
|
||||
assert.Equal(t, tc.expectedRuleCount, cfg.nrptEntryCount,
|
||||
"Should create %d NRPT rules for %d domains", tc.expectedRuleCount, tc.domainCount)
|
||||
|
||||
// Verify all expected rules exist
|
||||
for i := 0; i < tc.expectedRuleCount; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "NRPT rule %d should exist", i)
|
||||
}
|
||||
|
||||
// Verify no extra rules were created
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, tc.expectedRuleCount))
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "No NRPT rule should exist at index %d", tc.expectedRuleCount)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -376,9 +376,9 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
|
||||
}
|
||||
}
|
||||
|
||||
// Flow receiver domain is intentionally excluded from caching.
|
||||
// Cloud providers may rotate the IP behind this domain; a stale cached record
|
||||
// causes TLS certificate verification failures on reconnect.
|
||||
if serverDomains.Flow != "" {
|
||||
domains = append(domains, serverDomains.Flow)
|
||||
}
|
||||
|
||||
for _, stun := range serverDomains.Stuns {
|
||||
if stun != "" {
|
||||
|
||||
@@ -391,8 +391,7 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
||||
}
|
||||
assert.Len(t, resolver.GetCachedDomains(), 3)
|
||||
|
||||
// Update with partial ServerDomains (only flow domain - flow is intentionally excluded from
|
||||
// caching to prevent TLS failures from stale records, so all existing domains are preserved)
|
||||
// Update with partial ServerDomains (only flow domain - new type, should preserve all existing)
|
||||
partialDomains := dnsconfig.ServerDomains{
|
||||
Flow: "github.com",
|
||||
}
|
||||
@@ -401,10 +400,10 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
||||
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
||||
}
|
||||
|
||||
assert.Len(t, removedDomains, 0, "Should not remove any domains when only flow domain is provided")
|
||||
assert.Len(t, removedDomains, 0, "Should not remove any domains when adding new type")
|
||||
|
||||
finalDomains := resolver.GetCachedDomains()
|
||||
assert.Len(t, finalDomains, 3, "Flow domain is not cached; all original domains should be preserved")
|
||||
assert.Len(t, finalDomains, 4, "Should have all original domains plus new flow domain")
|
||||
|
||||
domainStrings := make([]string, len(finalDomains))
|
||||
for i, d := range finalDomains {
|
||||
@@ -413,5 +412,5 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
||||
assert.Contains(t, domainStrings, "example.org")
|
||||
assert.Contains(t, domainStrings, "google.com")
|
||||
assert.Contains(t, domainStrings, "cloudflare.com")
|
||||
assert.NotContains(t, domainStrings, "github.com")
|
||||
assert.Contains(t, domainStrings, "github.com")
|
||||
}
|
||||
|
||||
@@ -84,18 +84,3 @@ func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
|
||||
func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// BeginBatch mock implementation of BeginBatch from Server interface
|
||||
func (m *MockServer) BeginBatch() {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
// EndBatch mock implementation of EndBatch from Server interface
|
||||
func (m *MockServer) EndBatch() {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
// CancelBatch mock implementation of CancelBatch from Server interface
|
||||
func (m *MockServer) CancelBatch() {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
@@ -6,9 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -29,8 +27,6 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
const envSkipDNSProbe = "NB_SKIP_DNS_PROBE"
|
||||
|
||||
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||
type ReadyListener interface {
|
||||
OnReady()
|
||||
@@ -45,9 +41,6 @@ type IosDnsManager interface {
|
||||
type Server interface {
|
||||
RegisterHandler(domains domain.List, handler dns.Handler, priority int)
|
||||
DeregisterHandler(domains domain.List, priority int)
|
||||
BeginBatch()
|
||||
EndBatch()
|
||||
CancelBatch()
|
||||
Initialize() error
|
||||
Stop()
|
||||
DnsIP() netip.Addr
|
||||
@@ -90,7 +83,6 @@ type DefaultServer struct {
|
||||
currentConfigHash uint64
|
||||
handlerChain *HandlerChain
|
||||
extraDomains map[domain.Domain]int
|
||||
batchMode bool
|
||||
|
||||
mgmtCacheResolver *mgmt.Resolver
|
||||
|
||||
@@ -238,9 +230,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
|
||||
// convert to zone with simple ref counter
|
||||
s.extraDomains[toZone(domain)]++
|
||||
}
|
||||
if !s.batchMode {
|
||||
s.applyHostConfig()
|
||||
}
|
||||
s.applyHostConfig()
|
||||
}
|
||||
|
||||
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
||||
@@ -269,41 +259,9 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
|
||||
delete(s.extraDomains, zone)
|
||||
}
|
||||
}
|
||||
if !s.batchMode {
|
||||
s.applyHostConfig()
|
||||
}
|
||||
}
|
||||
|
||||
// BeginBatch starts batch mode for DNS handler registration/deregistration.
|
||||
// In batch mode, applyHostConfig() is not called after each handler operation,
|
||||
// allowing multiple handlers to be registered/deregistered efficiently.
|
||||
// Must be followed by EndBatch() to apply the accumulated changes.
|
||||
func (s *DefaultServer) BeginBatch() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
log.Debugf("DNS batch mode enabled")
|
||||
s.batchMode = true
|
||||
}
|
||||
|
||||
// EndBatch ends batch mode and applies all accumulated DNS configuration changes.
|
||||
func (s *DefaultServer) EndBatch() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
log.Debugf("DNS batch mode disabled, applying accumulated changes")
|
||||
s.batchMode = false
|
||||
s.applyHostConfig()
|
||||
}
|
||||
|
||||
// CancelBatch cancels batch mode without applying accumulated changes.
|
||||
// This is useful when operations fail partway through and you want to
|
||||
// discard partial state rather than applying it.
|
||||
func (s *DefaultServer) CancelBatch() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
log.Debugf("DNS batch mode cancelled, discarding accumulated changes")
|
||||
s.batchMode = false
|
||||
}
|
||||
|
||||
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
||||
log.Debugf("deregistering handler with priority %d for %v", priority, domains)
|
||||
|
||||
@@ -481,17 +439,6 @@ func (s *DefaultServer) SearchDomains() []string {
|
||||
// ProbeAvailability tests each upstream group's servers for availability
|
||||
// and deactivates the group if no server responds
|
||||
func (s *DefaultServer) ProbeAvailability() {
|
||||
if val := os.Getenv(envSkipDNSProbe); val != "" {
|
||||
skipProbe, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", envSkipDNSProbe, err)
|
||||
}
|
||||
if skipProbe {
|
||||
log.Infof("skipping DNS probe due to %s", envSkipDNSProbe)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, mux := range s.dnsMuxMap {
|
||||
wg.Add(1)
|
||||
@@ -561,7 +508,6 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
s.currentConfig.RouteAll = false
|
||||
}
|
||||
|
||||
// Always apply host config for management updates, regardless of batch mode
|
||||
s.applyHostConfig()
|
||||
|
||||
s.shutdownWg.Add(1)
|
||||
@@ -669,7 +615,7 @@ func (s *DefaultServer) applyHostConfig() {
|
||||
s.registerFallback(config)
|
||||
}
|
||||
|
||||
// registerFallback registers original nameservers as low-priority fallback handlers.
|
||||
// registerFallback registers original nameservers as low-priority fallback handlers
|
||||
func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
||||
hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS)
|
||||
if !ok {
|
||||
@@ -678,7 +624,6 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
||||
|
||||
originalNameservers := hostMgrWithNS.getOriginalNameservers()
|
||||
if len(originalNameservers) == 0 {
|
||||
s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -926,7 +871,6 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
}
|
||||
}
|
||||
|
||||
// Always apply host config when nameserver goes down, regardless of batch mode
|
||||
s.applyHostConfig()
|
||||
|
||||
go func() {
|
||||
@@ -962,7 +906,6 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
||||
}
|
||||
|
||||
// Always apply host config when nameserver reactivates, regardless of batch mode
|
||||
s.applyHostConfig()
|
||||
|
||||
s.updateNSState(nsGroup, nil, true)
|
||||
|
||||
@@ -18,12 +18,7 @@ func TestGetServerDns(t *testing.T) {
|
||||
t.Errorf("invalid dns server instance: %s", err)
|
||||
}
|
||||
|
||||
mockSrvB, ok := srvB.(*MockServer)
|
||||
if !ok {
|
||||
t.Errorf("returned server is not a MockServer")
|
||||
}
|
||||
|
||||
if mockSrvB != srv {
|
||||
if srvB != srv {
|
||||
t.Errorf("mismatch dns instances")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,21 +8,15 @@ import (
|
||||
|
||||
type MockResponseWriter struct {
|
||||
WriteMsgFunc func(m *dns.Msg) error
|
||||
lastResponse *dns.Msg
|
||||
}
|
||||
|
||||
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
|
||||
rw.lastResponse = m
|
||||
if rw.WriteMsgFunc != nil {
|
||||
return rw.WriteMsgFunc(m)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rw *MockResponseWriter) GetLastResponse() *dns.Msg {
|
||||
return rw.lastResponse
|
||||
}
|
||||
|
||||
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
|
||||
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
||||
|
||||
@@ -351,13 +351,9 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
||||
return fmt.Errorf("upstream check call error")
|
||||
}
|
||||
|
||||
err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx))
|
||||
err := backoff.Retry(operation, exponentialBackOff)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString())
|
||||
} else {
|
||||
log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err)
|
||||
}
|
||||
log.Warn(err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -190,75 +190,50 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg, startTime time.Time) {
|
||||
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
|
||||
if len(query.Question) == 0 {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
question := query.Question[0]
|
||||
qname := strings.ToLower(question.Name)
|
||||
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
|
||||
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
||||
|
||||
logger.Tracef("question: domain=%s type=%s class=%s",
|
||||
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
||||
domain := strings.ToLower(question.Name)
|
||||
|
||||
resp := query.SetReply(query)
|
||||
network := resutil.NetworkForQtype(question.Qtype)
|
||||
if network == "" {
|
||||
resp.Rcode = dns.RcodeNotImplemented
|
||||
f.writeResponse(logger, w, resp, qname, startTime)
|
||||
return
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
|
||||
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
|
||||
// query doesn't match any configured domain
|
||||
if mostSpecificResId == "" {
|
||||
resp.Rcode = dns.RcodeRefused
|
||||
f.writeResponse(logger, w, resp, qname, startTime)
|
||||
return
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
||||
defer cancel()
|
||||
|
||||
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
|
||||
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
|
||||
if result.Err != nil {
|
||||
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
|
||||
return
|
||||
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
|
||||
return nil
|
||||
}
|
||||
|
||||
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
|
||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(qname, result.IPs, f.ttl)...)
|
||||
f.cache.set(qname, question.Qtype, result.IPs)
|
||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
|
||||
f.cache.set(domain, question.Qtype, result.IPs)
|
||||
|
||||
f.writeResponse(logger, w, resp, qname, startTime)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||
}
|
||||
|
||||
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
|
||||
type udpResponseWriter struct {
|
||||
dns.ResponseWriter
|
||||
query *dns.Msg
|
||||
}
|
||||
|
||||
func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
|
||||
opt := u.query.IsEdns0()
|
||||
maxSize := dns.MinMsgSize
|
||||
if opt != nil {
|
||||
maxSize = int(opt.UDPSize())
|
||||
}
|
||||
|
||||
if resp.Len() > maxSize {
|
||||
resp.Truncate(maxSize)
|
||||
}
|
||||
|
||||
return u.ResponseWriter.WriteMsg(resp)
|
||||
return resp
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
@@ -268,7 +243,30 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||
})
|
||||
|
||||
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
|
||||
resp := f.handleDNSQuery(logger, w, query)
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
opt := query.IsEdns0()
|
||||
maxSize := dns.MinMsgSize
|
||||
if opt != nil {
|
||||
// client advertised a larger EDNS0 buffer
|
||||
maxSize = int(opt.UDPSize())
|
||||
}
|
||||
|
||||
// if our response is too big, truncate and set the TC bit
|
||||
if resp.Len() > maxSize {
|
||||
resp.Truncate(maxSize)
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
@@ -278,7 +276,18 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||
})
|
||||
|
||||
f.handleDNSQuery(logger, w, query, startTime)
|
||||
resp := f.handleDNSQuery(logger, w, query)
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
||||
@@ -325,7 +334,6 @@ func (f *DNSForwarder) handleDNSError(
|
||||
resp *dns.Msg,
|
||||
domain string,
|
||||
result resutil.LookupResult,
|
||||
startTime time.Time,
|
||||
) {
|
||||
qType := question.Qtype
|
||||
qTypeName := dns.TypeToString[qType]
|
||||
@@ -335,7 +343,9 @@ func (f *DNSForwarder) handleDNSError(
|
||||
// NotFound: cache negative result and respond
|
||||
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
|
||||
f.cache.set(domain, question.Qtype, nil)
|
||||
f.writeResponse(logger, w, resp, domain, startTime)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -345,7 +355,9 @@ func (f *DNSForwarder) handleDNSError(
|
||||
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
f.writeResponse(logger, w, resp, domain, startTime)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
logger.Errorf("failed to write cached DNS response: %v", writeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -353,7 +365,9 @@ func (f *DNSForwarder) handleDNSError(
|
||||
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
|
||||
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
|
||||
resp.Rcode = verifyResult.Rcode
|
||||
f.writeResponse(logger, w, resp, domain, startTime)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -361,12 +375,15 @@ func (f *DNSForwarder) handleDNSError(
|
||||
// No cache or verification failed. Log with or without the server field for more context.
|
||||
var dnsErr *net.DNSError
|
||||
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
|
||||
logger.Warnf("upstream failure: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
|
||||
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
|
||||
} else {
|
||||
logger.Warnf(errResolveFailed, domain, result.Err)
|
||||
}
|
||||
|
||||
f.writeResponse(logger, w, resp, domain, startTime)
|
||||
// Write final failure response.
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
}
|
||||
|
||||
// getMatchingEntries retrieves the resource IDs for a given domain.
|
||||
|
||||
@@ -318,9 +318,8 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
||||
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
|
||||
resp := mockWriter.GetLastResponse()
|
||||
if tt.shouldResolve {
|
||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
|
||||
@@ -330,9 +329,10 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
||||
mockFirewall.AssertExpectations(t)
|
||||
mockResolver.AssertExpectations(t)
|
||||
} else {
|
||||
require.NotNil(t, resp, "Expected response")
|
||||
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
||||
"Unauthorized domain should not return successful answers")
|
||||
if resp != nil {
|
||||
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
||||
"Unauthorized domain should not return successful answers")
|
||||
}
|
||||
mockFirewall.AssertNotCalled(t, "UpdateSet")
|
||||
mockResolver.AssertNotCalled(t, "LookupNetIP")
|
||||
}
|
||||
@@ -466,16 +466,14 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
|
||||
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery, time.Now())
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
|
||||
|
||||
// Verify response
|
||||
resp := mockWriter.GetLastResponse()
|
||||
if tt.shouldResolve {
|
||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.NotEmpty(t, resp.Answer)
|
||||
} else {
|
||||
require.NotNil(t, resp, "Expected response")
|
||||
} else if resp != nil {
|
||||
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
|
||||
"Unauthorized domain should be refused or have no answers")
|
||||
}
|
||||
@@ -530,10 +528,9 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
||||
query.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
|
||||
// Verify response contains all IPs
|
||||
resp := mockWriter.GetLastResponse()
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
|
||||
@@ -608,7 +605,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
|
||||
// Check the response written to the writer
|
||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||
@@ -678,8 +675,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
||||
q1 := &dns.Msg{}
|
||||
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||
w1 := &test.MockResponseWriter{}
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
|
||||
resp1 := w1.GetLastResponse()
|
||||
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
||||
require.NotNil(t, resp1)
|
||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||
require.Len(t, resp1.Answer, 1)
|
||||
@@ -687,13 +683,13 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
||||
// Second query: serve from cache after upstream failure
|
||||
q2 := &dns.Msg{}
|
||||
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||
w2 := &test.MockResponseWriter{}
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
|
||||
var writtenResp *dns.Msg
|
||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
||||
|
||||
resp2 := w2.GetLastResponse()
|
||||
require.NotNil(t, resp2, "expected response to be written")
|
||||
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
|
||||
require.Len(t, resp2.Answer, 1)
|
||||
require.NotNil(t, writtenResp, "expected response to be written")
|
||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||
require.Len(t, writtenResp.Answer, 1)
|
||||
|
||||
mockResolver.AssertExpectations(t)
|
||||
}
|
||||
@@ -719,8 +715,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
||||
q1 := &dns.Msg{}
|
||||
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
||||
w1 := &test.MockResponseWriter{}
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
|
||||
resp1 := w1.GetLastResponse()
|
||||
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
||||
require.NotNil(t, resp1)
|
||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||
require.Len(t, resp1.Answer, 1)
|
||||
@@ -732,13 +727,13 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
||||
|
||||
q2 := &dns.Msg{}
|
||||
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
||||
w2 := &test.MockResponseWriter{}
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
|
||||
var writtenResp *dns.Msg
|
||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
||||
|
||||
resp2 := w2.GetLastResponse()
|
||||
require.NotNil(t, resp2)
|
||||
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
|
||||
require.Len(t, resp2.Answer, 1)
|
||||
require.NotNil(t, writtenResp)
|
||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||
require.Len(t, writtenResp.Answer, 1)
|
||||
|
||||
mockResolver.AssertExpectations(t)
|
||||
}
|
||||
@@ -789,9 +784,8 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
||||
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
|
||||
resp := mockWriter.GetLastResponse()
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
|
||||
@@ -903,15 +897,26 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
||||
query := &dns.Msg{}
|
||||
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||
var writtenResp *dns.Msg
|
||||
mockWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
writtenResp = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
resp := mockWriter.GetLastResponse()
|
||||
require.NotNil(t, resp, "Expected response to be written")
|
||||
assert.Equal(t, tt.expectedCode, resp.Rcode, tt.description)
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
|
||||
// If a response was returned, it means it should be written (happens in wrapper functions)
|
||||
if resp != nil && writtenResp == nil {
|
||||
writtenResp = resp
|
||||
}
|
||||
|
||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
|
||||
|
||||
if tt.expectNoAnswer {
|
||||
assert.Empty(t, resp.Answer, "Response should have no answer records")
|
||||
assert.Empty(t, writtenResp.Answer, "Response should have no answer records")
|
||||
}
|
||||
|
||||
mockResolver.AssertExpectations(t)
|
||||
@@ -926,8 +931,15 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
||||
query := &dns.Msg{}
|
||||
// Don't set any question
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||
writeCalled := false
|
||||
mockWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
writeCalled = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
|
||||
assert.Nil(t, mockWriter.GetLastResponse(), "Should not write response for empty query")
|
||||
assert.Nil(t, resp, "Should return nil for empty query")
|
||||
assert.False(t, writeCalled, "Should not write response for empty query")
|
||||
}
|
||||
|
||||
@@ -29,14 +29,12 @@ import (
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
"github.com/netbirdio/netbird/client/internal/expose"
|
||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
@@ -54,11 +52,13 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
||||
"github.com/netbirdio/netbird/client/jobexec"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
@@ -74,6 +74,7 @@ import (
|
||||
const (
|
||||
PeerConnectionTimeoutMax = 45000 // ms
|
||||
PeerConnectionTimeoutMin = 30000 // ms
|
||||
connInitLimit = 200
|
||||
disableAutoUpdate = "disabled"
|
||||
)
|
||||
|
||||
@@ -206,6 +207,7 @@ type Engine struct {
|
||||
syncRespMux sync.RWMutex
|
||||
persistSyncResponse bool
|
||||
latestSyncResponse *mgmProto.SyncResponse
|
||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||
flowManager nftypes.FlowManager
|
||||
|
||||
// auto-update
|
||||
@@ -221,8 +223,6 @@ type Engine struct {
|
||||
|
||||
jobExecutor *jobexec.Executor
|
||||
jobExecutorWG sync.WaitGroup
|
||||
|
||||
exposeManager *expose.Manager
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -265,6 +265,7 @@ func NewEngine(
|
||||
statusRecorder: statusRecorder,
|
||||
stateManager: stateManager,
|
||||
checks: checks,
|
||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||
jobExecutor: jobexec.NewExecutor(),
|
||||
}
|
||||
@@ -417,7 +418,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.cancel()
|
||||
}
|
||||
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
|
||||
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
|
||||
|
||||
wgIface, err := e.newWgIface()
|
||||
if err != nil {
|
||||
@@ -505,10 +505,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return fmt.Errorf("up wg interface: %w", err)
|
||||
}
|
||||
|
||||
// Set up notrack rules immediately after proxy is listening to prevent
|
||||
// conntrack entries from being created before the rules are in place
|
||||
e.setupWGProxyNoTrack()
|
||||
|
||||
// Set the WireGuard interface for rosenpass after interface is up
|
||||
if e.rpManager != nil {
|
||||
e.rpManager.SetInterface(e.wgInterface)
|
||||
@@ -543,12 +539,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
// monitor WireGuard interface lifecycle and restart engine on changes
|
||||
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
||||
e.shutdownWg.Add(1)
|
||||
wgIfaceName := e.wgInterface.Name()
|
||||
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
|
||||
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, wgIfaceName); shouldRestart {
|
||||
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
|
||||
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
||||
e.triggerClientRestart()
|
||||
} else if err != nil {
|
||||
@@ -574,11 +569,9 @@ func (e *Engine) createFirewall() error {
|
||||
|
||||
var err error
|
||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create firewall manager: %w", err)
|
||||
}
|
||||
if e.firewall == nil {
|
||||
return fmt.Errorf("create firewall manager: received nil manager")
|
||||
if err != nil || e.firewall == nil {
|
||||
log.Errorf("failed creating firewall manager: %s", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.initFirewall(); err != nil {
|
||||
@@ -624,23 +617,6 @@ func (e *Engine) initFirewall() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupWGProxyNoTrack configures connection tracking exclusion for WireGuard proxy traffic.
|
||||
// This prevents conntrack/MASQUERADE from affecting loopback traffic between WireGuard and the eBPF proxy.
|
||||
func (e *Engine) setupWGProxyNoTrack() {
|
||||
if e.firewall == nil {
|
||||
return
|
||||
}
|
||||
|
||||
proxyPort := e.wgInterface.GetProxyPort()
|
||||
if proxyPort == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if err := e.firewall.SetupEBPFProxyNoTrack(proxyPort, uint16(e.config.WgPort)); err != nil {
|
||||
log.Warnf("failed to setup ebpf proxy notrack: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) blockLanAccess() {
|
||||
if e.config.BlockInbound {
|
||||
// no need to set up extra deny rules if inbound is already blocked in general
|
||||
@@ -800,7 +776,7 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
|
||||
|
||||
disabled := autoUpdateSettings.Version == disableAutoUpdate
|
||||
|
||||
// stop and cleanup if disabled
|
||||
// Stop and cleanup if disabled
|
||||
if e.updateManager != nil && disabled {
|
||||
log.Infof("auto-update is disabled, stopping update manager")
|
||||
e.updateManager.Stop()
|
||||
@@ -829,10 +805,6 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
|
||||
}
|
||||
|
||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
started := time.Now()
|
||||
defer func() {
|
||||
log.Infof("sync finished in %s", time.Since(started))
|
||||
}()
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
@@ -1022,7 +994,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
state := e.statusRecorder.GetLocalPeerState()
|
||||
state.IP = e.wgInterface.Address().String()
|
||||
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
||||
state.KernelInterface = !e.wgInterface.IsUserspaceBind()
|
||||
state.KernelInterface = device.WireGuardModuleIsLoaded()
|
||||
state.FQDN = conf.GetFqdn()
|
||||
|
||||
e.statusRecorder.UpdateLocalPeerState(state)
|
||||
@@ -1538,6 +1510,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||
RelayManager: e.relayManager,
|
||||
SrWatcher: e.srWatcher,
|
||||
Semaphore: e.connSemaphore,
|
||||
}
|
||||
peerConn, err := peer.NewConn(config, serviceDependencies)
|
||||
if err != nil {
|
||||
@@ -1560,10 +1533,8 @@ func (e *Engine) receiveSignalEvents() {
|
||||
defer e.shutdownWg.Done()
|
||||
// connect to a stream of messages coming from the signal server
|
||||
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
|
||||
start := time.Now()
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
gotLock := time.Since(start)
|
||||
|
||||
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||
if e.ctx.Err() != nil {
|
||||
@@ -1587,8 +1558,6 @@ func (e *Engine) receiveSignalEvents() {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("receiveMSG: took %s to get lock for peer %s with session id %s", gotLock, msg.Key, offerAnswer.SessionID)
|
||||
|
||||
if msg.Body.Type == sProto.Body_OFFER {
|
||||
conn.OnRemoteOffer(*offerAnswer)
|
||||
} else {
|
||||
@@ -1675,7 +1644,6 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
||||
|
||||
func (e *Engine) close() {
|
||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||
|
||||
if e.wgInterface != nil {
|
||||
if err := e.wgInterface.Close(); err != nil {
|
||||
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
|
||||
@@ -1822,18 +1790,11 @@ func (e *Engine) GetRouteManager() routemanager.Manager {
|
||||
return e.routeManager
|
||||
}
|
||||
|
||||
// GetFirewallManager returns the firewall manager.
|
||||
// GetFirewallManager returns the firewall manager
|
||||
func (e *Engine) GetFirewallManager() firewallManager.Manager {
|
||||
return e.firewall
|
||||
}
|
||||
|
||||
// GetExposeManager returns the expose session manager.
|
||||
func (e *Engine) GetExposeManager() *expose.Manager {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
return e.exposeManager
|
||||
}
|
||||
|
||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||
iface, err := net.InterfaceByName(ifaceName)
|
||||
if err != nil {
|
||||
@@ -1933,7 +1894,7 @@ func (e *Engine) triggerClientRestart() {
|
||||
}
|
||||
|
||||
func (e *Engine) startNetworkMonitor() {
|
||||
if !e.config.NetworkMonitor || nbnetstack.IsEnabled() {
|
||||
if !e.config.NetworkMonitor {
|
||||
log.Infof("Network monitor is disabled, not starting")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||
@@ -95,10 +94,6 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
|
||||
// updateSSHClientConfig updates the SSH client configuration with peer information
|
||||
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
|
||||
if netstack.IsEnabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
peerInfo := e.extractPeerSSHInfo(remotePeers)
|
||||
if len(peerInfo) == 0 {
|
||||
log.Debug("no SSH-enabled peers found, skipping SSH config update")
|
||||
@@ -221,10 +216,6 @@ func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
|
||||
|
||||
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
|
||||
func (e *Engine) cleanupSSHConfig() {
|
||||
if netstack.IsEnabled() {
|
||||
return
|
||||
}
|
||||
|
||||
configMgr := sshconfig.New()
|
||||
|
||||
if err := configMgr.RemoveSSHClientConfig(); err != nil {
|
||||
|
||||
@@ -107,7 +107,6 @@ type MockWGIface struct {
|
||||
GetStatsFunc func() (map[string]configurer.WGStats, error)
|
||||
GetInterfaceGUIDStringFunc func() (string, error)
|
||||
GetProxyFunc func() wgproxy.Proxy
|
||||
GetProxyPortFunc func() uint16
|
||||
GetNetFunc func() *netstack.Net
|
||||
LastActivitiesFunc func() map[string]monotime.Time
|
||||
}
|
||||
@@ -204,13 +203,6 @@ func (m *MockWGIface) GetProxy() wgproxy.Proxy {
|
||||
return m.GetProxyFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetProxyPort() uint16 {
|
||||
if m.GetProxyPortFunc != nil {
|
||||
return m.GetProxyPortFunc()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetNet() *netstack.Net {
|
||||
return m.GetNetFunc()
|
||||
}
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
package expose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const renewTimeout = 10 * time.Second
|
||||
|
||||
// Response holds the response from exposing a service.
|
||||
type Response struct {
|
||||
ServiceName string
|
||||
ServiceURL string
|
||||
Domain string
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
NamePrefix string
|
||||
Domain string
|
||||
Port uint16
|
||||
Protocol int
|
||||
Pin string
|
||||
Password string
|
||||
UserGroups []string
|
||||
}
|
||||
|
||||
type ManagementClient interface {
|
||||
CreateExpose(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error)
|
||||
RenewExpose(ctx context.Context, domain string) error
|
||||
StopExpose(ctx context.Context, domain string) error
|
||||
}
|
||||
|
||||
// Manager handles expose session lifecycle via the management client.
|
||||
type Manager struct {
|
||||
mgmClient ManagementClient
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewManager creates a new expose Manager using the given management client.
|
||||
func NewManager(ctx context.Context, mgmClient ManagementClient) *Manager {
|
||||
return &Manager{mgmClient: mgmClient, ctx: ctx}
|
||||
}
|
||||
|
||||
// Expose creates a new expose session via the management server.
|
||||
func (m *Manager) Expose(ctx context.Context, req Request) (*Response, error) {
|
||||
log.Infof("exposing service on port %d", req.Port)
|
||||
resp, err := m.mgmClient.CreateExpose(ctx, toClientExposeRequest(req))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("expose session created for %s", resp.Domain)
|
||||
|
||||
return fromClientExposeResponse(resp), nil
|
||||
}
|
||||
|
||||
func (m *Manager) KeepAlive(ctx context.Context, domain string) error {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
defer m.stop(domain)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Infof("context canceled, stopping keep alive for %s", domain)
|
||||
|
||||
return nil
|
||||
case <-ticker.C:
|
||||
if err := m.renew(ctx, domain); err != nil {
|
||||
log.Errorf("renewing expose session for %s: %v", domain, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// renew extends the TTL of an active expose session.
|
||||
func (m *Manager) renew(ctx context.Context, domain string) error {
|
||||
renewCtx, cancel := context.WithTimeout(ctx, renewTimeout)
|
||||
defer cancel()
|
||||
return m.mgmClient.RenewExpose(renewCtx, domain)
|
||||
}
|
||||
|
||||
// stop terminates an active expose session.
|
||||
func (m *Manager) stop(domain string) {
|
||||
stopCtx, cancel := context.WithTimeout(m.ctx, renewTimeout)
|
||||
defer cancel()
|
||||
err := m.mgmClient.StopExpose(stopCtx, domain)
|
||||
if err != nil {
|
||||
log.Warnf("Failed stopping expose session for %s: %v", domain, err)
|
||||
}
|
||||
}
|
||||
@@ -1,95 +0,0 @@
|
||||
package expose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
)
|
||||
|
||||
func TestManager_Expose_Success(t *testing.T) {
|
||||
mock := &mgm.MockClient{
|
||||
CreateExposeFunc: func(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) {
|
||||
return &mgm.ExposeResponse{
|
||||
ServiceName: "my-service",
|
||||
ServiceURL: "https://my-service.example.com",
|
||||
Domain: "my-service.example.com",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
m := NewManager(context.Background(), mock)
|
||||
result, err := m.Expose(context.Background(), Request{Port: 8080})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "my-service", result.ServiceName, "service name should match")
|
||||
assert.Equal(t, "https://my-service.example.com", result.ServiceURL, "service URL should match")
|
||||
assert.Equal(t, "my-service.example.com", result.Domain, "domain should match")
|
||||
}
|
||||
|
||||
func TestManager_Expose_Error(t *testing.T) {
|
||||
mock := &mgm.MockClient{
|
||||
CreateExposeFunc: func(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) {
|
||||
return nil, errors.New("permission denied")
|
||||
},
|
||||
}
|
||||
|
||||
m := NewManager(context.Background(), mock)
|
||||
_, err := m.Expose(context.Background(), Request{Port: 8080})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied", "error should propagate")
|
||||
}
|
||||
|
||||
func TestManager_Renew_Success(t *testing.T) {
|
||||
mock := &mgm.MockClient{
|
||||
RenewExposeFunc: func(ctx context.Context, domain string) error {
|
||||
assert.Equal(t, "my-service.example.com", domain, "domain should be passed through")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
m := NewManager(context.Background(), mock)
|
||||
err := m.renew(context.Background(), "my-service.example.com")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestManager_Renew_Timeout(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
mock := &mgm.MockClient{
|
||||
RenewExposeFunc: func(ctx context.Context, domain string) error {
|
||||
return ctx.Err()
|
||||
},
|
||||
}
|
||||
|
||||
m := NewManager(ctx, mock)
|
||||
err := m.renew(ctx, "my-service.example.com")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestNewRequest(t *testing.T) {
|
||||
req := &daemonProto.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: daemonProto.ExposeProtocol_EXPOSE_HTTPS,
|
||||
Pin: "123456",
|
||||
Password: "secret",
|
||||
UserGroups: []string{"group1", "group2"},
|
||||
Domain: "custom.example.com",
|
||||
NamePrefix: "my-prefix",
|
||||
}
|
||||
|
||||
exposeReq := NewRequest(req)
|
||||
|
||||
assert.Equal(t, uint16(8080), exposeReq.Port, "port should match")
|
||||
assert.Equal(t, int(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match")
|
||||
assert.Equal(t, "123456", exposeReq.Pin, "pin should match")
|
||||
assert.Equal(t, "secret", exposeReq.Password, "password should match")
|
||||
assert.Equal(t, []string{"group1", "group2"}, exposeReq.UserGroups, "user groups should match")
|
||||
assert.Equal(t, "custom.example.com", exposeReq.Domain, "domain should match")
|
||||
assert.Equal(t, "my-prefix", exposeReq.NamePrefix, "name prefix should match")
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
package expose
|
||||
|
||||
import (
|
||||
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
)
|
||||
|
||||
// NewRequest converts a daemon ExposeServiceRequest to a management ExposeServiceRequest.
|
||||
func NewRequest(req *daemonProto.ExposeServiceRequest) *Request {
|
||||
return &Request{
|
||||
Port: uint16(req.Port),
|
||||
Protocol: int(req.Protocol),
|
||||
Pin: req.Pin,
|
||||
Password: req.Password,
|
||||
UserGroups: req.UserGroups,
|
||||
Domain: req.Domain,
|
||||
NamePrefix: req.NamePrefix,
|
||||
}
|
||||
}
|
||||
|
||||
func toClientExposeRequest(req Request) mgm.ExposeRequest {
|
||||
return mgm.ExposeRequest{
|
||||
NamePrefix: req.NamePrefix,
|
||||
Domain: req.Domain,
|
||||
Port: req.Port,
|
||||
Protocol: req.Protocol,
|
||||
Pin: req.Pin,
|
||||
Password: req.Password,
|
||||
UserGroups: req.UserGroups,
|
||||
}
|
||||
}
|
||||
|
||||
func fromClientExposeResponse(response *mgm.ExposeResponse) *Response {
|
||||
return &Response{
|
||||
ServiceName: response.ServiceName,
|
||||
Domain: response.Domain,
|
||||
ServiceURL: response.ServiceURL,
|
||||
}
|
||||
}
|
||||
@@ -28,7 +28,6 @@ type wgIfaceBase interface {
|
||||
Up() (*udpmux.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(newAddr string) error
|
||||
GetProxy() wgproxy.Proxy
|
||||
GetProxyPort() uint16
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemoveEndpointAddress(key string) error
|
||||
RemovePeer(peerKey string) error
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
@@ -75,13 +74,12 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error)
|
||||
return NewUDPListener(m.wgIface, peerCfg)
|
||||
}
|
||||
|
||||
// BindListener is used on Windows, JS, and netstack platforms:
|
||||
// BindListener is only used on Windows and JS platforms:
|
||||
// - JS: Cannot listen to UDP sockets
|
||||
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
|
||||
// gateway points to, preventing them from reaching the loopback interface.
|
||||
// - Netstack: Allows multiple instances on the same host without port conflicts.
|
||||
// BindListener bypasses these issues by passing data directly through the bind.
|
||||
if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() {
|
||||
// BindListener bypasses this by passing data directly through the bind.
|
||||
if runtime.GOOS != "windows" && runtime.GOOS != "js" {
|
||||
return NewUDPListener(m.wgIface, peerCfg)
|
||||
}
|
||||
|
||||
|
||||
201
client/internal/login.go
Normal file
201
client/internal/login.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// IsLoginRequired check that the server is support SSO or not
|
||||
func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, error) {
|
||||
mgmURL := config.ManagementURL
|
||||
mgmClient, err := getMgmClient(ctx, config.PrivateKey, mgmURL)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
cStatus, ok := status.FromError(err)
|
||||
if !ok || ok && cStatus.Code() != codes.Canceled {
|
||||
log.Warnf("failed to close the Management service client, err: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
_, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
|
||||
if isLoginNeeded(err) {
|
||||
return true, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Login or register the client
|
||||
func Login(ctx context.Context, config *profilemanager.Config, setupKey string, jwtToken string) error {
|
||||
mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
cStatus, ok := status.FromError(err)
|
||||
if !ok || ok && cStatus.Code() != codes.Canceled {
|
||||
log.Warnf("failed to close the Management service client, err: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
|
||||
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
|
||||
if serverKey != nil && isRegistrationNeeded(err) {
|
||||
log.Debugf("peer registration required")
|
||||
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
|
||||
// validate our peer's Wireguard PRIVATE key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var mgmTlsEnabled bool
|
||||
if mgmURL.Scheme == "https" {
|
||||
mgmTlsEnabled = true
|
||||
}
|
||||
|
||||
log.Debugf("connecting to the Management service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to the Management service %s %v", mgmURL.String(), err)
|
||||
return nil, err
|
||||
}
|
||||
return mgmClient, err
|
||||
}
|
||||
|
||||
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
||||
serverKey, err := mgmClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
sysInfo.SetFlags(
|
||||
config.RosenpassEnabled,
|
||||
config.RosenpassPermissive,
|
||||
config.ServerSSHAllowed,
|
||||
config.DisableClientRoutes,
|
||||
config.DisableServerRoutes,
|
||||
config.DisableDNS,
|
||||
config.DisableFirewall,
|
||||
config.BlockLANAccess,
|
||||
config.BlockInbound,
|
||||
config.LazyConnectionEnabled,
|
||||
config.EnableSSHRoot,
|
||||
config.EnableSSHSFTP,
|
||||
config.EnableSSHLocalPortForwarding,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||
return serverKey, loginResp, err
|
||||
}
|
||||
|
||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||
// Otherwise tries to register with the provided setupKey via command line.
|
||||
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
|
||||
validSetupKey, err := uuid.Parse(setupKey)
|
||||
if err != nil && jwtToken == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||
}
|
||||
|
||||
log.Debugf("sending peer registration request to Management Service")
|
||||
info := system.GetInfo(ctx)
|
||||
info.SetFlags(
|
||||
config.RosenpassEnabled,
|
||||
config.RosenpassPermissive,
|
||||
config.ServerSSHAllowed,
|
||||
config.DisableClientRoutes,
|
||||
config.DisableServerRoutes,
|
||||
config.DisableDNS,
|
||||
config.DisableFirewall,
|
||||
config.BlockLANAccess,
|
||||
config.BlockInbound,
|
||||
config.LazyConnectionEnabled,
|
||||
config.EnableSSHRoot,
|
||||
config.EnableSSHSFTP,
|
||||
config.EnableSSHLocalPortForwarding,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
|
||||
if err != nil {
|
||||
log.Errorf("failed registering peer %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("peer has been successfully registered on Management Service")
|
||||
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
func isLoginNeeded(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isRegistrationNeeded(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if s.Code() == codes.PermissionDenied {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -22,56 +22,51 @@ func prepareFd() (int, error) {
|
||||
|
||||
func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
for {
|
||||
// Wait until fd is readable or context is cancelled, to avoid a busy-loop
|
||||
// when the routing socket returns EAGAIN (e.g. immediately after wakeup).
|
||||
if err := waitReadable(ctx, fd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := make([]byte, 2048)
|
||||
n, err := unix.Read(fd, buf)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EINTR) {
|
||||
continue
|
||||
}
|
||||
if errors.Is(err, unix.EBADF) || errors.Is(err, unix.EINVAL) {
|
||||
return fmt.Errorf("routing socket closed: %w", err)
|
||||
}
|
||||
return fmt.Errorf("read routing socket: %w", err)
|
||||
}
|
||||
|
||||
if n < unix.SizeofRtMsghdr {
|
||||
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||
continue
|
||||
}
|
||||
|
||||
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||
|
||||
switch msg.Type {
|
||||
// handle route changes
|
||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||
route, err := parseRouteMessage(buf[:n])
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
buf := make([]byte, 2048)
|
||||
n, err := unix.Read(fd, buf)
|
||||
if err != nil {
|
||||
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
||||
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
|
||||
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if n < unix.SizeofRtMsghdr {
|
||||
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||
continue
|
||||
}
|
||||
|
||||
if route.Dst.Bits() != 0 {
|
||||
continue
|
||||
}
|
||||
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||
|
||||
intf := "<nil>"
|
||||
if route.Interface != nil {
|
||||
intf = route.Interface.Name
|
||||
}
|
||||
switch msg.Type {
|
||||
case unix.RTM_ADD:
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
case unix.RTM_DELETE:
|
||||
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||
// handle route changes
|
||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||
route, err := parseRouteMessage(buf[:n])
|
||||
if err != nil {
|
||||
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if route.Dst.Bits() != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
intf := "<nil>"
|
||||
if route.Interface != nil {
|
||||
intf = route.Interface.Name
|
||||
}
|
||||
switch msg.Type {
|
||||
case unix.RTM_ADD:
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
case unix.RTM_DELETE:
|
||||
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -95,33 +90,3 @@ func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
||||
|
||||
return systemops.MsgToRoute(msg)
|
||||
}
|
||||
|
||||
// waitReadable blocks until fd has data to read, or ctx is cancelled.
|
||||
func waitReadable(ctx context.Context, fd int) error {
|
||||
var fdset unix.FdSet
|
||||
if fd < 0 || fd/unix.NFDBITS >= len(fdset.Bits) {
|
||||
return fmt.Errorf("fd %d out of range for FdSet", fd)
|
||||
}
|
||||
|
||||
for {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fdset = unix.FdSet{}
|
||||
fdset.Set(fd)
|
||||
// Use a 1-second timeout so we can re-check ctx periodically.
|
||||
tv := unix.Timeval{Sec: 1}
|
||||
n, err := unix.Select(fd+1, &fdset, nil, nil, &tv)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EINTR) {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("select on routing socket: %w", err)
|
||||
}
|
||||
if n > 0 {
|
||||
return nil
|
||||
}
|
||||
// timeout — loop back and re-check ctx
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package peer
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
@@ -24,6 +25,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
)
|
||||
|
||||
type ServiceDependencies struct {
|
||||
@@ -32,6 +34,7 @@ type ServiceDependencies struct {
|
||||
IFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
RelayManager *relayClient.Manager
|
||||
SrWatcher *guard.SRWatcher
|
||||
Semaphore *semaphoregroup.SemaphoreGroup
|
||||
PeerConnDispatcher *dispatcher.ConnectionDispatcher
|
||||
}
|
||||
|
||||
@@ -108,8 +111,9 @@ type Conn struct {
|
||||
wgProxyRelay wgproxy.Proxy
|
||||
handshaker *Handshaker
|
||||
|
||||
guard *guard.Guard
|
||||
wg sync.WaitGroup
|
||||
guard *guard.Guard
|
||||
semaphore *semaphoregroup.SemaphoreGroup
|
||||
wg sync.WaitGroup
|
||||
|
||||
// debug purpose
|
||||
dumpState *stateDump
|
||||
@@ -135,6 +139,7 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
iFaceDiscover: services.IFaceDiscover,
|
||||
relayManager: services.RelayManager,
|
||||
srWatcher: services.SrWatcher,
|
||||
semaphore: services.Semaphore,
|
||||
statusRelay: worker.NewAtomicStatus(),
|
||||
statusICE: worker.NewAtomicStatus(),
|
||||
dumpState: dumpState,
|
||||
@@ -149,10 +154,15 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
|
||||
// be used.
|
||||
func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
if err := conn.semaphore.Add(engineCtx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
if conn.opened {
|
||||
conn.semaphore.Done()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -163,6 +173,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||
if err != nil {
|
||||
conn.semaphore.Done()
|
||||
return err
|
||||
}
|
||||
conn.workerICE = workerICE
|
||||
@@ -196,6 +207,10 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
conn.wg.Add(1)
|
||||
go func() {
|
||||
defer conn.wg.Done()
|
||||
|
||||
conn.waitInitialRandomSleepTime(conn.ctx)
|
||||
conn.semaphore.Done()
|
||||
|
||||
conn.guard.Start(conn.ctx, conn.onGuardEvent)
|
||||
}()
|
||||
conn.opened = true
|
||||
@@ -375,8 +390,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
}
|
||||
|
||||
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
|
||||
conn.enableWgWatcherIfNeeded()
|
||||
|
||||
presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey)
|
||||
if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil {
|
||||
conn.handleConfigurationFailure(err, wgProxy)
|
||||
@@ -389,13 +402,15 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
conn.wgProxyRelay.RedirectAs(ep)
|
||||
}
|
||||
|
||||
conn.enableWgWatcherIfNeeded()
|
||||
|
||||
conn.currentConnPriority = priority
|
||||
conn.statusICE.SetConnected()
|
||||
conn.updateIceState(iceConnInfo)
|
||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
||||
}
|
||||
|
||||
func (conn *Conn) onICEStateDisconnected(sessionChanged bool) {
|
||||
func (conn *Conn) onICEStateDisconnected() {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
@@ -415,18 +430,14 @@ func (conn *Conn) onICEStateDisconnected(sessionChanged bool) {
|
||||
if conn.isReadyToUpgrade() {
|
||||
conn.Log.Infof("ICE disconnected, set Relay to active connection")
|
||||
conn.dumpState.SwitchToRelay()
|
||||
if sessionChanged {
|
||||
conn.resetEndpoint()
|
||||
}
|
||||
|
||||
// todo consider to move after the ConfigureWGEndpoint
|
||||
conn.wgProxyRelay.Work()
|
||||
|
||||
presharedKey := conn.presharedKey(conn.rosenpassRemoteKey)
|
||||
if err := conn.endpointUpdater.SwitchWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil {
|
||||
if err := conn.endpointUpdater.ConfigureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil {
|
||||
conn.Log.Errorf("failed to switch to relay conn: %v", err)
|
||||
}
|
||||
|
||||
conn.wgProxyRelay.Work()
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
} else {
|
||||
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
|
||||
@@ -488,22 +499,19 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
return
|
||||
}
|
||||
|
||||
controller := isController(conn.config)
|
||||
|
||||
if controller {
|
||||
wgProxy.Work()
|
||||
}
|
||||
conn.enableWgWatcherIfNeeded()
|
||||
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), conn.presharedKey(rci.rosenpassPubKey)); err != nil {
|
||||
wgProxy.Work()
|
||||
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
|
||||
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
|
||||
if err := wgProxy.CloseConn(); err != nil {
|
||||
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
||||
}
|
||||
conn.Log.Errorf("Failed to update WireGuard peer configuration: %v", err)
|
||||
return
|
||||
}
|
||||
if !controller {
|
||||
wgProxy.Work()
|
||||
}
|
||||
|
||||
conn.enableWgWatcherIfNeeded()
|
||||
|
||||
wgConfigWorkaround()
|
||||
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
conn.statusRelay.SetConnected()
|
||||
@@ -655,6 +663,19 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
|
||||
maxWait := 300
|
||||
duration := time.Duration(rand.Intn(maxWait)) * time.Millisecond
|
||||
|
||||
timeout := time.NewTimer(duration)
|
||||
defer timeout.Stop()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-timeout.C:
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) isRelayed() bool {
|
||||
switch conn.currentConnPriority {
|
||||
case conntype.Relay, conntype.ICETurn:
|
||||
@@ -735,17 +756,6 @@ func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||
return wgProxy, nil
|
||||
}
|
||||
|
||||
func (conn *Conn) resetEndpoint() {
|
||||
if !isController(conn.config) {
|
||||
return
|
||||
}
|
||||
conn.Log.Infof("reset wg endpoint")
|
||||
conn.wgWatcher.Reset()
|
||||
if err := conn.endpointUpdater.RemoveEndpointAddress(); err != nil {
|
||||
conn.Log.Warnf("failed to remove endpoint address before update: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) isReadyToUpgrade() bool {
|
||||
return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay
|
||||
}
|
||||
@@ -851,3 +861,9 @@ func isController(config ConnConfig) bool {
|
||||
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
||||
return remoteRosenpassPubKey != nil
|
||||
}
|
||||
|
||||
// wgConfigWorkaround is a workaround for the issue with WireGuard configuration update
|
||||
// When update a peer configuration in near to each other time, the second update can be ignored by WireGuard
|
||||
func wgConfigWorkaround() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
)
|
||||
|
||||
var testDispatcher = dispatcher.NewConnectionDispatcher()
|
||||
@@ -52,6 +53,7 @@ func TestConn_GetKey(t *testing.T) {
|
||||
|
||||
sd := ServiceDependencies{
|
||||
SrWatcher: swWatcher,
|
||||
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
|
||||
PeerConnDispatcher: testDispatcher,
|
||||
}
|
||||
conn, err := NewConn(connConf, sd)
|
||||
@@ -69,6 +71,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
sd := ServiceDependencies{
|
||||
StatusRecorder: NewRecorder("https://mgm"),
|
||||
SrWatcher: swWatcher,
|
||||
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
|
||||
PeerConnDispatcher: testDispatcher,
|
||||
}
|
||||
conn, err := NewConn(connConf, sd)
|
||||
@@ -107,6 +110,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
sd := ServiceDependencies{
|
||||
StatusRecorder: NewRecorder("https://mgm"),
|
||||
SrWatcher: swWatcher,
|
||||
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
|
||||
PeerConnDispatcher: testDispatcher,
|
||||
}
|
||||
conn, err := NewConn(connConf, sd)
|
||||
|
||||
@@ -34,27 +34,28 @@ func NewEndpointUpdater(log *logrus.Entry, wgConfig WgConfig, initiator bool) *E
|
||||
}
|
||||
}
|
||||
|
||||
// ConfigureWGEndpoint sets up the WireGuard endpoint configuration.
|
||||
// The initiator immediately configures the endpoint, while the non-initiator
|
||||
// waits for a fallback period before configuring to avoid handshake congestion.
|
||||
func (e *EndpointUpdater) ConfigureWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
if e.initiator {
|
||||
e.log.Debugf("configure up WireGuard as initiator")
|
||||
return e.configureAsInitiator(addr, presharedKey)
|
||||
e.log.Debugf("configure up WireGuard as initiatr")
|
||||
return e.updateWireGuardPeer(addr, presharedKey)
|
||||
}
|
||||
|
||||
e.log.Debugf("configure up WireGuard as responder")
|
||||
return e.configureAsResponder(addr, presharedKey)
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) SwitchWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
// prevent to run new update while cancel the previous update
|
||||
e.waitForCloseTheDelayedUpdate()
|
||||
|
||||
return e.updateWireGuardPeer(addr, presharedKey)
|
||||
var ctx context.Context
|
||||
ctx, e.cancelFunc = context.WithCancel(context.Background())
|
||||
e.updateWg.Add(1)
|
||||
go e.scheduleDelayedUpdate(ctx, addr, presharedKey)
|
||||
|
||||
e.log.Debugf("configure up WireGuard and wait for handshake")
|
||||
return e.updateWireGuardPeer(nil, presharedKey)
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) RemoveWgPeer() error {
|
||||
@@ -65,38 +66,6 @@ func (e *EndpointUpdater) RemoveWgPeer() error {
|
||||
return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) RemoveEndpointAddress() error {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
e.waitForCloseTheDelayedUpdate()
|
||||
return e.wgConfig.WgInterface.RemoveEndpointAddress(e.wgConfig.RemoteKey)
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) configureAsInitiator(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) configureAsResponder(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||
// prevent to run new update while cancel the previous update
|
||||
e.waitForCloseTheDelayedUpdate()
|
||||
|
||||
e.log.Debugf("configure up WireGuard and wait for handshake")
|
||||
var ctx context.Context
|
||||
ctx, e.cancelFunc = context.WithCancel(context.Background())
|
||||
e.updateWg.Add(1)
|
||||
go e.scheduleDelayedUpdate(ctx, addr, presharedKey)
|
||||
|
||||
if err := e.updateWireGuardPeer(nil, presharedKey); err != nil {
|
||||
e.waitForCloseTheDelayedUpdate()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() {
|
||||
if e.cancelFunc == nil {
|
||||
return
|
||||
@@ -132,9 +101,3 @@ func (e *EndpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr, presharedKe
|
||||
presharedKey,
|
||||
)
|
||||
}
|
||||
|
||||
// wgConfigWorkaround is a workaround for the issue with WireGuard configuration update
|
||||
// When update a peer configuration in near to each other time, the second update can be ignored by WireGuard
|
||||
func wgConfigWorkaround() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package ice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -33,6 +32,24 @@ type ThreadSafeAgent struct {
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (a *ThreadSafeAgent) Close() error {
|
||||
var err error
|
||||
a.once.Do(func() {
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- a.Agent.Close()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err = <-done:
|
||||
case <-time.After(iceAgentCloseTimeout):
|
||||
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
|
||||
err = nil
|
||||
}
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
||||
iceKeepAlive := iceKeepAlive()
|
||||
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
||||
@@ -76,41 +93,9 @@ func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, c
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if agent == nil {
|
||||
return nil, fmt.Errorf("ice.NewAgent returned nil agent without error")
|
||||
}
|
||||
|
||||
return &ThreadSafeAgent{Agent: agent}, nil
|
||||
}
|
||||
|
||||
func (a *ThreadSafeAgent) Close() error {
|
||||
var err error
|
||||
a.once.Do(func() {
|
||||
// Defensive check to prevent nil pointer dereference
|
||||
// This can happen during sleep/wake transitions or memory corruption scenarios
|
||||
// github.com/netbirdio/netbird/client/internal/peer/ice.(*ThreadSafeAgent).Close(0x40006883f0?)
|
||||
// [signal 0xc0000005 code=0x0 addr=0x0 pc=0x7ff7e73af83c]
|
||||
agent := a.Agent
|
||||
if agent == nil {
|
||||
log.Warnf("ICE agent is nil during close, skipping")
|
||||
return
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- agent.Close()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err = <-done:
|
||||
case <-time.After(iceAgentCloseTimeout):
|
||||
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
|
||||
err = nil
|
||||
}
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func GenerateICECredentials() (string, string, error) {
|
||||
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
|
||||
if err != nil {
|
||||
|
||||
@@ -32,8 +32,6 @@ type WGWatcher struct {
|
||||
|
||||
enabled bool
|
||||
muEnabled sync.RWMutex
|
||||
|
||||
resetCh chan struct{}
|
||||
}
|
||||
|
||||
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
|
||||
@@ -42,7 +40,6 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
|
||||
wgIfaceStater: wgIfaceStater,
|
||||
peerKey: peerKey,
|
||||
stateDump: stateDump,
|
||||
resetCh: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,15 +76,6 @@ func (w *WGWatcher) IsEnabled() bool {
|
||||
return w.enabled
|
||||
}
|
||||
|
||||
// Reset signals the watcher that the WireGuard peer has been reset and a new
|
||||
// handshake is expected. This restarts the handshake timeout from scratch.
|
||||
func (w *WGWatcher) Reset() {
|
||||
select {
|
||||
case w.resetCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
||||
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
|
||||
w.log.Infof("WireGuard watcher started")
|
||||
@@ -117,12 +105,6 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn
|
||||
w.stateDump.WGcheckSuccess()
|
||||
|
||||
w.log.Debugf("WireGuard watcher reset timer: %v", resetTime)
|
||||
case <-w.resetCh:
|
||||
w.log.Infof("WireGuard watcher received peer reset, restarting handshake timeout")
|
||||
lastHandshake = time.Time{}
|
||||
enabledTime = time.Now()
|
||||
timer.Stop()
|
||||
timer.Reset(wgHandshakeOvertime)
|
||||
case <-ctx.Done():
|
||||
w.log.Infof("WireGuard watcher stopped")
|
||||
return
|
||||
|
||||
@@ -52,9 +52,8 @@ type WorkerICE struct {
|
||||
// increase by one when disconnecting the agent
|
||||
// with it the remote peer can discard the already deprecated offer/answer
|
||||
// Without it the remote peer may recreate a workable ICE connection
|
||||
sessionID ICESessionID
|
||||
remoteSessionChanged bool
|
||||
muxAgent sync.Mutex
|
||||
sessionID ICESessionID
|
||||
muxAgent sync.Mutex
|
||||
|
||||
localUfrag string
|
||||
localPwd string
|
||||
@@ -107,12 +106,9 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
return
|
||||
}
|
||||
w.log.Debugf("agent already exists, recreate the connection")
|
||||
w.remoteSessionChanged = true
|
||||
w.agentDialerCancel()
|
||||
if w.agent != nil {
|
||||
if err := w.agent.Close(); err != nil {
|
||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||
}
|
||||
if err := w.agent.Close(); err != nil {
|
||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||
}
|
||||
|
||||
sessionID, err := NewICESessionID()
|
||||
@@ -165,10 +161,6 @@ func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HA
|
||||
return
|
||||
}
|
||||
|
||||
if candidateInCGNAT(candidate, w.config.WgConfig.WgInterface.Address().Network) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := w.agent.AddRemoteCandidate(candidate); err != nil {
|
||||
w.log.Errorf("error while handling remote candidate")
|
||||
return
|
||||
@@ -312,17 +304,13 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
|
||||
w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
|
||||
}
|
||||
|
||||
func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) bool {
|
||||
func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) {
|
||||
cancel()
|
||||
if err := agent.Close(); err != nil {
|
||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||
}
|
||||
|
||||
w.muxAgent.Lock()
|
||||
defer w.muxAgent.Unlock()
|
||||
|
||||
sessionChanged := w.remoteSessionChanged
|
||||
w.remoteSessionChanged = false
|
||||
|
||||
if w.agent == agent {
|
||||
// consider to remove from here and move to the OnNewOffer
|
||||
@@ -335,7 +323,7 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
|
||||
w.agentConnecting = false
|
||||
w.remoteSessionID = ""
|
||||
}
|
||||
return sessionChanged
|
||||
w.muxAgent.Unlock()
|
||||
}
|
||||
|
||||
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
||||
@@ -366,10 +354,6 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
|
||||
return
|
||||
}
|
||||
|
||||
if candidateInCGNAT(candidate, w.config.WgConfig.WgInterface.Address().Network) {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored
|
||||
w.log.Debugf("discovered local candidate %s", candidate.String())
|
||||
go func() {
|
||||
@@ -440,11 +424,11 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
|
||||
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
||||
// notify the conn.onICEStateDisconnected changes to update the current used priority
|
||||
|
||||
sessionChanged := w.closeAgent(agent, dialerCancel)
|
||||
w.closeAgent(agent, dialerCancel)
|
||||
|
||||
if w.lastKnownState == ice.ConnectionStateConnected {
|
||||
w.lastKnownState = ice.ConnectionStateDisconnected
|
||||
w.conn.onICEStateDisconnected(sessionChanged)
|
||||
w.conn.onICEStateDisconnected()
|
||||
}
|
||||
default:
|
||||
return
|
||||
@@ -504,11 +488,6 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
|
||||
return ec, nil
|
||||
}
|
||||
|
||||
// cgnatPrefix is the RFC 6598 Carrier-Grade NAT range (100.64.0.0/10).
|
||||
// Addresses in this range are used by CNI plugins (Cilium, Calico, etc.) for pod networking
|
||||
// and are not suitable for direct peer-to-peer connectivity between hosts.
|
||||
var cgnatPrefix = netip.MustParsePrefix("100.64.0.0/10")
|
||||
|
||||
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
|
||||
addr, err := netip.ParseAddr(candidate.Address())
|
||||
if err != nil {
|
||||
@@ -537,32 +516,6 @@ func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool
|
||||
return false
|
||||
}
|
||||
|
||||
// candidateInCGNAT checks if a candidate address falls within the RFC 6598 CGNAT range (100.64.0.0/10).
|
||||
// These addresses are commonly used by Kubernetes CNI plugins (Cilium, Calico) for pod networking
|
||||
// and are not routable between hosts, making them unsuitable as ICE candidates.
|
||||
// The wgNetwork parameter is the NetBird WireGuard network prefix — if the candidate address is within
|
||||
// this network, it is not filtered here (it's handled separately by the NetBird network check).
|
||||
func candidateInCGNAT(candidate ice.Candidate, wgNetwork netip.Prefix) bool {
|
||||
addr, err := netip.ParseAddr(candidate.Address())
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !cgnatPrefix.Contains(addr) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Don't filter if the address is within the WireGuard network itself —
|
||||
// that's handled by the NetBird network membership check elsewhere.
|
||||
if wgNetwork.IsValid() && wgNetwork.Contains(addr) {
|
||||
return false
|
||||
}
|
||||
|
||||
log.Debugf("Ignoring candidate [%s], its address %s is in the CGNAT range (%s) likely assigned by a CNI plugin",
|
||||
candidate.String(), addr, cgnatPrefix)
|
||||
return true
|
||||
}
|
||||
|
||||
func isRelayCandidate(candidate ice.Candidate) bool {
|
||||
return candidate.Type() == ice.CandidateTypeRelay
|
||||
}
|
||||
|
||||
138
client/internal/pkce_auth.go
Normal file
138
client/internal/pkce_auth.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
)
|
||||
|
||||
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
|
||||
type PKCEAuthorizationFlow struct {
|
||||
ProviderConfig PKCEAuthProviderConfig
|
||||
}
|
||||
|
||||
// PKCEAuthProviderConfig has all attributes needed to initiate pkce authorization flow
|
||||
type PKCEAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
||||
AuthorizationEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// RedirectURL handles authorization code from IDP manager
|
||||
RedirectURLs []string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// ClientCertPair is used for mTLS authentication to the IDP
|
||||
ClientCertPair *tls.Certificate
|
||||
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
||||
DisablePromptLogin bool
|
||||
// LoginFlag is used to configure the PKCE flow login behavior
|
||||
LoginFlag common.LoginFlag
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
||||
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) {
|
||||
// validate our peer's Wireguard PRIVATE key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
var mgmTLSEnabled bool
|
||||
if mgmURL.Scheme == "https" {
|
||||
mgmTLSEnabled = true
|
||||
}
|
||||
|
||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
serverKey, err := mgmClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
protoPKCEAuthorizationFlow, err := mgmClient.GetPKCEAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
log.Errorf("failed to retrieve pkce flow: %v", err)
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
authFlow := PKCEAuthorizationFlow{
|
||||
ProviderConfig: PKCEAuthProviderConfig{
|
||||
Audience: protoPKCEAuthorizationFlow.GetProviderConfig().GetAudience(),
|
||||
ClientID: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientID(),
|
||||
ClientSecret: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
||||
TokenEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
||||
AuthorizationEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetAuthorizationEndpoint(),
|
||||
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
|
||||
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
|
||||
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
||||
ClientCertPair: clientCert,
|
||||
DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(),
|
||||
LoginFlag: common.LoginFlag(protoPKCEAuthorizationFlow.GetProviderConfig().GetLoginFlag()),
|
||||
},
|
||||
}
|
||||
|
||||
err = isPKCEProviderConfigValid(authFlow.ProviderConfig)
|
||||
if err != nil {
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
return authFlow, nil
|
||||
}
|
||||
|
||||
func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error {
|
||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
||||
}
|
||||
if config.AuthorizationEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Authorization Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "PKCE Auth Scopes")
|
||||
}
|
||||
if config.RedirectURLs == nil {
|
||||
return fmt.Errorf(errorMSGFormat, "PKCE Redirect URLs")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,7 +1,9 @@
|
||||
package profilemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -20,6 +22,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
@@ -39,8 +42,6 @@ const (
|
||||
var DefaultInterfaceBlacklist = []string{
|
||||
iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
|
||||
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
|
||||
// Kubernetes CNI interfaces
|
||||
"cilium_", "cilium", "lxc", "cali", "flannel", "cni", "weave",
|
||||
}
|
||||
|
||||
// ConfigInput carries configuration changes to the client
|
||||
@@ -97,7 +98,6 @@ type Config struct {
|
||||
WgPort int
|
||||
NetworkMonitor *bool
|
||||
IFaceBlackList []string
|
||||
IFaceBlackListAppliedDefaults []string `json:",omitempty"`
|
||||
DisableIPv6Discovery bool
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
@@ -198,7 +198,7 @@ func getConfigDirForUser(username string) (string, error) {
|
||||
|
||||
configDir := filepath.Join(DefaultConfigPathDir, username)
|
||||
if _, err := os.Stat(configDir); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(configDir, 0700); err != nil {
|
||||
if err := os.MkdirAll(configDir, 0600); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
@@ -206,15 +206,9 @@ func getConfigDirForUser(username string) (string, error) {
|
||||
return configDir, nil
|
||||
}
|
||||
|
||||
func fileExists(path string) (bool, error) {
|
||||
func fileExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
if err == nil {
|
||||
return true, nil
|
||||
}
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
return !os.IsNotExist(err)
|
||||
}
|
||||
|
||||
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
||||
@@ -258,7 +252,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
}
|
||||
|
||||
if config.AdminURL == nil {
|
||||
log.Infof("using default Admin URL %s", DefaultAdminURL)
|
||||
log.Infof("using default Admin URL %s", DefaultManagementURL)
|
||||
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
|
||||
if err != nil {
|
||||
return false, err
|
||||
@@ -357,7 +351,10 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if changed := config.mergeDefaultIFaceBlacklist(); changed {
|
||||
if len(config.IFaceBlackList) == 0 {
|
||||
log.Infof("filling in interface blacklist with defaults: [ %s ]",
|
||||
strings.Join(DefaultInterfaceBlacklist, " "))
|
||||
config.IFaceBlackList = append(config.IFaceBlackList, DefaultInterfaceBlacklist...)
|
||||
updated = true
|
||||
}
|
||||
|
||||
@@ -591,37 +588,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// mergeDefaultIFaceBlacklist ensures that new entries added to DefaultInterfaceBlacklist
|
||||
// are merged into an existing IFaceBlackList on upgrade, while respecting entries that
|
||||
// the user deliberately removed. It tracks which defaults have been offered via
|
||||
// IFaceBlackListAppliedDefaults so removals are not undone.
|
||||
func (config *Config) mergeDefaultIFaceBlacklist() (updated bool) {
|
||||
if len(config.IFaceBlackList) == 0 {
|
||||
log.Infof("filling in interface blacklist with defaults: [ %s ]",
|
||||
strings.Join(DefaultInterfaceBlacklist, " "))
|
||||
config.IFaceBlackList = append(config.IFaceBlackList, DefaultInterfaceBlacklist...)
|
||||
config.IFaceBlackListAppliedDefaults = append([]string{}, DefaultInterfaceBlacklist...)
|
||||
return true
|
||||
}
|
||||
|
||||
// Find defaults not yet tracked in AppliedDefaults — these are genuinely new.
|
||||
// Entries already in AppliedDefaults were either kept or deliberately removed by the user.
|
||||
newDefaults := util.SliceDiff(DefaultInterfaceBlacklist, config.IFaceBlackListAppliedDefaults)
|
||||
if len(newDefaults) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Only add entries not already present in the blacklist (avoid duplicates)
|
||||
toAdd := util.SliceDiff(newDefaults, config.IFaceBlackList)
|
||||
if len(toAdd) > 0 {
|
||||
log.Infof("merging new default interface blacklist entries: [ %s ]",
|
||||
strings.Join(toAdd, " "))
|
||||
config.IFaceBlackList = append(config.IFaceBlackList, toAdd...)
|
||||
}
|
||||
config.IFaceBlackListAppliedDefaults = append(config.IFaceBlackListAppliedDefaults, newDefaults...)
|
||||
return true
|
||||
}
|
||||
|
||||
// parseURL parses and validates a service URL
|
||||
func parseURL(serviceName, serviceURL string) (*url.URL, error) {
|
||||
parsedMgmtURL, err := url.ParseRequestURI(serviceURL)
|
||||
@@ -667,3 +633,273 @@ func isPreSharedKeyHidden(preSharedKey *string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdateConfig update existing configuration according to input configuration and return with the configuration
|
||||
func UpdateConfig(input ConfigInput) (*Config, error) {
|
||||
if !fileExists(input.ConfigPath) {
|
||||
return nil, fmt.Errorf("config file %s does not exist", input.ConfigPath)
|
||||
}
|
||||
|
||||
return update(input)
|
||||
}
|
||||
|
||||
// UpdateOrCreateConfig reads existing config or generates a new one
|
||||
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
if !fileExists(input.ConfigPath) {
|
||||
log.Infof("generating new config %s", input.ConfigPath)
|
||||
cfg, err := createNewConfig(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
if isPreSharedKeyHidden(input.PreSharedKey) {
|
||||
input.PreSharedKey = nil
|
||||
}
|
||||
err := util.EnforcePermission(input.ConfigPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to enforce permission on config dir: %v", err)
|
||||
}
|
||||
return update(input)
|
||||
}
|
||||
|
||||
func update(input ConfigInput) (*Config, error) {
|
||||
config := &Config{}
|
||||
|
||||
if _, err := util.ReadJson(input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updated, err := config.apply(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if updated {
|
||||
if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// GetConfig read config file and return with Config and if it was created. Errors out if it does not exist
|
||||
func GetConfig(configPath string) (*Config, error) {
|
||||
return readConfig(configPath, false)
|
||||
}
|
||||
|
||||
// UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain.
|
||||
// If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
|
||||
// The check is performed only for the NetBird's managed version.
|
||||
func UpdateOldManagementURL(ctx context.Context, config *Config, configPath string) (*Config, error) {
|
||||
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsedOldDefaultManagementURL, err := parseURL("Management URL", oldDefaultManagementURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.ManagementURL.Hostname() != defaultManagementURL.Hostname() &&
|
||||
config.ManagementURL.Hostname() != parsedOldDefaultManagementURL.Hostname() {
|
||||
// only do the check for the NetBird's managed version
|
||||
return config, nil
|
||||
}
|
||||
|
||||
var mgmTlsEnabled bool
|
||||
if config.ManagementURL.Scheme == "https" {
|
||||
mgmTlsEnabled = true
|
||||
}
|
||||
|
||||
if !mgmTlsEnabled {
|
||||
// only do the check for HTTPs scheme (the hosted version of the Management service is always HTTPs)
|
||||
return config, nil
|
||||
}
|
||||
|
||||
if config.ManagementURL.Port() != managementLegacyPortString &&
|
||||
config.ManagementURL.Hostname() == defaultManagementURL.Hostname() {
|
||||
return config, nil
|
||||
}
|
||||
|
||||
newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s:%d",
|
||||
config.ManagementURL.Scheme, defaultManagementURL.Hostname(), 443))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// here we check whether we could switch from the legacy 33073 port to the new 443
|
||||
log.Infof("attempting to switch from the legacy Management URL %s to the new one %s",
|
||||
config.ManagementURL.String(), newURL.String())
|
||||
key, err := wgtypes.ParseKey(config.PrivateKey)
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, err
|
||||
}
|
||||
|
||||
client, err := mgm.NewClient(ctx, newURL.Host, key, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, err
|
||||
}
|
||||
defer func() {
|
||||
err = client.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// gRPC check
|
||||
_, err = client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// everything is alright => update the config
|
||||
newConfig, err := UpdateConfig(ConfigInput{
|
||||
ManagementURL: newURL.String(),
|
||||
ConfigPath: configPath,
|
||||
})
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, fmt.Errorf("failed updating config file: %v", err)
|
||||
}
|
||||
log.Infof("successfully switched to the new Management URL: %s", newURL.String())
|
||||
|
||||
return newConfig, nil
|
||||
}
|
||||
|
||||
// CreateInMemoryConfig generate a new config but do not write out it to the store
|
||||
func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
|
||||
return createNewConfig(input)
|
||||
}
|
||||
|
||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
func ReadConfig(configPath string) (*Config, error) {
|
||||
return readConfig(configPath, true)
|
||||
}
|
||||
|
||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
func readConfig(configPath string, createIfMissing bool) (*Config, error) {
|
||||
if fileExists(configPath) {
|
||||
err := util.EnforcePermission(configPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to enforce permission on config dir: %v", err)
|
||||
}
|
||||
|
||||
config := &Config{}
|
||||
if _, err := util.ReadJson(configPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// initialize through apply() without changes
|
||||
if changed, err := config.apply(ConfigInput{}); err != nil {
|
||||
return nil, err
|
||||
} else if changed {
|
||||
if err = WriteOutConfig(configPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
} else if !createIfMissing {
|
||||
return nil, fmt.Errorf("config file %s does not exist", configPath)
|
||||
}
|
||||
|
||||
cfg, err := createNewConfig(ConfigInput{ConfigPath: configPath})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = WriteOutConfig(configPath, cfg)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
// WriteOutConfig write put the prepared config to the given path
|
||||
func WriteOutConfig(path string, config *Config) error {
|
||||
return util.WriteJson(context.Background(), path, config)
|
||||
}
|
||||
|
||||
// DirectWriteOutConfig writes config directly without atomic temp file operations.
|
||||
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
|
||||
func DirectWriteOutConfig(path string, config *Config) error {
|
||||
return util.DirectWriteJson(context.Background(), path, config)
|
||||
}
|
||||
|
||||
// DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes.
|
||||
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
|
||||
func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
if !fileExists(input.ConfigPath) {
|
||||
log.Infof("generating new config %s", input.ConfigPath)
|
||||
cfg, err := createNewConfig(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = util.DirectWriteJson(context.Background(), input.ConfigPath, cfg)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
if isPreSharedKeyHidden(input.PreSharedKey) {
|
||||
input.PreSharedKey = nil
|
||||
}
|
||||
|
||||
// Enforce permissions on existing config files (same as UpdateOrCreateConfig)
|
||||
if err := util.EnforcePermission(input.ConfigPath); err != nil {
|
||||
log.Errorf("failed to enforce permission on config file: %v", err)
|
||||
}
|
||||
|
||||
return directUpdate(input)
|
||||
}
|
||||
|
||||
func directUpdate(input ConfigInput) (*Config, error) {
|
||||
config := &Config{}
|
||||
|
||||
if _, err := util.ReadJson(input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updated, err := config.apply(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if updated {
|
||||
if err := util.DirectWriteJson(context.Background(), input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// ConfigToJSON serializes a Config struct to a JSON string.
|
||||
// This is useful for exporting config to alternative storage mechanisms
|
||||
// (e.g., UserDefaults on tvOS where file writes are blocked).
|
||||
func ConfigToJSON(config *Config) (string, error) {
|
||||
bs, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(bs), nil
|
||||
}
|
||||
|
||||
// ConfigFromJSON deserializes a JSON string to a Config struct.
|
||||
// This is useful for restoring config from alternative storage mechanisms.
|
||||
// After unmarshaling, defaults are applied to ensure the config is fully initialized.
|
||||
func ConfigFromJSON(jsonStr string) (*Config, error) {
|
||||
config := &Config{}
|
||||
err := json.Unmarshal([]byte(jsonStr), config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Apply defaults to ensure required fields are initialized.
|
||||
// This mirrors what readConfig does after loading from file.
|
||||
if _, err := config.apply(ConfigInput{}); err != nil {
|
||||
return nil, fmt.Errorf("failed to apply defaults to config: %w", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
@@ -1,301 +0,0 @@
|
||||
package profilemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// UpdateConfig update existing configuration according to input configuration and return with the configuration
|
||||
func UpdateConfig(input ConfigInput) (*Config, error) {
|
||||
configExists, err := fileExists(input.ConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||
}
|
||||
if !configExists {
|
||||
return nil, fmt.Errorf("config file %s does not exist", input.ConfigPath)
|
||||
}
|
||||
|
||||
return update(input)
|
||||
}
|
||||
|
||||
// UpdateOrCreateConfig reads existing config or generates a new one
|
||||
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
configExists, err := fileExists(input.ConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||
}
|
||||
if !configExists {
|
||||
log.Infof("generating new config %s", input.ConfigPath)
|
||||
cfg, err := createNewConfig(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
if isPreSharedKeyHidden(input.PreSharedKey) {
|
||||
input.PreSharedKey = nil
|
||||
}
|
||||
err = util.EnforcePermission(input.ConfigPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to enforce permission on config dir: %v", err)
|
||||
}
|
||||
return update(input)
|
||||
}
|
||||
|
||||
func update(input ConfigInput) (*Config, error) {
|
||||
config := &Config{}
|
||||
|
||||
if _, err := util.ReadJson(input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updated, err := config.apply(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if updated {
|
||||
if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// GetConfig read config file and return with Config and if it was created. Errors out if it does not exist
|
||||
func GetConfig(configPath string) (*Config, error) {
|
||||
return readConfig(configPath, false)
|
||||
}
|
||||
|
||||
// UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain.
|
||||
// If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
|
||||
// The check is performed only for the NetBird's managed version.
|
||||
func UpdateOldManagementURL(ctx context.Context, config *Config, configPath string) (*Config, error) {
|
||||
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsedOldDefaultManagementURL, err := parseURL("Management URL", oldDefaultManagementURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.ManagementURL.Hostname() != defaultManagementURL.Hostname() &&
|
||||
config.ManagementURL.Hostname() != parsedOldDefaultManagementURL.Hostname() {
|
||||
// only do the check for the NetBird's managed version
|
||||
return config, nil
|
||||
}
|
||||
|
||||
var mgmTlsEnabled bool
|
||||
if config.ManagementURL.Scheme == "https" {
|
||||
mgmTlsEnabled = true
|
||||
}
|
||||
|
||||
if !mgmTlsEnabled {
|
||||
// only do the check for HTTPs scheme (the hosted version of the Management service is always HTTPs)
|
||||
return config, nil
|
||||
}
|
||||
|
||||
if config.ManagementURL.Port() != managementLegacyPortString &&
|
||||
config.ManagementURL.Hostname() == defaultManagementURL.Hostname() {
|
||||
return config, nil
|
||||
}
|
||||
|
||||
newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s:%d",
|
||||
config.ManagementURL.Scheme, defaultManagementURL.Hostname(), 443))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// here we check whether we could switch from the legacy 33073 port to the new 443
|
||||
log.Infof("attempting to switch from the legacy Management URL %s to the new one %s",
|
||||
config.ManagementURL.String(), newURL.String())
|
||||
key, err := wgtypes.ParseKey(config.PrivateKey)
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, err
|
||||
}
|
||||
|
||||
client, err := mgm.NewClient(ctx, newURL.Host, key, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, err
|
||||
}
|
||||
defer func() {
|
||||
err = client.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// gRPC check
|
||||
_, err = client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// everything is alright => update the config
|
||||
newConfig, err := UpdateConfig(ConfigInput{
|
||||
ManagementURL: newURL.String(),
|
||||
ConfigPath: configPath,
|
||||
})
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, fmt.Errorf("failed updating config file: %v", err)
|
||||
}
|
||||
log.Infof("successfully switched to the new Management URL: %s", newURL.String())
|
||||
|
||||
return newConfig, nil
|
||||
}
|
||||
|
||||
// CreateInMemoryConfig generate a new config but do not write out it to the store
|
||||
func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
|
||||
return createNewConfig(input)
|
||||
}
|
||||
|
||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
func ReadConfig(configPath string) (*Config, error) {
|
||||
return readConfig(configPath, true)
|
||||
}
|
||||
|
||||
// readConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
func readConfig(configPath string, createIfMissing bool) (*Config, error) {
|
||||
configExists, err := fileExists(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||
}
|
||||
|
||||
if configExists {
|
||||
err := util.EnforcePermission(configPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to enforce permission on config dir: %v", err)
|
||||
}
|
||||
|
||||
config := &Config{}
|
||||
if _, err := util.ReadJson(configPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// initialize through apply() without changes
|
||||
if changed, err := config.apply(ConfigInput{}); err != nil {
|
||||
return nil, err
|
||||
} else if changed {
|
||||
if err = WriteOutConfig(configPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
} else if !createIfMissing {
|
||||
return nil, fmt.Errorf("config file %s does not exist", configPath)
|
||||
}
|
||||
|
||||
cfg, err := createNewConfig(ConfigInput{ConfigPath: configPath})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = WriteOutConfig(configPath, cfg)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
// WriteOutConfig write put the prepared config to the given path
|
||||
func WriteOutConfig(path string, config *Config) error {
|
||||
return util.WriteJson(context.Background(), path, config)
|
||||
}
|
||||
|
||||
// DirectWriteOutConfig writes config directly without atomic temp file operations.
|
||||
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
|
||||
func DirectWriteOutConfig(path string, config *Config) error {
|
||||
return util.DirectWriteJson(context.Background(), path, config)
|
||||
}
|
||||
|
||||
// DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes.
|
||||
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
|
||||
func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
configExists, err := fileExists(input.ConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||
}
|
||||
if !configExists {
|
||||
log.Infof("generating new config %s", input.ConfigPath)
|
||||
cfg, err := createNewConfig(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = util.DirectWriteJson(context.Background(), input.ConfigPath, cfg)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
if isPreSharedKeyHidden(input.PreSharedKey) {
|
||||
input.PreSharedKey = nil
|
||||
}
|
||||
|
||||
// Enforce permissions on existing config files (same as UpdateOrCreateConfig)
|
||||
if err := util.EnforcePermission(input.ConfigPath); err != nil {
|
||||
log.Errorf("failed to enforce permission on config file: %v", err)
|
||||
}
|
||||
|
||||
return directUpdate(input)
|
||||
}
|
||||
|
||||
func directUpdate(input ConfigInput) (*Config, error) {
|
||||
config := &Config{}
|
||||
|
||||
if _, err := util.ReadJson(input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updated, err := config.apply(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if updated {
|
||||
if err := util.DirectWriteJson(context.Background(), input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// ConfigToJSON serializes a Config struct to a JSON string.
|
||||
// This is useful for exporting config to alternative storage mechanisms
|
||||
// (e.g., UserDefaults on tvOS where file writes are blocked).
|
||||
func ConfigToJSON(config *Config) (string, error) {
|
||||
bs, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(bs), nil
|
||||
}
|
||||
|
||||
// ConfigFromJSON deserializes a JSON string to a Config struct.
|
||||
// This is useful for restoring config from alternative storage mechanisms.
|
||||
// After unmarshaling, defaults are applied to ensure the config is fully initialized.
|
||||
func ConfigFromJSON(jsonStr string) (*Config, error) {
|
||||
config := &Config{}
|
||||
err := json.Unmarshal([]byte(jsonStr), config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Apply defaults to ensure required fields are initialized.
|
||||
// This mirrors what readConfig does after loading from file.
|
||||
if _, err := config.apply(ConfigInput{}); err != nil {
|
||||
return nil, fmt.Errorf("failed to apply defaults to config: %w", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
@@ -108,87 +108,6 @@ func TestExtraIFaceBlackList(t *testing.T) {
|
||||
assert.Contains(t, readConf.(*Config).IFaceBlackList, "eth1")
|
||||
}
|
||||
|
||||
func TestIFaceBlackListMigratesNewDefaults(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
configPath := filepath.Join(tempDir, "config.json")
|
||||
|
||||
// Create a config that simulates an old install with a partial IFaceBlackList
|
||||
// (missing the newer CNI entries like "cilium_", "cali", etc.)
|
||||
config, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: configPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate an old config that predates AppliedDefaults tracking:
|
||||
// it has only the original entries, no CNI prefixes, and no AppliedDefaults.
|
||||
oldList := []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
|
||||
"Tailscale", "tailscale", "docker", "veth", "br-", "lo"}
|
||||
config.IFaceBlackList = oldList
|
||||
config.IFaceBlackListAppliedDefaults = nil
|
||||
err = WriteOutConfig(configPath, config)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Re-read the config — apply() should merge in missing defaults
|
||||
reloaded, err := GetConfig(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, entry := range DefaultInterfaceBlacklist {
|
||||
assert.Contains(t, reloaded.IFaceBlackList, entry,
|
||||
"IFaceBlackList should contain default entry %q after migration", entry)
|
||||
}
|
||||
|
||||
// Verify no duplicates were introduced
|
||||
seen := make(map[string]bool)
|
||||
for _, entry := range reloaded.IFaceBlackList {
|
||||
assert.False(t, seen[entry], "duplicate entry %q in IFaceBlackList", entry)
|
||||
seen[entry] = true
|
||||
}
|
||||
|
||||
// AppliedDefaults should now track all current defaults
|
||||
for _, entry := range DefaultInterfaceBlacklist {
|
||||
assert.Contains(t, reloaded.IFaceBlackListAppliedDefaults, entry,
|
||||
"AppliedDefaults should track %q", entry)
|
||||
}
|
||||
|
||||
// Re-read again — should not change (idempotent)
|
||||
reloaded2, err := GetConfig(configPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, reloaded.IFaceBlackList, reloaded2.IFaceBlackList,
|
||||
"IFaceBlackList should be stable on subsequent reads")
|
||||
}
|
||||
|
||||
func TestIFaceBlackListRespectsUserRemoval(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
configPath := filepath.Join(tempDir, "config.json")
|
||||
|
||||
// Create a fresh config (all defaults applied)
|
||||
config, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: configPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, config.IFaceBlackList, "cali")
|
||||
|
||||
// User deliberately removes "cali" from their blacklist
|
||||
filtered := make([]string, 0, len(config.IFaceBlackList))
|
||||
for _, entry := range config.IFaceBlackList {
|
||||
if entry != "cali" {
|
||||
filtered = append(filtered, entry)
|
||||
}
|
||||
}
|
||||
config.IFaceBlackList = filtered
|
||||
err = WriteOutConfig(configPath, config)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Re-read — "cali" should NOT be re-added because it's in AppliedDefaults
|
||||
reloaded, err := GetConfig(configPath)
|
||||
require.NoError(t, err)
|
||||
assert.NotContains(t, reloaded.IFaceBlackList, "cali",
|
||||
"user-removed entry should not be re-added")
|
||||
|
||||
// AppliedDefaults should still contain "cali" (it was offered)
|
||||
assert.Contains(t, reloaded.IFaceBlackListAppliedDefaults, "cali")
|
||||
}
|
||||
|
||||
func TestHiddenPreSharedKey(t *testing.T) {
|
||||
hidden := "**********"
|
||||
samplePreSharedKey := "mysecretpresharedkey"
|
||||
|
||||
@@ -256,11 +256,7 @@ func (s *ServiceManager) AddProfile(profileName, username string) error {
|
||||
}
|
||||
|
||||
profPath := filepath.Join(configDir, profileName+".json")
|
||||
profileExists, err := fileExists(profPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if profile exists: %w", err)
|
||||
}
|
||||
if profileExists {
|
||||
if fileExists(profPath) {
|
||||
return ErrProfileAlreadyExists
|
||||
}
|
||||
|
||||
@@ -289,11 +285,7 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error {
|
||||
return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName)
|
||||
}
|
||||
profPath := filepath.Join(configDir, profileName+".json")
|
||||
profileExists, err := fileExists(profPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if profile exists: %w", err)
|
||||
}
|
||||
if !profileExists {
|
||||
if !fileExists(profPath) {
|
||||
return ErrProfileNotFound
|
||||
}
|
||||
|
||||
|
||||
@@ -20,11 +20,7 @@ func (pm *ProfileManager) GetProfileState(profileName string) (*ProfileState, er
|
||||
}
|
||||
|
||||
stateFile := filepath.Join(configDir, profileName+".state.json")
|
||||
stateFileExists, err := fileExists(stateFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if profile state file exists: %w", err)
|
||||
}
|
||||
if !stateFileExists {
|
||||
if !fileExists(stateFile) {
|
||||
return nil, errors.New("profile state file does not exist")
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user