Compare commits

..

91 Commits

Author SHA1 Message Date
Viktor Liu
24b66fb406 Translate usernames to UPN format for domain login 2025-11-05 22:27:08 +01:00
Viktor Liu
9378b6b0a3 Merge branch 'ssh-rewrite' into move-licensed-code 2025-11-05 16:09:03 +01:00
Viktor Liu
3779a3385f Fix tests 2025-11-05 13:06:54 +01:00
Viktor Liu
b5d75ad9c4 Go fmt everything 2025-11-05 12:59:36 +01:00
Viktor Liu
8db91abfdf Merge branch 'main' into ssh-rewrite 2025-11-05 12:44:17 +01:00
Viktor Liu
6f817cad6d Remove duplicate code 2025-11-03 13:47:33 +01:00
Viktor Liu
e3bb8c1b7b Merge branch 'main' into ssh-rewrite 2025-11-03 13:43:29 +01:00
Viktor Liu
107066fa3d Merge branch 'main' into ssh-rewrite 2025-10-28 22:08:46 +01:00
Viktor Liu
a7a85d4dc8 Fix tests 2025-10-28 21:11:45 +01:00
Viktor Liu
576b4a779c Log shell 2025-10-28 18:15:53 +01:00
Viktor Liu
e6854dfd99 Improve session logging 2025-10-28 17:57:59 +01:00
Viktor Liu
6f14134988 Merge branch 'main' into ssh-rewrite 2025-10-28 16:50:23 +01:00
Viktor Liu
4fd64379da Move client-imported GPL code to separate package 2025-10-23 23:52:44 +02:00
Viktor Liu
c20202a6c3 Add new flags to test 2025-10-17 16:15:05 +02:00
Viktor Liu
4386a21956 Merge branch 'main' into ssh-rewrite 2025-10-17 15:34:36 +02:00
Zoltan Papp
5882daf5d9 Force relay connection, do not waste signaling resources on ICE connection (#4628) 2025-10-13 11:02:21 +02:00
Viktor Liu
11d71e6e22 Ignore default log file 2025-10-10 16:21:39 +02:00
Viktor Liu
4dadcfd9bd Remove client.log check 2025-10-10 16:17:46 +02:00
Viktor Liu
34b55c600e Log errors on debug 2025-10-10 16:11:13 +02:00
Viktor Liu
316c0afa9a Remove unused arg 2025-10-10 11:08:34 +02:00
Viktor Liu
cf97799db8 Fix test 2025-10-10 10:23:45 +02:00
Viktor Liu
4d297205c3 Fix test build 2025-10-09 17:26:25 +02:00
Viktor Liu
559f6aeeaf Improve logging 2025-10-08 18:54:56 +02:00
Viktor Liu
7216c201da Log priv check errors 2025-10-08 18:46:02 +02:00
Viktor Liu
4d89d0f115 Remove unused code 2025-10-08 18:39:41 +02:00
Viktor Liu
610c880ec9 Fix missing jwt config passed to peers 2025-10-08 16:47:11 +02:00
Viktor Liu
19adcb5f63 Merge branch 'main' into ssh-rewrite 2025-10-08 12:40:07 +02:00
Viktor Liu
f3d31698da Skip some auth tests on windows that are already covered 2025-10-07 23:39:01 +02:00
Viktor Liu
d9efe4e944 Add ssh authenatication with jwt (#4550) 2025-10-07 23:38:27 +02:00
Viktor Liu
7e0bbaaa3c Merge branch 'main' into ssh-rewrite 2025-10-07 09:41:07 +02:00
Viktor Liu
b3c7b3c7b2 Fix js build 2025-10-02 15:59:17 +02:00
Viktor Liu
66483ab48d Merge branch 'main' into ssh-rewrite 2025-10-02 15:53:12 +02:00
Viktor Liu
5272fc2b18 Merge branch 'main' into ssh-rewrite 2025-09-25 11:12:47 +02:00
Viktor Liu
4c53372815 Add missing flags 2025-08-27 09:59:12 +02:00
Viktor Liu
79d28b71ee Improve forwarding cancellation 2025-08-26 22:22:15 +02:00
Viktor Liu
77a352763d Fix button style 2025-08-26 21:19:04 +02:00
Viktor Liu
cdd5c6c005 Address review 2025-08-26 21:01:55 +02:00
Viktor Liu
b1a9242c98 Fix merge commit changes 2025-08-26 20:43:29 +02:00
Viktor Liu
b43ef4f17b Merge branch 'main' into ssh-rewrite 2025-08-26 20:09:47 +02:00
Viktor Liu
758a97c352 Generate ssh_config independently of ssh server 2025-07-14 22:02:41 +02:00
Viktor Liu
d93b7c2f38 Fix known hosts entries 2025-07-14 21:41:59 +02:00
Viktor Liu
fa893aa0a4 Fix build 2025-07-12 00:49:08 +02:00
Viktor Liu
ac7120871b Fix proto 2025-07-12 00:11:31 +02:00
Viktor Liu
9a7daa132e Fix client ssh file 2025-07-11 22:08:28 +02:00
Viktor Liu
cdded8c22e Merge branch 'main' into ssh-rewrite 2025-07-11 22:05:12 +02:00
Viktor Liu
e4e0b8fff9 Remove empty file 2025-07-04 17:09:54 +02:00
Viktor Liu
a4b067553d Merge branch 'main' into ssh-rewrite 2025-07-04 16:53:54 +02:00
Viktor Liu
088956645f Fix username validation and skip ci tests properly 2025-07-03 15:36:42 +02:00
Viktor Liu
aa30b7afe8 More windows tests 2025-07-03 14:11:20 +02:00
Viktor Liu
f1bb4d2ac3 Fix more Windows tests 2025-07-03 13:35:53 +02:00
Viktor Liu
982841e25b Test up tests users if none are available on CI 2025-07-03 12:33:31 +02:00
Viktor Liu
a476b8d12f Fix more windows tests 2025-07-03 11:26:04 +02:00
Viktor Liu
a21f924b26 Fix some windows tests 2025-07-03 10:20:16 +02:00
Viktor Liu
9e51d2e8fb Fix lint and sonar 2025-07-03 09:58:25 +02:00
Viktor Liu
3e490d974c Remove duplicated code 2025-07-03 03:40:27 +02:00
Viktor Liu
04bb314426 Allow sftp same user switching on windows 2025-07-03 02:19:12 +02:00
Viktor Liu
6e15882c11 Fix tests and windows username validation 2025-07-03 01:58:15 +02:00
Viktor Liu
76f9e11b29 Fix tests 2025-07-03 01:07:58 +02:00
Viktor Liu
612de2c784 Remove socketfilter temporarily 2025-07-02 22:00:10 +02:00
Viktor Liu
1fdde66c31 More lint 2025-07-02 21:55:25 +02:00
Viktor Liu
5970591d24 Fix lint 2025-07-02 21:32:39 +02:00
Viktor Liu
0d5408baec Fix lint 2025-07-02 21:04:58 +02:00
Viktor Liu
96084e3a02 Reduce complexity 2025-07-02 20:43:17 +02:00
Viktor Liu
4bbca28eb6 Fix lint 2025-07-02 20:23:23 +02:00
Viktor Liu
279b77dee0 Bump sftp 2025-07-02 19:42:57 +02:00
Viktor Liu
9d1554f9f7 Complete overhaul 2025-07-02 19:35:19 +02:00
Viktor Liu
f56075ca15 Tidy mod 2025-07-02 19:34:36 +02:00
Viktor Liu
6ed846ae29 Refactor ssh server and client 2025-07-02 19:34:36 +02:00
Viktor Liu
520f2cfdb4 Remove implicit inbound ssh firewall rules and change default port 2025-07-02 19:34:32 +02:00
Viktor Liu
0f79a8942d Fix route notificaiton 2025-07-02 17:24:14 +02:00
Viktor Liu
5299e9fda3 Merge branch 'main' into android-dns-routes 2025-07-02 15:23:14 +02:00
Viktor Liu
11bdf5b3a5 Use r 2025-06-26 15:41:56 +02:00
Viktor Liu
5fc95d4a0c Display domains properly 2025-06-26 15:36:14 +02:00
Viktor Liu
c7884039b8 Revert "Fix errorf"
This reverts commit 26fc32f1be.
2025-06-25 15:17:31 +02:00
Viktor Liu
26fc32f1be Fix errorf 2025-06-25 15:03:55 +02:00
Viktor Liu
a79cb1c11b Merge branch 'main' into android-dns-routes 2025-06-18 17:27:13 +02:00
Viktor Liu
306d75fe1a Set up fake ip route only if the dns feature flag is enabled 2025-06-17 22:29:13 +02:00
Viktor Liu
9468e69c8c Extract static error 2025-06-17 21:47:05 +02:00
Viktor Liu
f51ce7cee5 Remove nil checks 2025-06-17 21:41:58 +02:00
Viktor Liu
d47c6b624e Fix spelling 2025-06-17 20:02:52 +02:00
Viktor Liu
471f90e8db Rename methods 2025-06-17 15:52:34 +02:00
Viktor Liu
1a3b04d2fe Swap tracking and nat order 2025-06-17 15:45:22 +02:00
Viktor Liu
51b9e93eb9 Merge branch 'main' into android-dns-routes 2025-06-17 15:12:05 +02:00
Viktor Liu
2952669e97 Fix lint 2025-06-17 14:16:59 +02:00
Viktor Liu
7cd44a9a3c Improve nat perf 2025-06-17 13:55:57 +02:00
Viktor Liu
8684981b57 Add tests 2025-06-17 13:41:06 +02:00
Viktor Liu
8e94d85d14 Rename test files 2025-06-17 12:46:17 +02:00
Viktor Liu
631b77dc3c Remove some allocations 2025-06-17 12:44:52 +02:00
Viktor Liu
50ac3d437e Fix lint issues 2025-06-17 03:07:28 +02:00
Viktor Liu
49bbd90557 Fix test 2025-06-17 02:57:15 +02:00
Viktor Liu
bb74e903cd Implement dns routes for Android 2025-06-17 02:48:13 +02:00
601 changed files with 9668 additions and 66886 deletions

View File

@@ -1,15 +1,15 @@
FROM golang:1.25-bookworm FROM golang:1.23-bullseye
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
&& apt-get -y install --no-install-recommends\ && apt-get -y install --no-install-recommends\
gettext-base=0.21-12 \ gettext-base=0.21-4 \
iptables=1.8.9-2 \ iptables=1.8.7-1 \
libgl1-mesa-dev=22.3.6-1+deb12u1 \ libgl1-mesa-dev=20.3.5-1 \
xorg-dev=1:7.7+23 \ xorg-dev=1:7.7+22 \
libayatana-appindicator3-dev=0.5.92-1 \ libayatana-appindicator3-dev=0.5.5-2+deb11u2 \
&& apt-get clean \ && apt-get clean \
&& rm -rf /var/lib/apt/lists/* \ && rm -rf /var/lib/apt/lists/* \
&& go install -v golang.org/x/tools/gopls@latest && go install -v golang.org/x/tools/gopls@v0.18.1
WORKDIR /app WORKDIR /app

View File

@@ -1,11 +0,0 @@
#!/bin/bash
echo "Running pre-push hook..."
if ! make lint; then
echo ""
echo "Hint: To push without verification, run:"
echo " git push --no-verify"
exit 1
fi
echo "All checks passed!"

View File

@@ -3,108 +3,40 @@ name: Check License Dependencies
on: on:
push: push:
branches: [ main ] branches: [ main ]
paths:
- 'go.mod'
- 'go.sum'
- '.github/workflows/check-license-dependencies.yml'
pull_request: pull_request:
paths:
- 'go.mod'
- 'go.sum'
- '.github/workflows/check-license-dependencies.yml'
jobs: jobs:
check-internal-dependencies: check-dependencies:
name: Check Internal AGPL Dependencies
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Check for problematic license dependencies
run: |
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
echo ""
# Find all directories except the problematic ones and system dirs
FOUND_ISSUES=0
while IFS= read -r dir; do
echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
if [ -n "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
echo "$RESULTS"
FOUND_ISSUES=1
else
echo "✓ No problematic dependencies found"
fi
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
echo ""
if [ $FOUND_ISSUES -eq 1 ]; then
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
exit 1
else
echo ""
echo "✅ All internal license dependencies are clean"
fi
check-external-licenses:
name: Check External GPL/AGPL Licenses
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Set up Go - name: Check for problematic license dependencies
uses: actions/setup-go@v5
with:
go-version-file: 'go.mod'
cache: true
- name: Install go-licenses
run: go install github.com/google/go-licenses@v1.6.0
- name: Check for GPL/AGPL licensed dependencies
run: | run: |
echo "Checking for GPL/AGPL/LGPL licensed dependencies..." echo "Checking for dependencies on management/, signal/, and relay/ packages..."
echo "" echo ""
# Check all Go packages for copyleft licenses, excluding internal netbird packages # Find all directories except the problematic ones and system dirs
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true) FOUND_ISSUES=0
while IFS= read -r dir; do
if [ -n "$COPYLEFT_DEPS" ]; then echo "=== Checking $dir ==="
echo "Found copyleft licensed dependencies:" # Search for problematic imports, excluding test files
echo "$COPYLEFT_DEPS" RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
echo "" if [ -n "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
# Filter out dependencies that are only pulled in by internal AGPL packages echo "$RESULTS"
INCOMPATIBLE="" FOUND_ISSUES=1
while IFS=',' read -r package url license; do else
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then echo "✓ No problematic dependencies found"
# Find ALL packages that import this GPL package using go list
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
else
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
fi
fi
done <<< "$COPYLEFT_DEPS"
if [ -n "$INCOMPATIBLE" ]; then
echo ""
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
echo -e "$INCOMPATIBLE"
exit 1
fi fi
fi done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
echo "✅ All external license dependencies are compatible with BSD-3-Clause" echo ""
if [ $FOUND_ISSUES -eq 1 ]; then
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
exit 1
else
echo "✅ All license dependencies are clean"
fi

View File

@@ -15,14 +15,13 @@ jobs:
name: "Client / Unit" name: "Client / Unit"
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version-file: "go.mod" go-version: "1.23.x"
cache: false cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v4

View File

@@ -25,7 +25,7 @@ jobs:
release: "14.2" release: "14.2"
prepare: | prepare: |
pkg install -y curl pkgconf xorg pkg install -y curl pkgconf xorg
GO_TARBALL="go1.25.3.freebsd-amd64.tar.gz" GO_TARBALL="go1.23.12.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL" GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -vLO "$GO_URL" curl -vLO "$GO_URL"
tar -C /usr/local -vxzf "$GO_TARBALL" tar -C /usr/local -vxzf "$GO_TARBALL"
@@ -39,7 +39,7 @@ jobs:
# check all component except management, since we do not support management server on freebsd # check all component except management, since we do not support management server on freebsd
time go test -timeout 1m -failfast ./base62/... time go test -timeout 1m -failfast ./base62/...
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use` # NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
time go test -timeout 8m -failfast -v -p 1 ./client/... time go test -timeout 8m -failfast -p 1 ./client/...
time go test -timeout 1m -failfast ./dns/... time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/... time go test -timeout 1m -failfast ./encryption/...
time go test -timeout 1m -failfast ./formatter/... time go test -timeout 1m -failfast ./formatter/...

View File

@@ -30,7 +30,7 @@ jobs:
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version-file: "go.mod" go-version: "1.23.x"
cache: false cache: false
- name: Get Go environment - name: Get Go environment
@@ -106,15 +106,15 @@ jobs:
arch: [ '386','amd64' ] arch: [ '386','amd64' ]
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version-file: "go.mod" go-version: "1.23.x"
cache: false cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment - name: Get Go environment
run: | run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
@@ -151,15 +151,15 @@ jobs:
needs: [ build-cache ] needs: [ build-cache ]
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version-file: "go.mod" go-version: "1.23.x"
cache: false cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment - name: Get Go environment
id: go-env id: go-env
run: | run: |
@@ -200,7 +200,7 @@ jobs:
-e GOCACHE=${CONTAINER_GOCACHE} \ -e GOCACHE=${CONTAINER_GOCACHE} \
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \ -e GOMODCACHE=${CONTAINER_GOMODCACHE} \
-e CONTAINER=${CONTAINER} \ -e CONTAINER=${CONTAINER} \
golang:1.25-alpine \ golang:1.23-alpine \
sh -c ' \ sh -c ' \
apk update; apk add --no-cache \ apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \ ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
@@ -220,15 +220,15 @@ jobs:
raceFlag: "-race" raceFlag: "-race"
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version-file: "go.mod" go-version: "1.23.x"
cache: false cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Install dependencies - name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true' if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386 run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
@@ -259,7 +259,7 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test ${{ matrix.raceFlag }} \ go test ${{ matrix.raceFlag }} \
-exec 'sudo' \ -exec 'sudo' \
-timeout 10m -p 1 ./relay/... ./shared/relay/... -timeout 10m ./relay/... ./shared/relay/...
test_signal: test_signal:
name: "Signal / Unit" name: "Signal / Unit"
@@ -270,15 +270,15 @@ jobs:
arch: [ '386','amd64' ] arch: [ '386','amd64' ]
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version-file: "go.mod" go-version: "1.23.x"
cache: false cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Install dependencies - name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true' if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386 run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
@@ -321,15 +321,15 @@ jobs:
store: [ 'sqlite', 'postgres', 'mysql' ] store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version-file: "go.mod" go-version: "1.23.x"
cache: false cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment - name: Get Go environment
run: | run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
@@ -408,16 +408,15 @@ jobs:
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \ -v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
-p 9090:9090 \ -p 9090:9090 \
prom/prometheus prom/prometheus
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version-file: "go.mod" go-version: "1.23.x"
cache: false cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment - name: Get Go environment
run: | run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
@@ -498,15 +497,15 @@ jobs:
-p 9090:9090 \ -p 9090:9090 \
prom/prometheus prom/prometheus
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version-file: "go.mod" go-version: "1.23.x"
cache: false cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment - name: Get Go environment
run: | run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
@@ -562,15 +561,15 @@ jobs:
store: [ 'sqlite', 'postgres'] store: [ 'sqlite', 'postgres']
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version-file: "go.mod" go-version: "1.23.x"
cache: false cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment - name: Get Go environment
run: | run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV

View File

@@ -24,7 +24,7 @@ jobs:
uses: actions/setup-go@v5 uses: actions/setup-go@v5
id: go id: go
with: with:
go-version-file: "go.mod" go-version: "1.23.x"
cache: false cache: false
- name: Get Go environment - name: Get Go environment

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell - name: codespell
uses: codespell-project/actions-codespell@v2 uses: codespell-project/actions-codespell@v2
with: with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros
skip: go.mod,go.sum skip: go.mod,go.sum
golangci: golangci:
strategy: strategy:
@@ -46,16 +46,13 @@ jobs:
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version-file: "go.mod" go-version: "1.23.x"
cache: false cache: false
- name: Install dependencies - name: Install dependencies
if: matrix.os == 'ubuntu-latest' if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: golangci-lint - name: golangci-lint
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0 uses: golangci/golangci-lint-action@v4
with: with:
version: latest version: latest
skip-cache: true args: --timeout=12m --out-format colored-line-number
skip-save-cache: true
cache-invalidation-interval: 0
args: --timeout=12m

View File

@@ -20,7 +20,7 @@ jobs:
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version-file: "go.mod" go-version: "1.23.x"
- name: Setup Android SDK - name: Setup Android SDK
uses: android-actions/setup-android@v3 uses: android-actions/setup-android@v3
with: with:
@@ -39,7 +39,7 @@ jobs:
- name: Setup NDK - name: Setup NDK
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620" run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
- name: install gomobile - name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init - name: gomobile init
run: gomobile init run: gomobile init
- name: build android netbird lib - name: build android netbird lib
@@ -56,9 +56,9 @@ jobs:
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version-file: "go.mod" go-version: "1.23.x"
- name: install gomobile - name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init - name: gomobile init
run: gomobile init run: gomobile init
- name: build iOS netbird lib - name: build iOS netbird lib

View File

@@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.1.0" SIGN_PIPE_VER: "v0.0.23"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH" COPYRIGHT: "NetBird GmbH"
@@ -19,102 +19,8 @@ concurrency:
cancel-in-progress: true cancel-in-progress: true
jobs: jobs:
release_freebsd_port:
name: "FreeBSD Port / Build & Test"
runs-on: ubuntu-22.04
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Generate FreeBSD port diff
run: bash release_files/freebsd-port-diff.sh
- name: Generate FreeBSD port issue body
run: bash release_files/freebsd-port-issue-body.sh
- name: Check if diff was generated
id: check_diff
run: |
if ls netbird-*.diff 1> /dev/null 2>&1; then
echo "diff_exists=true" >> $GITHUB_OUTPUT
else
echo "diff_exists=false" >> $GITHUB_OUTPUT
echo "No diff file generated (port may already be up to date)"
fi
- name: Extract version
if: steps.check_diff.outputs.diff_exists == 'true'
id: version
run: |
VERSION=$(ls netbird-*.diff | sed 's/netbird-\(.*\)\.diff/\1/')
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "Generated files for version: $VERSION"
cat netbird-*.diff
- name: Test FreeBSD port
if: steps.check_diff.outputs.diff_exists == 'true'
uses: vmactions/freebsd-vm@v1
with:
usesh: true
copyback: false
release: "15.0"
prepare: |
# Install required packages
pkg install -y git curl portlint go
# Install Go for building
GO_TARBALL="go1.25.5.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -LO "$GO_URL"
tar -C /usr/local -xzf "$GO_TARBALL"
# Clone ports tree (shallow, only what we need)
git clone --depth 1 --filter=blob:none https://git.FreeBSD.org/ports.git /usr/ports
cd /usr/ports
run: |
set -e -x
export PATH=$PATH:/usr/local/go/bin
# Find the diff file
echo "Finding diff file..."
DIFF_FILE=$(find $PWD -name "netbird-*.diff" -type f 2>/dev/null | head -1)
echo "Found: $DIFF_FILE"
if [[ -z "$DIFF_FILE" ]]; then
echo "ERROR: Could not find diff file"
find ~ -name "*.diff" -type f 2>/dev/null || true
exit 1
fi
# Apply the generated diff from /usr/ports (diff has a/security/netbird/... paths)
cd /usr/ports
patch -p1 -V none < "$DIFF_FILE"
# Show patched Makefile
version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}')
cd /usr/ports/security/netbird
export BATCH=yes
make package
pkg add ./work/pkg/netbird-*.pkg
netbird version | grep "$version"
echo "FreeBSD port test completed successfully!"
- name: Upload FreeBSD port files
if: steps.check_diff.outputs.diff_exists == 'true'
uses: actions/upload-artifact@v4
with:
name: freebsd-port-files
path: |
./netbird-*-issue.txt
./netbird-*.diff
retention-days: 30
release: release:
runs-on: ubuntu-latest-m runs-on: ubuntu-22.04
env: env:
flags: "" flags: ""
steps: steps:
@@ -134,7 +40,7 @@ jobs:
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version-file: "go.mod" go-version: "1.23"
cache: false cache: false
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v4
@@ -230,7 +136,7 @@ jobs:
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version-file: "go.mod" go-version: "1.23"
cache: false cache: false
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v4
@@ -294,7 +200,7 @@ jobs:
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version-file: "go.mod" go-version: "1.23"
cache: false cache: false
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v4

View File

@@ -67,13 +67,10 @@ jobs:
- name: Install curl - name: Install curl
run: sudo apt-get install -y curl run: sudo apt-get install -y curl
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version-file: "go.mod" go-version: "1.23.x"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v4
@@ -83,6 +80,9 @@ jobs:
restore-keys: | restore-keys: |
${{ runner.os }}-go- ${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v4
- name: Setup MySQL privileges - name: Setup MySQL privileges
if: matrix.store == 'mysql' if: matrix.store == 'mysql'
run: | run: |
@@ -243,7 +243,6 @@ jobs:
working-directory: infrastructure_files/artifacts working-directory: infrastructure_files/artifacts
run: | run: |
sleep 30 sleep 30
docker compose logs
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City_[0-9]*.mmdb docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City_[0-9]*.mmdb
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames_[0-9]*.db docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames_[0-9]*.db

View File

@@ -14,27 +14,26 @@ jobs:
js_lint: js_lint:
name: "JS / Lint" name: "JS / Lint"
runs-on: ubuntu-latest runs-on: ubuntu-latest
env:
GOOS: js
GOARCH: wasm
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version-file: "go.mod" go-version: "1.23.x"
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: Install golangci-lint - name: Install golangci-lint
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0 uses: golangci/golangci-lint-action@d6238b002a20823d52840fda27e2d4891c5952dc
with: with:
version: latest version: latest
install-mode: binary install-mode: binary
skip-cache: true skip-cache: true
skip-save-cache: true skip-pkg-cache: true
cache-invalidation-interval: 0 skip-build-cache: true
working-directory: ./client - name: Run golangci-lint for WASM
run: |
GOOS=js GOARCH=wasm golangci-lint run --timeout=12m --out-format colored-line-number ./client/...
continue-on-error: true continue-on-error: true
js_build: js_build:
@@ -46,7 +45,7 @@ jobs:
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version-file: "go.mod" go-version: "1.23.x"
- name: Build Wasm client - name: Build Wasm client
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
env: env:
@@ -61,8 +60,8 @@ jobs:
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)" echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
if [ ${SIZE} -gt 57671680 ]; then if [ ${SIZE} -gt 52428800 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!" echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!"
exit 1 exit 1
fi fi

1
.gitignore vendored
View File

@@ -31,4 +31,3 @@ infrastructure_files/setup-*.env
.DS_Store .DS_Store
vendor/ vendor/
/netbird /netbird
client/netbird-electron/

View File

@@ -1,124 +1,139 @@
version: "2" run:
linters: # Timeout for analysis, e.g. 30s, 5m.
default: none # Default: 1m
enable: timeout: 6m
- bodyclose
- dupword # This file contains only configs which differ from defaults.
- durationcheck # All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
- errcheck linters-settings:
- forbidigo errcheck:
- gocritic # Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
- gosec # Such cases aren't reported by default.
- govet # Default: false
- ineffassign check-type-assertions: false
- mirror
- misspell gosec:
- nilerr includes:
- nilnil - G101 # Look for hard coded credentials
- predeclared #- G102 # Bind to all interfaces
- revive - G103 # Audit the use of unsafe block
- sqlclosecheck - G104 # Audit errors not checked
- staticcheck - G106 # Audit the use of ssh.InsecureIgnoreHostKey
- unused #- G107 # Url provided to HTTP request as taint input
- wastedassign - G108 # Profiling endpoint automatically exposed on /debug/pprof
settings: - G109 # Potential Integer overflow made by strconv.Atoi result conversion to int16/32
errcheck: - G110 # Potential DoS vulnerability via decompression bomb
check-type-assertions: false - G111 # Potential directory traversal
gocritic: #- G112 # Potential slowloris attack
disabled-checks: - G113 # Usage of Rat.SetString in math/big with an overflow (CVE-2022-23772)
- commentFormatting #- G114 # Use of net/http serve function that has no support for setting timeouts
- captLocal - G201 # SQL query construction using format string
- deprecatedComment - G202 # SQL query construction using string concatenation
gosec: - G203 # Use of unescaped data in HTML templates
includes: #- G204 # Audit use of command execution
- G101 - G301 # Poor file permissions used when creating a directory
- G103 - G302 # Poor file permissions used with chmod
- G104 - G303 # Creating tempfile using a predictable path
- G106 - G304 # File path provided as taint input
- G108 - G305 # File traversal when extracting zip/tar archive
- G109 - G306 # Poor file permissions used when writing to a new file
- G110 - G307 # Poor file permissions used when creating a file with os.Create
- G111 #- G401 # Detect the usage of DES, RC4, MD5 or SHA1
- G201 #- G402 # Look for bad TLS connection settings
- G202 - G403 # Ensure minimum RSA key length of 2048 bits
- G203 #- G404 # Insecure random number source (rand)
- G301 #- G501 # Import blocklist: crypto/md5
- G302 - G502 # Import blocklist: crypto/des
- G303 - G503 # Import blocklist: crypto/rc4
- G304 - G504 # Import blocklist: net/http/cgi
- G305 #- G505 # Import blocklist: crypto/sha1
- G306 - G601 # Implicit memory aliasing of items from a range statement
- G307 - G602 # Slice access out of bounds
- G403
- G502 gocritic:
- G503 disabled-checks:
- G504 - commentFormatting
- G601 - captLocal
- G602 - deprecatedComment
govet:
enable: govet:
- nilness # Enable all analyzers.
enable-all: false # Default: false
revive: enable-all: false
rules: enable:
- name: exported - nilness
arguments:
- checkPrivateReceivers revive:
- sayRepetitiveInsteadOfStutters
severity: warning
disabled: false
exclusions:
generated: lax
presets:
- comments
- common-false-positives
- legacy
- std-error-handling
rules: rules:
- linters: - name: exported
- forbidigo severity: warning
path: management/cmd/root\.go disabled: false
- linters: arguments:
- forbidigo - "checkPrivateReceivers"
path: signal/cmd/root\.go - "sayRepetitiveInsteadOfStutters"
- linters: tenv:
- unused # The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
path: sharedsock/filter\.go # Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
- linters: # Default: false
- unused all: true
path: client/firewall/iptables/rule\.go
- linters: linters:
- gosec disable-all: true
- mirror enable:
path: test\.go ## enabled by default
- linters: - errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases
- nilnil - gosimple # specializes in simplifying a code
path: mock\.go - govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
- linters: - ineffassign # detects when assignments to existing variables are not used
- staticcheck - staticcheck # is a go vet on steroids, applying a ton of static analysis checks
text: grpc.DialContext is deprecated - tenv # Tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17.
- linters: - typecheck # like the front-end of a Go compiler, parses and type-checks Go code
- staticcheck - unused # checks for unused constants, variables, functions and types
text: grpc.WithBlock is deprecated ## disable by default but the have interesting results so lets add them
- linters: - bodyclose # checks whether HTTP response body is closed successfully
- staticcheck - dupword # dupword checks for duplicate words in the source code
text: "QF1001" - durationcheck # durationcheck checks for two durations multiplied together
- linters: - forbidigo # forbidigo forbids identifiers
- staticcheck - gocritic # provides diagnostics that check for bugs, performance and style issues
text: "QF1008" - gosec # inspects source code for security problems
- linters: - mirror # mirror reports wrong mirror patterns of bytes/strings usage
- staticcheck - misspell # misspess finds commonly misspelled English words in comments
text: "QF1012" - nilerr # finds the code that returns nil even if it checks that the error is not nil
paths: - nilnil # checks that there is no simultaneous return of nil error and an invalid value
- third_party$ - predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
- builtin$ - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
- examples$ - sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
# - thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
- wastedassign # wastedassign finds wasted assignment statements
issues: issues:
# Maximum count of issues with the same text.
# Set to 0 to disable.
# Default: 3
max-same-issues: 5 max-same-issues: 5
formatters:
exclusions: exclude-rules:
generated: lax # allow fmt
paths: - path: management/cmd/root\.go
- third_party$ linters: forbidigo
- builtin$ - path: signal/cmd/root\.go
- examples$ linters: forbidigo
- path: sharedsock/filter\.go
linters:
- unused
- path: client/firewall/iptables/rule\.go
linters:
- unused
- path: test\.go
linters:
- mirror
- gosec
- path: mock\.go
linters:
- nilnil
# Exclude specific deprecation warnings for grpc methods
- linters:
- staticcheck
text: "grpc.DialContext is deprecated"
- linters:
- staticcheck
text: "grpc.WithBlock is deprecated"

View File

@@ -713,10 +713,8 @@ checksum:
extra_files: extra_files:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh - glob: ./infrastructure_files/getting-started-with-zitadel.sh
- glob: ./release_files/install.sh - glob: ./release_files/install.sh
- glob: ./infrastructure_files/getting-started.sh
release: release:
extra_files: extra_files:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh - glob: ./infrastructure_files/getting-started-with-zitadel.sh
- glob: ./release_files/install.sh - glob: ./release_files/install.sh
- glob: ./infrastructure_files/getting-started.sh

View File

@@ -136,14 +136,6 @@ checked out and set up:
go mod tidy go mod tidy
``` ```
6. Configure Git hooks for automatic linting:
```bash
make setup-hooks
```
This will configure Git to run linting automatically before each push, helping catch issues early.
### Dev Container Support ### Dev Container Support
If you prefer using a dev container for development, NetBird now includes support for dev containers. If you prefer using a dev container for development, NetBird now includes support for dev containers.

View File

@@ -1,27 +0,0 @@
.PHONY: lint lint-all lint-install setup-hooks
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
# Install golangci-lint locally if needed
$(GOLANGCI_LINT):
@echo "Installing golangci-lint..."
@mkdir -p ./bin
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
# Lint only changed files (fast, for pre-push)
lint: $(GOLANGCI_LINT)
@echo "Running lint on changed files..."
@$(GOLANGCI_LINT) run --new-from-rev=origin/main --timeout=2m
# Lint entire codebase (slow, matches CI)
lint-all: $(GOLANGCI_LINT)
@echo "Running lint on all files..."
@$(GOLANGCI_LINT) run --timeout=12m
# Just install the linter
lint-install: $(GOLANGCI_LINT)
# Setup git hooks for all developers
setup-hooks:
@git config core.hooksPath .githooks
@chmod +x .githooks/pre-push
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"

View File

@@ -38,11 +38,6 @@
</strong> </strong>
<br> <br>
<strong>
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
</strong>
<br>
<br>
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest"> <a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
New: NetBird terraform provider New: NetBird terraform provider
</a> </a>
@@ -90,7 +85,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird
**Infrastructure requirements:** **Infrastructure requirements:**
- A Linux VM with at least **1CPU** and **2GB** of memory. - A Linux VM with at least **1CPU** and **2GB** of memory.
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port: **3478**. - The VM should be publicly accessible on TCP ports **80** and **443** and UDP ports: **3478**, **49152-65535**.
- **Public domain** name pointing to the VM. - **Public domain** name pointing to the VM.
**Software requirements:** **Software requirements:**
@@ -103,7 +98,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird
**Steps** **Steps**
- Download and run the installation script: - Download and run the installation script:
```bash ```bash
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started.sh | bash export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started-with-zitadel.sh | bash
``` ```
- Once finished, you can manage the resources via `docker-compose` - Once finished, you can manage the resources via `docker-compose`
@@ -118,7 +113,7 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups. [Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
<p float="left" align="middle"> <p float="left" align="middle">
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700"/> <img src="https://docs.netbird.io/docs-static/img/architecture/high-level-dia.png" width="700"/>
</p> </p>
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details. See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client . # sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest # sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.23.2 FROM alpine:3.22.2
# iproute2: busybox doesn't display ip rules properly # iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \ RUN apk add --no-cache \
bash \ bash \

View File

@@ -4,13 +4,10 @@ package android
import ( import (
"context" "context"
"fmt"
"os" "os"
"slices" "slices"
"sync" "sync"
"golang.org/x/exp/maps"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
@@ -19,13 +16,10 @@ import (
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
) )
// ConnectionListener export internal Listener for mobile // ConnectionListener export internal Listener for mobile
@@ -59,6 +53,7 @@ func init() {
// Client struct manage the life circle of background service // Client struct manage the life circle of background service
type Client struct { type Client struct {
cfgFile string
tunAdapter device.TunAdapter tunAdapter device.TunAdapter
iFaceDiscover IFaceDiscover iFaceDiscover IFaceDiscover
recorder *peer.Status recorder *peer.Status
@@ -72,11 +67,12 @@ type Client struct {
} }
// NewClient instantiate a new Client // NewClient instantiate a new Client
func NewClient(androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client { func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
execWorkaround(androidSDKVersion) execWorkaround(androidSDKVersion)
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket) net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
return &Client{ return &Client{
cfgFile: cfgFile,
deviceName: deviceName, deviceName: deviceName,
uiVersion: uiVersion, uiVersion: uiVersion,
tunAdapter: tunAdapter, tunAdapter: tunAdapter,
@@ -88,16 +84,10 @@ func NewClient(androidSDKVersion int, deviceName string, uiVersion string, tunAd
} }
// Run start the internal client. It is a blocker function // Run start the internal client. It is a blocker function
func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList) exportEnvList(envList)
cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath()
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: cfgFile, ConfigPath: c.cfgFile,
}) })
if err != nil { if err != nil {
return err return err
@@ -117,29 +107,23 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
c.ctxCancelLock.Unlock() c.ctxCancelLock.Unlock()
auth := NewAuthWithConfig(ctx, cfg) auth := NewAuthWithConfig(ctx, cfg)
err = auth.login(urlOpener, isAndroidTV) err = auth.login(urlOpener)
if err != nil { if err != nil {
return err return err
} }
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
} }
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
// In this case make no sense handle registration steps. // In this case make no sense handle registration steps.
func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList) exportEnvList(envList)
cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath()
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: cfgFile, ConfigPath: c.cfgFile,
}) })
if err != nil { if err != nil {
return err return err
@@ -157,8 +141,8 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources
@@ -172,19 +156,6 @@ func (c *Client) Stop() {
c.ctxCancel() c.ctxCancel()
} }
func (c *Client) RenewTun(fd int) error {
if c.connectClient == nil {
return fmt.Errorf("engine not running")
}
e := c.connectClient.Engine()
if e == nil {
return fmt.Errorf("engine not initialized")
}
return e.RenewTun(fd)
}
// SetTraceLogLevel configure the logger to trace level // SetTraceLogLevel configure the logger to trace level
func (c *Client) SetTraceLogLevel() { func (c *Client) SetTraceLogLevel() {
log.SetLevel(log.TraceLevel) log.SetLevel(log.TraceLevel)
@@ -206,7 +177,6 @@ func (c *Client) PeersList() *PeerInfoArray {
p.IP, p.IP,
p.FQDN, p.FQDN,
p.ConnStatus.String(), p.ConnStatus.String(),
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
} }
peerInfos[n] = pi peerInfos[n] = pi
} }
@@ -231,43 +201,31 @@ func (c *Client) Networks() *NetworkArray {
return nil return nil
} }
routeSelector := routeManager.GetRouteSelector()
if routeSelector == nil {
log.Error("could not get route selector")
return nil
}
networkArray := &NetworkArray{ networkArray := &NetworkArray{
items: make([]Network, 0), items: make([]Network, 0),
} }
resolvedDomains := c.recorder.GetResolvedDomainsStates()
for id, routes := range routeManager.GetClientRoutesWithNetID() { for id, routes := range routeManager.GetClientRoutesWithNetID() {
if len(routes) == 0 { if len(routes) == 0 {
continue continue
} }
r := routes[0] r := routes[0]
domains := c.getNetworkDomainsFromRoute(r, resolvedDomains)
netStr := r.Network.String() netStr := r.Network.String()
if r.IsDynamic() { if r.IsDynamic() {
netStr = r.Domains.SafeString() netStr = r.Domains.SafeString()
} }
routePeer, err := c.recorder.GetPeer(routes[0].Peer) peer, err := c.recorder.GetPeer(routes[0].Peer)
if err != nil { if err != nil {
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err) log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
continue continue
} }
network := Network{ network := Network{
Name: string(id), Name: string(id),
Network: netStr, Network: netStr,
Peer: routePeer.FQDN, Peer: peer.FQDN,
Status: routePeer.ConnStatus.String(), Status: peer.ConnStatus.String(),
IsSelected: routeSelector.IsSelected(id),
Domains: domains,
} }
networkArray.Add(network) networkArray.Add(network)
} }
@@ -295,69 +253,6 @@ func (c *Client) RemoveConnectionListener() {
c.recorder.RemoveConnectionListener() c.recorder.RemoveConnectionListener()
} }
func (c *Client) toggleRoute(command routeCommand) error {
return command.toggleRoute()
}
func (c *Client) getRouteManager() (routemanager.Manager, error) {
client := c.connectClient
if client == nil {
return nil, fmt.Errorf("not connected")
}
engine := client.Engine()
if engine == nil {
return nil, fmt.Errorf("engine is not running")
}
manager := engine.GetRouteManager()
if manager == nil {
return nil, fmt.Errorf("could not get route manager")
}
return manager, nil
}
func (c *Client) SelectRoute(route string) error {
manager, err := c.getRouteManager()
if err != nil {
return err
}
return c.toggleRoute(selectRouteCommand{route: route, manager: manager})
}
func (c *Client) DeselectRoute(route string) error {
manager, err := c.getRouteManager()
if err != nil {
return err
}
return c.toggleRoute(deselectRouteCommand{route: route, manager: manager})
}
// getNetworkDomainsFromRoute extracts domains from a route and enriches each domain
// with its resolved IP addresses from the provided resolvedDomains map.
func (c *Client) getNetworkDomainsFromRoute(route *route.Route, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) NetworkDomains {
domains := NetworkDomains{}
for _, d := range route.Domains {
networkDomain := NetworkDomain{
Address: d.SafeString(),
}
if info, exists := resolvedDomains[d]; exists {
for _, prefix := range info.Prefixes {
networkDomain.addResolvedIP(prefix.Addr().String())
}
}
domains.Add(&networkDomain)
}
return domains
}
func exportEnvList(list *EnvList) { func exportEnvList(list *EnvList) {
if list == nil { if list == nil {
return return

View File

@@ -3,7 +3,15 @@ package android
import ( import (
"context" "context"
"fmt" "fmt"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
@@ -24,7 +32,7 @@ type ErrListener interface {
// URLOpener it is a callback interface. The Open function will be triggered if // URLOpener it is a callback interface. The Open function will be triggered if
// the backend want to show an url for the user // the backend want to show an url for the user
type URLOpener interface { type URLOpener interface {
Open(url string, userCode string) Open(string)
OnLoginSuccess() OnLoginSuccess()
} }
@@ -76,21 +84,34 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
} }
func (a *Auth) saveConfigIfSSOSupported() (bool, error) { func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config) supportsSSO := true
if err != nil { err := a.withBackOff(a.ctx, func() (err error) {
return false, fmt.Errorf("failed to create auth client: %v", err) _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
} if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
defer authClient.Close() _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err)
if !ok {
return err
}
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
supportsSSO = false
err = nil
}
supportsSSO, err := authClient.IsSSOSupported(a.ctx) return err
if err != nil { }
return false, fmt.Errorf("failed to check SSO support: %v", err)
} return err
})
if !supportsSSO { if !supportsSSO {
return false, nil return false, nil
} }
if err != nil {
return false, fmt.Errorf("backoff cycle failed: %v", err)
}
err = profilemanager.WriteOutConfig(a.cfgPath, a.config) err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
return true, err return true, err
} }
@@ -108,26 +129,28 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
} }
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error { func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
//nolint //nolint
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName) ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
err, _ = authClient.Login(ctxWithValues, setupKey, "")
err := a.withBackOff(a.ctx, func() error {
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
// we got an answer from management, exit backoff earlier
return backoff.Permanent(backoffErr)
}
return backoffErr
})
if err != nil { if err != nil {
return fmt.Errorf("login failed: %v", err) return fmt.Errorf("backoff cycle failed: %v", err)
} }
return profilemanager.WriteOutConfig(a.cfgPath, a.config) return profilemanager.WriteOutConfig(a.cfgPath, a.config)
} }
// Login try register the client on the server // Login try register the client on the server
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidTV bool) { func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) {
go func() { go func() {
err := a.login(urlOpener, isAndroidTV) err := a.login(urlOpener)
if err != nil { if err != nil {
resultListener.OnError(err) resultListener.OnError(err)
} else { } else {
@@ -136,42 +159,50 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidT
}() }()
} }
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error { func (a *Auth) login(urlOpener URLOpener) error {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config) var needsLogin bool
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
// check if we need to generate JWT token // check if we need to generate JWT token
needsLogin, err := authClient.IsLoginRequired(a.ctx) err := a.withBackOff(a.ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
return
})
if err != nil { if err != nil {
return fmt.Errorf("failed to check login requirement: %v", err) return fmt.Errorf("backoff cycle failed: %v", err)
} }
jwtToken := "" jwtToken := ""
if needsLogin { if needsLogin {
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, isAndroidTV) tokenInfo, err := a.foregroundGetTokenInfo(urlOpener)
if err != nil { if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err) return fmt.Errorf("interactive sso login failed: %v", err)
} }
jwtToken = tokenInfo.GetTokenToUse() jwtToken = tokenInfo.GetTokenToUse()
} }
err, _ = authClient.Login(a.ctx, "", jwtToken) err = a.withBackOff(a.ctx, func() error {
if err != nil { err := internal.Login(a.ctx, a.config, "", jwtToken)
return fmt.Errorf("login failed: %v", err)
}
go urlOpener.OnLoginSuccess() if err == nil {
go urlOpener.OnLoginSuccess()
}
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil
}
return err
})
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
}
return nil return nil
} }
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) { func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, isAndroidTV) oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get OAuth flow: %v", err) return nil, err
} }
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO()) flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
@@ -179,12 +210,24 @@ func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err) return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
} }
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode) go urlOpener.Open(flowInfo.VerificationURIComplete)
tokenInfo, err := oAuthFlow.WaitToken(a.ctx, flowInfo) waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
defer cancel()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
if err != nil { if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err) return nil, fmt.Errorf("waiting for browser login failed: %v", err)
} }
return &tokenInfo, nil return &tokenInfo, nil
} }
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
return backoff.RetryNotify(
bf,
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
func(err error, duration time.Duration) {
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
})
}

View File

@@ -1,56 +0,0 @@
//go:build android
package android
import "fmt"
type ResolvedIPs struct {
resolvedIPs []string
}
func (r *ResolvedIPs) Add(ipAddress string) {
r.resolvedIPs = append(r.resolvedIPs, ipAddress)
}
func (r *ResolvedIPs) Get(i int) (string, error) {
if i < 0 || i >= len(r.resolvedIPs) {
return "", fmt.Errorf("%d is out of range", i)
}
return r.resolvedIPs[i], nil
}
func (r *ResolvedIPs) Size() int {
return len(r.resolvedIPs)
}
type NetworkDomain struct {
Address string
resolvedIPs ResolvedIPs
}
func (d *NetworkDomain) addResolvedIP(resolvedIP string) {
d.resolvedIPs.Add(resolvedIP)
}
func (d *NetworkDomain) GetResolvedIPs() *ResolvedIPs {
return &d.resolvedIPs
}
type NetworkDomains struct {
domains []*NetworkDomain
}
func (n *NetworkDomains) Add(domain *NetworkDomain) {
n.domains = append(n.domains, domain)
}
func (n *NetworkDomains) Get(i int) (*NetworkDomain, error) {
if i < 0 || i >= len(n.domains) {
return nil, fmt.Errorf("%d is out of range", i)
}
return n.domains[i], nil
}
func (n *NetworkDomains) Size() int {
return len(n.domains)
}

View File

@@ -3,16 +3,10 @@
package android package android
type Network struct { type Network struct {
Name string Name string
Network string Network string
Peer string Peer string
Status string Status string
IsSelected bool
Domains NetworkDomains
}
func (n Network) GetNetworkDomains() *NetworkDomains {
return &n.Domains
} }
type NetworkArray struct { type NetworkArray struct {

View File

@@ -1,5 +1,3 @@
//go:build android
package android package android
// PeerInfo describe information about the peers. It designed for the UI usage // PeerInfo describe information about the peers. It designed for the UI usage
@@ -7,11 +5,6 @@ type PeerInfo struct {
IP string IP string
FQDN string FQDN string
ConnStatus string // Todo replace to enum ConnStatus string // Todo replace to enum
Routes PeerRoutes
}
func (p *PeerInfo) GetPeerRoutes() *PeerRoutes {
return &p.Routes
} }
// PeerInfoArray is a wrapper of []PeerInfo // PeerInfoArray is a wrapper of []PeerInfo

View File

@@ -1,20 +0,0 @@
//go:build android
package android
import "fmt"
type PeerRoutes struct {
routes []string
}
func (p *PeerRoutes) Get(i int) (string, error) {
if i < 0 || i >= len(p.routes) {
return "", fmt.Errorf("%d is out of range", i)
}
return p.routes[i], nil
}
func (p *PeerRoutes) Size() int {
return len(p.routes)
}

View File

@@ -1,10 +0,0 @@
//go:build android
package android
// PlatformFiles groups paths to files used internally by the engine that can't be created/modified
// at their default locations due to android OS restrictions.
type PlatformFiles interface {
ConfigurationFilePath() string
StateFilePath() string
}

View File

@@ -1,257 +0,0 @@
//go:build android
package android
import (
"fmt"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
const (
// Android-specific config filename (different from desktop default.json)
defaultConfigFilename = "netbird.cfg"
// Subdirectory for non-default profiles (must match Java Preferences.java)
profilesSubdir = "profiles"
// Android uses a single user context per app (non-empty username required by ServiceManager)
androidUsername = "android"
)
// Profile represents a profile for gomobile
type Profile struct {
Name string
IsActive bool
}
// ProfileArray wraps profiles for gomobile compatibility
type ProfileArray struct {
items []*Profile
}
// Length returns the number of profiles
func (p *ProfileArray) Length() int {
return len(p.items)
}
// Get returns the profile at index i
func (p *ProfileArray) Get(i int) *Profile {
if i < 0 || i >= len(p.items) {
return nil
}
return p.items[i]
}
/*
/data/data/io.netbird.client/files/ ← configDir parameter
├── netbird.cfg ← Default profile config
├── state.json ← Default profile state
├── active_profile.json ← Active profile tracker (JSON with Name + Username)
└── profiles/ ← Subdirectory for non-default profiles
├── work.json ← Work profile config
├── work.state.json ← Work profile state
├── personal.json ← Personal profile config
└── personal.state.json ← Personal profile state
*/
// ProfileManager manages profiles for Android
// It wraps the internal profilemanager to provide Android-specific behavior
type ProfileManager struct {
configDir string
serviceMgr *profilemanager.ServiceManager
}
// NewProfileManager creates a new profile manager for Android
func NewProfileManager(configDir string) *ProfileManager {
// Set the default config path for Android (stored in root configDir, not profiles/)
defaultConfigPath := filepath.Join(configDir, defaultConfigFilename)
// Set global paths for Android
profilemanager.DefaultConfigPathDir = configDir
profilemanager.DefaultConfigPath = defaultConfigPath
profilemanager.ActiveProfileStatePath = filepath.Join(configDir, "active_profile.json")
// Create ServiceManager with profiles/ subdirectory
// This avoids modifying the global ConfigDirOverride for profile listing
profilesDir := filepath.Join(configDir, profilesSubdir)
serviceMgr := profilemanager.NewServiceManagerWithProfilesDir(defaultConfigPath, profilesDir)
return &ProfileManager{
configDir: configDir,
serviceMgr: serviceMgr,
}
}
// ListProfiles returns all available profiles
func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
// Use ServiceManager (looks in profiles/ directory, checks active_profile.json for IsActive)
internalProfiles, err := pm.serviceMgr.ListProfiles(androidUsername)
if err != nil {
return nil, fmt.Errorf("failed to list profiles: %w", err)
}
// Convert internal profiles to Android Profile type
var profiles []*Profile
for _, p := range internalProfiles {
profiles = append(profiles, &Profile{
Name: p.Name,
IsActive: p.IsActive,
})
}
return &ProfileArray{items: profiles}, nil
}
// GetActiveProfile returns the currently active profile name
func (pm *ProfileManager) GetActiveProfile() (string, error) {
// Use ServiceManager to stay consistent with ListProfiles
// ServiceManager uses active_profile.json
activeState, err := pm.serviceMgr.GetActiveProfileState()
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return activeState.Name, nil
}
// SwitchProfile switches to a different profile
func (pm *ProfileManager) SwitchProfile(profileName string) error {
// Use ServiceManager to stay consistent with ListProfiles
// ServiceManager uses active_profile.json
err := pm.serviceMgr.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: profileName,
Username: androidUsername,
})
if err != nil {
return fmt.Errorf("failed to switch profile: %w", err)
}
log.Infof("switched to profile: %s", profileName)
return nil
}
// AddProfile creates a new profile
func (pm *ProfileManager) AddProfile(profileName string) error {
// Use ServiceManager (creates profile in profiles/ directory)
if err := pm.serviceMgr.AddProfile(profileName, androidUsername); err != nil {
return fmt.Errorf("failed to add profile: %w", err)
}
log.Infof("created new profile: %s", profileName)
return nil
}
// LogoutProfile logs out from a profile (clears authentication)
func (pm *ProfileManager) LogoutProfile(profileName string) error {
profileName = sanitizeProfileName(profileName)
configPath, err := pm.getProfileConfigPath(profileName)
if err != nil {
return err
}
// Check if profile exists
if _, err := os.Stat(configPath); os.IsNotExist(err) {
return fmt.Errorf("profile '%s' does not exist", profileName)
}
// Read current config using internal profilemanager
config, err := profilemanager.ReadConfig(configPath)
if err != nil {
return fmt.Errorf("failed to read profile config: %w", err)
}
// Clear authentication by removing private key and SSH key
config.PrivateKey = ""
config.SSHKey = ""
// Save config using internal profilemanager
if err := profilemanager.WriteOutConfig(configPath, config); err != nil {
return fmt.Errorf("failed to save config: %w", err)
}
log.Infof("logged out from profile: %s", profileName)
return nil
}
// RemoveProfile deletes a profile
func (pm *ProfileManager) RemoveProfile(profileName string) error {
// Use ServiceManager (removes profile from profiles/ directory)
if err := pm.serviceMgr.RemoveProfile(profileName, androidUsername); err != nil {
return fmt.Errorf("failed to remove profile: %w", err)
}
log.Infof("removed profile: %s", profileName)
return nil
}
// getProfileConfigPath returns the config file path for a profile
// This is needed for Android-specific path handling (netbird.cfg for default profile)
func (pm *ProfileManager) getProfileConfigPath(profileName string) (string, error) {
if profileName == "" || profileName == profilemanager.DefaultProfileName {
// Android uses netbird.cfg for default profile instead of default.json
// Default profile is stored in root configDir, not in profiles/
return filepath.Join(pm.configDir, defaultConfigFilename), nil
}
// Non-default profiles are stored in profiles subdirectory
// This matches the Java Preferences.java expectation
profileName = sanitizeProfileName(profileName)
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
return filepath.Join(profilesDir, profileName+".json"), nil
}
// GetConfigPath returns the config file path for a given profile
// Java should call this instead of constructing paths with Preferences.configFile()
func (pm *ProfileManager) GetConfigPath(profileName string) (string, error) {
return pm.getProfileConfigPath(profileName)
}
// GetStateFilePath returns the state file path for a given profile
// Java should call this instead of constructing paths with Preferences.stateFile()
func (pm *ProfileManager) GetStateFilePath(profileName string) (string, error) {
if profileName == "" || profileName == profilemanager.DefaultProfileName {
return filepath.Join(pm.configDir, "state.json"), nil
}
profileName = sanitizeProfileName(profileName)
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
return filepath.Join(profilesDir, profileName+".state.json"), nil
}
// GetActiveConfigPath returns the config file path for the currently active profile
// Java should call this instead of Preferences.getActiveProfileName() + Preferences.configFile()
func (pm *ProfileManager) GetActiveConfigPath() (string, error) {
activeProfile, err := pm.GetActiveProfile()
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return pm.GetConfigPath(activeProfile)
}
// GetActiveStateFilePath returns the state file path for the currently active profile
// Java should call this instead of Preferences.getActiveProfileName() + Preferences.stateFile()
func (pm *ProfileManager) GetActiveStateFilePath() (string, error) {
activeProfile, err := pm.GetActiveProfile()
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return pm.GetStateFilePath(activeProfile)
}
// sanitizeProfileName removes invalid characters from profile name
func sanitizeProfileName(name string) string {
// Keep only alphanumeric, underscore, and hyphen
var result strings.Builder
for _, r := range name {
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') ||
(r >= '0' && r <= '9') || r == '_' || r == '-' {
result.WriteRune(r)
}
}
return result.String()
}

View File

@@ -1,67 +0,0 @@
//go:build android
package android
import (
"fmt"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/route"
)
func executeRouteToggle(id string, manager routemanager.Manager,
operationName string,
routeOperation func(routes []route.NetID, allRoutes []route.NetID) error) error {
netID := route.NetID(id)
routes := []route.NetID{netID}
log.Debugf("%s with id: %s", operationName, id)
if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil {
log.Debugf("error when %s: %s", operationName, err)
return fmt.Errorf("error %s: %w", operationName, err)
}
manager.TriggerSelection(manager.GetClientRoutes())
return nil
}
type routeCommand interface {
toggleRoute() error
}
type selectRouteCommand struct {
route string
manager routemanager.Manager
}
func (s selectRouteCommand) toggleRoute() error {
routeSelector := s.manager.GetRouteSelector()
if routeSelector == nil {
return fmt.Errorf("no route selector available")
}
routeOperation := func(routes []route.NetID, allRoutes []route.NetID) error {
return routeSelector.SelectRoutes(routes, true, allRoutes)
}
return executeRouteToggle(s.route, s.manager, "selecting route", routeOperation)
}
type deselectRouteCommand struct {
route string
manager routemanager.Manager
}
func (d deselectRouteCommand) toggleRoute() error {
routeSelector := d.manager.GetRouteSelector()
if routeSelector == nil {
return fmt.Errorf("no route selector available")
}
return executeRouteToggle(d.route, d.manager, "deselecting route", routeSelector.DeselectRoutes)
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server" "github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/upload-server/types" "github.com/netbirdio/netbird/upload-server/types"
) )
@@ -97,6 +98,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
request := &proto.DebugBundleRequest{ request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag, Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: systemInfoFlag, SystemInfo: systemInfoFlag,
LogFileCount: logFileCount, LogFileCount: logFileCount,
} }
@@ -134,7 +136,6 @@ func setLogLevel(cmd *cobra.Command, args []string) error {
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
level := server.ParseLogLevel(args[0]) level := server.ParseLogLevel(args[0])
if level == proto.LogLevel_UNKNOWN { if level == proto.LogLevel_UNKNOWN {
//nolint
return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0]) return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0])
} }
@@ -219,37 +220,21 @@ func runForDuration(cmd *cobra.Command, args []string) error {
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
cpuProfilingStarted := false headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
if _, err := client.StartCPUProfile(cmd.Context(), &proto.StartCPUProfileRequest{}); err != nil { statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
cmd.PrintErrf("Failed to start CPU profiling: %v\n", err)
} else {
cpuProfilingStarted = true
defer func() {
if cpuProfilingStarted {
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
}
}
}()
}
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil { if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
return waitErr return waitErr
} }
cmd.Println("\nDuration completed") cmd.Println("\nDuration completed")
if cpuProfilingStarted {
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
} else {
cpuProfilingStarted = false
}
}
cmd.Println("Creating debug bundle...") cmd.Println("Creating debug bundle...")
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
request := &proto.DebugBundleRequest{ request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag, Anonymize: anonymizeFlag,
Status: statusOutput,
SystemInfo: systemInfoFlag, SystemInfo: systemInfoFlag,
LogFileCount: logFileCount, LogFileCount: logFileCount,
} }
@@ -316,6 +301,25 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func getStatusOutput(cmd *cobra.Command, anon bool) string {
var statusOutputString string
statusResp, err := getStatus(cmd.Context(), true)
if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
)
}
return statusOutputString
}
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error { func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
ticker := time.NewTicker(1 * time.Second) ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop() defer ticker.Stop()
@@ -374,8 +378,7 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
InternalConfig: config, InternalConfig: config,
StatusRecorder: recorder, StatusRecorder: recorder,
SyncResponse: syncResponse, SyncResponse: syncResponse,
LogPath: logFilePath, LogFile: logFilePath,
CPUProfile: nil,
}, },
debug.BundleConfig{ debug.BundleConfig{
IncludeSystemInfo: true, IncludeSystemInfo: true,

View File

@@ -4,11 +4,14 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"os/exec"
"os/user" "os/user"
"runtime" "runtime"
"strings" "strings"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
@@ -80,7 +83,6 @@ var loginCmd = &cobra.Command{
func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey string, activeProf *profilemanager.Profile, username string, pm *profilemanager.ProfileManager) error { func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey string, activeProf *profilemanager.Profile, username string, pm *profilemanager.ProfileManager) error {
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil { if err != nil {
//nolint
return fmt.Errorf("failed to connect to daemon error: %v\n"+ return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+ "If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err) "\nnetbird service install \nnetbird service start\n", err)
@@ -104,13 +106,6 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
Username: &username, Username: &username,
} }
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
loginRequest.Hint = &profileState.Email
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
loginRequest.OptionalPreSharedKey = &preSharedKey loginRequest.OptionalPreSharedKey = &preSharedKey
} }
@@ -206,7 +201,6 @@ func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManage
func switchProfile(ctx context.Context, profileName string, username string) error { func switchProfile(ctx context.Context, profileName string, username string) error {
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil { if err != nil {
//nolint
return fmt.Errorf("failed to connect to daemon error: %v\n"+ return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+ "If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err) "\nnetbird service install \nnetbird service start\n", err)
@@ -247,7 +241,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
return fmt.Errorf("read config file %s: %v", configFilePath, err) return fmt.Errorf("read config file %s: %v", configFilePath, err)
} }
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name) err = foregroundLogin(ctx, cmd, config, setupKey)
if err != nil { if err != nil {
return fmt.Errorf("foreground login failed: %v", err) return fmt.Errorf("foreground login failed: %v", err)
} }
@@ -275,50 +269,54 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
return nil return nil
} }
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error { func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error {
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
needsLogin := false needsLogin := false
err, isAuthError := authClient.Login(ctx, "", "") err := WithBackOff(func() error {
if isAuthError { err := internal.Login(ctx, config, "", "")
needsLogin = true if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
} else if err != nil { needsLogin = true
return fmt.Errorf("login check failed: %v", err) return nil
}
return err
})
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
} }
jwtToken := "" jwtToken := ""
if setupKey == "" && needsLogin { if setupKey == "" && needsLogin {
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName) tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config)
if err != nil { if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err) return fmt.Errorf("interactive sso login failed: %v", err)
} }
jwtToken = tokenInfo.GetTokenToUse() jwtToken = tokenInfo.GetTokenToUse()
} }
err, _ = authClient.Login(ctx, setupKey, jwtToken) var lastError error
err = WithBackOff(func() error {
err := internal.Login(ctx, config, setupKey, jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
lastError = err
return nil
}
return err
})
if lastError != nil {
return fmt.Errorf("login failed: %v", lastError)
}
if err != nil { if err != nil {
return fmt.Errorf("login failed: %v", err) return fmt.Errorf("backoff cycle failed: %v", err)
} }
return nil return nil
} }
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) { func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) {
hint := "" oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
pm := profilemanager.NewProfileManager()
profileState, err := pm.GetProfileState(profileName)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
hint = profileState.Email
}
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), false, hint)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -330,7 +328,11 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser) openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo) waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
defer c()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
if err != nil { if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err) return nil, fmt.Errorf("waiting for browser login failed: %v", err)
} }
@@ -355,13 +357,21 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
cmd.Println("") cmd.Println("")
if !noBrowser { if !noBrowser {
if err := util.OpenBrowser(verificationURIComplete); err != nil { if err := openBrowser(verificationURIComplete); err != nil {
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" + cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys") "https://docs.netbird.io/how-to/register-machines-using-setup-keys")
} }
} }
} }
// openBrowser opens the URL in a browser, respecting the BROWSER environment variable.
func openBrowser(url string) error {
if browser := os.Getenv("BROWSER"); browser != "" {
return exec.Command(browser, url).Start()
}
return open.Run(url)
}
// isUnixRunningDesktop checks if a Linux OS is running desktop environment // isUnixRunningDesktop checks if a Linux OS is running desktop environment
func isUnixRunningDesktop() bool { func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" { if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {

View File

@@ -1,4 +1,5 @@
//go:build pprof //go:build pprof
// +build pprof
package cmd package cmd

View File

@@ -85,9 +85,6 @@ var (
// Execute executes the root command. // Execute executes the root command.
func Execute() error { func Execute() error {
if isUpdateBinary() {
return updateCmd.Execute()
}
return rootCmd.Execute() return rootCmd.Execute()
} }
@@ -390,7 +387,6 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr) conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil { if err != nil {
//nolint
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+ return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+ "If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err) "\nnetbird service install \nnetbird service start\n", err)

View File

@@ -259,7 +259,6 @@ func isServiceRunning() (bool, error) {
} }
const ( const (
networkdConf = "/etc/systemd/networkd.conf"
networkdConfDir = "/etc/systemd/networkd.conf.d" networkdConfDir = "/etc/systemd/networkd.conf.d"
networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf" networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf"
networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing
@@ -274,16 +273,12 @@ ManageForeignRoutingPolicyRules=no
// configureSystemdNetworkd creates a drop-in configuration file to prevent // configureSystemdNetworkd creates a drop-in configuration file to prevent
// systemd-networkd from removing NetBird's routes and policy rules. // systemd-networkd from removing NetBird's routes and policy rules.
func configureSystemdNetworkd() error { func configureSystemdNetworkd() error {
if _, err := os.Stat(networkdConf); os.IsNotExist(err) { parentDir := filepath.Dir(networkdConfDir)
log.Debug("systemd-networkd not in use, skipping configuration") if _, err := os.Stat(parentDir); os.IsNotExist(err) {
log.Debug("systemd networkd.conf.d parent directory does not exist, skipping configuration")
return nil return nil
} }
// nolint:gosec // standard networkd permissions
if err := os.MkdirAll(networkdConfDir, 0755); err != nil {
return fmt.Errorf("create networkd.conf.d directory: %w", err)
}
// nolint:gosec // standard networkd permissions // nolint:gosec // standard networkd permissions
if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil { if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil {
return fmt.Errorf("write networkd configuration: %w", err) return fmt.Errorf("write networkd configuration: %w", err)

View File

@@ -1,176 +0,0 @@
package main
import (
"fmt"
"os"
"time"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
var (
bundlePubKeysRootPrivKeyFile string
bundlePubKeysPubKeyFiles []string
bundlePubKeysFile string
createArtifactKeyRootPrivKeyFile string
createArtifactKeyPrivKeyFile string
createArtifactKeyPubKeyFile string
createArtifactKeyExpiration time.Duration
)
var createArtifactKeyCmd = &cobra.Command{
Use: "create-artifact-key",
Short: "Create a new artifact signing key",
Long: `Generate a new artifact signing key pair signed by the root private key.
The artifact key will be used to sign software artifacts/updates.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if createArtifactKeyExpiration <= 0 {
return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)")
}
if err := handleCreateArtifactKey(cmd, createArtifactKeyRootPrivKeyFile, createArtifactKeyPrivKeyFile, createArtifactKeyPubKeyFile, createArtifactKeyExpiration); err != nil {
return fmt.Errorf("failed to create artifact key: %w", err)
}
return nil
},
}
var bundlePubKeysCmd = &cobra.Command{
Use: "bundle-pub-keys",
Short: "Bundle multiple artifact public keys into a signed package",
Long: `Bundle one or more artifact public keys into a signed package using the root private key.
This command is typically used to distribute or authorize a set of valid artifact signing keys.`,
RunE: func(cmd *cobra.Command, args []string) error {
if len(bundlePubKeysPubKeyFiles) == 0 {
return fmt.Errorf("at least one --artifact-pub-key-file must be provided")
}
if err := handleBundlePubKeys(cmd, bundlePubKeysRootPrivKeyFile, bundlePubKeysPubKeyFiles, bundlePubKeysFile); err != nil {
return fmt.Errorf("failed to bundle public keys: %w", err)
}
return nil
},
}
func init() {
rootCmd.AddCommand(createArtifactKeyCmd)
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the artifact key")
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPrivKeyFile, "artifact-priv-key-file", "", "Path where the artifact private key will be saved")
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPubKeyFile, "artifact-pub-key-file", "", "Path where the artifact public key will be saved")
createArtifactKeyCmd.Flags().DurationVar(&createArtifactKeyExpiration, "expiration", 0, "Expiration duration for the artifact key (e.g., 720h, 365d, 8760h)")
if err := createArtifactKeyCmd.MarkFlagRequired("root-private-key-file"); err != nil {
panic(fmt.Errorf("mark root-private-key-file as required: %w", err))
}
if err := createArtifactKeyCmd.MarkFlagRequired("artifact-priv-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-priv-key-file as required: %w", err))
}
if err := createArtifactKeyCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err))
}
if err := createArtifactKeyCmd.MarkFlagRequired("expiration"); err != nil {
panic(fmt.Errorf("mark expiration as required: %w", err))
}
rootCmd.AddCommand(bundlePubKeysCmd)
bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the bundle")
bundlePubKeysCmd.Flags().StringArrayVar(&bundlePubKeysPubKeyFiles, "artifact-pub-key-file", nil, "Path(s) to the artifact public key files to include in the bundle (can be repeated)")
bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysFile, "bundle-pub-key-file", "", "Path where the public keys will be saved")
if err := bundlePubKeysCmd.MarkFlagRequired("root-private-key-file"); err != nil {
panic(fmt.Errorf("mark root-private-key-file as required: %w", err))
}
if err := bundlePubKeysCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err))
}
if err := bundlePubKeysCmd.MarkFlagRequired("bundle-pub-key-file"); err != nil {
panic(fmt.Errorf("mark bundle-pub-key-file as required: %w", err))
}
}
func handleCreateArtifactKey(cmd *cobra.Command, rootPrivKeyFile, artifactPrivKeyFile, artifactPubKeyFile string, expiration time.Duration) error {
cmd.Println("Creating new artifact signing key...")
privKeyPEM, err := os.ReadFile(rootPrivKeyFile)
if err != nil {
return fmt.Errorf("read root private key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
artifactKey, privPEM, pubPEM, signature, err := reposign.GenerateArtifactKey(privateRootKey, expiration)
if err != nil {
return fmt.Errorf("generate artifact key: %w", err)
}
if err := os.WriteFile(artifactPrivKeyFile, privPEM, 0o600); err != nil {
return fmt.Errorf("write private key file (%s): %w", artifactPrivKeyFile, err)
}
if err := os.WriteFile(artifactPubKeyFile, pubPEM, 0o600); err != nil {
return fmt.Errorf("write public key file (%s): %w", artifactPubKeyFile, err)
}
signatureFile := artifactPubKeyFile + ".sig"
if err := os.WriteFile(signatureFile, signature, 0o600); err != nil {
return fmt.Errorf("write signature file (%s): %w", signatureFile, err)
}
cmd.Printf("✅ Artifact key created successfully.\n")
cmd.Printf("%s\n", artifactKey.String())
return nil
}
func handleBundlePubKeys(cmd *cobra.Command, rootPrivKeyFile string, artifactPubKeyFiles []string, bundlePubKeysFile string) error {
cmd.Println("📦 Bundling public keys into signed package...")
privKeyPEM, err := os.ReadFile(rootPrivKeyFile)
if err != nil {
return fmt.Errorf("read root private key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
publicKeys := make([]reposign.PublicKey, 0, len(artifactPubKeyFiles))
for _, pubFile := range artifactPubKeyFiles {
pubPem, err := os.ReadFile(pubFile)
if err != nil {
return fmt.Errorf("read public key file: %w", err)
}
pk, err := reposign.ParseArtifactPubKey(pubPem)
if err != nil {
return fmt.Errorf("failed to parse artifact key: %w", err)
}
publicKeys = append(publicKeys, pk)
}
parsedKeys, signature, err := reposign.BundleArtifactKeys(privateRootKey, publicKeys)
if err != nil {
return fmt.Errorf("bundle artifact keys: %w", err)
}
if err := os.WriteFile(bundlePubKeysFile, parsedKeys, 0o600); err != nil {
return fmt.Errorf("write public keys file (%s): %w", bundlePubKeysFile, err)
}
signatureFile := bundlePubKeysFile + ".sig"
if err := os.WriteFile(signatureFile, signature, 0o600); err != nil {
return fmt.Errorf("write signature file (%s): %w", signatureFile, err)
}
cmd.Printf("✅ Bundle created with %d public keys.\n", len(artifactPubKeyFiles))
return nil
}

View File

@@ -1,276 +0,0 @@
package main
import (
"fmt"
"os"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
const (
envArtifactPrivateKey = "NB_ARTIFACT_PRIV_KEY"
)
var (
signArtifactPrivKeyFile string
signArtifactArtifactFile string
verifyArtifactPubKeyFile string
verifyArtifactFile string
verifyArtifactSignatureFile string
verifyArtifactKeyPubKeyFile string
verifyArtifactKeyRootPubKeyFile string
verifyArtifactKeySignatureFile string
verifyArtifactKeyRevocationFile string
)
var signArtifactCmd = &cobra.Command{
Use: "sign-artifact",
Short: "Sign an artifact using an artifact private key",
Long: `Sign a software artifact (e.g., update bundle or binary) using the artifact's private key.
This command produces a detached signature that can be verified using the corresponding artifact public key.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if err := handleSignArtifact(cmd, signArtifactPrivKeyFile, signArtifactArtifactFile); err != nil {
return fmt.Errorf("failed to sign artifact: %w", err)
}
return nil
},
}
var verifyArtifactCmd = &cobra.Command{
Use: "verify-artifact",
Short: "Verify an artifact signature using an artifact public key",
Long: `Verify a software artifact signature using the artifact's public key.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if err := handleVerifyArtifact(cmd, verifyArtifactPubKeyFile, verifyArtifactFile, verifyArtifactSignatureFile); err != nil {
return fmt.Errorf("failed to verify artifact: %w", err)
}
return nil
},
}
var verifyArtifactKeyCmd = &cobra.Command{
Use: "verify-artifact-key",
Short: "Verify an artifact public key was signed by a root key",
Long: `Verify that an artifact public key (or bundle) was properly signed by a root key.
This validates the chain of trust from the root key to the artifact key.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if err := handleVerifyArtifactKey(cmd, verifyArtifactKeyPubKeyFile, verifyArtifactKeyRootPubKeyFile, verifyArtifactKeySignatureFile, verifyArtifactKeyRevocationFile); err != nil {
return fmt.Errorf("failed to verify artifact key: %w", err)
}
return nil
},
}
func init() {
rootCmd.AddCommand(signArtifactCmd)
rootCmd.AddCommand(verifyArtifactCmd)
rootCmd.AddCommand(verifyArtifactKeyCmd)
signArtifactCmd.Flags().StringVar(&signArtifactPrivKeyFile, "artifact-key-file", "", fmt.Sprintf("Path to the artifact private key file used for signing (or set %s env var)", envArtifactPrivateKey))
signArtifactCmd.Flags().StringVar(&signArtifactArtifactFile, "artifact-file", "", "Path to the artifact to be signed")
// artifact-file is required, but artifact-key-file can come from env var
if err := signArtifactCmd.MarkFlagRequired("artifact-file"); err != nil {
panic(fmt.Errorf("mark artifact-file as required: %w", err))
}
verifyArtifactCmd.Flags().StringVar(&verifyArtifactPubKeyFile, "artifact-public-key-file", "", "Path to the artifact public key file")
verifyArtifactCmd.Flags().StringVar(&verifyArtifactFile, "artifact-file", "", "Path to the artifact to be verified")
verifyArtifactCmd.Flags().StringVar(&verifyArtifactSignatureFile, "signature-file", "", "Path to the signature file")
if err := verifyArtifactCmd.MarkFlagRequired("artifact-public-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-public-key-file as required: %w", err))
}
if err := verifyArtifactCmd.MarkFlagRequired("artifact-file"); err != nil {
panic(fmt.Errorf("mark artifact-file as required: %w", err))
}
if err := verifyArtifactCmd.MarkFlagRequired("signature-file"); err != nil {
panic(fmt.Errorf("mark signature-file as required: %w", err))
}
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyPubKeyFile, "artifact-key-file", "", "Path to the artifact public key file or bundle")
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRootPubKeyFile, "root-key-file", "", "Path to the root public key file or bundle")
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeySignatureFile, "signature-file", "", "Path to the signature file")
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRevocationFile, "revocation-file", "", "Path to the revocation list file (optional)")
if err := verifyArtifactKeyCmd.MarkFlagRequired("artifact-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-key-file as required: %w", err))
}
if err := verifyArtifactKeyCmd.MarkFlagRequired("root-key-file"); err != nil {
panic(fmt.Errorf("mark root-key-file as required: %w", err))
}
if err := verifyArtifactKeyCmd.MarkFlagRequired("signature-file"); err != nil {
panic(fmt.Errorf("mark signature-file as required: %w", err))
}
}
func handleSignArtifact(cmd *cobra.Command, privKeyFile, artifactFile string) error {
cmd.Println("🖋️ Signing artifact...")
// Load private key from env var or file
var privKeyPEM []byte
var err error
if envKey := os.Getenv(envArtifactPrivateKey); envKey != "" {
// Use key from environment variable
privKeyPEM = []byte(envKey)
} else if privKeyFile != "" {
// Fall back to file
privKeyPEM, err = os.ReadFile(privKeyFile)
if err != nil {
return fmt.Errorf("read private key file: %w", err)
}
} else {
return fmt.Errorf("artifact private key must be provided via %s environment variable or --artifact-key-file flag", envArtifactPrivateKey)
}
privateKey, err := reposign.ParseArtifactKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse artifact private key: %w", err)
}
artifactData, err := os.ReadFile(artifactFile)
if err != nil {
return fmt.Errorf("read artifact file: %w", err)
}
signature, err := reposign.SignData(privateKey, artifactData)
if err != nil {
return fmt.Errorf("sign artifact: %w", err)
}
sigFile := artifactFile + ".sig"
if err := os.WriteFile(artifactFile+".sig", signature, 0o600); err != nil {
return fmt.Errorf("write signature file (%s): %w", sigFile, err)
}
cmd.Printf("✅ Artifact signed successfully.\n")
cmd.Printf("Signature file: %s\n", sigFile)
return nil
}
func handleVerifyArtifact(cmd *cobra.Command, pubKeyFile, artifactFile, signatureFile string) error {
cmd.Println("🔍 Verifying artifact...")
// Read artifact public key
pubKeyPEM, err := os.ReadFile(pubKeyFile)
if err != nil {
return fmt.Errorf("read public key file: %w", err)
}
publicKey, err := reposign.ParseArtifactPubKey(pubKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse artifact public key: %w", err)
}
// Read artifact data
artifactData, err := os.ReadFile(artifactFile)
if err != nil {
return fmt.Errorf("read artifact file: %w", err)
}
// Read signature
sigBytes, err := os.ReadFile(signatureFile)
if err != nil {
return fmt.Errorf("read signature file: %w", err)
}
signature, err := reposign.ParseSignature(sigBytes)
if err != nil {
return fmt.Errorf("failed to parse signature: %w", err)
}
// Validate artifact
if err := reposign.ValidateArtifact([]reposign.PublicKey{publicKey}, artifactData, *signature); err != nil {
return fmt.Errorf("artifact verification failed: %w", err)
}
cmd.Println("✅ Artifact signature is valid")
cmd.Printf("Artifact: %s\n", artifactFile)
cmd.Printf("Signed by key: %s\n", signature.KeyID)
cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST"))
return nil
}
func handleVerifyArtifactKey(cmd *cobra.Command, artifactKeyFile, rootKeyFile, signatureFile, revocationFile string) error {
cmd.Println("🔍 Verifying artifact key...")
// Read artifact key data
artifactKeyData, err := os.ReadFile(artifactKeyFile)
if err != nil {
return fmt.Errorf("read artifact key file: %w", err)
}
// Read root public key(s)
rootKeyData, err := os.ReadFile(rootKeyFile)
if err != nil {
return fmt.Errorf("read root key file: %w", err)
}
rootPublicKeys, err := parseRootPublicKeys(rootKeyData)
if err != nil {
return fmt.Errorf("failed to parse root public key(s): %w", err)
}
// Read signature
sigBytes, err := os.ReadFile(signatureFile)
if err != nil {
return fmt.Errorf("read signature file: %w", err)
}
signature, err := reposign.ParseSignature(sigBytes)
if err != nil {
return fmt.Errorf("failed to parse signature: %w", err)
}
// Read optional revocation list
var revocationList *reposign.RevocationList
if revocationFile != "" {
revData, err := os.ReadFile(revocationFile)
if err != nil {
return fmt.Errorf("read revocation file: %w", err)
}
revocationList, err = reposign.ParseRevocationList(revData)
if err != nil {
return fmt.Errorf("failed to parse revocation list: %w", err)
}
}
// Validate artifact key(s)
validKeys, err := reposign.ValidateArtifactKeys(rootPublicKeys, artifactKeyData, *signature, revocationList)
if err != nil {
return fmt.Errorf("artifact key verification failed: %w", err)
}
cmd.Println("✅ Artifact key(s) verified successfully")
cmd.Printf("Signed by root key: %s\n", signature.KeyID)
cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST"))
cmd.Printf("\nValid artifact keys (%d):\n", len(validKeys))
for i, key := range validKeys {
cmd.Printf(" [%d] Key ID: %s\n", i+1, key.Metadata.ID)
cmd.Printf(" Created: %s\n", key.Metadata.CreatedAt.Format("2006-01-02 15:04:05 MST"))
if !key.Metadata.ExpiresAt.IsZero() {
cmd.Printf(" Expires: %s\n", key.Metadata.ExpiresAt.Format("2006-01-02 15:04:05 MST"))
} else {
cmd.Printf(" Expires: Never\n")
}
}
return nil
}
// parseRootPublicKeys parses a root public key from PEM data
func parseRootPublicKeys(data []byte) ([]reposign.PublicKey, error) {
key, err := reposign.ParseRootPublicKey(data)
if err != nil {
return nil, err
}
return []reposign.PublicKey{key}, nil
}

View File

@@ -1,21 +0,0 @@
package main
import (
"os"
"github.com/spf13/cobra"
)
var rootCmd = &cobra.Command{
Use: "signer",
Short: "A CLI tool for managing cryptographic keys and artifacts",
Long: `signer is a command-line tool that helps you manage
root keys, artifact keys, and revocation lists securely.`,
}
func main() {
if err := rootCmd.Execute(); err != nil {
rootCmd.Println(err)
os.Exit(1)
}
}

View File

@@ -1,220 +0,0 @@
package main
import (
"fmt"
"os"
"time"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
const (
defaultRevocationListExpiration = 365 * 24 * time.Hour // 1 year
)
var (
keyID string
revocationListFile string
privateRootKeyFile string
publicRootKeyFile string
signatureFile string
expirationDuration time.Duration
)
var createRevocationListCmd = &cobra.Command{
Use: "create-revocation-list",
Short: "Create a new revocation list signed by the private root key",
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
return handleCreateRevocationList(cmd, revocationListFile, privateRootKeyFile)
},
}
var extendRevocationListCmd = &cobra.Command{
Use: "extend-revocation-list",
Short: "Extend an existing revocation list with a given key ID",
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
return handleExtendRevocationList(cmd, keyID, revocationListFile, privateRootKeyFile)
},
}
var verifyRevocationListCmd = &cobra.Command{
Use: "verify-revocation-list",
Short: "Verify a revocation list signature using the public root key",
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
return handleVerifyRevocationList(cmd, revocationListFile, signatureFile, publicRootKeyFile)
},
}
func init() {
rootCmd.AddCommand(createRevocationListCmd)
rootCmd.AddCommand(extendRevocationListCmd)
rootCmd.AddCommand(verifyRevocationListCmd)
createRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file")
createRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file")
createRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)")
if err := createRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
panic(err)
}
if err := createRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil {
panic(err)
}
extendRevocationListCmd.Flags().StringVar(&keyID, "key-id", "", "ID of the key to extend the revocation list for")
extendRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file")
extendRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file")
extendRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)")
if err := extendRevocationListCmd.MarkFlagRequired("key-id"); err != nil {
panic(err)
}
if err := extendRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
panic(err)
}
if err := extendRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil {
panic(err)
}
verifyRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the revocation list file")
verifyRevocationListCmd.Flags().StringVar(&signatureFile, "signature-file", "", "Path to the signature file")
verifyRevocationListCmd.Flags().StringVar(&publicRootKeyFile, "public-root-key", "", "Path to the public root key PEM file")
if err := verifyRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
panic(err)
}
if err := verifyRevocationListCmd.MarkFlagRequired("signature-file"); err != nil {
panic(err)
}
if err := verifyRevocationListCmd.MarkFlagRequired("public-root-key"); err != nil {
panic(err)
}
}
func handleCreateRevocationList(cmd *cobra.Command, revocationListFile string, privateRootKeyFile string) error {
privKeyPEM, err := os.ReadFile(privateRootKeyFile)
if err != nil {
return fmt.Errorf("failed to read private root key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
rlBytes, sigBytes, err := reposign.CreateRevocationList(*privateRootKey, expirationDuration)
if err != nil {
return fmt.Errorf("failed to create revocation list: %w", err)
}
if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", rlBytes, sigBytes); err != nil {
return fmt.Errorf("failed to write output files: %w", err)
}
cmd.Println("✅ Revocation list created successfully")
return nil
}
func handleExtendRevocationList(cmd *cobra.Command, keyID, revocationListFile, privateRootKeyFile string) error {
privKeyPEM, err := os.ReadFile(privateRootKeyFile)
if err != nil {
return fmt.Errorf("failed to read private root key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
rlBytes, err := os.ReadFile(revocationListFile)
if err != nil {
return fmt.Errorf("failed to read revocation list file: %w", err)
}
rl, err := reposign.ParseRevocationList(rlBytes)
if err != nil {
return fmt.Errorf("failed to parse revocation list: %w", err)
}
kid, err := reposign.ParseKeyID(keyID)
if err != nil {
return fmt.Errorf("invalid key ID: %w", err)
}
newRLBytes, sigBytes, err := reposign.ExtendRevocationList(*privateRootKey, *rl, kid, expirationDuration)
if err != nil {
return fmt.Errorf("failed to extend revocation list: %w", err)
}
if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", newRLBytes, sigBytes); err != nil {
return fmt.Errorf("failed to write output files: %w", err)
}
cmd.Println("✅ Revocation list extended successfully")
return nil
}
func handleVerifyRevocationList(cmd *cobra.Command, revocationListFile, signatureFile, publicRootKeyFile string) error {
// Read revocation list file
rlBytes, err := os.ReadFile(revocationListFile)
if err != nil {
return fmt.Errorf("failed to read revocation list file: %w", err)
}
// Read signature file
sigBytes, err := os.ReadFile(signatureFile)
if err != nil {
return fmt.Errorf("failed to read signature file: %w", err)
}
// Read public root key file
pubKeyPEM, err := os.ReadFile(publicRootKeyFile)
if err != nil {
return fmt.Errorf("failed to read public root key file: %w", err)
}
// Parse public root key
publicKey, err := reposign.ParseRootPublicKey(pubKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse public root key: %w", err)
}
// Parse signature
signature, err := reposign.ParseSignature(sigBytes)
if err != nil {
return fmt.Errorf("failed to parse signature: %w", err)
}
// Validate revocation list
rl, err := reposign.ValidateRevocationList([]reposign.PublicKey{publicKey}, rlBytes, *signature)
if err != nil {
return fmt.Errorf("failed to validate revocation list: %w", err)
}
// Display results
cmd.Println("✅ Revocation list signature is valid")
cmd.Printf("Last Updated: %s\n", rl.LastUpdated.Format(time.RFC3339))
cmd.Printf("Expires At: %s\n", rl.ExpiresAt.Format(time.RFC3339))
cmd.Printf("Number of revoked keys: %d\n", len(rl.Revoked))
if len(rl.Revoked) > 0 {
cmd.Println("\nRevoked Keys:")
for keyID, revokedTime := range rl.Revoked {
cmd.Printf(" - %s (revoked at: %s)\n", keyID, revokedTime.Format(time.RFC3339))
}
}
return nil
}
func writeOutputFiles(rlPath, sigPath string, rlBytes, sigBytes []byte) error {
if err := os.WriteFile(rlPath, rlBytes, 0o600); err != nil {
return fmt.Errorf("failed to write revocation list file: %w", err)
}
if err := os.WriteFile(sigPath, sigBytes, 0o600); err != nil {
return fmt.Errorf("failed to write signature file: %w", err)
}
return nil
}

View File

@@ -1,74 +0,0 @@
package main
import (
"fmt"
"os"
"time"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
var (
privKeyFile string
pubKeyFile string
rootExpiration time.Duration
)
var createRootKeyCmd = &cobra.Command{
Use: "create-root-key",
Short: "Create a new root key pair",
Long: `Create a new root key pair and specify an expiration time for it.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
// Validate expiration
if rootExpiration <= 0 {
return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)")
}
// Run main logic
if err := handleGenerateRootKey(cmd, privKeyFile, pubKeyFile, rootExpiration); err != nil {
return fmt.Errorf("failed to generate root key: %w", err)
}
return nil
},
}
func init() {
rootCmd.AddCommand(createRootKeyCmd)
createRootKeyCmd.Flags().StringVar(&privKeyFile, "priv-key-file", "", "Path to output private key file")
createRootKeyCmd.Flags().StringVar(&pubKeyFile, "pub-key-file", "", "Path to output public key file")
createRootKeyCmd.Flags().DurationVar(&rootExpiration, "expiration", 0, "Expiration time for the root key (e.g., 720h,)")
if err := createRootKeyCmd.MarkFlagRequired("priv-key-file"); err != nil {
panic(err)
}
if err := createRootKeyCmd.MarkFlagRequired("pub-key-file"); err != nil {
panic(err)
}
if err := createRootKeyCmd.MarkFlagRequired("expiration"); err != nil {
panic(err)
}
}
func handleGenerateRootKey(cmd *cobra.Command, privKeyFile, pubKeyFile string, expiration time.Duration) error {
rk, privPEM, pubPEM, err := reposign.GenerateRootKey(expiration)
if err != nil {
return fmt.Errorf("generate root key: %w", err)
}
// Write private key
if err := os.WriteFile(privKeyFile, privPEM, 0o600); err != nil {
return fmt.Errorf("write private key file (%s): %w", privKeyFile, err)
}
// Write public key
if err := os.WriteFile(pubKeyFile, pubPEM, 0o600); err != nil {
return fmt.Errorf("write public key file (%s): %w", pubKeyFile, err)
}
cmd.Printf("%s\n\n", rk.String())
cmd.Printf("✅ Root key pair generated successfully.\n")
return nil
}

View File

@@ -14,9 +14,7 @@ import (
"strings" "strings"
"syscall" "syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
sshclient "github.com/netbirdio/netbird/client/ssh/client" sshclient "github.com/netbirdio/netbird/client/ssh/client"
@@ -36,7 +34,6 @@ const (
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding" enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding" enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
disableSSHAuthFlag = "disable-ssh-auth" disableSSHAuthFlag = "disable-ssh-auth"
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
) )
var ( var (
@@ -50,8 +47,6 @@ var (
knownHostsFile string knownHostsFile string
identityFile string identityFile string
skipCachedToken bool skipCachedToken bool
requestPTY bool
sshNoBrowser bool
) )
var ( var (
@@ -61,7 +56,6 @@ var (
enableSSHLocalPortForward bool enableSSHLocalPortForward bool
enableSSHRemotePortForward bool enableSSHRemotePortForward bool
disableSSHAuth bool disableSSHAuth bool
sshJWTCacheTTL int
) )
func init() { func init() {
@@ -71,18 +65,14 @@ func init() {
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server") upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server") upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication") upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)")
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port") sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc) sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)") sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)")
sshCmd.PersistentFlags().BoolVarP(&requestPTY, "tty", "t", false, "Force pseudo-terminal allocation")
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)") sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)") sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)") sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file")
_ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used")
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication") sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
sshCmd.PersistentFlags().BoolVar(&sshNoBrowser, noBrowserFlag, false, noBrowserDesc)
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport") sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport") sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
@@ -107,9 +97,9 @@ SSH Options:
-p, --port int Remote SSH port (default 22) -p, --port int Remote SSH port (default 22)
-u, --user string SSH username -u, --user string SSH username
--login string SSH username (alias for --user) --login string SSH username (alias for --user)
-t, --tty Force pseudo-terminal allocation
--strict-host-key-checking Enable strict host key checking (default: true) --strict-host-key-checking Enable strict host key checking (default: true)
-o, --known-hosts string Path to known_hosts file -o, --known-hosts string Path to known_hosts file
-i, --identity string Path to SSH private key file
Examples: Examples:
netbird ssh peer-hostname netbird ssh peer-hostname
@@ -117,10 +107,8 @@ Examples:
netbird ssh --login root peer-hostname netbird ssh --login root peer-hostname
netbird ssh peer-hostname ls -la netbird ssh peer-hostname ls -la
netbird ssh peer-hostname whoami netbird ssh peer-hostname whoami
netbird ssh -t peer-hostname tmux # Force PTY for tmux/screen netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding
netbird ssh -t peer-hostname sudo -i # Force PTY for interactive sudo netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding
netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding
netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding
netbird ssh -L "*:8080:localhost:80" peer-hostname # Bind to all interfaces netbird ssh -L "*:8080:localhost:80" peer-hostname # Bind to all interfaces
netbird ssh -L 8080:/tmp/socket peer-hostname # Unix socket forwarding`, netbird ssh -L 8080:/tmp/socket peer-hostname # Unix socket forwarding`,
DisableFlagParsing: true, DisableFlagParsing: true,
@@ -155,10 +143,10 @@ func sshFn(cmd *cobra.Command, args []string) error {
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT) signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
sshctx, cancel := context.WithCancel(ctx) sshctx, cancel := context.WithCancel(ctx)
errCh := make(chan error, 1)
go func() { go func() {
if err := runSSH(sshctx, host, cmd); err != nil { if err := runSSH(sshctx, host, cmd); err != nil {
errCh <- err cmd.Printf("Error: %v\n", err)
os.Exit(1)
} }
cancel() cancel()
}() }()
@@ -166,10 +154,6 @@ func sshFn(cmd *cobra.Command, args []string) error {
select { select {
case <-sig: case <-sig:
cancel() cancel()
<-sshctx.Done()
return nil
case err := <-errCh:
return err
case <-sshctx.Done(): case <-sshctx.Done():
} }
@@ -187,21 +171,6 @@ func getEnvOrDefault(flagName, defaultValue string) string {
return defaultValue return defaultValue
} }
// getBoolEnvOrDefault checks for boolean environment variables with WT_ and NB_ prefixes
func getBoolEnvOrDefault(flagName string, defaultValue bool) bool {
if envValue := os.Getenv("WT_" + flagName); envValue != "" {
if parsed, err := strconv.ParseBool(envValue); err == nil {
return parsed
}
}
if envValue := os.Getenv("NB_" + flagName); envValue != "" {
if parsed, err := strconv.ParseBool(envValue); err == nil {
return parsed
}
}
return defaultValue
}
// resetSSHGlobals sets SSH globals to their default values // resetSSHGlobals sets SSH globals to their default values
func resetSSHGlobals() { func resetSSHGlobals() {
port = sshserver.DefaultSSHPort port = sshserver.DefaultSSHPort
@@ -213,7 +182,6 @@ func resetSSHGlobals() {
strictHostKeyChecking = true strictHostKeyChecking = true
knownHostsFile = "" knownHostsFile = ""
identityFile = "" identityFile = ""
sshNoBrowser = false
} }
// parseCustomSSHFlags extracts -L, -R flags and returns filtered args // parseCustomSSHFlags extracts -L, -R flags and returns filtered args
@@ -383,12 +351,10 @@ type sshFlags struct {
Port int Port int
Username string Username string
Login string Login string
RequestPTY bool
StrictHostKeyChecking bool StrictHostKeyChecking bool
KnownHostsFile string KnownHostsFile string
IdentityFile string IdentityFile string
SkipCachedToken bool SkipCachedToken bool
NoBrowser bool
ConfigPath string ConfigPath string
LogLevel string LogLevel string
LocalForwards []string LocalForwards []string
@@ -400,7 +366,6 @@ type sshFlags struct {
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) { func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
defaultConfigPath := getEnvOrDefault("CONFIG", configPath) defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel) defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
defaultNoBrowser := getBoolEnvOrDefault("NO_BROWSER", false)
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError) fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
fs.SetOutput(nil) fs.SetOutput(nil)
@@ -408,25 +373,22 @@ func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
flags := &sshFlags{} flags := &sshFlags{}
fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port") fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port")
fs.IntVar(&flags.Port, "port", sshserver.DefaultSSHPort, "SSH port") fs.Int("port", sshserver.DefaultSSHPort, "SSH port")
fs.StringVar(&flags.Username, "u", "", sshUsernameDesc) fs.StringVar(&flags.Username, "u", "", sshUsernameDesc)
fs.StringVar(&flags.Username, "user", "", sshUsernameDesc) fs.String("user", "", sshUsernameDesc)
fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)") fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)")
fs.BoolVar(&flags.RequestPTY, "t", false, "Force pseudo-terminal allocation")
fs.BoolVar(&flags.RequestPTY, "tty", false, "Force pseudo-terminal allocation")
fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking") fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking")
fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file") fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file")
fs.StringVar(&flags.KnownHostsFile, "known-hosts", "", "Path to known_hosts file") fs.String("known-hosts", "", "Path to known_hosts file")
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file") fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file") fs.String("identity", "", "Path to SSH private key file")
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication") fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
fs.BoolVar(&flags.NoBrowser, "no-browser", defaultNoBrowser, noBrowserDesc)
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location") fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location") fs.String("config", defaultConfigPath, "Netbird config file location")
fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level") fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level")
fs.StringVar(&flags.LogLevel, "log-level", defaultLogLevel, "sets Netbird log level") fs.String("log-level", defaultLogLevel, "sets Netbird log level")
return fs, flags return fs, flags
} }
@@ -447,10 +409,7 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
fs, flags := createSSHFlagSet() fs, flags := createSSHFlagSet()
if err := fs.Parse(filteredArgs); err != nil { if err := fs.Parse(filteredArgs); err != nil {
if errors.Is(err, flag.ErrHelp) { return parseHostnameAndCommand(filteredArgs)
return nil
}
return err
} }
remaining := fs.Args() remaining := fs.Args()
@@ -465,12 +424,10 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
username = flags.Login username = flags.Login
} }
requestPTY = flags.RequestPTY
strictHostKeyChecking = flags.StrictHostKeyChecking strictHostKeyChecking = flags.StrictHostKeyChecking
knownHostsFile = flags.KnownHostsFile knownHostsFile = flags.KnownHostsFile
identityFile = flags.IdentityFile identityFile = flags.IdentityFile
skipCachedToken = flags.SkipCachedToken skipCachedToken = flags.SkipCachedToken
sshNoBrowser = flags.NoBrowser
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) { if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
configPath = flags.ConfigPath configPath = flags.ConfigPath
@@ -530,7 +487,6 @@ func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
DaemonAddr: daemonAddr, DaemonAddr: daemonAddr,
SkipCachedToken: skipCachedToken, SkipCachedToken: skipCachedToken,
InsecureSkipVerify: !strictHostKeyChecking, InsecureSkipVerify: !strictHostKeyChecking,
NoBrowser: sshNoBrowser,
}) })
if err != nil { if err != nil {
@@ -564,29 +520,10 @@ func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
// executeSSHCommand executes a command over SSH. // executeSSHCommand executes a command over SSH.
func executeSSHCommand(ctx context.Context, c *sshclient.Client, command string) error { func executeSSHCommand(ctx context.Context, c *sshclient.Client, command string) error {
var err error if err := c.ExecuteCommandWithIO(ctx, command); err != nil {
if requestPTY {
err = c.ExecuteCommandWithPTY(ctx, command)
} else {
err = c.ExecuteCommandWithIO(ctx, command)
}
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil return nil
} }
var exitErr *ssh.ExitError
if errors.As(err, &exitErr) {
os.Exit(exitErr.ExitStatus())
}
var exitMissingErr *ssh.ExitMissingError
if errors.As(err, &exitMissingErr) {
log.Debugf("Remote command exited without exit status: %v", err)
return nil
}
return fmt.Errorf("execute command: %w", err) return fmt.Errorf("execute command: %w", err)
} }
return nil return nil
@@ -598,13 +535,6 @@ func openSSHTerminal(ctx context.Context, c *sshclient.Client) error {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil return nil
} }
var exitMissingErr *ssh.ExitMissingError
if errors.As(err, &exitMissingErr) {
log.Debugf("Remote terminal exited without exit status: %v", err)
return nil
}
return fmt.Errorf("open terminal: %w", err) return fmt.Errorf("open terminal: %w", err)
} }
return nil return nil
@@ -634,11 +564,7 @@ func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward
return err return err
} }
if err := validateDestinationPort(remoteAddr); err != nil { cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr)
return fmt.Errorf("invalid remote address: %w", err)
}
log.Debugf("Local port forwarding: %s -> %s", localAddr, remoteAddr)
go func() { go func() {
if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) { if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) {
@@ -656,11 +582,7 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar
return err return err
} }
if err := validateDestinationPort(localAddr); err != nil { cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr)
return fmt.Errorf("invalid local address: %w", err)
}
log.Debugf("Remote port forwarding: %s -> %s", remoteAddr, localAddr)
go func() { go func() {
if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) { if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) {
@@ -671,35 +593,6 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar
return nil return nil
} }
// validateDestinationPort checks that the destination address has a valid port.
// Port 0 is only valid for bind addresses (where the OS picks an available port),
// not for destination addresses where we need to connect.
func validateDestinationPort(addr string) error {
if strings.HasPrefix(addr, "/") || strings.HasPrefix(addr, "./") {
return nil
}
_, portStr, err := net.SplitHostPort(addr)
if err != nil {
return fmt.Errorf("parse address %s: %w", addr, err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("invalid port %s: %w", portStr, err)
}
if port == 0 {
return fmt.Errorf("port 0 is not valid for destination address")
}
if port < 0 || port > 65535 {
return fmt.Errorf("port %d out of range (1-65535)", port)
}
return nil
}
// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80". // parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80".
// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket". // Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket".
func parsePortForwardSpec(spec string) (string, string, error) { func parsePortForwardSpec(spec string) (string, string, error) {
@@ -809,9 +702,7 @@ func sshProxyFn(cmd *cobra.Command, args []string) error {
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile { if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
logOutput = firstLogFile logOutput = firstLogFile
} }
if err := util.InitLog(logLevel, logOutput); err != nil {
proxyLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
if err := util.InitLog(proxyLogLevel, logOutput); err != nil {
return fmt.Errorf("init log: %w", err) return fmt.Errorf("init log: %w", err)
} }
@@ -823,23 +714,10 @@ func sshProxyFn(cmd *cobra.Command, args []string) error {
return fmt.Errorf("invalid port: %s", portStr) return fmt.Errorf("invalid port: %s", portStr)
} }
// Check env var for browser setting since this command is invoked via SSH ProxyCommand proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr())
// where command-line flags cannot be passed. Default is to open browser.
noBrowser := getBoolEnvOrDefault("NO_BROWSER", false)
var browserOpener func(string) error
if !noBrowser {
browserOpener = util.OpenBrowser
}
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr(), browserOpener)
if err != nil { if err != nil {
return fmt.Errorf("create SSH proxy: %w", err) return fmt.Errorf("create SSH proxy: %w", err)
} }
defer func() {
if err := proxy.Close(); err != nil {
log.Debugf("close SSH proxy: %v", err)
}
}()
if err := proxy.Connect(cmd.Context()); err != nil { if err := proxy.Connect(cmd.Context()); err != nil {
return fmt.Errorf("SSH proxy: %w", err) return fmt.Errorf("SSH proxy: %w", err)
@@ -858,8 +736,7 @@ var sshDetectCmd = &cobra.Command{
} }
func sshDetectFn(cmd *cobra.Command, args []string) error { func sshDetectFn(cmd *cobra.Command, args []string) error {
detectLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel) if err := util.InitLog(logLevel, "console"); err != nil {
if err := util.InitLog(detectLogLevel, "console"); err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode()) os.Exit(detection.ServerTypeRegular.ExitCode())
} }
@@ -868,21 +745,15 @@ func sshDetectFn(cmd *cobra.Command, args []string) error {
port, err := strconv.Atoi(portStr) port, err := strconv.Atoi(portStr)
if err != nil { if err != nil {
log.Debugf("invalid port %q: %v", portStr, err)
os.Exit(detection.ServerTypeRegular.ExitCode()) os.Exit(detection.ServerTypeRegular.ExitCode())
} }
ctx, cancel := context.WithTimeout(cmd.Context(), detection.DefaultTimeout) dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port)
dialer := &net.Dialer{}
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
if err != nil { if err != nil {
log.Debugf("SSH server detection failed: %v", err)
cancel()
os.Exit(detection.ServerTypeRegular.ExitCode()) os.Exit(detection.ServerTypeRegular.ExitCode())
} }
cancel()
os.Exit(serverType.ExitCode()) os.Exit(serverType.ExitCode())
return nil return nil
} }

View File

@@ -8,7 +8,6 @@ import (
"io" "io"
"os" "os"
"os/user" "os/user"
"strings"
"github.com/pkg/sftp" "github.com/pkg/sftp"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -52,7 +51,7 @@ func sftpMainDirect(cmd *cobra.Command) error {
if windowsDomain != "" { if windowsDomain != "" {
expectedUsername = fmt.Sprintf(`%s\%s`, windowsDomain, windowsUsername) expectedUsername = fmt.Sprintf(`%s\%s`, windowsDomain, windowsUsername)
} }
if !strings.EqualFold(currentUser.Username, expectedUsername) && !strings.EqualFold(currentUser.Username, windowsUsername) { if currentUser.Username != expectedUsername && currentUser.Username != windowsUsername {
cmd.PrintErrf("user switching failed\n") cmd.PrintErrf("user switching failed\n")
os.Exit(sshserver.ExitCodeValidationFail) os.Exit(sshserver.ExitCodeValidationFail)
} }

View File

@@ -667,51 +667,3 @@ func TestSSHCommand_ParameterIsolation(t *testing.T) {
}) })
} }
} }
func TestSSHCommand_InvalidFlagRejection(t *testing.T) {
// Test that invalid flags are properly rejected and not misinterpreted as hostnames
tests := []struct {
name string
args []string
description string
}{
{
name: "invalid long flag before hostname",
args: []string{"--invalid-flag", "hostname"},
description: "Invalid flag should return parse error, not treat flag as hostname",
},
{
name: "invalid short flag before hostname",
args: []string{"-x", "hostname"},
description: "Invalid short flag should return parse error",
},
{
name: "invalid flag with value before hostname",
args: []string{"--invalid-option=value", "hostname"},
description: "Invalid flag with value should return parse error",
},
{
name: "typo in known flag",
args: []string{"--por", "2222", "hostname"},
description: "Typo in flag name should return parse error (not silently ignored)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
// Should return an error for invalid flags
assert.Error(t, err, tt.description)
// Should not have set host to the invalid flag
assert.NotEqual(t, tt.args[0], host, "Invalid flag should not be interpreted as hostname")
})
}
}

View File

@@ -99,17 +99,17 @@ func statusFunc(cmd *cobra.Command, args []string) error {
profName = activeProf.Name profName = activeProf.Name
} }
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName) var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
var statusOutputString string var statusOutputString string
switch { switch {
case detailFlag: case detailFlag:
statusOutputString = outputInformationHolder.FullDetailSummary() statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
case jsonFlag: case jsonFlag:
statusOutputString, err = outputInformationHolder.JSON() statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
case yamlFlag: case yamlFlag:
statusOutputString, err = outputInformationHolder.YAML() statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
default: default:
statusOutputString = outputInformationHolder.GeneralSummary(false, false, false, false) statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
} }
if err != nil { if err != nil {
@@ -124,7 +124,6 @@ func statusFunc(cmd *cobra.Command, args []string) error {
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) { func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil { if err != nil {
//nolint
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+ return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+ "If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err) "\nnetbird service install \nnetbird service start\n", err)

View File

@@ -13,13 +13,6 @@ import (
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/job"
clientProto "github.com/netbirdio/netbird/client/proto" clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server" client "github.com/netbirdio/netbird/client/server"
"github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/server/config"
@@ -27,6 +20,8 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
@@ -89,7 +84,11 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
} }
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, nil
}
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish) t.Cleanup(ctrl.Finish)
@@ -98,8 +97,6 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
peersmanager := peers.NewManager(store, permissionsManagerMock) peersmanager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl) settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersmanager)
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore) iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
@@ -113,21 +110,13 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
Return(&types.Settings{}, nil). Return(&types.Settings{}, nil).
AnyTimes() AnyTimes()
ctx := context.Background() accountManager, err := mgmt.BuildManager(context.Background(), config, store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil { mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{})
t.Fatal(err)
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -185,7 +185,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath) _, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name) err = foregroundLogin(ctx, cmd, config, providedSetupKey)
if err != nil { if err != nil {
return fmt.Errorf("foreground login failed: %v", err) return fmt.Errorf("foreground login failed: %v", err)
} }
@@ -197,10 +197,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
r := peer.NewRecorder(config.ManagementURL.String()) r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus() r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r, false) connectClient := internal.NewConnectClient(ctx, config, r)
SetupDebugHandler(ctx, config, r, connectClient, "") SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil, util.FindFirstLogPath(logFiles)) return connectClient.Run(nil)
} }
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error { func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {
@@ -216,7 +216,6 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil { if err != nil {
//nolint
return fmt.Errorf("failed to connect to daemon error: %v\n"+ return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+ "If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err) "\nnetbird service install \nnetbird service start\n", err)
@@ -287,13 +286,6 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
loginRequest.ProfileName = &activeProf.Name loginRequest.ProfileName = &activeProf.Name
loginRequest.Username = &username loginRequest.Username = &username
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
loginRequest.Hint = &profileState.Email
}
var loginErr error var loginErr error
var loginResp *proto.LoginResponse var loginResp *proto.LoginResponse
@@ -363,18 +355,14 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
req.EnableSSHSFTP = &enableSSHSFTP req.EnableSSHSFTP = &enableSSHSFTP
} }
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed { if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward req.EnableSSHLocalPortForward = &enableSSHLocalPortForward
} }
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed { if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward req.EnableSSHRemotePortForward = &enableSSHRemotePortForward
} }
if cmd.Flag(disableSSHAuthFlag).Changed { if cmd.Flag(disableSSHAuthFlag).Changed {
req.DisableSSHAuth = &disableSSHAuth req.DisableSSHAuth = &disableSSHAuth
} }
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
req.SshJWTCacheTTL = &sshJWTCacheTTL32
}
if cmd.Flag(interfaceNameFlag).Changed { if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil { if err := parseInterfaceName(interfaceName); err != nil {
log.Errorf("parse interface name: %v", err) log.Errorf("parse interface name: %v", err)
@@ -479,10 +467,6 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.DisableSSHAuth = &disableSSHAuth ic.DisableSSHAuth = &disableSSHAuth
} }
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
}
if cmd.Flag(interfaceNameFlag).Changed { if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil { if err := parseInterfaceName(interfaceName); err != nil {
return nil, err return nil, err
@@ -603,11 +587,6 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.DisableSSHAuth = &disableSSHAuth loginRequest.DisableSSHAuth = &disableSSHAuth
} }
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
}
if cmd.Flag(disableAutoConnectFlag).Changed { if cmd.Flag(disableAutoConnectFlag).Changed {
loginRequest.DisableAutoConnect = &autoConnectDisabled loginRequest.DisableAutoConnect = &autoConnectDisabled
} }

View File

@@ -1,13 +0,0 @@
//go:build !windows && !darwin
package cmd
import (
"github.com/spf13/cobra"
)
var updateCmd *cobra.Command
func isUpdateBinary() bool {
return false
}

View File

@@ -1,75 +0,0 @@
//go:build windows || darwin
package cmd
import (
"context"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
"github.com/netbirdio/netbird/util"
)
var (
updateCmd = &cobra.Command{
Use: "update",
Short: "Update the NetBird client application",
RunE: updateFunc,
}
tempDirFlag string
installerFile string
serviceDirFlag string
dryRunFlag bool
)
func init() {
updateCmd.Flags().StringVar(&tempDirFlag, "temp-dir", "", "temporary dir")
updateCmd.Flags().StringVar(&installerFile, "installer-file", "", "installer file")
updateCmd.Flags().StringVar(&serviceDirFlag, "service-dir", "", "service directory")
updateCmd.Flags().BoolVar(&dryRunFlag, "dry-run", false, "dry run the update process without making any changes")
}
// isUpdateBinary checks if the current executable is named "update" or "update.exe"
func isUpdateBinary() bool {
// Remove extension for cross-platform compatibility
execPath, err := os.Executable()
if err != nil {
return false
}
baseName := filepath.Base(execPath)
name := strings.TrimSuffix(baseName, filepath.Ext(baseName))
return name == installer.UpdaterBinaryNameWithoutExtension()
}
func updateFunc(cmd *cobra.Command, args []string) error {
if err := setupLogToFile(tempDirFlag); err != nil {
return err
}
log.Infof("updater started: %s", serviceDirFlag)
updater := installer.NewWithDir(tempDirFlag)
if err := updater.Setup(context.Background(), dryRunFlag, installerFile, serviceDirFlag); err != nil {
log.Errorf("failed to update application: %v", err)
return err
}
return nil
}
func setupLogToFile(dir string) error {
logFile := filepath.Join(dir, installer.LogFile)
if _, err := os.Stat(logFile); err == nil {
if err := os.Remove(logFile); err != nil {
log.Errorf("failed to remove existing log file: %v\n", err)
}
}
return util.InitLog(logLevel, util.LogConsole, logFile)
}

View File

@@ -16,12 +16,10 @@ import (
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh" sshcommon "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
) )
var ( var (
@@ -40,7 +38,6 @@ type Client struct {
setupKey string setupKey string
jwtToken string jwtToken string
connect *internal.ConnectClient connect *internal.ConnectClient
recorder *peer.Status
} }
// Options configures a new Client. // Options configures a new Client.
@@ -69,8 +66,6 @@ type Options struct {
StatePath string StatePath string
// DisableClientRoutes disables the client routes // DisableClientRoutes disables the client routes
DisableClientRoutes bool DisableClientRoutes bool
// BlockInbound blocks all inbound connections from peers
BlockInbound bool
} }
// validateCredentials checks that exactly one credential type is provided // validateCredentials checks that exactly one credential type is provided
@@ -139,7 +134,6 @@ func New(opts Options) (*Client, error) {
PreSharedKey: &opts.PreSharedKey, PreSharedKey: &opts.PreSharedKey,
DisableServerRoutes: &t, DisableServerRoutes: &t,
DisableClientRoutes: &opts.DisableClientRoutes, DisableClientRoutes: &opts.DisableClientRoutes,
BlockInbound: &opts.BlockInbound,
} }
if opts.ConfigPath != "" { if opts.ConfigPath != "" {
config, err = profilemanager.UpdateOrCreateConfig(input) config, err = profilemanager.UpdateOrCreateConfig(input)
@@ -167,40 +161,26 @@ func New(opts Options) (*Client, error) {
func (c *Client) Start(startCtx context.Context) error { func (c *Client) Start(startCtx context.Context) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if c.connect != nil { if c.cancel != nil {
return ErrClientAlreadyStarted return ErrClientAlreadyStarted
} }
ctx, cancel := context.WithCancel(internal.CtxInitState(context.Background())) ctx := internal.CtxInitState(context.Background())
defer func() {
if c.connect == nil {
cancel()
}
}()
// nolint:staticcheck // nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName) ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config) if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
if err != nil {
return fmt.Errorf("create auth client: %w", err)
}
defer authClient.Close()
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err) return fmt.Errorf("login: %w", err)
} }
recorder := peer.NewRecorder(c.config.ManagementURL.String()) recorder := peer.NewRecorder(c.config.ManagementURL.String())
c.recorder = recorder client := internal.NewConnectClient(ctx, c.config, recorder)
client := internal.NewConnectClient(ctx, c.config, recorder, false)
client.SetSyncResponsePersistence(true)
// either startup error (permanent backoff err) or nil err (successful engine up) // either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available // TODO: make after-startup backoff err available
run := make(chan struct{}) run := make(chan struct{})
clientErr := make(chan error, 1) clientErr := make(chan error, 1)
go func() { go func() {
if err := client.Run(run, ""); err != nil { if err := client.Run(run); err != nil {
clientErr <- err clientErr <- err
} }
}() }()
@@ -217,7 +197,6 @@ func (c *Client) Start(startCtx context.Context) error {
} }
c.connect = client c.connect = client
c.cancel = cancel
return nil return nil
} }
@@ -232,23 +211,17 @@ func (c *Client) Stop(ctx context.Context) error {
return ErrClientNotStarted return ErrClientNotStarted
} }
if c.cancel != nil {
c.cancel()
c.cancel = nil
}
done := make(chan error, 1) done := make(chan error, 1)
connect := c.connect
go func() { go func() {
done <- connect.Stop() done <- c.connect.Stop()
}() }()
select { select {
case <-ctx.Done(): case <-ctx.Done():
c.connect = nil c.cancel = nil
return ctx.Err() return ctx.Err()
case err := <-done: case err := <-done:
c.connect = nil c.cancel = nil
if err != nil { if err != nil {
return fmt.Errorf("stop: %w", err) return fmt.Errorf("stop: %w", err)
} }
@@ -342,62 +315,6 @@ func (c *Client) NewHTTPClient() *http.Client {
} }
} }
// Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock()
recorder := c.recorder
connect := c.connect
c.mu.Unlock()
if recorder == nil {
return peer.FullStatus{}, errors.New("client not started")
}
if connect != nil {
engine := connect.Engine()
if engine != nil {
_ = engine.RunHealthProbes(false)
}
}
return recorder.GetFullStatus(), nil
}
// GetLatestSyncResponse returns the latest sync response from the management server.
func (c *Client) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
engine, err := c.getEngine()
if err != nil {
return nil, err
}
syncResp, err := engine.GetLatestSyncResponse()
if err != nil {
return nil, fmt.Errorf("get sync response: %w", err)
}
return syncResp, nil
}
// SetLogLevel sets the logging level for the client and its components.
func (c *Client) SetLogLevel(levelStr string) error {
level, err := logrus.ParseLevel(levelStr)
if err != nil {
return fmt.Errorf("parse log level: %w", err)
}
logrus.SetLevel(level)
c.mu.Lock()
connect := c.connect
c.mu.Unlock()
if connect != nil {
connect.SetLogLevel(level)
}
return nil
}
// VerifySSHHostKey verifies an SSH host key against stored peer keys. // VerifySSHHostKey verifies an SSH host key against stored peer keys.
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network, // Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
// ErrNoStoredKey if peer has no stored key, or an error for verification failures. // ErrNoStoredKey if peer has no stored key, or an error for verification failures.

View File

@@ -1,14 +1,13 @@
package iptables package iptables
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"slices" "slices"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/google/uuid" "github.com/google/uuid"
ipset "github.com/lrh3321/ipset-go" "github.com/nadoo/ipset"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -41,13 +40,19 @@ type aclManager struct {
} }
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) { func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
return &aclManager{ m := &aclManager{
iptablesClient: iptablesClient, iptablesClient: iptablesClient,
wgIface: wgIface, wgIface: wgIface,
entries: make(map[string][][]string), entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry), optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(), ipsetStore: newIpsetStore(),
}, nil }
if err := ipset.Init(); err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
}
return m, nil
} }
func (m *aclManager) init(stateManager *statemanager.Manager) error { func (m *aclManager) init(stateManager *statemanager.Manager) error {
@@ -93,8 +98,8 @@ func (m *aclManager) AddPeerFiltering(
specs = append(specs, "-j", actionToStr(action)) specs = append(specs, "-j", actionToStr(action))
if ipsetName != "" { if ipsetName != "" {
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists { if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
if err := m.addToIPSet(ipsetName, ip); err != nil { if err := ipset.Add(ipsetName, ip.String()); err != nil {
return nil, fmt.Errorf("add IP to ipset: %w", err) return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
} }
// if ruleset already exists it means we already have the firewall rule // if ruleset already exists it means we already have the firewall rule
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager. // so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
@@ -108,18 +113,14 @@ func (m *aclManager) AddPeerFiltering(
}}, nil }}, nil
} }
if err := m.flushIPSet(ipsetName); err != nil { if err := ipset.Flush(ipsetName); err != nil {
if errors.Is(err, ipset.ErrSetNotExist) { log.Errorf("flush ipset %s before use it: %s", ipsetName, err)
log.Debugf("flush ipset %s before use: %v", ipsetName, err)
} else {
log.Errorf("flush ipset %s before use: %v", ipsetName, err)
}
} }
if err := m.createIPSet(ipsetName); err != nil { if err := ipset.Create(ipsetName); err != nil {
return nil, fmt.Errorf("create ipset: %w", err) return nil, fmt.Errorf("failed to create ipset: %w", err)
} }
if err := m.addToIPSet(ipsetName, ip); err != nil { if err := ipset.Add(ipsetName, ip.String()); err != nil {
return nil, fmt.Errorf("add IP to ipset: %w", err) return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
} }
ipList := newIpList(ip.String()) ipList := newIpList(ip.String())
@@ -171,16 +172,11 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
return fmt.Errorf("invalid rule type") return fmt.Errorf("invalid rule type")
} }
shouldDestroyIpset := false
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok { if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
// delete IP from ruleset IPs list and ipset // delete IP from ruleset IPs list and ipset
if _, ok := ipsetList.ips[r.ip]; ok { if _, ok := ipsetList.ips[r.ip]; ok {
ip := net.ParseIP(r.ip) if err := ipset.Del(r.ipsetName, r.ip); err != nil {
if ip == nil { return fmt.Errorf("failed to delete ip from ipset: %w", err)
return fmt.Errorf("parse IP %s", r.ip)
}
if err := m.delFromIPSet(r.ipsetName, ip); err != nil {
return fmt.Errorf("delete ip from ipset: %w", err)
} }
delete(ipsetList.ips, r.ip) delete(ipsetList.ips, r.ip)
} }
@@ -194,7 +190,10 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
// we delete last IP from the set, that means we need to delete // we delete last IP from the set, that means we need to delete
// set itself and associated firewall rule too // set itself and associated firewall rule too
m.ipsetStore.deleteIpset(r.ipsetName) m.ipsetStore.deleteIpset(r.ipsetName)
shouldDestroyIpset = true
if err := ipset.Destroy(r.ipsetName); err != nil {
log.Errorf("delete empty ipset: %v", err)
}
} }
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil { if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
@@ -207,16 +206,6 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
} }
} }
if shouldDestroyIpset {
if err := m.destroyIPSet(r.ipsetName); err != nil {
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("destroy empty ipset: %v", err)
} else {
log.Errorf("destroy empty ipset: %v", err)
}
}
}
m.updateState() m.updateState()
return nil return nil
@@ -275,19 +264,11 @@ func (m *aclManager) cleanChains() error {
} }
for _, ipsetName := range m.ipsetStore.ipsetNames() { for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := m.flushIPSet(ipsetName); err != nil { if err := ipset.Flush(ipsetName); err != nil {
if errors.Is(err, ipset.ErrSetNotExist) { log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
log.Debugf("flush ipset %q during reset: %v", ipsetName, err)
} else {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
}
} }
if err := m.destroyIPSet(ipsetName); err != nil { if err := ipset.Destroy(ipsetName); err != nil {
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) { log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
log.Debugf("destroy ipset %q during reset: %v", ipsetName, err)
} else {
log.Errorf("destroy ipset %q during reset: %v", ipsetName, err)
}
} }
m.ipsetStore.deleteIpset(ipsetName) m.ipsetStore.deleteIpset(ipsetName)
} }
@@ -386,8 +367,11 @@ func (m *aclManager) updateState() {
// filterRuleSpecs returns the specs of a filtering rule // filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) { func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
// don't use IP matching if IP is 0.0.0.0 matchByIP := true
matchByIP := !ip.IsUnspecified() // don't use IP matching if IP is ip 0.0.0.0
if ip.String() == "0.0.0.0" {
matchByIP = false
}
if matchByIP { if matchByIP {
if ipsetName != "" { if ipsetName != "" {
@@ -432,61 +416,3 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
return ipsetName + actionSuffix return ipsetName + actionSuffix
} }
} }
func (m *aclManager) createIPSet(name string) error {
opts := ipset.CreateOptions{
Replace: true,
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err)
}
log.Debugf("created ipset %s with type hash:net", name)
return nil
}
func (m *aclManager) addToIPSet(name string, ip net.IP) error {
cidr := uint8(32)
if ip.To4() == nil {
cidr = 128
}
entry := &ipset.Entry{
IP: ip,
CIDR: cidr,
Replace: true,
}
if err := ipset.Add(name, entry); err != nil {
return fmt.Errorf("add IP to ipset %s: %w", name, err)
}
return nil
}
func (m *aclManager) delFromIPSet(name string, ip net.IP) error {
cidr := uint8(32)
if ip.To4() == nil {
cidr = 128
}
entry := &ipset.Entry{
IP: ip,
CIDR: cidr,
}
if err := ipset.Del(name, entry); err != nil {
return fmt.Errorf("delete IP from ipset %s: %w", name, err)
}
return nil
}
func (m *aclManager) flushIPSet(name string) error {
return ipset.Flush(name)
}
func (m *aclManager) destroyIPSet(name string) error {
return ipset.Destroy(name)
}

View File

@@ -83,10 +83,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
return fmt.Errorf("acl manager init: %w", err) return fmt.Errorf("acl manager init: %w", err)
} }
if err := m.initNoTrackChain(); err != nil {
return fmt.Errorf("init notrack chain: %w", err)
}
// persist early to ensure cleanup of chains // persist early to ensure cleanup of chains
go func() { go func() {
if err := stateManager.PersistState(context.Background()); err != nil { if err := stateManager.PersistState(context.Background()); err != nil {
@@ -181,10 +177,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
var merr *multierror.Error var merr *multierror.Error
if err := m.cleanupNoTrackChain(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
}
if err := m.aclMgr.Reset(); err != nil { if err := m.aclMgr.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err)) merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
} }
@@ -285,125 +277,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
} }
const (
chainNameRaw = "NETBIRD-RAW"
chainOUTPUT = "OUTPUT"
tableRaw = "raw"
)
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which
// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
//
// Traffic flows that need NOTRACK:
//
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
// Matched by: sport=wgPort
//
// 2. Egress: Proxy -> WireGuard (via raw socket)
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
// Matched by: dport=wgPort
//
// 3. Ingress: Packets to WireGuard
// dst=127.0.0.1:wgPort
// Matched by: dport=wgPort
//
// 4. Ingress: Packets to proxy (after eBPF rewrite)
// dst=127.0.0.1:proxyPort
// Matched by: dport=proxyPort
//
// Rules are cleaned up when the firewall manager is closed.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
wgPortStr := fmt.Sprintf("%d", wgPort)
proxyPortStr := fmt.Sprintf("%d", proxyPort)
// Egress rules: match outgoing loopback UDP packets
outputRuleSport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--sport", wgPortStr, "-j", "NOTRACK"}
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleSport...); err != nil {
return fmt.Errorf("add output sport notrack rule: %w", err)
}
outputRuleDport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"}
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleDport...); err != nil {
return fmt.Errorf("add output dport notrack rule: %w", err)
}
// Ingress rules: match incoming loopback UDP packets
preroutingRuleWg := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"}
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleWg...); err != nil {
return fmt.Errorf("add prerouting wg notrack rule: %w", err)
}
preroutingRuleProxy := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", proxyPortStr, "-j", "NOTRACK"}
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleProxy...); err != nil {
return fmt.Errorf("add prerouting proxy notrack rule: %w", err)
}
log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
return nil
}
func (m *Manager) initNoTrackChain() error {
if err := m.cleanupNoTrackChain(); err != nil {
log.Debugf("cleanup notrack chain: %v", err)
}
if err := m.ipv4Client.NewChain(tableRaw, chainNameRaw); err != nil {
return fmt.Errorf("create chain: %w", err)
}
jumpRule := []string{"-j", chainNameRaw}
if err := m.ipv4Client.InsertUnique(tableRaw, chainOUTPUT, 1, jumpRule...); err != nil {
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
log.Debugf("delete orphan chain: %v", delErr)
}
return fmt.Errorf("add output jump rule: %w", err)
}
if err := m.ipv4Client.InsertUnique(tableRaw, chainPREROUTING, 1, jumpRule...); err != nil {
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); delErr != nil {
log.Debugf("delete output jump rule: %v", delErr)
}
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
log.Debugf("delete orphan chain: %v", delErr)
}
return fmt.Errorf("add prerouting jump rule: %w", err)
}
return nil
}
func (m *Manager) cleanupNoTrackChain() error {
exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw)
if err != nil {
return fmt.Errorf("check chain exists: %w", err)
}
if !exists {
return nil
}
jumpRule := []string{"-j", chainNameRaw}
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); err != nil {
return fmt.Errorf("remove output jump rule: %w", err)
}
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPREROUTING, jumpRule...); err != nil {
return fmt.Errorf("remove prerouting jump rule: %w", err)
}
if err := m.ipv4Client.ClearAndDeleteChain(tableRaw, chainNameRaw); err != nil {
return fmt.Errorf("clear and delete chain: %w", err)
}
return nil
}
func getConntrackEstablished() []string { func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
} }

View File

@@ -161,7 +161,7 @@ func TestIptablesManagerDenyRules(t *testing.T) {
t.Logf(" [%d] %s", i, rule) t.Logf(" [%d] %s", i, rule)
} }
var denyRuleIndex, acceptRuleIndex = -1, -1 var denyRuleIndex, acceptRuleIndex int = -1, -1
for i, rule := range rules { for i, rule := range rules {
if strings.Contains(rule, "DROP") { if strings.Contains(rule, "DROP") {
t.Logf("Found DROP rule at index %d: %s", i, rule) t.Logf("Found DROP rule at index %d: %s", i, rule)

View File

@@ -10,7 +10,7 @@ import (
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
ipset "github.com/lrh3321/ipset-go" "github.com/nadoo/ipset"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
@@ -107,6 +107,10 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1
}, },
) )
if err := ipset.Init(); err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
}
return r, nil return r, nil
} }
@@ -228,12 +232,12 @@ func (r *router) findSets(rule []string) []string {
} }
func (r *router) createIpSet(setName string, sources []netip.Prefix) error { func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
if err := r.createIPSet(setName); err != nil { if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
return fmt.Errorf("create set %s: %w", setName, err) return fmt.Errorf("create set %s: %w", setName, err)
} }
for _, prefix := range sources { for _, prefix := range sources {
if err := r.addPrefixToIPSet(setName, prefix); err != nil { if err := ipset.AddPrefix(setName, prefix); err != nil {
return fmt.Errorf("add element to set %s: %w", setName, err) return fmt.Errorf("add element to set %s: %w", setName, err)
} }
} }
@@ -242,7 +246,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
} }
func (r *router) deleteIpSet(setName string) error { func (r *router) deleteIpSet(setName string) error {
if err := r.destroyIPSet(setName); err != nil { if err := ipset.Destroy(setName); err != nil {
return fmt.Errorf("destroy set %s: %w", setName, err) return fmt.Errorf("destroy set %s: %w", setName, err)
} }
@@ -911,8 +915,8 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue continue
} }
if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil { if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err)) merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err))
} }
} }
if merr == nil { if merr == nil {
@@ -989,37 +993,3 @@ func applyPort(flag string, port *firewall.Port) []string {
return []string{flag, strconv.Itoa(int(port.Values[0]))} return []string{flag, strconv.Itoa(int(port.Values[0]))}
} }
func (r *router) createIPSet(name string) error {
opts := ipset.CreateOptions{
Replace: true,
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err)
}
log.Debugf("created ipset %s with type hash:net", name)
return nil
}
func (r *router) addPrefixToIPSet(name string, prefix netip.Prefix) error {
addr := prefix.Addr()
ip := addr.AsSlice()
entry := &ipset.Entry{
IP: ip,
CIDR: uint8(prefix.Bits()),
Replace: true,
}
if err := ipset.Add(name, entry); err != nil {
return fmt.Errorf("add prefix to ipset %s: %w", name, err)
}
return nil
}
func (r *router) destroyIPSet(name string) error {
return ipset.Destroy(name)
}

View File

@@ -168,10 +168,6 @@ type Manager interface {
// RemoveInboundDNAT removes inbound DNAT rule // RemoveInboundDNAT removes inbound DNAT rule
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from interfering with WireGuard proxy communication.
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
} }
func GenKey(format string, pair RouterPair) string { func GenKey(format string, pair RouterPair) string {

View File

@@ -12,7 +12,6 @@ import (
"github.com/google/nftables/binaryutil" "github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
@@ -49,10 +48,8 @@ type Manager struct {
rConn *nftables.Conn rConn *nftables.Conn
wgIface iFaceMapper wgIface iFaceMapper
router *router router *router
aclManager *AclManager aclManager *AclManager
notrackOutputChain *nftables.Chain
notrackPreroutingChain *nftables.Chain
} }
// Create nftables firewall manager // Create nftables firewall manager
@@ -94,10 +91,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
return fmt.Errorf("acl manager init: %w", err) return fmt.Errorf("acl manager init: %w", err)
} }
if err := m.initNoTrackChains(workTable); err != nil {
return fmt.Errorf("init notrack chains: %w", err)
}
stateManager.RegisterState(&ShutdownState{}) stateManager.RegisterState(&ShutdownState{})
// We only need to record minimal interface state for potential recreation. // We only need to record minimal interface state for potential recreation.
@@ -295,15 +288,7 @@ func (m *Manager) Flush() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if err := m.aclManager.Flush(); err != nil { return m.aclManager.Flush()
return err
}
if err := m.refreshNoTrackChains(); err != nil {
log.Errorf("failed to refresh notrack chains: %v", err)
}
return nil
} }
// AddDNATRule adds a DNAT rule // AddDNATRule adds a DNAT rule
@@ -346,176 +331,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
} }
const (
chainNameRawOutput = "netbird-raw-out"
chainNameRawPrerouting = "netbird-raw-pre"
)
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which
// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
//
// Traffic flows that need NOTRACK:
//
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
// Matched by: sport=wgPort
//
// 2. Egress: Proxy -> WireGuard (via raw socket)
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
// Matched by: dport=wgPort
//
// 3. Ingress: Packets to WireGuard
// dst=127.0.0.1:wgPort
// Matched by: dport=wgPort
//
// 4. Ingress: Packets to proxy (after eBPF rewrite)
// dst=127.0.0.1:proxyPort
// Matched by: dport=proxyPort
//
// Rules are cleaned up when the firewall manager is closed.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.notrackOutputChain == nil || m.notrackPreroutingChain == nil {
return fmt.Errorf("notrack chains not initialized")
}
proxyPortBytes := binaryutil.BigEndian.PutUint16(proxyPort)
wgPortBytes := binaryutil.BigEndian.PutUint16(wgPort)
loopback := []byte{127, 0, 0, 1}
// Egress rules: match outgoing loopback UDP packets
m.rConn.AddRule(&nftables.Rule{
Table: m.notrackOutputChain.Table,
Chain: m.notrackOutputChain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 0, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // sport=wgPort
&expr.Counter{},
&expr.Notrack{},
},
})
m.rConn.AddRule(&nftables.Rule{
Table: m.notrackOutputChain.Table,
Chain: m.notrackOutputChain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort
&expr.Counter{},
&expr.Notrack{},
},
})
// Ingress rules: match incoming loopback UDP packets
m.rConn.AddRule(&nftables.Rule{
Table: m.notrackPreroutingChain.Table,
Chain: m.notrackPreroutingChain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort
&expr.Counter{},
&expr.Notrack{},
},
})
m.rConn.AddRule(&nftables.Rule{
Table: m.notrackPreroutingChain.Table,
Chain: m.notrackPreroutingChain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: proxyPortBytes}, // dport=proxyPort
&expr.Counter{},
&expr.Notrack{},
},
})
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf("flush notrack rules: %w", err)
}
log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
return nil
}
func (m *Manager) initNoTrackChains(table *nftables.Table) error {
m.notrackOutputChain = m.rConn.AddChain(&nftables.Chain{
Name: chainNameRawOutput,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityRaw,
})
m.notrackPreroutingChain = m.rConn.AddChain(&nftables.Chain{
Name: chainNameRawPrerouting,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityRaw,
})
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf("flush chain creation: %w", err)
}
return nil
}
func (m *Manager) refreshNoTrackChains() error {
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list chains: %w", err)
}
tableName := getTableName()
for _, c := range chains {
if c.Table.Name != tableName {
continue
}
switch c.Name {
case chainNameRawOutput:
m.notrackOutputChain = c
case chainNameRawPrerouting:
m.notrackPreroutingChain = c
}
}
return nil
}
func (m *Manager) createWorkTable() (*nftables.Table, error) { func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil { if err != nil {

View File

@@ -198,7 +198,7 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
t.Logf("Found %d rules in nftables chain", len(rules)) t.Logf("Found %d rules in nftables chain", len(rules))
// Find the accept and deny rules and verify deny comes before accept // Find the accept and deny rules and verify deny comes before accept
var acceptRuleIndex, denyRuleIndex = -1, -1 var acceptRuleIndex, denyRuleIndex int = -1, -1
for i, rule := range rules { for i, rule := range rules {
hasAcceptHTTPSet := false hasAcceptHTTPSet := false
hasDenyHTTPSet := false hasDenyHTTPSet := false
@@ -208,13 +208,11 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
for _, e := range rule.Exprs { for _, e := range rule.Exprs {
// Check for set lookup // Check for set lookup
if lookup, ok := e.(*expr.Lookup); ok { if lookup, ok := e.(*expr.Lookup); ok {
switch lookup.SetName { if lookup.SetName == "accept-http" {
case "accept-http":
hasAcceptHTTPSet = true hasAcceptHTTPSet = true
case "deny-http": } else if lookup.SetName == "deny-http" {
hasDenyHTTPSet = true hasDenyHTTPSet = true
} }
} }
// Check for port 80 // Check for port 80
if cmp, ok := e.(*expr.Cmp); ok { if cmp, ok := e.(*expr.Cmp); ok {
@@ -224,10 +222,9 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
} }
// Check for verdict // Check for verdict
if verdict, ok := e.(*expr.Verdict); ok { if verdict, ok := e.(*expr.Verdict); ok {
switch verdict.Kind { if verdict.Kind == expr.VerdictAccept {
case expr.VerdictAccept:
action = "ACCEPT" action = "ACCEPT"
case expr.VerdictDrop: } else if verdict.Kind == expr.VerdictDrop {
action = "DROP" action = "DROP"
} }
} }
@@ -389,97 +386,6 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
verifyIptablesOutput(t, stdout, stderr) verifyIptablesOutput(t, stdout, stderr)
} }
func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
if _, err := exec.LookPath("iptables-save"); err != nil {
t.Skipf("iptables-save not available on this system: %v", err)
}
// First ensure iptables-nft tables exist by running iptables-save
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "failed to create manager")
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
err := manager.Close(nil)
require.NoError(t, err, "failed to reset manager state")
// Verify iptables output after reset
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
})
const octet2Count = 25
const octet3Count = 255
prefixes := make([]netip.Prefix, 0, (octet2Count-1)*(octet3Count-1))
for i := 1; i < octet2Count; i++ {
for j := 1; j < octet3Count; j++ {
addr := netip.AddrFrom4([4]byte{192, byte(j), byte(i), 0})
prefixes = append(prefixes, netip.PrefixFrom(addr, 24))
}
}
_, err = manager.AddRouteFiltering(
nil,
prefixes,
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err, "failed to add route filtering rule")
stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
}
func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
if _, err := exec.LookPath("iptables-save"); err != nil {
t.Skipf("iptables-save not available on this system: %v", err)
}
// First ensure iptables-nft tables exist by running iptables-save
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "failed to create manager")
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
err := manager.Close(nil)
require.NoError(t, err, "failed to reset manager state")
// Verify iptables output after reset
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
})
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{},
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err, "failed to add route filtering rule")
stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
}
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) { func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
t.Helper() t.Helper()
require.Equal(t, len(got), len(want), "expression count mismatch") require.Equal(t, len(got), len(want), "expression count mismatch")

View File

@@ -27,11 +27,7 @@ import (
) )
const ( const (
tableNat = "nat" tableNat = "nat"
tableMangle = "mangle"
tableRaw = "raw"
tableSecurity = "security"
chainNameNatPrerouting = "PREROUTING" chainNameNatPrerouting = "PREROUTING"
chainNameRoutingFw = "netbird-rt-fwd" chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting" chainNameRoutingNat = "netbird-rt-postrouting"
@@ -48,12 +44,10 @@ const (
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40 ipTCPHeaderMinSize = 40
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
maxPrefixesSet = 1500
refreshRulesMapError = "refresh rules map: %w"
) )
const refreshRulesMapError = "refresh rules map: %w"
var ( var (
errFilterTableNotFound = fmt.Errorf("'filter' table not found") errFilterTableNotFound = fmt.Errorf("'filter' table not found")
) )
@@ -97,7 +91,11 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
var err error var err error
r.filterTable, err = r.loadFilterTable() r.filterTable, err = r.loadFilterTable()
if err != nil { if err != nil {
log.Debugf("ip filter table not found: %v", err) if errors.Is(err, errFilterTableNotFound) {
log.Warnf("table 'filter' not found for forward rules")
} else {
return nil, fmt.Errorf("load filter table: %w", err)
}
} }
return r, nil return r, nil
@@ -177,7 +175,7 @@ func (r *router) removeNatPreroutingRules() error {
func (r *router) loadFilterTable() (*nftables.Table, error) { func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil { if err != nil {
return nil, fmt.Errorf("list tables: %w", err) return nil, fmt.Errorf("unable to list tables: %v", err)
} }
for _, table := range tables { for _, table := range tables {
@@ -189,39 +187,14 @@ func (r *router) loadFilterTable() (*nftables.Table, error) {
return nil, errFilterTableNotFound return nil, errFilterTableNotFound
} }
func hookName(hook *nftables.ChainHook) string {
if hook == nil {
return "unknown"
}
switch *hook {
case *nftables.ChainHookForward:
return chainNameForward
case *nftables.ChainHookInput:
return chainNameInput
default:
return fmt.Sprintf("hook(%d)", *hook)
}
}
func familyName(family nftables.TableFamily) string {
switch family {
case nftables.TableFamilyIPv4:
return "ip"
case nftables.TableFamilyIPv6:
return "ip6"
case nftables.TableFamilyINet:
return "inet"
default:
return fmt.Sprintf("family(%d)", family)
}
}
func (r *router) createContainers() error { func (r *router) createContainers() error {
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{ r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingFw, Name: chainNameRoutingFw,
Table: r.workTable, Table: r.workTable,
}) })
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
prio := *nftables.ChainPriorityNATSource - 1 prio := *nftables.ChainPriorityNATSource - 1
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat, Name: chainNameRoutingNat,
@@ -263,12 +236,9 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeFilter, Type: nftables.ChainTypeFilter,
}) })
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw]) // Add the single NAT rule that matches on mark
if err := r.addPostroutingRules(); err != nil {
r.addPostroutingRules() return fmt.Errorf("add single nat rule: %v", err)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("initialize tables: %v", err)
} }
if err := r.addMSSClampingRules(); err != nil { if err := r.addMSSClampingRules(); err != nil {
@@ -280,7 +250,11 @@ func (r *router) createContainers() error {
} }
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
log.Errorf("failed to refresh rules: %s", err) log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("initialize tables: %v", err)
} }
return nil return nil
@@ -515,35 +489,16 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
} }
elements := convertPrefixesToSet(prefixes) elements := convertPrefixesToSet(prefixes)
nElements := len(elements) if err := r.conn.AddSet(nfset, elements); err != nil {
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
maxElements := maxPrefixesSet * 2
initialElements := elements[:min(maxElements, nElements)]
if err := r.conn.AddSet(nfset, initialElements); err != nil {
return nil, fmt.Errorf("error adding set %s: %w", setName, err)
} }
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err) return nil, fmt.Errorf("flush error: %w", err)
} }
log.Debugf("Created new ipset: %s with %d initial prefixes (total prefixes %d)", setName, len(initialElements)/2, len(prefixes))
var subEnd int log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
for subStart := maxElements; subStart < nElements; subStart += maxElements {
subEnd = min(subStart+maxElements, nElements)
subElement := elements[subStart:subEnd]
nSubPrefixes := len(subElement) / 2
log.Tracef("Adding new prefixes (%d) in ipset: %s", nSubPrefixes, setName)
if err := r.conn.SetAddElements(nfset, subElement); err != nil {
return nil, fmt.Errorf("error adding prefixes (%d) to set %s: %w", nSubPrefixes, setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Debugf("Added new prefixes (%d) in ipset: %s", nSubPrefixes, setName)
}
log.Infof("Created new ipset: %s with %d prefixes", setName, len(prefixes))
return nfset, nil return nfset, nil
} }
@@ -740,7 +695,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
} }
// addPostroutingRules adds the masquerade rules // addPostroutingRules adds the masquerade rules
func (r *router) addPostroutingRules() { func (r *router) addPostroutingRules() error {
// First masquerade rule for traffic coming in from WireGuard interface // First masquerade rule for traffic coming in from WireGuard interface
exprs := []expr.Any{ exprs := []expr.Any{
// Match on the first fwmark // Match on the first fwmark
@@ -806,6 +761,8 @@ func (r *router) addPostroutingRules() {
Chain: r.chains[chainNameRoutingNat], Chain: r.chains[chainNameRoutingNat],
Exprs: exprs2, Exprs: exprs2,
}) })
return nil
} }
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. // addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
@@ -882,7 +839,7 @@ func (r *router) addMSSClampingRules() error {
Exprs: exprsOut, Exprs: exprsOut,
}) })
return r.conn.Flush() return nil
} }
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls // addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
@@ -982,21 +939,8 @@ func (r *router) RemoveAllLegacyRouteRules() error {
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. // In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
// This method also adds INPUT chain rules to allow traffic to the local interface. // This method also adds INPUT chain rules to allow traffic to the local interface.
func (r *router) acceptForwardRules() error { func (r *router) acceptForwardRules() error {
var merr *multierror.Error
if err := r.acceptFilterTableRules(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.acceptExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) acceptFilterTableRules() error {
if r.filterTable == nil { if r.filterTable == nil {
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
return nil return nil
} }
@@ -1009,11 +953,11 @@ func (r *router) acceptFilterTableRules() error {
// Try iptables first and fallback to nftables if iptables is not available // Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New() ipt, err := iptables.New()
if err != nil { if err != nil {
// iptables is not available but the filter table exists // filter table exists but iptables is not
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables" fw = "nftables"
return r.acceptFilterRulesNftables(r.filterTable) return r.acceptFilterRulesNftables()
} }
return r.acceptFilterRulesIptables(ipt) return r.acceptFilterRulesIptables(ipt)
@@ -1024,7 +968,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
for _, rule := range r.getAcceptForwardRules() { for _, rule := range r.getAcceptForwardRules() {
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil { if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err)) merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err))
} else { } else {
log.Debugf("added iptables forward rule: %v", rule) log.Debugf("added iptables forward rule: %v", rule)
} }
@@ -1032,7 +976,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
inputRule := r.getAcceptInputRule() inputRule := r.getAcceptInputRule()
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil { if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err)) merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err))
} else { } else {
log.Debugf("added iptables input rule: %v", inputRule) log.Debugf("added iptables input rule: %v", inputRule)
} }
@@ -1052,70 +996,18 @@ func (r *router) getAcceptInputRule() []string {
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"} return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
} }
// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables. func (r *router) acceptFilterRulesNftables() error {
// This is used when iptables is not available.
func (r *router) acceptFilterRulesNftables(table *nftables.Table) error {
intf := ifname(r.wgIface.Name()) intf := ifname(r.wgIface.Name())
forwardChain := &nftables.Chain{
Name: chainNameForward,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
}
r.insertForwardAcceptRules(forwardChain, intf)
inputChain := &nftables.Chain{
Name: chainNameInput,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
}
r.insertInputAcceptRule(inputChain, intf)
return r.conn.Flush()
}
// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables).
// It dynamically finds chains at call time to handle chains that may have been created after startup.
func (r *router) acceptExternalChainsRules() error {
chains := r.findExternalChains()
if len(chains) == 0 {
return nil
}
intf := ifname(r.wgIface.Name())
for _, chain := range chains {
if chain.Hooknum == nil {
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
continue
}
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
switch *chain.Hooknum {
case *nftables.ChainHookForward:
r.insertForwardAcceptRules(chain, intf)
case *nftables.ChainHookInput:
r.insertInputAcceptRule(chain, intf)
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush external chain rules: %w", err)
}
return nil
}
func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
iifRule := &nftables.Rule{ iifRule := &nftables.Rule{
Table: chain.Table, Table: r.filterTable,
Chain: chain, Chain: &nftables.Chain{
Name: chainNameForward,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: []expr.Any{ Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{ &expr.Cmp{
@@ -1138,19 +1030,30 @@ func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
Data: intf, Data: intf,
}, },
} }
oifRule := &nftables.Rule{ oifRule := &nftables.Rule{
Table: chain.Table, Table: r.filterTable,
Chain: chain, Chain: &nftables.Chain{
Name: chainNameForward,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: append(oifExprs, getEstablishedExprs(2)...), Exprs: append(oifExprs, getEstablishedExprs(2)...),
UserData: []byte(userDataAcceptForwardRuleOif), UserData: []byte(userDataAcceptForwardRuleOif),
} }
r.conn.InsertRule(oifRule) r.conn.InsertRule(oifRule)
}
func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
inputRule := &nftables.Rule{ inputRule := &nftables.Rule{
Table: chain.Table, Table: r.filterTable,
Chain: chain, Chain: &nftables.Chain{
Name: chainNameInput,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
},
Exprs: []expr.Any{ Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{ &expr.Cmp{
@@ -1164,44 +1067,32 @@ func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
UserData: []byte(userDataAcceptInputRule), UserData: []byte(userDataAcceptInputRule),
} }
r.conn.InsertRule(inputRule) r.conn.InsertRule(inputRule)
return nil
} }
func (r *router) removeAcceptFilterRules() error { func (r *router) removeAcceptFilterRules() error {
var merr *multierror.Error
if err := r.removeFilterTableRules(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.removeExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) removeFilterTableRules() error {
if r.filterTable == nil { if r.filterTable == nil {
return nil return nil
} }
ipt, err := iptables.New() ipt, err := iptables.New()
if err != nil { if err != nil {
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err) log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
return r.removeAcceptRulesFromTable(r.filterTable) return r.removeAcceptFilterRulesNftables()
} }
return r.removeAcceptFilterRulesIptables(ipt) return r.removeAcceptFilterRulesIptables(ipt)
} }
func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error { func (r *router) removeAcceptFilterRulesNftables() error {
chains, err := r.conn.ListChainsOfTableFamily(table.Family) chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil { if err != nil {
return fmt.Errorf("list chains: %v", err) return fmt.Errorf("list chains: %v", err)
} }
for _, chain := range chains { for _, chain := range chains {
if chain.Table.Name != table.Name { if chain.Table.Name != r.filterTable.Name {
continue continue
} }
@@ -1209,101 +1100,27 @@ func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
continue continue
} }
if err := r.removeAcceptRulesFromChain(table, chain); err != nil { rules, err := r.conn.GetRules(r.filterTable, chain)
return err
}
}
return r.conn.Flush()
}
func (r *router) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error {
rules, err := r.conn.GetRules(table, chain)
if err != nil {
return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err)
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err)
}
}
}
return nil
}
// removeExternalChainsRules removes our accept rules from all external chains.
// This is deterministic - it scans for chains at removal time rather than relying on saved state,
// ensuring cleanup works even after a crash or if chains changed.
func (r *router) removeExternalChainsRules() error {
chains := r.findExternalChains()
if len(chains) == 0 {
return nil
}
for _, chain := range chains {
if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil {
log.Warnf("remove rules from external chain %s/%s: %v", chain.Table.Name, chain.Name, err)
}
}
return r.conn.Flush()
}
// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks.
// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal).
func (r *router) findExternalChains() []*nftables.Chain {
var chains []*nftables.Chain
families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet}
for _, family := range families {
allChains, err := r.conn.ListChainsOfTableFamily(family)
if err != nil { if err != nil {
log.Debugf("list chains for family %d: %v", family, err) return fmt.Errorf("get rules: %v", err)
continue
} }
for _, chain := range allChains { for _, rule := range rules {
if r.isExternalChain(chain) { if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
chains = append(chains, chain) bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule: %v", err)
}
} }
} }
} }
return chains if err := r.conn.Flush(); err != nil {
} return fmt.Errorf(flushError, err)
func (r *router) isExternalChain(chain *nftables.Chain) bool {
if r.workTable != nil && chain.Table.Name == r.workTable.Name {
return false
} }
// Skip all iptables-managed tables in the ip family return nil
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
return false
}
if chain.Type != nftables.ChainTypeFilter {
return false
}
if chain.Hooknum == nil {
return false
}
return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput
}
func isIptablesTable(name string) bool {
switch name {
case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity:
return true
}
return false
} }
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error { func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
@@ -1311,13 +1128,13 @@ func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
for _, rule := range r.getAcceptForwardRules() { for _, rule := range r.getAcceptForwardRules() {
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil { if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err)) merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err))
} }
} }
inputRule := r.getAcceptInputRule() inputRule := r.getAcceptInputRule()
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil { if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err)) merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err))
} }
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
@@ -1379,7 +1196,7 @@ func (r *router) refreshRulesMap() error {
for _, chain := range r.chains { for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain) rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil { if err != nil {
return fmt.Errorf("list rules: %w", err) return fmt.Errorf(" unable to list rules: %v", err)
} }
for _, rule := range rules { for _, rule := range rules {
if len(rule.UserData) > 0 { if len(rule.UserData) > 0 {

View File

@@ -29,7 +29,7 @@ import (
) )
const ( const (
layerTypeAll = 255 layerTypeAll = 0
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40 ipTCPHeaderMinSize = 40
@@ -262,7 +262,10 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
} }
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) { func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) {
wgPrefix := iface.Address().Network wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
if err != nil {
return nil, fmt.Errorf("parse wireguard network: %w", err)
}
log.Debugf("blocking invalid routed traffic for %s", wgPrefix) log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
rule, err := m.addRouteFiltering( rule, err := m.addRouteFiltering(
@@ -436,7 +439,19 @@ func (m *Manager) AddPeerFiltering(
r.sPort = sPort r.sPort = sPort
r.dPort = dPort r.dPort = dPort
r.protoLayer = protoToLayer(proto, r.ipLayer) switch proto {
case firewall.ProtocolTCP:
r.protoLayer = layers.LayerTypeTCP
case firewall.ProtocolUDP:
r.protoLayer = layers.LayerTypeUDP
case firewall.ProtocolICMP:
r.protoLayer = layers.LayerTypeICMPv4
if r.ipLayer == layers.LayerTypeIPv6 {
r.protoLayer = layers.LayerTypeICMPv6
}
case firewall.ProtocolALL:
r.protoLayer = layerTypeAll
}
m.mutex.Lock() m.mutex.Lock()
var targetMap map[netip.Addr]RuleSet var targetMap map[netip.Addr]RuleSet
@@ -481,17 +496,16 @@ func (m *Manager) addRouteFiltering(
} }
ruleID := uuid.New().String() ruleID := uuid.New().String()
rule := RouteRule{ rule := RouteRule{
// TODO: consolidate these IDs // TODO: consolidate these IDs
id: ruleID, id: ruleID,
mgmtId: id, mgmtId: id,
sources: sources, sources: sources,
dstSet: destination.Set, dstSet: destination.Set,
protoLayer: protoToLayer(proto, layers.LayerTypeIPv4), proto: proto,
srcPort: sPort, srcPort: sPort,
dstPort: dPort, dstPort: dPort,
action: action, action: action,
} }
if destination.IsPrefix() { if destination.IsPrefix() {
rule.destinations = []netip.Prefix{destination.Prefix} rule.destinations = []netip.Prefix{destination.Prefix}
@@ -570,14 +584,6 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil } func (m *Manager) Flush() error { return nil }
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.SetupEBPFProxyNoTrack(proxyPort, wgPort)
}
// UpdateSet updates the rule destinations associated with the given set // UpdateSet updates the rule destinations associated with the given set
// by merging the existing prefixes with the new ones, then deduplicating. // by merging the existing prefixes with the new ones, then deduplicating.
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
@@ -789,7 +795,7 @@ func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeade
pseudoSum += uint32(d.ip4.Protocol) pseudoSum += uint32(d.ip4.Protocol)
pseudoSum += uint32(tcpLength) pseudoSum += uint32(tcpLength)
var sum = pseudoSum var sum uint32 = pseudoSum
for i := 0; i < tcpLength-1; i += 2 { for i := 0; i < tcpLength-1; i += 2 {
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1]) sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
} }
@@ -939,7 +945,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool { func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
ruleID, blocked := m.peerACLsBlock(srcIP, d, packetData) ruleID, blocked := m.peerACLsBlock(srcIP, d, packetData)
if blocked { if blocked {
pnum := getProtocolFromPacket(d) _, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
@@ -1004,22 +1010,20 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
return false return false
} }
protoLayer := d.decoded[1] proto, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
ruleID, pass := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort) ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
if !pass { if !pass {
proto := getProtocolFromPacket(d)
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, proto, srcIP, srcPort, dstIP, dstPort) ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{ m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(), FlowID: uuid.New(),
Type: nftypes.TypeDrop, Type: nftypes.TypeDrop,
RuleID: ruleID, RuleID: ruleID,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: proto, Protocol: pnum,
SourceIP: srcIP, SourceIP: srcIP,
DestIP: dstIP, DestIP: dstIP,
SourcePort: srcPort, SourcePort: srcPort,
@@ -1048,33 +1052,16 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
return true return true
} }
func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType { func getProtocolFromPacket(d *decoder) (firewall.Protocol, nftypes.Protocol) {
switch proto {
case firewall.ProtocolTCP:
return layers.LayerTypeTCP
case firewall.ProtocolUDP:
return layers.LayerTypeUDP
case firewall.ProtocolICMP:
if ipLayer == layers.LayerTypeIPv6 {
return layers.LayerTypeICMPv6
}
return layers.LayerTypeICMPv4
case firewall.ProtocolALL:
return layerTypeAll
}
return 0
}
func getProtocolFromPacket(d *decoder) nftypes.Protocol {
switch d.decoded[1] { switch d.decoded[1] {
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
return nftypes.TCP return firewall.ProtocolTCP, nftypes.TCP
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
return nftypes.UDP return firewall.ProtocolUDP, nftypes.UDP
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return nftypes.ICMP return firewall.ProtocolICMP, nftypes.ICMP
default: default:
return nftypes.ProtocolUnknown return firewall.ProtocolALL, nftypes.ProtocolUnknown
} }
} }
@@ -1246,30 +1233,19 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
} }
// routeACLsPass returns true if the packet is allowed by the route ACLs // routeACLsPass returns true if the packet is allowed by the route ACLs
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) { func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
m.mutex.RLock() m.mutex.RLock()
defer m.mutex.RUnlock() defer m.mutex.RUnlock()
for _, rule := range m.routeRules { for _, rule := range m.routeRules {
if matches := m.ruleMatches(rule, srcIP, dstIP, protoLayer, srcPort, dstPort); matches { if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
return rule.mgmtId, rule.action == firewall.ActionAccept return rule.mgmtId, rule.action == firewall.ActionAccept
} }
} }
return nil, false return nil, false
} }
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool { func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
// TODO: handle ipv6 vs ipv4 icmp rules
if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer {
return false
}
if protoLayer == layers.LayerTypeTCP || protoLayer == layers.LayerTypeUDP {
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
return false
}
}
destMatched := false destMatched := false
for _, dst := range rule.destinations { for _, dst := range rule.destinations {
if dst.Contains(dstAddr) { if dst.Contains(dstAddr) {
@@ -1288,8 +1264,21 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
break break
} }
} }
if !sourceMatched {
return false
}
return sourceMatched if rule.proto != firewall.ProtocolALL && rule.proto != proto {
return false
}
if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP {
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
return false
}
}
return true
} }
// AddUDPPacketHook calls hook when UDP packet from given direction matched // AddUDPPacketHook calls hook when UDP packet from given direction matched

View File

@@ -955,7 +955,7 @@ func BenchmarkRouteACLs(b *testing.B) {
for _, tc := range cases { for _, tc := range cases {
srcIP := netip.MustParseAddr(tc.srcIP) srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := netip.MustParseAddr(tc.dstIP) dstIP := netip.MustParseAddr(tc.dstIP)
manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), 0, tc.dstPort) manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
} }
} }
} }

View File

@@ -1259,7 +1259,7 @@ func TestRouteACLFiltering(t *testing.T) {
// testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed // testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
// to the forwarder // to the forwarder
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), tc.srcPort, tc.dstPort) _, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
require.Equal(t, tc.shouldPass, isAllowed) require.Equal(t, tc.shouldPass, isAllowed)
}) })
} }
@@ -1445,7 +1445,7 @@ func TestRouteACLOrder(t *testing.T) {
srcIP := netip.MustParseAddr(p.srcIP) srcIP := netip.MustParseAddr(p.srcIP)
dstIP := netip.MustParseAddr(p.dstIP) dstIP := netip.MustParseAddr(p.dstIP)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(p.proto, layers.LayerTypeIPv4), p.srcPort, p.dstPort) _, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i) require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
} }
}) })
@@ -1488,13 +1488,13 @@ func TestRouteACLSet(t *testing.T) {
dstIP := netip.MustParseAddr("192.168.1.100") dstIP := netip.MustParseAddr("192.168.1.100")
// Check that traffic is dropped (empty set shouldn't match anything) // Check that traffic is dropped (empty set shouldn't match anything)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) _, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
require.False(t, isAllowed, "Empty set should not allow any traffic") require.False(t, isAllowed, "Empty set should not allow any traffic")
err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}) err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
require.NoError(t, err) require.NoError(t, err)
// Now the packet should be allowed // Now the packet should be allowed
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) _, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed") require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
} }

View File

@@ -767,9 +767,9 @@ func TestUpdateSetMerge(t *testing.T) {
dstIP2 := netip.MustParseAddr("192.168.1.100") dstIP2 := netip.MustParseAddr("192.168.1.100")
dstIP3 := netip.MustParseAddr("172.16.0.100") dstIP3 := netip.MustParseAddr("172.16.0.100")
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) _, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) _, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) _, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed") require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed") require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed")
@@ -784,8 +784,8 @@ func TestUpdateSetMerge(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Check that all original prefixes are still included // Check that all original prefixes are still included
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) _, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) _, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update") require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update") require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update")
@@ -793,8 +793,8 @@ func TestUpdateSetMerge(t *testing.T) {
dstIP4 := netip.MustParseAddr("172.16.1.100") dstIP4 := netip.MustParseAddr("172.16.1.100")
dstIP5 := netip.MustParseAddr("10.1.0.50") dstIP5 := netip.MustParseAddr("10.1.0.50")
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) _, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80)
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) _, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed") require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed")
require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed") require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed")
@@ -922,7 +922,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
srcIP := netip.MustParseAddr("100.10.0.1") srcIP := netip.MustParseAddr("100.10.0.1")
for _, tc := range testCases { for _, tc := range testCases {
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) _, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80)
require.Equal(t, tc.expected, isAllowed, tc.desc) require.Equal(t, tc.expected, isAllowed, tc.desc)
} }
} }

View File

@@ -2,7 +2,6 @@ package forwarder
import ( import (
"fmt" "fmt"
"sync/atomic"
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
@@ -17,7 +16,7 @@ type endpoint struct {
logger *nblog.Logger logger *nblog.Logger
dispatcher stack.NetworkDispatcher dispatcher stack.NetworkDispatcher
device *wgdevice.Device device *wgdevice.Device
mtu atomic.Uint32 mtu uint32
} }
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
@@ -29,7 +28,7 @@ func (e *endpoint) IsAttached() bool {
} }
func (e *endpoint) MTU() uint32 { func (e *endpoint) MTU() uint32 {
return e.mtu.Load() return e.mtu
} }
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
@@ -83,22 +82,6 @@ func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
return true return true
} }
func (e *endpoint) Close() {
// Endpoint cleanup - nothing to do as device is managed externally
}
func (e *endpoint) SetLinkAddress(tcpip.LinkAddress) {
// Link address is not used for this endpoint type
}
func (e *endpoint) SetMTU(mtu uint32) {
e.mtu.Store(mtu)
}
func (e *endpoint) SetOnCloseAction(func()) {
// No action needed on close
}
type epID stack.TransportEndpointID type epID stack.TransportEndpointID
func (i epID) String() string { func (i epID) String() string {

View File

@@ -7,7 +7,6 @@ import (
"net/netip" "net/netip"
"runtime" "runtime"
"sync" "sync"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/buffer"
@@ -36,16 +35,14 @@ type Forwarder struct {
logger *nblog.Logger logger *nblog.Logger
flowLogger nftypes.FlowLogger flowLogger nftypes.FlowLogger
// ruleIdMap is used to store the rule ID for a given connection // ruleIdMap is used to store the rule ID for a given connection
ruleIdMap sync.Map ruleIdMap sync.Map
stack *stack.Stack stack *stack.Stack
endpoint *endpoint endpoint *endpoint
udpForwarder *udpForwarder udpForwarder *udpForwarder
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
ip tcpip.Address ip tcpip.Address
netstack bool netstack bool
hasRawICMPAccess bool
pingSemaphore chan struct{}
} }
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) { func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
@@ -63,8 +60,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
endpoint := &endpoint{ endpoint := &endpoint{
logger: logger, logger: logger,
device: iface.GetWGDevice(), device: iface.GetWGDevice(),
mtu: uint32(mtu),
} }
endpoint.mtu.Store(uint32(mtu))
if err := s.CreateNIC(nicID, endpoint); err != nil { if err := s.CreateNIC(nicID, endpoint); err != nil {
return nil, fmt.Errorf("create NIC: %v", err) return nil, fmt.Errorf("create NIC: %v", err)
@@ -106,16 +103,15 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
f := &Forwarder{ f := &Forwarder{
logger: logger, logger: logger,
flowLogger: flowLogger, flowLogger: flowLogger,
stack: s, stack: s,
endpoint: endpoint, endpoint: endpoint,
udpForwarder: newUDPForwarder(mtu, logger, flowLogger), udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
netstack: netstack, netstack: netstack,
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()), ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
pingSemaphore: make(chan struct{}, 3),
} }
receiveWindow := defaultReceiveWindow receiveWindow := defaultReceiveWindow
@@ -133,8 +129,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP) s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
f.checkICMPCapability()
log.Debugf("forwarder: Initialization complete with NIC %d", nicID) log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
return f, nil return f, nil
} }
@@ -204,24 +198,3 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe
DstPort: dstPort, DstPort: dstPort,
} }
} }
// checkICMPCapability tests whether we have raw ICMP socket access at startup.
func (f *Forwarder) checkICMPCapability() {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
lc := net.ListenConfig{}
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil {
f.hasRawICMPAccess = false
f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback")
return
}
if err := conn.Close(); err != nil {
f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err)
}
f.hasRawICMPAccess = true
f.logger.Debug("forwarder: Raw ICMP socket access available")
}

View File

@@ -2,11 +2,8 @@ package forwarder
import ( import (
"context" "context"
"fmt"
"net" "net"
"net/netip" "net/netip"
"os/exec"
"runtime"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
@@ -17,95 +14,30 @@ import (
) )
// handleICMP handles ICMP packets from the network stack // handleICMP handles ICMP packets from the network stack
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice()) icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
icmpType := uint8(icmpHdr.Type())
icmpCode := uint8(icmpHdr.Code())
flowID := uuid.New() if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
f.sendICMPEvent(nftypes.TypeStart, flowID, id, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 0, 0) // dont process our own replies
// For Echo Requests, send and wait for response
if icmpHdr.Type() == header.ICMPv4Echo {
return f.handleICMPEcho(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()))
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc), forward without waiting
if !f.hasRawICMPAccess {
f.logger.Debug2("forwarder: Cannot handle ICMP type %v without raw socket access for %v", icmpHdr.Type(), epID(id))
return false
}
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond)
if err != nil {
f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err)
return true return true
} }
if err := conn.Close(); err != nil {
f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", err)
}
return true flowID := uuid.New()
} f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0)
// handleICMPEcho handles ICMP echo requests asynchronously with rate limiting. ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType, icmpCode uint8) bool {
select {
case f.pingSemaphore <- struct{}{}:
icmpData := stack.PayloadSince(pkt.TransportHeader()).ToSlice()
rxBytes := pkt.Size()
go func() {
defer func() { <-f.pingSemaphore }()
if f.hasRawICMPAccess {
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
} else {
f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
}
}()
default:
f.logger.Debug3("forwarder: ICMP rate limit exceeded for %v type %v code %v",
epID(id), icmpType, icmpCode)
}
return true
}
// forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection.
// The caller is responsible for closing the returned connection.
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) {
ctx, cancel := context.WithTimeout(f.ctx, timeout)
defer cancel() defer cancel()
lc := net.ListenConfig{} lc := net.ListenConfig{}
// TODO: support non-root
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil { if err != nil {
return nil, fmt.Errorf("create ICMP socket: %w", err) f.logger.Error2("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
}
dstIP := f.determineDialAddr(id.LocalAddress) // This will make netstack reply on behalf of the original destination, that's ok for now
dst := &net.IPAddr{IP: dstIP} return false
if _, err = conn.WriteTo(payload, dst); err != nil {
if closeErr := conn.Close(); closeErr != nil {
f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", closeErr)
}
return nil, fmt.Errorf("write ICMP packet: %w", err)
}
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpType, icmpCode)
return conn, nil
}
// handleICMPViaSocket handles ICMP echo requests using raw sockets.
func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
sendTime := time.Now()
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second)
if err != nil {
f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err)
return
} }
defer func() { defer func() {
if err := conn.Close(); err != nil { if err := conn.Close(); err != nil {
@@ -113,22 +45,38 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp
} }
}() }()
txBytes := f.handleEchoResponse(conn, id) dstIP := f.determineDialAddr(id.LocalAddress)
rtt := time.Since(sendTime).Round(10 * time.Microsecond) dst := &net.IPAddr{IP: dstIP}
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)", fullPacket := stack.PayloadSince(pkt.TransportHeader())
epID(id), icmpType, icmpCode, rtt) payload := fullPacket.AsSlice()
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) if _, err = conn.WriteTo(payload, dst); err != nil {
f.logger.Error2("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
return true
}
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
// For Echo Requests, send and handle response
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
rxBytes := pkt.Size()
txBytes := f.handleEchoResponse(icmpHdr, conn, id)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
return true
} }
func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) int { func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int {
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err) f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err)
return 0 return 0
} }
response := make([]byte, f.endpoint.mtu.Load()) response := make([]byte, f.endpoint.mtu)
n, _, err := conn.ReadFrom(response) n, _, err := conn.ReadFrom(response)
if err != nil { if err != nil {
if !isTimeout(err) { if !isTimeout(err) {
@@ -137,7 +85,31 @@ func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEn
return 0 return 0
} }
return f.injectICMPReply(id, response[:n]) ipHdr := make([]byte, header.IPv4MinimumSize)
ip := header.IPv4(ipHdr)
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(header.IPv4MinimumSize + n),
TTL: 64,
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: id.LocalAddress,
DstAddr: id.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
fullPacket := make([]byte, 0, len(ipHdr)+n)
fullPacket = append(fullPacket, ipHdr...)
fullPacket = append(fullPacket, response[:n]...)
if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error1("forwarder: Failed to inject ICMP response: %v", err)
return 0
}
f.logger.Trace3("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
return len(fullPacket)
} }
// sendICMPEvent stores flow events for ICMP packets // sendICMPEvent stores flow events for ICMP packets
@@ -180,95 +152,3 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
f.flowLogger.StoreEvent(fields) f.flowLogger.StoreEvent(fields)
} }
// handleICMPViaPing handles ICMP echo requests by executing the system ping binary.
// This is used as a fallback when raw socket access is not available.
func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel()
dstIP := f.determineDialAddr(id.LocalAddress)
cmd := buildPingCommand(ctx, dstIP, 5*time.Second)
pingStart := time.Now()
if err := cmd.Run(); err != nil {
f.logger.Warn4("forwarder: Ping binary failed for %v type %v code %v: %v", epID(id),
icmpType, icmpCode, err)
return
}
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
epID(id), icmpType, icmpCode)
txBytes := f.synthesizeEchoReply(id, icmpData)
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
epID(id), icmpType, icmpCode, rtt)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
// buildPingCommand creates a platform-specific ping command.
func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd {
timeoutSec := int(timeout.Seconds())
if timeoutSec < 1 {
timeoutSec = 1
}
switch runtime.GOOS {
case "linux", "android":
return exec.CommandContext(ctx, "ping", "-c", "1", "-W", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
case "darwin", "ios":
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
case "freebsd":
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), target.String())
case "openbsd", "netbsd":
return exec.CommandContext(ctx, "ping", "-c", "1", "-w", fmt.Sprintf("%d", timeoutSec), target.String())
case "windows":
return exec.CommandContext(ctx, "ping", "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
default:
return exec.CommandContext(ctx, "ping", "-c", "1", target.String())
}
}
// synthesizeEchoReply creates an ICMP echo reply from raw ICMP data and injects it back into the network stack.
// Returns the size of the injected packet.
func (f *Forwarder) synthesizeEchoReply(id stack.TransportEndpointID, icmpData []byte) int {
replyICMP := make([]byte, len(icmpData))
copy(replyICMP, icmpData)
replyICMPHdr := header.ICMPv4(replyICMP)
replyICMPHdr.SetType(header.ICMPv4EchoReply)
replyICMPHdr.SetChecksum(0)
replyICMPHdr.SetChecksum(header.ICMPv4Checksum(replyICMPHdr, 0))
return f.injectICMPReply(id, replyICMP)
}
// injectICMPReply wraps an ICMP payload in an IP header and injects it into the network stack.
// Returns the total size of the injected packet, or 0 if injection failed.
func (f *Forwarder) injectICMPReply(id stack.TransportEndpointID, icmpPayload []byte) int {
ipHdr := make([]byte, header.IPv4MinimumSize)
ip := header.IPv4(ipHdr)
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(header.IPv4MinimumSize + len(icmpPayload)),
TTL: 64,
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: id.LocalAddress,
DstAddr: id.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
fullPacket := make([]byte, 0, len(ipHdr)+len(icmpPayload))
fullPacket = append(fullPacket, ipHdr...)
fullPacket = append(fullPacket, icmpPayload...)
// Bypass netstack and send directly to peer to avoid looping through our ICMP handler
if err := f.endpoint.device.CreateOutboundPacket(fullPacket, id.RemoteAddress.AsSlice()); err != nil {
f.logger.Error1("forwarder: Failed to send ICMP reply to peer: %v", err)
return 0
}
return len(fullPacket)
}

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"net/netip" "net/netip"
"sync" "sync"
@@ -132,10 +131,10 @@ func (f *udpForwarder) cleanup() {
} }
// handleUDP is called by the UDP forwarder for new packets // handleUDP is called by the UDP forwarder for new packets
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool { func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if f.ctx.Err() != nil { if f.ctx.Err() != nil {
f.logger.Trace("forwarder: context done, dropping UDP packet") f.logger.Trace("forwarder: context done, dropping UDP packet")
return false return
} }
id := r.ID() id := r.ID()
@@ -145,7 +144,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
f.udpForwarder.RUnlock() f.udpForwarder.RUnlock()
if exists { if exists {
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id)) f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
return true return
} }
flowID := uuid.New() flowID := uuid.New()
@@ -163,7 +162,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
if err != nil { if err != nil {
f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err) f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err)
// TODO: Send ICMP error message // TODO: Send ICMP error message
return false return
} }
// Create wait queue for blocking syscalls // Create wait queue for blocking syscalls
@@ -174,10 +173,10 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
if err := outConn.Close(); err != nil { if err := outConn.Close(); err != nil {
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err) f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
} }
return false return
} }
inConn := gonet.NewUDPConn(&wq, ep) inConn := gonet.NewUDPConn(f.stack, &wq, ep)
connCtx, connCancel := context.WithCancel(f.ctx) connCtx, connCancel := context.WithCancel(f.ctx)
pConn := &udpPacketConn{ pConn := &udpPacketConn{
@@ -200,7 +199,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
if err := outConn.Close(); err != nil { if err := outConn.Close(); err != nil {
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err) f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
} }
return true return
} }
f.udpForwarder.conns[id] = pConn f.udpForwarder.conns[id] = pConn
f.udpForwarder.Unlock() f.udpForwarder.Unlock()
@@ -209,7 +208,6 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
f.logger.Trace1("forwarder: established UDP connection %v", epID(id)) f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
go f.proxyUDP(connCtx, pConn, id, ep) go f.proxyUDP(connCtx, pConn, id, ep)
return true
} }
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) { func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
@@ -350,7 +348,7 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu
} }
func isClosedError(err error) bool { func isClosedError(err error) bool {
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) || errors.Is(err, io.EOF) return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
} }
func isTimeout(err error) bool { func isTimeout(err error) bool {

View File

@@ -130,7 +130,6 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
// 127.0.0.0/8 // 127.0.0.0/8
newIPv4Bitmap[127] = &ipv4LowBitmap{} newIPv4Bitmap[127] = &ipv4LowBitmap{}
for i := 0; i < 8192; i++ { for i := 0; i < 8192; i++ {
// #nosec G602 -- bitmap is defined as [8192]uint32, loop range is correct
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
} }

View File

@@ -218,7 +218,7 @@ func BenchmarkIPChecks(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
// nolint:gosimple // nolint:gosimple
_ = mapManager.localIPs[ip.String()] _, _ = mapManager.localIPs[ip.String()]
} }
}) })
@@ -227,7 +227,7 @@ func BenchmarkIPChecks(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
// nolint:gosimple // nolint:gosimple
_ = mapManager.localIPs[ip.String()] _, _ = mapManager.localIPs[ip.String()]
} }
}) })
} }

View File

@@ -168,15 +168,6 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
} }
} }
func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
if l.level.Load() >= uint32(LevelWarn) {
select {
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
default:
}
}
}
func (l *Logger) Debug1(format string, arg1 any) { func (l *Logger) Debug1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelDebug) { if l.level.Load() >= uint32(LevelDebug) {
select { select {

View File

@@ -234,10 +234,9 @@ func TestInboundPortDNATNegative(t *testing.T) {
require.False(t, translated, "Packet should NOT be translated for %s", tc.name) require.False(t, translated, "Packet should NOT be translated for %s", tc.name)
d = parsePacket(t, packet) d = parsePacket(t, packet)
switch tc.protocol { if tc.protocol == layers.IPProtocolTCP {
case layers.IPProtocolTCP:
require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged") require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged")
case layers.IPProtocolUDP: } else if tc.protocol == layers.IPProtocolUDP {
require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged") require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged")
} }
}) })

View File

@@ -34,7 +34,7 @@ type RouteRule struct {
sources []netip.Prefix sources []netip.Prefix
dstSet firewall.Set dstSet firewall.Set
destinations []netip.Prefix destinations []netip.Prefix
protoLayer gopacket.LayerType proto firewall.Protocol
srcPort *firewall.Port srcPort *firewall.Port
dstPort *firewall.Port dstPort *firewall.Port
action firewall.Action action firewall.Action

View File

@@ -379,9 +379,9 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
} }
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace { func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
protoLayer := d.decoded[1] proto, _ := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
id, allowed := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort) id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
strId := string(id) strId := string(id)
if id == nil { if id == nil {

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors"
"fmt" "fmt"
"runtime" "runtime"
"time" "time"
@@ -11,6 +12,7 @@ import (
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
@@ -18,6 +20,9 @@ import (
"github.com/netbirdio/netbird/util/embeddedroots" "github.com/netbirdio/netbird/util/embeddedroots"
) )
// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready
var ErrConnectionShutdown = errors.New("connection shutdown before ready")
// Backoff returns a backoff configuration for gRPC calls // Backoff returns a backoff configuration for gRPC calls
func Backoff(ctx context.Context) backoff.BackOff { func Backoff(ctx context.Context) backoff.BackOff {
b := backoff.NewExponentialBackOff() b := backoff.NewExponentialBackOff()
@@ -26,6 +31,26 @@ func Backoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(b, ctx) return backoff.WithContext(b, ctx)
} }
// waitForConnectionReady blocks until the connection becomes ready or fails.
// Returns an error if the connection times out, is cancelled, or enters shutdown state.
func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error {
conn.Connect()
state := conn.GetState()
for state != connectivity.Ready && state != connectivity.Shutdown {
if !conn.WaitForStateChange(ctx, state) {
return fmt.Errorf("wait state change from %s: %w", state, ctx.Err())
}
state = conn.GetState()
}
if state == connectivity.Shutdown {
return ErrConnectionShutdown
}
return nil
}
// CreateConnection creates a gRPC client connection with the appropriate transport options. // CreateConnection creates a gRPC client connection with the appropriate transport options.
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
@@ -43,22 +68,25 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
})) }))
} }
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second) conn, err := grpc.NewClient(
defer cancel()
conn, err := grpc.DialContext(
connCtx,
addr, addr,
transportOption, transportOption,
WithCustomDialer(tlsEnabled, component), WithCustomDialer(tlsEnabled, component),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{ grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second, Time: 30 * time.Second,
Timeout: 10 * time.Second, Timeout: 10 * time.Second,
}), }),
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("dial context: %w", err) return nil, fmt.Errorf("new client: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
if err := waitForConnectionReady(ctx, conn); err != nil {
_ = conn.Close()
return nil, err
} }
return conn, nil return conn, nil

View File

@@ -1,169 +0,0 @@
package bind
import (
"errors"
"net"
"sync"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
)
var (
errNoIPv4Conn = errors.New("no IPv4 connection available")
errNoIPv6Conn = errors.New("no IPv6 connection available")
errInvalidAddr = errors.New("invalid address type")
)
// DualStackPacketConn wraps IPv4 and IPv6 UDP connections and routes writes
// to the appropriate connection based on the destination address.
// ReadFrom is not used in the hot path - ICEBind receives packets via
// BatchReader.ReadBatch() directly. This is only used by udpMux for sending.
type DualStackPacketConn struct {
ipv4Conn net.PacketConn
ipv6Conn net.PacketConn
readFromWarn sync.Once
}
// NewDualStackPacketConn creates a new dual-stack packet connection.
func NewDualStackPacketConn(ipv4Conn, ipv6Conn net.PacketConn) *DualStackPacketConn {
return &DualStackPacketConn{
ipv4Conn: ipv4Conn,
ipv6Conn: ipv6Conn,
}
}
// ReadFrom reads from the available connection (preferring IPv4).
// NOTE: This method is NOT used in the data path. ICEBind receives packets via
// BatchReader.ReadBatch() directly for both IPv4 and IPv6, which is much more efficient.
// This implementation exists only to satisfy the net.PacketConn interface for the udpMux,
// but the udpMux only uses WriteTo() for sending STUN responses - it never calls ReadFrom()
// because STUN packets are filtered and forwarded via HandleSTUNMessage() from the receive path.
func (d *DualStackPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
d.readFromWarn.Do(func() {
log.Warn("DualStackPacketConn.ReadFrom called - this is unexpected and may indicate an inefficient code path")
})
if d.ipv4Conn != nil {
return d.ipv4Conn.ReadFrom(b)
}
if d.ipv6Conn != nil {
return d.ipv6Conn.ReadFrom(b)
}
return 0, nil, net.ErrClosed
}
// WriteTo writes to the appropriate connection based on the address type.
func (d *DualStackPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
return 0, &net.OpError{
Op: "write",
Net: "udp",
Addr: addr,
Err: errInvalidAddr,
}
}
if udpAddr.IP.To4() == nil {
if d.ipv6Conn != nil {
return d.ipv6Conn.WriteTo(b, addr)
}
return 0, &net.OpError{
Op: "write",
Net: "udp6",
Addr: addr,
Err: errNoIPv6Conn,
}
}
if d.ipv4Conn != nil {
return d.ipv4Conn.WriteTo(b, addr)
}
return 0, &net.OpError{
Op: "write",
Net: "udp4",
Addr: addr,
Err: errNoIPv4Conn,
}
}
// Close closes both connections.
func (d *DualStackPacketConn) Close() error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.Close(); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.Close(); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}
// LocalAddr returns the local address of the IPv4 connection if available,
// otherwise the IPv6 connection.
func (d *DualStackPacketConn) LocalAddr() net.Addr {
if d.ipv4Conn != nil {
return d.ipv4Conn.LocalAddr()
}
if d.ipv6Conn != nil {
return d.ipv6Conn.LocalAddr()
}
return nil
}
// SetDeadline sets the deadline for both connections.
func (d *DualStackPacketConn) SetDeadline(t time.Time) error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.SetDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.SetDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}
// SetReadDeadline sets the read deadline for both connections.
func (d *DualStackPacketConn) SetReadDeadline(t time.Time) error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.SetReadDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.SetReadDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}
// SetWriteDeadline sets the write deadline for both connections.
func (d *DualStackPacketConn) SetWriteDeadline(t time.Time) error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.SetWriteDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.SetWriteDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}

View File

@@ -1,119 +0,0 @@
package bind
import (
"net"
"testing"
)
var (
ipv4Addr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345}
ipv6Addr = &net.UDPAddr{IP: net.ParseIP("::1"), Port: 12345}
payload = make([]byte, 1200)
)
func BenchmarkWriteTo_DirectUDPConn(b *testing.B) {
conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = conn.WriteTo(payload, ipv4Addr)
}
}
func BenchmarkWriteTo_DualStack_IPv4Only(b *testing.B) {
conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn.Close()
ds := NewDualStackPacketConn(conn, nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, ipv4Addr)
}
}
func BenchmarkWriteTo_DualStack_IPv6Only(b *testing.B) {
conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
b.Skipf("IPv6 not available: %v", err)
}
defer conn.Close()
ds := NewDualStackPacketConn(nil, conn)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, ipv6Addr)
}
}
func BenchmarkWriteTo_DualStack_Both_IPv4Traffic(b *testing.B) {
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn4.Close()
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
b.Skipf("IPv6 not available: %v", err)
}
defer conn6.Close()
ds := NewDualStackPacketConn(conn4, conn6)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, ipv4Addr)
}
}
func BenchmarkWriteTo_DualStack_Both_IPv6Traffic(b *testing.B) {
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn4.Close()
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
b.Skipf("IPv6 not available: %v", err)
}
defer conn6.Close()
ds := NewDualStackPacketConn(conn4, conn6)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, ipv6Addr)
}
}
func BenchmarkWriteTo_DualStack_Both_MixedTraffic(b *testing.B) {
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn4.Close()
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
b.Skipf("IPv6 not available: %v", err)
}
defer conn6.Close()
ds := NewDualStackPacketConn(conn4, conn6)
addrs := []net.Addr{ipv4Addr, ipv6Addr}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, addrs[i&1])
}
}

View File

@@ -1,191 +0,0 @@
package bind
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDualStackPacketConn_RoutesWritesToCorrectSocket(t *testing.T) {
ipv4Conn := &mockPacketConn{network: "udp4"}
ipv6Conn := &mockPacketConn{network: "udp6"}
dualStack := NewDualStackPacketConn(ipv4Conn, ipv6Conn)
tests := []struct {
name string
addr *net.UDPAddr
wantSocket string
}{
{
name: "IPv4 address",
addr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234},
wantSocket: "udp4",
},
{
name: "IPv6 address",
addr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234},
wantSocket: "udp6",
},
{
name: "IPv4-mapped IPv6 goes to IPv4",
addr: &net.UDPAddr{IP: net.ParseIP("::ffff:192.168.1.1"), Port: 1234},
wantSocket: "udp4",
},
{
name: "IPv4 loopback",
addr: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234},
wantSocket: "udp4",
},
{
name: "IPv6 loopback",
addr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 1234},
wantSocket: "udp6",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ipv4Conn.writeCount = 0
ipv6Conn.writeCount = 0
n, err := dualStack.WriteTo([]byte("test"), tt.addr)
require.NoError(t, err)
assert.Equal(t, 4, n)
if tt.wantSocket == "udp4" {
assert.Equal(t, 1, ipv4Conn.writeCount, "expected write to IPv4")
assert.Equal(t, 0, ipv6Conn.writeCount, "expected no write to IPv6")
} else {
assert.Equal(t, 0, ipv4Conn.writeCount, "expected no write to IPv4")
assert.Equal(t, 1, ipv6Conn.writeCount, "expected write to IPv6")
}
})
}
}
func TestDualStackPacketConn_IPv4OnlyRejectsIPv6(t *testing.T) {
dualStack := NewDualStackPacketConn(&mockPacketConn{network: "udp4"}, nil)
// IPv4 works
_, err := dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234})
require.NoError(t, err)
// IPv6 fails
_, err = dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234})
require.Error(t, err)
assert.Contains(t, err.Error(), "no IPv6 connection")
}
func TestDualStackPacketConn_IPv6OnlyRejectsIPv4(t *testing.T) {
dualStack := NewDualStackPacketConn(nil, &mockPacketConn{network: "udp6"})
// IPv6 works
_, err := dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234})
require.NoError(t, err)
// IPv4 fails
_, err = dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234})
require.Error(t, err)
assert.Contains(t, err.Error(), "no IPv4 connection")
}
// TestDualStackPacketConn_ReadFromIsNotUsedInHotPath documents that ReadFrom
// only reads from one socket (IPv4 preferred). This is fine because the actual
// receive path uses wireguard-go's BatchReader directly, not ReadFrom.
func TestDualStackPacketConn_ReadFromIsNotUsedInHotPath(t *testing.T) {
ipv4Conn := &mockPacketConn{
network: "udp4",
readData: []byte("from ipv4"),
readAddr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234},
}
ipv6Conn := &mockPacketConn{
network: "udp6",
readData: []byte("from ipv6"),
readAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234},
}
dualStack := NewDualStackPacketConn(ipv4Conn, ipv6Conn)
buf := make([]byte, 100)
n, addr, err := dualStack.ReadFrom(buf)
require.NoError(t, err)
// reads from IPv4 (preferred) - this is expected behavior
assert.Equal(t, "from ipv4", string(buf[:n]))
assert.Equal(t, "192.168.1.1", addr.(*net.UDPAddr).IP.String())
}
func TestDualStackPacketConn_LocalAddrPrefersIPv4(t *testing.T) {
ipv4Addr := &net.UDPAddr{IP: net.ParseIP("0.0.0.0"), Port: 51820}
ipv6Addr := &net.UDPAddr{IP: net.ParseIP("::"), Port: 51820}
tests := []struct {
name string
ipv4 net.PacketConn
ipv6 net.PacketConn
wantAddr net.Addr
}{
{
name: "both available returns IPv4",
ipv4: &mockPacketConn{localAddr: ipv4Addr},
ipv6: &mockPacketConn{localAddr: ipv6Addr},
wantAddr: ipv4Addr,
},
{
name: "IPv4 only",
ipv4: &mockPacketConn{localAddr: ipv4Addr},
ipv6: nil,
wantAddr: ipv4Addr,
},
{
name: "IPv6 only",
ipv4: nil,
ipv6: &mockPacketConn{localAddr: ipv6Addr},
wantAddr: ipv6Addr,
},
{
name: "neither returns nil",
ipv4: nil,
ipv6: nil,
wantAddr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dualStack := NewDualStackPacketConn(tt.ipv4, tt.ipv6)
assert.Equal(t, tt.wantAddr, dualStack.LocalAddr())
})
}
}
// mock
type mockPacketConn struct {
network string
writeCount int
readData []byte
readAddr net.Addr
localAddr net.Addr
}
func (m *mockPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
if m.readData != nil {
return copy(b, m.readData), m.readAddr, nil
}
return 0, nil, nil
}
func (m *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
m.writeCount++
return len(b), nil
}
func (m *mockPacketConn) Close() error { return nil }
func (m *mockPacketConn) LocalAddr() net.Addr { return m.localAddr }
func (m *mockPacketConn) SetDeadline(t time.Time) error { return nil }
func (m *mockPacketConn) SetReadDeadline(t time.Time) error { return nil }
func (m *mockPacketConn) SetWriteDeadline(t time.Time) error { return nil }

View File

@@ -14,6 +14,7 @@ import (
"github.com/pion/stun/v3" "github.com/pion/stun/v3"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn" wgConn "golang.zx2c4.com/wireguard/conn"
@@ -26,8 +27,8 @@ type receiverCreator struct {
iceBind *ICEBind iceBind *ICEBind
} }
func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc { func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
return rc.iceBind.createReceiverFn(pc, conn, rxOffload, msgPool) return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
} }
// ICEBind is a bind implementation with two main features: // ICEBind is a bind implementation with two main features:
@@ -57,8 +58,6 @@ type ICEBind struct {
muUDPMux sync.Mutex muUDPMux sync.Mutex
udpMux *udpmux.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
ipv4Conn *net.UDPConn
ipv6Conn *net.UDPConn
} }
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind { func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
@@ -104,12 +103,6 @@ func (s *ICEBind) Close() error {
close(s.closedChan) close(s.closedChan)
s.muUDPMux.Lock()
s.ipv4Conn = nil
s.ipv6Conn = nil
s.udpMux = nil
s.muUDPMux.Unlock()
return s.StdNetBind.Close() return s.StdNetBind.Close()
} }
@@ -167,18 +160,19 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
return nil return nil
} }
func (s *ICEBind) createReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc { func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
s.muUDPMux.Lock() s.muUDPMux.Lock()
defer s.muUDPMux.Unlock() defer s.muUDPMux.Unlock()
// Detect IPv4 vs IPv6 from connection's local address s.udpMux = udpmux.NewUniversalUDPMuxDefault(
if localAddr := conn.LocalAddr().(*net.UDPAddr); localAddr.IP.To4() != nil { udpmux.UniversalUDPMuxParams{
s.ipv4Conn = conn UDPConn: nbnet.WrapPacketConn(conn),
} else { Net: s.transportNet,
s.ipv6Conn = conn FilterFn: s.filterFn,
} WGAddress: s.address,
s.createOrUpdateMux() MTU: s.mtu,
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := getMessages(msgsPool) msgs := getMessages(msgsPool)
for i := range bufs { for i := range bufs {
@@ -186,13 +180,12 @@ func (s *ICEBind) createReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxO
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)] (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
} }
defer putMessages(msgs, msgsPool) defer putMessages(msgs, msgsPool)
var numMsgs int var numMsgs int
if runtime.GOOS == "linux" || runtime.GOOS == "android" { if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if rxOffload { if rxOffload {
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams) readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
//nolint:staticcheck //nolint
_, err = pc.ReadBatch((*msgs)[readAt:], 0) numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -214,12 +207,12 @@ func (s *ICEBind) createReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxO
} }
numMsgs = 1 numMsgs = 1
} }
for i := 0; i < numMsgs; i++ { for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i] msg := &(*msgs)[i]
// todo: handle err // todo: handle err
if ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr); ok { ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
if ok {
continue continue
} }
sizes[i] = msg.N sizes[i] = msg.N
@@ -240,38 +233,6 @@ func (s *ICEBind) createReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxO
} }
} }
// createOrUpdateMux creates or updates the UDP mux with the available connections.
// Must be called with muUDPMux held.
func (s *ICEBind) createOrUpdateMux() {
var muxConn net.PacketConn
switch {
case s.ipv4Conn != nil && s.ipv6Conn != nil:
muxConn = NewDualStackPacketConn(
nbnet.WrapPacketConn(s.ipv4Conn),
nbnet.WrapPacketConn(s.ipv6Conn),
)
case s.ipv4Conn != nil:
muxConn = nbnet.WrapPacketConn(s.ipv4Conn)
case s.ipv6Conn != nil:
muxConn = nbnet.WrapPacketConn(s.ipv6Conn)
default:
return
}
// Don't close the old mux - it doesn't own the underlying connections.
// The sockets are managed by WireGuard's StdNetBind, not by us.
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
udpmux.UniversalUDPMuxParams{
UDPConn: muxConn,
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
MTU: s.mtu,
},
)
}
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) { func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
for i := range buffers { for i := range buffers {
if !stun.IsMessage(buffers[i]) { if !stun.IsMessage(buffers[i]) {
@@ -284,14 +245,9 @@ func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr)
return true, err return true, err
} }
s.muUDPMux.Lock() muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
mux := s.udpMux if muxErr != nil {
s.muUDPMux.Unlock() log.Warnf("failed to handle STUN packet")
if mux != nil {
if muxErr := mux.HandleSTUNMessage(msg, addr); muxErr != nil {
log.Warnf("failed to handle STUN packet: %v", muxErr)
}
} }
buffers[i] = []byte{} buffers[i] = []byte{}

View File

@@ -1,324 +0,0 @@
package bind
import (
"fmt"
"net"
"net/netip"
"sync"
"testing"
"time"
"github.com/pion/transport/v3/stdnet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func TestICEBind_CreatesReceiverForBothIPv4AndIPv6(t *testing.T) {
iceBind := setupICEBind(t)
ipv4Conn, ipv6Conn := createDualStackConns(t)
defer ipv4Conn.Close()
defer ipv6Conn.Close()
rc := receiverCreator{iceBind}
pool := createMsgPool()
// Simulate wireguard-go calling CreateReceiverFn for IPv4
ipv4RecvFn := rc.CreateReceiverFn(ipv4.NewPacketConn(ipv4Conn), ipv4Conn, false, pool)
require.NotNil(t, ipv4RecvFn)
iceBind.muUDPMux.Lock()
assert.NotNil(t, iceBind.ipv4Conn, "should store IPv4 connection")
assert.Nil(t, iceBind.ipv6Conn, "IPv6 not added yet")
assert.NotNil(t, iceBind.udpMux, "mux should be created after first connection")
iceBind.muUDPMux.Unlock()
// Simulate wireguard-go calling CreateReceiverFn for IPv6
ipv6RecvFn := rc.CreateReceiverFn(ipv6.NewPacketConn(ipv6Conn), ipv6Conn, false, pool)
require.NotNil(t, ipv6RecvFn)
iceBind.muUDPMux.Lock()
assert.NotNil(t, iceBind.ipv4Conn, "should still have IPv4 connection")
assert.NotNil(t, iceBind.ipv6Conn, "should now have IPv6 connection")
assert.NotNil(t, iceBind.udpMux, "mux should still exist")
iceBind.muUDPMux.Unlock()
mux, err := iceBind.GetICEMux()
require.NoError(t, err)
require.NotNil(t, mux)
}
func TestICEBind_WorksWithIPv4Only(t *testing.T) {
iceBind := setupICEBind(t)
ipv4Conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
require.NoError(t, err)
defer ipv4Conn.Close()
rc := receiverCreator{iceBind}
recvFn := rc.CreateReceiverFn(ipv4.NewPacketConn(ipv4Conn), ipv4Conn, false, createMsgPool())
require.NotNil(t, recvFn)
iceBind.muUDPMux.Lock()
assert.NotNil(t, iceBind.ipv4Conn)
assert.Nil(t, iceBind.ipv6Conn)
assert.NotNil(t, iceBind.udpMux)
iceBind.muUDPMux.Unlock()
mux, err := iceBind.GetICEMux()
require.NoError(t, err)
require.NotNil(t, mux)
}
func TestICEBind_WorksWithIPv6Only(t *testing.T) {
iceBind := setupICEBind(t)
ipv6Conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
t.Skipf("IPv6 not available: %v", err)
}
defer ipv6Conn.Close()
rc := receiverCreator{iceBind}
recvFn := rc.CreateReceiverFn(ipv6.NewPacketConn(ipv6Conn), ipv6Conn, false, createMsgPool())
require.NotNil(t, recvFn)
iceBind.muUDPMux.Lock()
assert.Nil(t, iceBind.ipv4Conn)
assert.NotNil(t, iceBind.ipv6Conn)
assert.NotNil(t, iceBind.udpMux)
iceBind.muUDPMux.Unlock()
mux, err := iceBind.GetICEMux()
require.NoError(t, err)
require.NotNil(t, mux)
}
// TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously verifies that we can communicate
// with peers on different address families through the same DualStackPacketConn.
func TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously(t *testing.T) {
// two "remote peers" listening on different address families
ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0")
defer ipv4Peer.Close()
ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0})
if err != nil {
t.Skipf("IPv6 not available: %v", err)
}
defer ipv6Peer.Close()
// our local dual-stack connection
ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0")
defer ipv4Local.Close()
ipv6Local := listenUDP(t, "udp6", "[::1]:0")
defer ipv6Local.Close()
dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local)
// send to both peers
_, err = dualStack.WriteTo([]byte("to-ipv4"), ipv4Peer.LocalAddr())
require.NoError(t, err)
_, err = dualStack.WriteTo([]byte("to-ipv6"), ipv6Peer.LocalAddr())
require.NoError(t, err)
// verify IPv4 peer got its packet from the IPv4 socket
buf := make([]byte, 100)
_ = ipv4Peer.SetReadDeadline(time.Now().Add(time.Second))
n, addr, err := ipv4Peer.ReadFrom(buf)
require.NoError(t, err)
assert.Equal(t, "to-ipv4", string(buf[:n]))
assert.Equal(t, ipv4Local.LocalAddr().(*net.UDPAddr).Port, addr.(*net.UDPAddr).Port)
// verify IPv6 peer got its packet from the IPv6 socket
_ = ipv6Peer.SetReadDeadline(time.Now().Add(time.Second))
n, addr, err = ipv6Peer.ReadFrom(buf)
require.NoError(t, err)
assert.Equal(t, "to-ipv6", string(buf[:n]))
assert.Equal(t, ipv6Local.LocalAddr().(*net.UDPAddr).Port, addr.(*net.UDPAddr).Port)
}
// TestICEBind_HandlesConcurrentMixedTraffic sends packets concurrently to both IPv4
// and IPv6 peers. Verifies no packets get misrouted (IPv4 peer only gets v4- packets,
// IPv6 peer only gets v6- packets). Some packet loss is acceptable for UDP.
func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0")
defer ipv4Peer.Close()
ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0})
if err != nil {
t.Skipf("IPv6 not available: %v", err)
}
defer ipv6Peer.Close()
ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0")
defer ipv4Local.Close()
ipv6Local := listenUDP(t, "udp6", "[::1]:0")
defer ipv6Local.Close()
dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local)
const packetsPerFamily = 500
ipv4Received := make(chan string, packetsPerFamily)
ipv6Received := make(chan string, packetsPerFamily)
startGate := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
buf := make([]byte, 100)
for i := 0; i < packetsPerFamily; i++ {
n, _, err := ipv4Peer.ReadFrom(buf)
if err != nil {
return
}
ipv4Received <- string(buf[:n])
}
}()
wg.Add(1)
go func() {
defer wg.Done()
buf := make([]byte, 100)
for i := 0; i < packetsPerFamily; i++ {
n, _, err := ipv6Peer.ReadFrom(buf)
if err != nil {
return
}
ipv6Received <- string(buf[:n])
}
}()
wg.Add(1)
go func() {
defer wg.Done()
<-startGate
for i := 0; i < packetsPerFamily; i++ {
_, _ = dualStack.WriteTo([]byte(fmt.Sprintf("v4-%04d", i)), ipv4Peer.LocalAddr())
}
}()
wg.Add(1)
go func() {
defer wg.Done()
<-startGate
for i := 0; i < packetsPerFamily; i++ {
_, _ = dualStack.WriteTo([]byte(fmt.Sprintf("v6-%04d", i)), ipv6Peer.LocalAddr())
}
}()
close(startGate)
time.AfterFunc(5*time.Second, func() {
_ = ipv4Peer.SetReadDeadline(time.Now())
_ = ipv6Peer.SetReadDeadline(time.Now())
})
wg.Wait()
close(ipv4Received)
close(ipv6Received)
ipv4Count := 0
for pkt := range ipv4Received {
require.True(t, len(pkt) >= 3 && pkt[:3] == "v4-", "IPv4 peer got misrouted packet: %s", pkt)
ipv4Count++
}
ipv6Count := 0
for pkt := range ipv6Received {
require.True(t, len(pkt) >= 3 && pkt[:3] == "v6-", "IPv6 peer got misrouted packet: %s", pkt)
ipv6Count++
}
assert.Equal(t, packetsPerFamily, ipv4Count)
assert.Equal(t, packetsPerFamily, ipv6Count)
}
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
tests := []struct {
name string
network string
addr string
wantIPv4 bool
}{
{"IPv4 any", "udp4", "0.0.0.0:0", true},
{"IPv4 loopback", "udp4", "127.0.0.1:0", true},
{"IPv6 any", "udp6", "[::]:0", false},
{"IPv6 loopback", "udp6", "[::1]:0", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr, err := net.ResolveUDPAddr(tt.network, tt.addr)
require.NoError(t, err)
conn, err := net.ListenUDP(tt.network, addr)
if err != nil {
t.Skipf("%s not available: %v", tt.network, err)
}
defer conn.Close()
localAddr := conn.LocalAddr().(*net.UDPAddr)
isIPv4 := localAddr.IP.To4() != nil
assert.Equal(t, tt.wantIPv4, isIPv4)
})
}
}
// helpers
func setupICEBind(t *testing.T) *ICEBind {
t.Helper()
transportNet, err := stdnet.NewNet()
require.NoError(t, err)
address := wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/10"),
}
return NewICEBind(transportNet, nil, address, 1280)
}
func createDualStackConns(t *testing.T) (*net.UDPConn, *net.UDPConn) {
t.Helper()
ipv4Conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
require.NoError(t, err)
ipv6Conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
ipv4Conn.Close()
t.Skipf("IPv6 not available: %v", err)
}
return ipv4Conn, ipv6Conn
}
func createMsgPool() *sync.Pool {
return &sync.Pool{
New: func() any {
msgs := make([]ipv6.Message, 1)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, 0, 40)
}
return &msgs
},
}
}
func listenUDP(t *testing.T, network, addr string) *net.UDPConn {
t.Helper()
udpAddr, err := net.ResolveUDPAddr(network, addr)
require.NoError(t, err)
conn, err := net.ListenUDP(network, udpAddr)
require.NoError(t, err)
return conn
}

View File

@@ -3,22 +3,8 @@ package configurer
import ( import (
"net" "net"
"net/netip" "net/netip"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// buildPresharedKeyConfig creates a wgtypes.Config for setting a preshared key on a peer.
// This is a shared helper used by both kernel and userspace configurers.
func buildPresharedKeyConfig(peerKey wgtypes.Key, psk wgtypes.Key, updateOnly bool) wgtypes.Config {
return wgtypes.Config{
Peers: []wgtypes.PeerConfig{{
PublicKey: peerKey,
PresharedKey: &psk,
UpdateOnly: updateOnly,
}},
}
}
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet { func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
ipNets := make([]net.IPNet, len(prefixes)) ipNets := make([]net.IPNet, len(prefixes))
for i, prefix := range prefixes { for i, prefix := range prefixes {

View File

@@ -15,6 +15,8 @@ import (
"github.com/netbirdio/netbird/monotime" "github.com/netbirdio/netbird/monotime"
) )
var zeroKey wgtypes.Key
type KernelConfigurer struct { type KernelConfigurer struct {
deviceName string deviceName string
} }
@@ -46,18 +48,6 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
return nil return nil
} }
// SetPresharedKey sets the preshared key for a peer.
// If updateOnly is true, only updates the existing peer; if false, creates or updates.
func (c *KernelConfigurer) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
parsedPeerKey, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
cfg := buildPresharedKeyConfig(parsedPeerKey, psk, updateOnly)
return c.configure(cfg)
}
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {
@@ -289,7 +279,7 @@ func (c *KernelConfigurer) FullStats() (*Stats, error) {
TxBytes: p.TransmitBytes, TxBytes: p.TransmitBytes,
RxBytes: p.ReceiveBytes, RxBytes: p.ReceiveBytes,
LastHandshake: p.LastHandshakeTime, LastHandshake: p.LastHandshakeTime,
PresharedKey: [32]byte(p.PresharedKey), PresharedKey: p.PresharedKey != zeroKey,
} }
if p.Endpoint != nil { if p.Endpoint != nil {
peer.Endpoint = *p.Endpoint peer.Endpoint = *p.Endpoint

View File

@@ -22,16 +22,17 @@ import (
) )
const ( const (
privateKey = "private_key" privateKey = "private_key"
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec" ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
ipcKeyTxBytes = "tx_bytes" ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
ipcKeyRxBytes = "rx_bytes" ipcKeyTxBytes = "tx_bytes"
allowedIP = "allowed_ip" ipcKeyRxBytes = "rx_bytes"
endpoint = "endpoint" allowedIP = "allowed_ip"
fwmark = "fwmark" endpoint = "endpoint"
listenPort = "listen_port" fwmark = "fwmark"
publicKey = "public_key" listenPort = "listen_port"
presharedKey = "preshared_key" publicKey = "public_key"
presharedKey = "preshared_key"
) )
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found") var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
@@ -71,18 +72,6 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
// SetPresharedKey sets the preshared key for a peer.
// If updateOnly is true, only updates the existing peer; if false, creates or updates.
func (c *WGUSPConfigurer) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
parsedPeerKey, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
cfg := buildPresharedKeyConfig(parsedPeerKey, psk, updateOnly)
return c.device.IpcSet(toWgUserspaceString(cfg))
}
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {
@@ -433,25 +422,13 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
hexKey := hex.EncodeToString(p.PublicKey[:]) hexKey := hex.EncodeToString(p.PublicKey[:])
sb.WriteString(fmt.Sprintf("public_key=%s\n", hexKey)) sb.WriteString(fmt.Sprintf("public_key=%s\n", hexKey))
if p.Remove {
sb.WriteString("remove=true\n")
}
if p.UpdateOnly {
sb.WriteString("update_only=true\n")
}
if p.PresharedKey != nil { if p.PresharedKey != nil {
preSharedHexKey := hex.EncodeToString(p.PresharedKey[:]) preSharedHexKey := hex.EncodeToString(p.PresharedKey[:])
sb.WriteString(fmt.Sprintf("preshared_key=%s\n", preSharedHexKey)) sb.WriteString(fmt.Sprintf("preshared_key=%s\n", preSharedHexKey))
} }
if p.Endpoint != nil { if p.Remove {
sb.WriteString(fmt.Sprintf("endpoint=%s\n", p.Endpoint.String())) sb.WriteString("remove=true")
}
if p.PersistentKeepaliveInterval != nil {
sb.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", int(p.PersistentKeepaliveInterval.Seconds())))
} }
if p.ReplaceAllowedIPs { if p.ReplaceAllowedIPs {
@@ -461,6 +438,14 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
for _, aip := range p.AllowedIPs { for _, aip := range p.AllowedIPs {
sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String())) sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String()))
} }
if p.Endpoint != nil {
sb.WriteString(fmt.Sprintf("endpoint=%s\n", p.Endpoint.String()))
}
if p.PersistentKeepaliveInterval != nil {
sb.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", int(p.PersistentKeepaliveInterval.Seconds())))
}
} }
return sb.String() return sb.String()
} }
@@ -558,7 +543,7 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
continue continue
} }
host, portStr, err := net.SplitHostPort(val) host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
if err != nil { if err != nil {
log.Errorf("failed to parse endpoint: %v", err) log.Errorf("failed to parse endpoint: %v", err)
continue continue
@@ -614,9 +599,7 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
continue continue
} }
if val != "" && val != "0000000000000000000000000000000000000000000000000000000000000000" { if val != "" && val != "0000000000000000000000000000000000000000000000000000000000000000" {
if pskKey, err := hexToWireguardKey(val); err == nil { currentPeer.PresharedKey = true
currentPeer.PresharedKey = [32]byte(pskKey)
}
} }
} }
} }

View File

@@ -12,7 +12,7 @@ type Peer struct {
TxBytes int64 TxBytes int64
RxBytes int64 RxBytes int64
LastHandshake time.Time LastHandshake time.Time
PresharedKey [32]byte PresharedKey bool
} }
type Stats struct { type Stats struct {

View File

@@ -3,7 +3,6 @@
package device package device
import ( import (
"fmt"
"strings" "strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -20,12 +19,11 @@ import (
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform // WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
type WGTunDevice struct { type WGTunDevice struct {
address wgaddr.Address address wgaddr.Address
port int port int
key string key string
mtu uint16 mtu uint16
iceBind *bind.ICEBind iceBind *bind.ICEBind
// todo: review if we can eliminate the TunAdapter
tunAdapter TunAdapter tunAdapter TunAdapter
disableDNS bool disableDNS bool
@@ -34,19 +32,17 @@ type WGTunDevice struct {
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
udpMux *udpmux.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
renewableTun *RenewableTUN
} }
func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice { func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
return &WGTunDevice{ return &WGTunDevice{
address: address, address: address,
port: port, port: port,
key: key, key: key,
mtu: mtu, mtu: mtu,
iceBind: iceBind, iceBind: iceBind,
tunAdapter: tunAdapter, tunAdapter: tunAdapter,
disableDNS: disableDNS, disableDNS: disableDNS,
renewableTun: NewRenewableTUN(),
} }
} }
@@ -69,17 +65,14 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
return nil, err return nil, err
} }
unmonitoredTUN, name, err := tun.CreateUnmonitoredTUNFromFD(fd) tunDevice, name, err := tun.CreateUnmonitoredTUNFromFD(fd)
if err != nil { if err != nil {
_ = unix.Close(fd) _ = unix.Close(fd)
log.Errorf("failed to create Android interface: %s", err) log.Errorf("failed to create Android interface: %s", err)
return nil, err return nil, err
} }
t.renewableTun.AddDevice(unmonitoredTUN)
t.name = name t.name = name
t.filteredDevice = newDeviceFilter(t.renewableTun) t.filteredDevice = newDeviceFilter(tunDevice)
log.Debugf("attaching to interface %v", name) log.Debugf("attaching to interface %v", name)
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] ")) t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
@@ -111,23 +104,6 @@ func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *WGTunDevice) RenewTun(fd int) error {
if t.device == nil {
return fmt.Errorf("device not initialized")
}
unmonitoredTUN, _, err := tun.CreateUnmonitoredTUNFromFD(fd)
if err != nil {
_ = unix.Close(fd)
log.Errorf("failed to renew Android interface: %s", err)
return err
}
t.renewableTun.AddDevice(unmonitoredTUN)
return nil
}
func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error { func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
// todo implement // todo implement
return nil return nil

View File

@@ -1,7 +1,9 @@
//go:build ios
// +build ios
package device package device
import ( import (
"fmt"
"os" "os"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -43,31 +45,10 @@ func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu
} }
} }
// ErrInvalidTunnelFD is returned when the tunnel file descriptor is invalid (0).
// This typically means the Swift code couldn't find the utun control socket.
var ErrInvalidTunnelFD = fmt.Errorf("invalid tunnel file descriptor: fd is 0 (Swift failed to locate utun socket)")
func (t *TunDevice) Create() (WGConfigurer, error) { func (t *TunDevice) Create() (WGConfigurer, error) {
log.Infof("create tun interface") log.Infof("create tun interface")
var tunDevice tun.Device dupTunFd, err := unix.Dup(t.tunFd)
var err error
// Validate the tunnel file descriptor.
// On iOS/tvOS, the FD must be provided by the NEPacketTunnelProvider.
// A value of 0 means the Swift code couldn't find the utun control socket
// (the low-level APIs like ctl_info, sockaddr_ctl may not be exposed in
// tvOS SDK headers). This is a hard error - there's no viable fallback
// since tun.CreateTUN() cannot work within the iOS/tvOS sandbox.
if t.tunFd == 0 {
log.Errorf("Tunnel file descriptor is 0 - Swift code failed to locate the utun control socket. " +
"On tvOS, ensure the NEPacketTunnelProvider is properly configured and the tunnel is started.")
return nil, ErrInvalidTunnelFD
}
// Normal iOS/tvOS path: use the provided file descriptor from NEPacketTunnelProvider
var dupTunFd int
dupTunFd, err = unix.Dup(t.tunFd)
if err != nil { if err != nil {
log.Errorf("Unable to dup tun fd: %v", err) log.Errorf("Unable to dup tun fd: %v", err)
return nil, err return nil, err
@@ -79,7 +60,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
_ = unix.Close(dupTunFd) _ = unix.Close(dupTunFd)
return nil, err return nil, err
} }
tunDevice, err = tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0) tunDevice, err := tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0)
if err != nil { if err != nil {
log.Errorf("Unable to create new tun device from fd: %v", err) log.Errorf("Unable to create new tun device from fd: %v", err)
_ = unix.Close(dupTunFd) _ = unix.Close(dupTunFd)

View File

@@ -2,13 +2,6 @@
package device package device
import "fmt"
func (t *TunNetstackDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) { func (t *TunNetstackDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
return t.create() return t.create()
} }
func (t *TunNetstackDevice) RenewTun(fd int) error {
// Doesn't make sense in Android for Netstack.
return fmt.Errorf("this function has not been implemented in Netstack for Android")
}

View File

@@ -17,7 +17,6 @@ type WGConfigurer interface {
RemovePeer(peerKey string) error RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
Close() Close()
GetStats() (map[string]configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error) FullStats() (*configurer.Stats, error)

View File

@@ -1,309 +0,0 @@
//go:build android
package device
import (
"io"
"os"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun"
)
// closeAwareDevice wraps a tun.Device along with a flag
// indicating whether its Close method was called.
//
// It also redirects tun.Device's Events() to a separate goroutine
// and closes it when Close is called.
//
// The WaitGroup and CloseOnce fields are used to ensure that the
// goroutine is awaited and closed only once.
type closeAwareDevice struct {
isClosed atomic.Bool
tun.Device
closeEventCh chan struct{}
wg sync.WaitGroup
closeOnce sync.Once
}
func newClosableDevice(tunDevice tun.Device) *closeAwareDevice {
return &closeAwareDevice{
Device: tunDevice,
isClosed: atomic.Bool{},
closeEventCh: make(chan struct{}),
}
}
// redirectEvents redirects the Events() method of the underlying tun.Device
// to the given channel (RenewableTUN's events channel).
func (c *closeAwareDevice) redirectEvents(out chan tun.Event) {
c.wg.Add(1)
go func() {
defer c.wg.Done()
for {
select {
case ev, ok := <-c.Device.Events():
if !ok {
return
}
if ev == tun.EventDown {
continue
}
select {
case out <- ev:
case <-c.closeEventCh:
return
}
case <-c.closeEventCh:
return
}
}
}()
}
// Close calls the underlying Device's Close method
// after setting isClosed to true.
func (c *closeAwareDevice) Close() (err error) {
c.closeOnce.Do(func() {
c.isClosed.Store(true)
close(c.closeEventCh)
err = c.Device.Close()
c.wg.Wait()
})
return err
}
func (c *closeAwareDevice) IsClosed() bool {
return c.isClosed.Load()
}
type RenewableTUN struct {
devices []*closeAwareDevice
mu sync.Mutex
cond *sync.Cond
events chan tun.Event
closed atomic.Bool
}
func NewRenewableTUN() *RenewableTUN {
r := &RenewableTUN{
devices: make([]*closeAwareDevice, 0),
mu: sync.Mutex{},
events: make(chan tun.Event, 16),
}
r.cond = sync.NewCond(&r.mu)
return r
}
func (r *RenewableTUN) File() *os.File {
for {
dev := r.peekLast()
if dev == nil {
if !r.waitForDevice() {
return nil
}
continue
}
file := dev.File()
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return file
}
}
// Read reads from an underlying tun.Device kept in the r.devices slice.
// If no device is available, it waits for one to be added via AddDevice().
//
// On error, it retries reading from the newest device instead of returning the error
// if the device is closed; if not, it propagates the error.
func (r *RenewableTUN) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
for {
dev := r.peekLast()
if dev == nil {
// wait until AddDevice() signals a new device via cond.Broadcast()
if !r.waitForDevice() { // returns false if the renewable TUN itself is closed
return 0, io.EOF
}
continue
}
n, err = dev.Read(bufs, sizes, offset)
if err == nil {
return n, nil
}
// swap in progress; retry on the newest instead of returning the error
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return n, err // propagate non-swap error
}
}
// Write writes to underlying tun.Device kept in the r.devices slice.
// If no device is available, it waits for one to be added via AddDevice().
//
// On error, it retries writing to the newest device instead of returning the error
// if the device is closed; if not, it propagates the error.
func (r *RenewableTUN) Write(bufs [][]byte, offset int) (int, error) {
for {
dev := r.peekLast()
if dev == nil {
if !r.waitForDevice() {
return 0, io.EOF
}
continue
}
n, err := dev.Write(bufs, offset)
if err == nil {
return n, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return n, err
}
}
func (r *RenewableTUN) MTU() (int, error) {
for {
dev := r.peekLast()
if dev == nil {
if !r.waitForDevice() {
return 0, io.EOF
}
continue
}
mtu, err := dev.MTU()
if err == nil {
return mtu, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return 0, err
}
}
func (r *RenewableTUN) Name() (string, error) {
for {
dev := r.peekLast()
if dev == nil {
if !r.waitForDevice() {
return "", io.EOF
}
continue
}
name, err := dev.Name()
if err == nil {
return name, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return "", err
}
}
// Events returns a channel that is fed events from the underlying tun.Device's events channel
// once it is added.
func (r *RenewableTUN) Events() <-chan tun.Event {
return r.events
}
func (r *RenewableTUN) Close() error {
// Attempts to set the RenewableTUN closed flag to true.
// If it's already true, returns immediately.
if !r.closed.CompareAndSwap(false, true) {
return nil // already closed: idempotent
}
r.mu.Lock()
devices := r.devices
r.devices = nil
r.cond.Broadcast()
r.mu.Unlock()
var lastErr error
log.Debugf("closing %d devices", len(devices))
for _, device := range devices {
if err := device.Close(); err != nil {
log.Debugf("error closing a device: %v", err)
lastErr = err
}
}
close(r.events)
return lastErr
}
func (r *RenewableTUN) BatchSize() int {
return 1
}
func (r *RenewableTUN) AddDevice(device tun.Device) {
r.mu.Lock()
if r.closed.Load() {
r.mu.Unlock()
_ = device.Close()
return
}
var toClose *closeAwareDevice
if len(r.devices) > 0 {
toClose = r.devices[len(r.devices)-1]
}
cad := newClosableDevice(device)
cad.redirectEvents(r.events)
r.devices = []*closeAwareDevice{cad}
r.cond.Broadcast()
r.mu.Unlock()
if toClose != nil {
if err := toClose.Close(); err != nil {
log.Debugf("error closing last device: %v", err)
}
}
}
func (r *RenewableTUN) waitForDevice() bool {
r.mu.Lock()
defer r.mu.Unlock()
for len(r.devices) == 0 && !r.closed.Load() {
r.cond.Wait()
}
return !r.closed.Load()
}
func (r *RenewableTUN) peekLast() *closeAwareDevice {
r.mu.Lock()
defer r.mu.Unlock()
if len(r.devices) == 0 {
return nil
}
return r.devices[len(r.devices)-1]
}

View File

@@ -21,6 +21,5 @@ type WGTunDevice interface {
FilteredDevice() *device.FilteredDevice FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device Device() *wgdevice.Device
GetNet() *netstack.Net GetNet() *netstack.Net
RenewTun(fd int) error
GetICEBind() device.EndpointManager GetICEBind() device.EndpointManager
} }

View File

@@ -50,7 +50,6 @@ func ValidateMTU(mtu uint16) error {
type wgProxyFactory interface { type wgProxyFactory interface {
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
GetProxyPort() uint16
Free() error Free() error
} }
@@ -81,12 +80,6 @@ func (w *WGIface) GetProxy() wgproxy.Proxy {
return w.wgProxyFactory.GetProxy() return w.wgProxyFactory.GetProxy()
} }
// GetProxyPort returns the proxy port used by the WireGuard proxy.
// Returns 0 if no proxy port is used (e.g., for userspace WireGuard).
func (w *WGIface) GetProxyPort() uint16 {
return w.wgProxyFactory.GetProxyPort()
}
// GetBind returns the EndpointManager userspace bind mode. // GetBind returns the EndpointManager userspace bind mode.
func (w *WGIface) GetBind() device.EndpointManager { func (w *WGIface) GetBind() device.EndpointManager {
w.mu.Lock() w.mu.Lock()
@@ -304,19 +297,6 @@ func (w *WGIface) FullStats() (*configurer.Stats, error) {
return w.configurer.FullStats() return w.configurer.FullStats()
} }
// SetPresharedKey sets or updates the preshared key for a peer.
// If updateOnly is true, only updates existing peer; if false, creates or updates.
func (w *WGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
return w.configurer.SetPresharedKey(peerKey, psk, updateOnly)
}
func (w *WGIface) waitUntilRemoved() error { func (w *WGIface) waitUntilRemoved() error {
maxWaitTime := 5 * time.Second maxWaitTime := 5 * time.Second
timeout := time.NewTimer(maxWaitTime) timeout := time.NewTimer(maxWaitTime)

View File

@@ -24,7 +24,3 @@ func (w *WGIface) Create() error {
func (w *WGIface) CreateOnAndroid([]string, string, []string) error { func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on non mobile") return fmt.Errorf("this function has not implemented on non mobile")
} }
func (w *WGIface) RenewTun(fd int) error {
return fmt.Errorf("this function has not been implemented on non-android")
}

View File

@@ -6,7 +6,6 @@ import (
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up. // CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one. // Will reuse an existing one.
// todo: review does this function really necessary or can we merge it with iOS
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error { func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
@@ -23,9 +22,3 @@ func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []s
func (w *WGIface) Create() error { func (w *WGIface) Create() error {
return fmt.Errorf("this function has not implemented on this platform") return fmt.Errorf("this function has not implemented on this platform")
} }
func (w *WGIface) RenewTun(fd int) error {
w.mu.Lock()
defer w.mu.Unlock()
return w.tun.RenewTun(fd)
}

View File

@@ -39,7 +39,3 @@ func (w *WGIface) Create() error {
func (w *WGIface) CreateOnAndroid([]string, string, []string) error { func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on this platform") return fmt.Errorf("this function has not implemented on this platform")
} }
func (w *WGIface) RenewTun(fd int) error {
return fmt.Errorf("this function has not been implemented on this platform")
}

View File

@@ -1,7 +1,6 @@
package iface package iface
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -10,13 +9,13 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/stdnet"
) )
// keep darwin compatibility // keep darwin compatibility
@@ -41,7 +40,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
addr := "100.64.0.1/8" addr := "100.64.0.1/8"
wgPort := 33100 wgPort := 33100
newNet, err := stdnet.NewNet(context.Background(), nil) newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -124,7 +123,7 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
func Test_CreateInterface(t *testing.T) { func Test_CreateInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
wgIP := "10.99.99.1/32" wgIP := "10.99.99.1/32"
newNet, err := stdnet.NewNet(context.Background(), nil) newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -167,7 +166,7 @@ func Test_Close(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
wgIP := "10.99.99.2/32" wgIP := "10.99.99.2/32"
wgPort := 33100 wgPort := 33100
newNet, err := stdnet.NewNet(context.Background(), nil) newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -212,7 +211,7 @@ func TestRecreation(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
wgIP := "10.99.99.2/32" wgIP := "10.99.99.2/32"
wgPort := 33100 wgPort := 33100
newNet, err := stdnet.NewNet(context.Background(), nil) newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -285,7 +284,7 @@ func Test_ConfigureInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
wgIP := "10.99.99.5/30" wgIP := "10.99.99.5/30"
wgPort := 33100 wgPort := 33100
newNet, err := stdnet.NewNet(context.Background(), nil) newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -340,7 +339,7 @@ func Test_ConfigureInterface(t *testing.T) {
func Test_UpdatePeer(t *testing.T) { func Test_UpdatePeer(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
wgIP := "10.99.99.9/30" wgIP := "10.99.99.9/30"
newNet, err := stdnet.NewNet(context.Background(), nil) newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -410,7 +409,7 @@ func Test_UpdatePeer(t *testing.T) {
func Test_RemovePeer(t *testing.T) { func Test_RemovePeer(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
wgIP := "10.99.99.13/30" wgIP := "10.99.99.13/30"
newNet, err := stdnet.NewNet(context.Background(), nil) newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -472,7 +471,7 @@ func Test_ConnectPeers(t *testing.T) {
peer2wgPort := 33200 peer2wgPort := 33200
keepAlive := 1 * time.Second keepAlive := 1 * time.Second
newNet, err := stdnet.NewNet(context.Background(), nil) newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -515,7 +514,7 @@ func Test_ConnectPeers(t *testing.T) {
guid = fmt.Sprintf("{%s}", uuid.New().String()) guid = fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid) device.CustomWindowsGUIDString = strings.ToLower(guid)
newNet, err = stdnet.NewNet(context.Background(), nil) newNet, err = stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -1,7 +1,6 @@
package udpmux package udpmux
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"net" "net"
@@ -13,9 +12,8 @@ import (
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/stun/v3" "github.com/pion/stun/v3"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/stdnet"
) )
/* /*
@@ -201,7 +199,7 @@ func (m *SingleSocketUDPMux) updateLocalAddresses() {
if len(networks) > 0 { if len(networks) > 0 {
if m.params.Net == nil { if m.params.Net == nil {
var err error var err error
if m.params.Net, err = stdnet.NewNet(context.Background(), nil); err != nil { if m.params.Net, err = stdnet.NewNet(); err != nil {
m.params.Logger.Errorf("failed to get create network: %v", err) m.params.Logger.Errorf("failed to get create network: %v", err)
} }
} }

View File

@@ -114,21 +114,21 @@ func (p *ProxyBind) Pause() {
} }
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) { func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
ep, err := addrToEndpoint(endpoint)
if err != nil {
log.Errorf("failed to start package redirection: %v", err)
return
}
p.pausedCond.L.Lock() p.pausedCond.L.Lock()
p.paused = false p.paused = false
p.wgCurrentUsed = ep p.wgCurrentUsed = addrToEndpoint(endpoint)
p.pausedCond.Signal() p.pausedCond.Signal()
p.pausedCond.L.Unlock() p.pausedCond.L.Unlock()
} }
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
ip, _ := netip.AddrFromSlice(addr.IP.To4())
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}
}
func (p *ProxyBind) CloseConn() error { func (p *ProxyBind) CloseConn() error {
if p.cancel == nil { if p.cancel == nil {
return fmt.Errorf("proxy not started") return fmt.Errorf("proxy not started")
@@ -212,16 +212,3 @@ func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port)) netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
return &netipAddr, nil return &netipAddr, nil
} }
func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) {
if addr == nil {
return nil, fmt.Errorf("invalid address")
}
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return nil, fmt.Errorf("convert %s to netip.Addr", addr)
}
addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}, nil
}

View File

@@ -8,6 +8,8 @@ import (
"net" "net"
"sync" "sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -24,10 +26,13 @@ const (
loopbackAddr = "127.0.0.1" loopbackAddr = "127.0.0.1"
) )
var (
localHostNetIP = net.ParseIP("127.0.0.1")
)
// WGEBPFProxy definition for proxy with EBPF support // WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct { type WGEBPFProxy struct {
localWGListenPort int localWGListenPort int
proxyPort int
mtu uint16 mtu uint16
ebpfManager ebpfMgr.Manager ebpfManager ebpfMgr.Manager
@@ -35,8 +40,7 @@ type WGEBPFProxy struct {
turnConnMutex sync.Mutex turnConnMutex sync.Mutex
lastUsedPort uint16 lastUsedPort uint16
rawConnIPv4 net.PacketConn rawConn net.PacketConn
rawConnIPv6 net.PacketConn
conn transport.UDPConn conn transport.UDPConn
ctx context.Context ctx context.Context
@@ -58,39 +62,23 @@ func NewWGEBPFProxy(wgPort int, mtu uint16) *WGEBPFProxy {
// Listen load ebpf program and listen the proxy // Listen load ebpf program and listen the proxy
func (p *WGEBPFProxy) Listen() error { func (p *WGEBPFProxy) Listen() error {
pl := portLookup{} pl := portLookup{}
proxyPort, err := pl.searchFreePort() wgPorxyPort, err := pl.searchFreePort()
if err != nil {
return err
}
p.proxyPort = proxyPort
// Prepare IPv4 raw socket (required)
p.rawConnIPv4, err = rawsocket.PrepareSenderRawSocketIPv4()
if err != nil { if err != nil {
return err return err
} }
// Prepare IPv6 raw socket (optional) p.rawConn, err = rawsocket.PrepareSenderRawSocket()
p.rawConnIPv6, err = rawsocket.PrepareSenderRawSocketIPv6()
if err != nil { if err != nil {
log.Warnf("failed to prepare IPv6 raw socket, continuing with IPv4 only: %v", err) return err
} }
err = p.ebpfManager.LoadWgProxy(proxyPort, p.localWGListenPort) err = p.ebpfManager.LoadWgProxy(wgPorxyPort, p.localWGListenPort)
if err != nil { if err != nil {
if closeErr := p.rawConnIPv4.Close(); closeErr != nil {
log.Warnf("failed to close IPv4 raw socket: %v", closeErr)
}
if p.rawConnIPv6 != nil {
if closeErr := p.rawConnIPv6.Close(); closeErr != nil {
log.Warnf("failed to close IPv6 raw socket: %v", closeErr)
}
}
return err return err
} }
addr := net.UDPAddr{ addr := net.UDPAddr{
Port: proxyPort, Port: wgPorxyPort,
IP: net.ParseIP(loopbackAddr), IP: net.ParseIP(loopbackAddr),
} }
@@ -106,7 +94,7 @@ func (p *WGEBPFProxy) Listen() error {
p.conn = conn p.conn = conn
go p.proxyToRemote() go p.proxyToRemote()
log.Infof("local wg proxy listening on: %d", proxyPort) log.Infof("local wg proxy listening on: %d", wgPorxyPort)
return nil return nil
} }
@@ -147,25 +135,12 @@ func (p *WGEBPFProxy) Free() error {
result = multierror.Append(result, err) result = multierror.Append(result, err)
} }
if p.rawConnIPv4 != nil { if err := p.rawConn.Close(); err != nil {
if err := p.rawConnIPv4.Close(); err != nil { result = multierror.Append(result, err)
result = multierror.Append(result, err)
}
}
if p.rawConnIPv6 != nil {
if err := p.rawConnIPv6.Close(); err != nil {
result = multierror.Append(result, err)
}
} }
return nberrors.FormatErrorOrNil(result) return nberrors.FormatErrorOrNil(result)
} }
// GetProxyPort returns the proxy listening port.
func (p *WGEBPFProxy) GetProxyPort() uint16 {
return uint16(p.proxyPort)
}
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn // proxyToRemote read messages from local WireGuard interface and forward it to remote conn
// From this go routine has only one instance. // From this go routine has only one instance.
func (p *WGEBPFProxy) proxyToRemote() { func (p *WGEBPFProxy) proxyToRemote() {
@@ -241,3 +216,34 @@ generatePort:
} }
return p.lastUsedPort, nil return p.lastUsedPort, nil
} }
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
payload := gopacket.Payload(data)
ipH := &layers.IPv4{
DstIP: localHostNetIP,
SrcIP: endpointAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(endpointAddr.Port),
DstPort: layers.UDPPort(p.localWGListenPort),
}
err := udpH.SetNetworkLayerForChecksum(ipH)
if err != nil {
return fmt.Errorf("set network layer for checksum: %w", err)
}
layerBuffer := gopacket.NewSerializeBuffer()
err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
if err != nil {
return fmt.Errorf("serialize layers: %w", err)
}
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
return fmt.Errorf("write to raw conn: %w", err)
}
return nil
}

View File

@@ -10,89 +10,12 @@ import (
"net" "net"
"sync" "sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bufsize" "github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener" "github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
var (
errIPv6ConnNotAvailable = errors.New("IPv6 endpoint but rawConnIPv6 is not available")
errIPv4ConnNotAvailable = errors.New("IPv4 endpoint but rawConnIPv4 is not available")
localHostNetIPv4 = net.ParseIP("127.0.0.1")
localHostNetIPv6 = net.ParseIP("::1")
serializeOpts = gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
)
// PacketHeaders holds pre-created headers and buffers for efficient packet sending
type PacketHeaders struct {
ipH gopacket.SerializableLayer
udpH *layers.UDP
layerBuffer gopacket.SerializeBuffer
localHostAddr net.IP
isIPv4 bool
}
func NewPacketHeaders(localWGListenPort int, endpoint *net.UDPAddr) (*PacketHeaders, error) {
var ipH gopacket.SerializableLayer
var networkLayer gopacket.NetworkLayer
var localHostAddr net.IP
var isIPv4 bool
// Check if source address is IPv4 or IPv6
if endpoint.IP.To4() != nil {
// IPv4 path
ipv4 := &layers.IPv4{
DstIP: localHostNetIPv4,
SrcIP: endpoint.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
ipH = ipv4
networkLayer = ipv4
localHostAddr = localHostNetIPv4
isIPv4 = true
} else {
// IPv6 path
ipv6 := &layers.IPv6{
DstIP: localHostNetIPv6,
SrcIP: endpoint.IP,
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
}
ipH = ipv6
networkLayer = ipv6
localHostAddr = localHostNetIPv6
isIPv4 = false
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(endpoint.Port),
DstPort: layers.UDPPort(localWGListenPort),
}
if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil {
return nil, fmt.Errorf("set network layer for checksum: %w", err)
}
return &PacketHeaders{
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
localHostAddr: localHostAddr,
isIPv4: isIPv4,
}, nil
}
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
type ProxyWrapper struct { type ProxyWrapper struct {
wgeBPFProxy *WGEBPFProxy wgeBPFProxy *WGEBPFProxy
@@ -101,10 +24,8 @@ type ProxyWrapper struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
wgRelayedEndpointAddr *net.UDPAddr wgRelayedEndpointAddr *net.UDPAddr
headers *PacketHeaders wgEndpointCurrentUsedAddr *net.UDPAddr
headerCurrentUsed *PacketHeaders
rawConn net.PacketConn
paused bool paused bool
pausedCond *sync.Cond pausedCond *sync.Cond
@@ -120,32 +41,15 @@ func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
closeListener: listener.NewCloseListener(), closeListener: listener.NewCloseListener(),
} }
} }
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn) addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
if err != nil { if err != nil {
return fmt.Errorf("add turn conn: %w", err) return fmt.Errorf("add turn conn: %w", err)
} }
headers, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, addr)
if err != nil {
return fmt.Errorf("create packet sender: %w", err)
}
// Check if required raw connection is available
if !headers.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
return errIPv6ConnNotAvailable
}
if headers.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
return errIPv4ConnNotAvailable
}
p.remoteConn = remoteConn p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx) p.ctx, p.cancel = context.WithCancel(ctx)
p.wgRelayedEndpointAddr = addr p.wgRelayedEndpointAddr = addr
p.headers = headers return err
p.rawConn = p.selectRawConn(headers)
return nil
} }
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
@@ -164,8 +68,7 @@ func (p *ProxyWrapper) Work() {
p.pausedCond.L.Lock() p.pausedCond.L.Lock()
p.paused = false p.paused = false
p.headerCurrentUsed = p.headers p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
p.rawConn = p.selectRawConn(p.headerCurrentUsed)
if !p.isStarted { if !p.isStarted {
p.isStarted = true p.isStarted = true
@@ -188,32 +91,10 @@ func (p *ProxyWrapper) Pause() {
} }
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) { func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
if endpoint == nil || endpoint.IP == nil {
log.Errorf("failed to start package redirection, endpoint is nil")
return
}
header, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, endpoint)
if err != nil {
log.Errorf("failed to create packet headers: %s", err)
return
}
// Check if required raw connection is available
if !header.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
log.Error(errIPv6ConnNotAvailable)
return
}
if header.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
log.Error(errIPv4ConnNotAvailable)
return
}
p.pausedCond.L.Lock() p.pausedCond.L.Lock()
p.paused = false p.paused = false
p.headerCurrentUsed = header p.wgEndpointCurrentUsedAddr = endpoint
p.rawConn = p.selectRawConn(header)
p.pausedCond.Signal() p.pausedCond.Signal()
p.pausedCond.L.Unlock() p.pausedCond.L.Unlock()
@@ -255,7 +136,7 @@ func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
p.pausedCond.Wait() p.pausedCond.Wait()
} }
err = p.sendPkg(buf[:n], p.headerCurrentUsed) err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
p.pausedCond.L.Unlock() p.pausedCond.L.Unlock()
if err != nil { if err != nil {
@@ -281,29 +162,3 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
} }
return n, nil return n, nil
} }
func (p *ProxyWrapper) sendPkg(data []byte, header *PacketHeaders) error {
defer func() {
if err := header.layerBuffer.Clear(); err != nil {
log.Errorf("failed to clear layer buffer: %s", err)
}
}()
payload := gopacket.Payload(data)
if err := gopacket.SerializeLayers(header.layerBuffer, serializeOpts, header.ipH, header.udpH, payload); err != nil {
return fmt.Errorf("serialize layers: %w", err)
}
if _, err := p.rawConn.WriteTo(header.layerBuffer.Bytes(), &net.IPAddr{IP: header.localHostAddr}); err != nil {
return fmt.Errorf("write to raw conn: %w", err)
}
return nil
}
func (p *ProxyWrapper) selectRawConn(header *PacketHeaders) net.PacketConn {
if header.isIPv4 {
return p.wgeBPFProxy.rawConnIPv4
}
return p.wgeBPFProxy.rawConnIPv6
}

View File

@@ -3,19 +3,12 @@
package wgproxy package wgproxy
import ( import (
"os"
"strconv"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
) )
const (
envDisableEBPFWGProxy = "NB_DISABLE_EBPF_WG_PROXY"
)
type KernelFactory struct { type KernelFactory struct {
wgPort int wgPort int
mtu uint16 mtu uint16
@@ -29,12 +22,6 @@ func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
mtu: mtu, mtu: mtu,
} }
if isEBPFDisabled() {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
log.Infof("eBPF WireGuard proxy is disabled via %s environment variable", envDisableEBPFWGProxy)
return f
}
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, mtu) ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, mtu)
if err := ebpfProxy.Listen(); err != nil { if err := ebpfProxy.Listen(); err != nil {
log.Infof("WireGuard Proxy Factory will produce UDP proxy") log.Infof("WireGuard Proxy Factory will produce UDP proxy")
@@ -54,30 +41,9 @@ func (w *KernelFactory) GetProxy() Proxy {
return ebpf.NewProxyWrapper(w.ebpfProxy) return ebpf.NewProxyWrapper(w.ebpfProxy)
} }
// GetProxyPort returns the eBPF proxy port, or 0 if eBPF is not active.
func (w *KernelFactory) GetProxyPort() uint16 {
if w.ebpfProxy == nil {
return 0
}
return w.ebpfProxy.GetProxyPort()
}
func (w *KernelFactory) Free() error { func (w *KernelFactory) Free() error {
if w.ebpfProxy == nil { if w.ebpfProxy == nil {
return nil return nil
} }
return w.ebpfProxy.Free() return w.ebpfProxy.Free()
} }
func isEBPFDisabled() bool {
val := os.Getenv(envDisableEBPFWGProxy)
if val == "" {
return false
}
disabled, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envDisableEBPFWGProxy, err)
return false
}
return disabled
}

View File

@@ -24,11 +24,6 @@ func (w *USPFactory) GetProxy() Proxy {
return proxyBind.NewProxyBind(w.bind, w.mtu) return proxyBind.NewProxyBind(w.bind, w.mtu)
} }
// GetProxyPort returns 0 as userspace WireGuard doesn't use a separate proxy port.
func (w *USPFactory) GetProxyPort() uint16 {
return 0
}
func (w *USPFactory) Free() error { func (w *USPFactory) Free() error {
return nil return nil
} }

View File

@@ -8,87 +8,43 @@ import (
"os" "os"
"syscall" "syscall"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
nbnet "github.com/netbirdio/netbird/client/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
// PrepareSenderRawSocketIPv4 creates and configures a raw socket for sending IPv4 packets func PrepareSenderRawSocket() (net.PacketConn, error) {
func PrepareSenderRawSocketIPv4() (net.PacketConn, error) {
return prepareSenderRawSocket(syscall.AF_INET, true)
}
// PrepareSenderRawSocketIPv6 creates and configures a raw socket for sending IPv6 packets
func PrepareSenderRawSocketIPv6() (net.PacketConn, error) {
return prepareSenderRawSocket(syscall.AF_INET6, false)
}
func prepareSenderRawSocket(family int, isIPv4 bool) (net.PacketConn, error) {
// Create a raw socket. // Create a raw socket.
fd, err := syscall.Socket(family, syscall.SOCK_RAW, syscall.IPPROTO_RAW) fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil { if err != nil {
return nil, fmt.Errorf("creating raw socket failed: %w", err) return nil, fmt.Errorf("creating raw socket failed: %w", err)
} }
// Set the header include option on the socket to tell the kernel that headers are included in the packet. // Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
// For IPv4, we need to set IP_HDRINCL. For IPv6, we need to set IPV6_HDRINCL to accept application-provided IPv6 headers. err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if isIPv4 { if err != nil {
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, unix.IP_HDRINCL, 1) return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
if err != nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
} else {
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IPV6, unix.IPV6_HDRINCL, 1)
if err != nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("setting IPV6_HDRINCL failed: %w", err)
}
} }
// Bind the socket to the "lo" interface. // Bind the socket to the "lo" interface.
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo") err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil { if err != nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("binding to lo interface failed: %w", err) return nil, fmt.Errorf("binding to lo interface failed: %w", err)
} }
// Set the fwmark on the socket. // Set the fwmark on the socket.
err = nbnet.SetSocketOpt(fd) err = nbnet.SetSocketOpt(fd)
if err != nil { if err != nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("setting fwmark failed: %w", err) return nil, fmt.Errorf("setting fwmark failed: %w", err)
} }
// Convert the file descriptor to a PacketConn. // Convert the file descriptor to a PacketConn.
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)) file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
if file == nil { if file == nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("converting fd to file failed") return nil, fmt.Errorf("converting fd to file failed")
} }
packetConn, err := net.FilePacketConn(file) packetConn, err := net.FilePacketConn(file)
if err != nil { if err != nil {
if closeErr := file.Close(); closeErr != nil {
log.Warnf("failed to close file: %v", closeErr)
}
return nil, fmt.Errorf("converting file to packet conn failed: %w", err) return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
} }
// Close the original file to release the FD (net.FilePacketConn duplicates it)
if closeErr := file.Close(); closeErr != nil {
log.Warnf("failed to close file after creating packet conn: %v", closeErr)
}
return packetConn, nil return packetConn, nil
} }

View File

@@ -1,353 +0,0 @@
//go:build linux && !android
package wgproxy
import (
"context"
"net"
"testing"
"time"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
// compareUDPAddr compares two UDP addresses, ignoring IPv6 zone IDs
// IPv6 link-local addresses include zone IDs (e.g., fe80::1%lo) which we should ignore
func compareUDPAddr(addr1, addr2 net.Addr) bool {
udpAddr1, ok1 := addr1.(*net.UDPAddr)
udpAddr2, ok2 := addr2.(*net.UDPAddr)
if !ok1 || !ok2 {
return addr1.String() == addr2.String()
}
// Compare IP and Port, ignoring zone
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
}
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
wgPort := 51850
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("192.168.0.56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
wgPort := 51851
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("fe80::56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
func TestRedirectAs_UDP_IPv4(t *testing.T) {
wgPort := 51852
proxy := udp.NewWGUDPProxy(wgPort, 1280)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("192.168.0.56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_UDP_IPv6 tests RedirectAs with UDP proxy using IPv6 addresses
func TestRedirectAs_UDP_IPv6(t *testing.T) {
wgPort := 51853
proxy := udp.NewWGUDPProxy(wgPort, 1280)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("fe80::56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// testRedirectAs is a helper function that tests the RedirectAs functionality
// It verifies that:
// 1. Initial traffic from relay connection works
// 2. After calling RedirectAs, packets appear to come from the p2p endpoint
// 3. Multiple packets are correctly redirected with the new source address
func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *net.UDPAddr) {
t.Helper()
ctx := context.Background()
// Create WireGuard listeners on both IPv4 and IPv6 to support both P2P connection types
// In reality, WireGuard binds to a port and receives from both IPv4 and IPv6
wgListener4, err := net.ListenUDP("udp4", &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: wgPort,
})
if err != nil {
t.Fatalf("failed to create IPv4 WireGuard listener: %v", err)
}
defer wgListener4.Close()
wgListener6, err := net.ListenUDP("udp6", &net.UDPAddr{
IP: net.ParseIP("::1"),
Port: wgPort,
})
if err != nil {
t.Fatalf("failed to create IPv6 WireGuard listener: %v", err)
}
defer wgListener6.Close()
// Determine which listener to use based on the NetBird address IP version
// (this is where initial traffic will come from before RedirectAs is called)
var wgListener *net.UDPConn
if p2pEndpoint.IP.To4() == nil {
wgListener = wgListener6
} else {
wgListener = wgListener4
}
// Create relay server and connection
relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0, // Random port
})
if err != nil {
t.Fatalf("failed to create relay server: %v", err)
}
defer relayServer.Close()
relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
if err != nil {
t.Fatalf("failed to create relay connection: %v", err)
}
defer relayConn.Close()
// Add TURN connection to proxy
if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
t.Fatalf("failed to add TURN connection: %v", err)
}
defer func() {
if err := proxy.CloseConn(); err != nil {
t.Errorf("failed to close proxy connection: %v", err)
}
}()
// Start the proxy
proxy.Work()
// Phase 1: Test initial relay traffic
msgFromRelay := []byte("hello from relay")
if _, err := relayServer.WriteTo(msgFromRelay, relayConn.LocalAddr()); err != nil {
t.Fatalf("failed to write to relay server: %v", err)
}
// Set read deadline to avoid hanging
if err := wgListener4.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Fatalf("failed to set read deadline: %v", err)
}
buf := make([]byte, 1024)
n, _, err := wgListener4.ReadFrom(buf)
if err != nil {
t.Fatalf("failed to read from WireGuard listener: %v", err)
}
if n != len(msgFromRelay) {
t.Errorf("expected %d bytes, got %d", len(msgFromRelay), n)
}
if string(buf[:n]) != string(msgFromRelay) {
t.Errorf("expected message %q, got %q", msgFromRelay, buf[:n])
}
// Phase 2: Redirect to p2p endpoint
proxy.RedirectAs(p2pEndpoint)
// Give the proxy a moment to process the redirect
time.Sleep(100 * time.Millisecond)
// Phase 3: Test redirected traffic
redirectedMessages := [][]byte{
[]byte("redirected message 1"),
[]byte("redirected message 2"),
[]byte("redirected message 3"),
}
for i, msg := range redirectedMessages {
if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
t.Fatalf("failed to write redirected message %d: %v", i+1, err)
}
if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Fatalf("failed to set read deadline: %v", err)
}
n, srcAddr, err := wgListener.ReadFrom(buf)
if err != nil {
t.Fatalf("failed to read redirected message %d: %v", i+1, err)
}
// Verify message content
if string(buf[:n]) != string(msg) {
t.Errorf("message %d: expected %q, got %q", i+1, msg, buf[:n])
}
// Verify source address matches p2p endpoint (this is the key test)
// Use compareUDPAddr to ignore IPv6 zone IDs
if !compareUDPAddr(srcAddr, p2pEndpoint) {
t.Errorf("message %d: expected source address %s, got %s",
i+1, p2pEndpoint.String(), srcAddr.String())
}
}
}
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
func TestRedirectAs_Multiple_Switches(t *testing.T) {
wgPort := 51856
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
ctx := context.Background()
// Create WireGuard listener
wgListener, err := net.ListenUDP("udp4", &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: wgPort,
})
if err != nil {
t.Fatalf("failed to create WireGuard listener: %v", err)
}
defer wgListener.Close()
// Create relay server and connection
relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0,
})
if err != nil {
t.Fatalf("failed to create relay server: %v", err)
}
defer relayServer.Close()
relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
if err != nil {
t.Fatalf("failed to create relay connection: %v", err)
}
defer relayConn.Close()
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
t.Fatalf("failed to add TURN connection: %v", err)
}
defer func() {
if err := proxy.CloseConn(); err != nil {
t.Errorf("failed to close proxy connection: %v", err)
}
}()
proxy.Work()
// Test switching between multiple endpoints - using addresses in local subnet
endpoints := []*net.UDPAddr{
{IP: net.ParseIP("192.168.0.100"), Port: 51820},
{IP: net.ParseIP("192.168.0.101"), Port: 51821},
{IP: net.ParseIP("192.168.0.102"), Port: 51822},
}
for i, endpoint := range endpoints {
proxy.RedirectAs(endpoint)
time.Sleep(100 * time.Millisecond)
msg := []byte("test message")
if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
t.Fatalf("failed to write message for endpoint %d: %v", i, err)
}
buf := make([]byte, 1024)
if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Fatalf("failed to set read deadline: %v", err)
}
n, srcAddr, err := wgListener.ReadFrom(buf)
if err != nil {
t.Fatalf("failed to read message for endpoint %d: %v", i, err)
}
if string(buf[:n]) != string(msg) {
t.Errorf("endpoint %d: expected message %q, got %q", i, msg, buf[:n])
}
if !compareUDPAddr(srcAddr, endpoint) {
t.Errorf("endpoint %d: expected source %s, got %s",
i, endpoint.String(), srcAddr.String())
}
}
}

View File

@@ -56,7 +56,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
// the connection is complete, an error is returned. Once successfully // the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the // connected, any expiration of the context will not affect the
// connection. // connection.
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error { func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
dialer := net.Dialer{} dialer := net.Dialer{}
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil { if err != nil {

View File

@@ -19,56 +19,37 @@ var (
FixLengths: true, FixLengths: true,
} }
localHostNetIPAddrV4 = &net.IPAddr{ localHostNetIPAddr = &net.IPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
} }
localHostNetIPAddrV6 = &net.IPAddr{
IP: net.ParseIP("::1"),
}
) )
type SrcFaker struct { type SrcFaker struct {
srcAddr *net.UDPAddr srcAddr *net.UDPAddr
rawSocket net.PacketConn rawSocket net.PacketConn
ipH gopacket.SerializableLayer ipH gopacket.SerializableLayer
udpH gopacket.SerializableLayer udpH gopacket.SerializableLayer
layerBuffer gopacket.SerializeBuffer layerBuffer gopacket.SerializeBuffer
localHostAddr *net.IPAddr
} }
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) { func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
// Create only the raw socket for the address family we need rawSocket, err := rawsocket.PrepareSenderRawSocket()
var rawSocket net.PacketConn
var err error
var localHostAddr *net.IPAddr
if srcAddr.IP.To4() != nil {
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv4()
localHostAddr = localHostNetIPAddrV4
} else {
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv6()
localHostAddr = localHostNetIPAddrV6
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
ipH, udpH, err := prepareHeaders(dstPort, srcAddr) ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
if err != nil { if err != nil {
if closeErr := rawSocket.Close(); closeErr != nil {
log.Warnf("failed to close raw socket: %v", closeErr)
}
return nil, err return nil, err
} }
f := &SrcFaker{ f := &SrcFaker{
srcAddr: srcAddr, srcAddr: srcAddr,
rawSocket: rawSocket, rawSocket: rawSocket,
ipH: ipH, ipH: ipH,
udpH: udpH, udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(), layerBuffer: gopacket.NewSerializeBuffer(),
localHostAddr: localHostAddr,
} }
return f, nil return f, nil
@@ -91,7 +72,7 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
if err != nil { if err != nil {
return 0, fmt.Errorf("serialize layers: %w", err) return 0, fmt.Errorf("serialize layers: %w", err)
} }
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), f.localHostAddr) n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
if err != nil { if err != nil {
return 0, fmt.Errorf("write to raw conn: %w", err) return 0, fmt.Errorf("write to raw conn: %w", err)
} }
@@ -99,40 +80,19 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
} }
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) { func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
var ipH gopacket.SerializableLayer ipH := &layers.IPv4{
var networkLayer gopacket.NetworkLayer DstIP: net.ParseIP("127.0.0.1"),
SrcIP: srcAddr.IP,
// Check if source IP is IPv4 or IPv6 Version: 4,
if srcAddr.IP.To4() != nil { TTL: 64,
// IPv4 Protocol: layers.IPProtocolUDP,
ipv4 := &layers.IPv4{
DstIP: localHostNetIPAddrV4.IP,
SrcIP: srcAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
ipH = ipv4
networkLayer = ipv4
} else {
// IPv6
ipv6 := &layers.IPv6{
DstIP: localHostNetIPAddrV6.IP,
SrcIP: srcAddr.IP,
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
}
ipH = ipv6
networkLayer = ipv6
} }
udpH := &layers.UDP{ udpH := &layers.UDP{
SrcPort: layers.UDPPort(srcAddr.Port), SrcPort: layers.UDPPort(srcAddr.Port),
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
} }
err := udpH.SetNetworkLayerForChecksum(networkLayer) err := udpH.SetNetworkLayerForChecksum(ipH)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err) return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
} }

Some files were not shown because too many files have changed in this diff Show More