mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 16:56:39 +00:00
Compare commits
52 Commits
v0.59.13
...
coderabbit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8c5648bb7b | ||
|
|
b7e98acd1f | ||
|
|
433bc4ead9 | ||
|
|
011cc81678 | ||
|
|
537151e0f3 | ||
|
|
a9c28ef723 | ||
|
|
c29bb1a289 | ||
|
|
447cd287f5 | ||
|
|
5748bdd64e | ||
|
|
08f31fbcb3 | ||
|
|
932c02eaab | ||
|
|
abcbde26f9 | ||
|
|
90e3b8009f | ||
|
|
94d34dc0c5 | ||
|
|
44851e06fb | ||
|
|
3f4f825ec1 | ||
|
|
f538e6e9ae | ||
|
|
cb6b086164 | ||
|
|
71b6855e09 | ||
|
|
9bdc4908fb | ||
|
|
031ab11178 | ||
|
|
d2e48d4f5e | ||
|
|
27dd97c9c4 | ||
|
|
e87b4ace11 | ||
|
|
a232cf614c | ||
|
|
a293f760af | ||
|
|
10e9cf8c62 | ||
|
|
7193bd2da7 | ||
|
|
52948ccd61 | ||
|
|
4b77359042 | ||
|
|
387d43bcc1 | ||
|
|
e47d815dd2 | ||
|
|
cb83b7c0d3 | ||
|
|
ddcd182859 | ||
|
|
aca0398105 | ||
|
|
02200d790b | ||
|
|
f31bba87b4 | ||
|
|
7285fef0f0 | ||
|
|
20973063d8 | ||
|
|
ba2e9b6d88 | ||
|
|
131d7a3694 | ||
|
|
290fe2d8b9 | ||
|
|
7fb1a2fe31 | ||
|
|
32146e576d | ||
|
|
1311364397 | ||
|
|
68f56b797d | ||
|
|
3351b38434 | ||
|
|
05cbead39b | ||
|
|
60f4d5f9b0 | ||
|
|
4eeb2d8deb | ||
|
|
d71a82769c | ||
|
|
0d79301141 |
11
.githooks/pre-push
Executable file
11
.githooks/pre-push
Executable file
@@ -0,0 +1,11 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
echo "Running pre-push hook..."
|
||||||
|
if ! make lint; then
|
||||||
|
echo ""
|
||||||
|
echo "Hint: To push without verification, run:"
|
||||||
|
echo " git push --no-verify"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "All checks passed!"
|
||||||
12
.github/workflows/check-license-dependencies.yml
vendored
12
.github/workflows/check-license-dependencies.yml
vendored
@@ -24,23 +24,25 @@ jobs:
|
|||||||
- name: Check for problematic license dependencies
|
- name: Check for problematic license dependencies
|
||||||
run: |
|
run: |
|
||||||
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
||||||
|
echo ""
|
||||||
|
|
||||||
# Find all directories except the problematic ones and system dirs
|
# Find all directories except the problematic ones and system dirs
|
||||||
FOUND_ISSUES=0
|
FOUND_ISSUES=0
|
||||||
find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort | while read dir; do
|
while IFS= read -r dir; do
|
||||||
echo "=== Checking $dir ==="
|
echo "=== Checking $dir ==="
|
||||||
# Search for problematic imports, excluding test files
|
# Search for problematic imports, excluding test files
|
||||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||||
if [ ! -z "$RESULTS" ]; then
|
if [ -n "$RESULTS" ]; then
|
||||||
echo "❌ Found problematic dependencies:"
|
echo "❌ Found problematic dependencies:"
|
||||||
echo "$RESULTS"
|
echo "$RESULTS"
|
||||||
FOUND_ISSUES=1
|
FOUND_ISSUES=1
|
||||||
else
|
else
|
||||||
echo "✓ No problematic dependencies found"
|
echo "✓ No problematic dependencies found"
|
||||||
fi
|
fi
|
||||||
done
|
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
|
||||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
|
||||||
echo ""
|
echo ""
|
||||||
|
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
||||||
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
||||||
exit 1
|
exit 1
|
||||||
|
|||||||
7
.github/workflows/golang-test-darwin.yml
vendored
7
.github/workflows/golang-test-darwin.yml
vendored
@@ -15,13 +15,14 @@ jobs:
|
|||||||
name: "Client / Unit"
|
name: "Client / Unit"
|
||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
|
|||||||
2
.github/workflows/golang-test-freebsd.yml
vendored
2
.github/workflows/golang-test-freebsd.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
|||||||
release: "14.2"
|
release: "14.2"
|
||||||
prepare: |
|
prepare: |
|
||||||
pkg install -y curl pkgconf xorg
|
pkg install -y curl pkgconf xorg
|
||||||
GO_TARBALL="go1.23.12.freebsd-amd64.tar.gz"
|
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
|
||||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||||
curl -vLO "$GO_URL"
|
curl -vLO "$GO_URL"
|
||||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||||
|
|||||||
71
.github/workflows/golang-test-linux.yml
vendored
71
.github/workflows/golang-test-linux.yml
vendored
@@ -30,7 +30,7 @@ jobs:
|
|||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
@@ -106,15 +106,15 @@ jobs:
|
|||||||
arch: [ '386','amd64' ]
|
arch: [ '386','amd64' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
@@ -151,15 +151,15 @@ jobs:
|
|||||||
needs: [ build-cache ]
|
needs: [ build-cache ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
id: go-env
|
id: go-env
|
||||||
run: |
|
run: |
|
||||||
@@ -200,7 +200,7 @@ jobs:
|
|||||||
-e GOCACHE=${CONTAINER_GOCACHE} \
|
-e GOCACHE=${CONTAINER_GOCACHE} \
|
||||||
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
|
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
|
||||||
-e CONTAINER=${CONTAINER} \
|
-e CONTAINER=${CONTAINER} \
|
||||||
golang:1.23-alpine \
|
golang:1.24-alpine \
|
||||||
sh -c ' \
|
sh -c ' \
|
||||||
apk update; apk add --no-cache \
|
apk update; apk add --no-cache \
|
||||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||||
@@ -220,15 +220,15 @@ jobs:
|
|||||||
raceFlag: "-race"
|
raceFlag: "-race"
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
if: steps.cache.outputs.cache-hit != 'true'
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||||
@@ -270,15 +270,15 @@ jobs:
|
|||||||
arch: [ '386','amd64' ]
|
arch: [ '386','amd64' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
if: steps.cache.outputs.cache-hit != 'true'
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||||
@@ -321,15 +321,15 @@ jobs:
|
|||||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
@@ -408,15 +408,16 @@ jobs:
|
|||||||
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
|
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
|
||||||
-p 9090:9090 \
|
-p 9090:9090 \
|
||||||
prom/prometheus
|
prom/prometheus
|
||||||
- name: Install Go
|
|
||||||
uses: actions/setup-go@v5
|
|
||||||
with:
|
|
||||||
go-version: "1.23.x"
|
|
||||||
cache: false
|
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: "go.mod"
|
||||||
|
cache: false
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
@@ -497,15 +498,15 @@ jobs:
|
|||||||
-p 9090:9090 \
|
-p 9090:9090 \
|
||||||
prom/prometheus
|
prom/prometheus
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
@@ -561,15 +562,15 @@ jobs:
|
|||||||
store: [ 'sqlite', 'postgres']
|
store: [ 'sqlite', 'postgres']
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
|||||||
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
|||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
id: go
|
id: go
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
|
|||||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -46,7 +46,7 @@ jobs:
|
|||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
if: matrix.os == 'ubuntu-latest'
|
if: matrix.os == 'ubuntu-latest'
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ jobs:
|
|||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
- name: Setup Android SDK
|
- name: Setup Android SDK
|
||||||
uses: android-actions/setup-android@v3
|
uses: android-actions/setup-android@v3
|
||||||
with:
|
with:
|
||||||
@@ -39,7 +39,7 @@ jobs:
|
|||||||
- name: Setup NDK
|
- name: Setup NDK
|
||||||
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
|
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
|
||||||
- name: install gomobile
|
- name: install gomobile
|
||||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
|
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab
|
||||||
- name: gomobile init
|
- name: gomobile init
|
||||||
run: gomobile init
|
run: gomobile init
|
||||||
- name: build android netbird lib
|
- name: build android netbird lib
|
||||||
@@ -56,9 +56,9 @@ jobs:
|
|||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
- name: install gomobile
|
- name: install gomobile
|
||||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
|
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab
|
||||||
- name: gomobile init
|
- name: gomobile init
|
||||||
run: gomobile init
|
run: gomobile init
|
||||||
- name: build iOS netbird lib
|
- name: build iOS netbird lib
|
||||||
|
|||||||
8
.github/workflows/release.yml
vendored
8
.github/workflows/release.yml
vendored
@@ -20,7 +20,7 @@ concurrency:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
release:
|
release:
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-latest-m
|
||||||
env:
|
env:
|
||||||
flags: ""
|
flags: ""
|
||||||
steps:
|
steps:
|
||||||
@@ -40,7 +40,7 @@ jobs:
|
|||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
@@ -136,7 +136,7 @@ jobs:
|
|||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
@@ -200,7 +200,7 @@ jobs:
|
|||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
|
|||||||
@@ -67,10 +67,13 @@ jobs:
|
|||||||
- name: Install curl
|
- name: Install curl
|
||||||
run: sudo apt-get install -y curl
|
run: sudo apt-get install -y curl
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
@@ -80,9 +83,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-go-
|
${{ runner.os }}-go-
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Setup MySQL privileges
|
- name: Setup MySQL privileges
|
||||||
if: matrix.store == 'mysql'
|
if: matrix.store == 'mysql'
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
8
.github/workflows/wasm-build-validation.yml
vendored
8
.github/workflows/wasm-build-validation.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
|||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||||
- name: Install golangci-lint
|
- name: Install golangci-lint
|
||||||
@@ -45,7 +45,7 @@ jobs:
|
|||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
- name: Build Wasm client
|
- name: Build Wasm client
|
||||||
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
|
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
|
||||||
env:
|
env:
|
||||||
@@ -60,8 +60,8 @@ jobs:
|
|||||||
|
|
||||||
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||||
|
|
||||||
if [ ${SIZE} -gt 52428800 ]; then
|
if [ ${SIZE} -gt 57671680 ]; then
|
||||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!"
|
echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|||||||
@@ -136,6 +136,14 @@ checked out and set up:
|
|||||||
go mod tidy
|
go mod tidy
|
||||||
```
|
```
|
||||||
|
|
||||||
|
6. Configure Git hooks for automatic linting:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make setup-hooks
|
||||||
|
```
|
||||||
|
|
||||||
|
This will configure Git to run linting automatically before each push, helping catch issues early.
|
||||||
|
|
||||||
### Dev Container Support
|
### Dev Container Support
|
||||||
|
|
||||||
If you prefer using a dev container for development, NetBird now includes support for dev containers.
|
If you prefer using a dev container for development, NetBird now includes support for dev containers.
|
||||||
|
|||||||
27
Makefile
Normal file
27
Makefile
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
.PHONY: lint lint-all lint-install setup-hooks
|
||||||
|
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
|
||||||
|
|
||||||
|
# Install golangci-lint locally if needed
|
||||||
|
$(GOLANGCI_LINT):
|
||||||
|
@echo "Installing golangci-lint..."
|
||||||
|
@mkdir -p ./bin
|
||||||
|
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||||
|
|
||||||
|
# Lint only changed files (fast, for pre-push)
|
||||||
|
lint: $(GOLANGCI_LINT)
|
||||||
|
@echo "Running lint on changed files..."
|
||||||
|
@$(GOLANGCI_LINT) run --new-from-rev=origin/main --timeout=2m
|
||||||
|
|
||||||
|
# Lint entire codebase (slow, matches CI)
|
||||||
|
lint-all: $(GOLANGCI_LINT)
|
||||||
|
@echo "Running lint on all files..."
|
||||||
|
@$(GOLANGCI_LINT) run --timeout=12m
|
||||||
|
|
||||||
|
# Just install the linter
|
||||||
|
lint-install: $(GOLANGCI_LINT)
|
||||||
|
|
||||||
|
# Setup git hooks for all developers
|
||||||
|
setup-hooks:
|
||||||
|
@git config core.hooksPath .githooks
|
||||||
|
@chmod +x .githooks/pre-push
|
||||||
|
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"
|
||||||
@@ -4,10 +4,13 @@ package android
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
@@ -16,10 +19,13 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
"github.com/netbirdio/netbird/client/net"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
"github.com/netbirdio/netbird/client/net"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionListener export internal Listener for mobile
|
// ConnectionListener export internal Listener for mobile
|
||||||
@@ -53,7 +59,6 @@ func init() {
|
|||||||
|
|
||||||
// Client struct manage the life circle of background service
|
// Client struct manage the life circle of background service
|
||||||
type Client struct {
|
type Client struct {
|
||||||
cfgFile string
|
|
||||||
tunAdapter device.TunAdapter
|
tunAdapter device.TunAdapter
|
||||||
iFaceDiscover IFaceDiscover
|
iFaceDiscover IFaceDiscover
|
||||||
recorder *peer.Status
|
recorder *peer.Status
|
||||||
@@ -67,12 +72,11 @@ type Client struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewClient instantiate a new Client
|
// NewClient instantiate a new Client
|
||||||
func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
func NewClient(androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
||||||
execWorkaround(androidSDKVersion)
|
execWorkaround(androidSDKVersion)
|
||||||
|
|
||||||
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
|
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
|
||||||
return &Client{
|
return &Client{
|
||||||
cfgFile: cfgFile,
|
|
||||||
deviceName: deviceName,
|
deviceName: deviceName,
|
||||||
uiVersion: uiVersion,
|
uiVersion: uiVersion,
|
||||||
tunAdapter: tunAdapter,
|
tunAdapter: tunAdapter,
|
||||||
@@ -84,10 +88,16 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Run start the internal client. It is a blocker function
|
// Run start the internal client. It is a blocker function
|
||||||
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
||||||
exportEnvList(envList)
|
exportEnvList(envList)
|
||||||
|
|
||||||
|
cfgFile := platformFiles.ConfigurationFilePath()
|
||||||
|
stateFile := platformFiles.StateFilePath()
|
||||||
|
|
||||||
|
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
|
||||||
|
|
||||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||||
ConfigPath: c.cfgFile,
|
ConfigPath: cfgFile,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -107,23 +117,29 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
|||||||
c.ctxCancelLock.Unlock()
|
c.ctxCancelLock.Unlock()
|
||||||
|
|
||||||
auth := NewAuthWithConfig(ctx, cfg)
|
auth := NewAuthWithConfig(ctx, cfg)
|
||||||
err = auth.login(urlOpener)
|
err = auth.login(urlOpener, isAndroidTV)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
|
||||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
|
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||||
// In this case make no sense handle registration steps.
|
// In this case make no sense handle registration steps.
|
||||||
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
||||||
exportEnvList(envList)
|
exportEnvList(envList)
|
||||||
|
|
||||||
|
cfgFile := platformFiles.ConfigurationFilePath()
|
||||||
|
stateFile := platformFiles.StateFilePath()
|
||||||
|
|
||||||
|
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
|
||||||
|
|
||||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||||
ConfigPath: c.cfgFile,
|
ConfigPath: cfgFile,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -141,8 +157,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
|
||||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
|
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
@@ -156,6 +172,19 @@ func (c *Client) Stop() {
|
|||||||
c.ctxCancel()
|
c.ctxCancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) RenewTun(fd int) error {
|
||||||
|
if c.connectClient == nil {
|
||||||
|
return fmt.Errorf("engine not running")
|
||||||
|
}
|
||||||
|
|
||||||
|
e := c.connectClient.Engine()
|
||||||
|
if e == nil {
|
||||||
|
return fmt.Errorf("engine not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.RenewTun(fd)
|
||||||
|
}
|
||||||
|
|
||||||
// SetTraceLogLevel configure the logger to trace level
|
// SetTraceLogLevel configure the logger to trace level
|
||||||
func (c *Client) SetTraceLogLevel() {
|
func (c *Client) SetTraceLogLevel() {
|
||||||
log.SetLevel(log.TraceLevel)
|
log.SetLevel(log.TraceLevel)
|
||||||
@@ -177,6 +206,7 @@ func (c *Client) PeersList() *PeerInfoArray {
|
|||||||
p.IP,
|
p.IP,
|
||||||
p.FQDN,
|
p.FQDN,
|
||||||
p.ConnStatus.String(),
|
p.ConnStatus.String(),
|
||||||
|
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
|
||||||
}
|
}
|
||||||
peerInfos[n] = pi
|
peerInfos[n] = pi
|
||||||
}
|
}
|
||||||
@@ -201,22 +231,32 @@ func (c *Client) Networks() *NetworkArray {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
routeSelector := routeManager.GetRouteSelector()
|
||||||
|
if routeSelector == nil {
|
||||||
|
log.Error("could not get route selector")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
networkArray := &NetworkArray{
|
networkArray := &NetworkArray{
|
||||||
items: make([]Network, 0),
|
items: make([]Network, 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
resolvedDomains := c.recorder.GetResolvedDomainsStates()
|
||||||
|
|
||||||
for id, routes := range routeManager.GetClientRoutesWithNetID() {
|
for id, routes := range routeManager.GetClientRoutesWithNetID() {
|
||||||
if len(routes) == 0 {
|
if len(routes) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
r := routes[0]
|
r := routes[0]
|
||||||
|
domains := c.getNetworkDomainsFromRoute(r, resolvedDomains)
|
||||||
netStr := r.Network.String()
|
netStr := r.Network.String()
|
||||||
|
|
||||||
if r.IsDynamic() {
|
if r.IsDynamic() {
|
||||||
netStr = r.Domains.SafeString()
|
netStr = r.Domains.SafeString()
|
||||||
}
|
}
|
||||||
|
|
||||||
peer, err := c.recorder.GetPeer(routes[0].Peer)
|
routePeer, err := c.recorder.GetPeer(routes[0].Peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
|
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
|
||||||
continue
|
continue
|
||||||
@@ -224,8 +264,10 @@ func (c *Client) Networks() *NetworkArray {
|
|||||||
network := Network{
|
network := Network{
|
||||||
Name: string(id),
|
Name: string(id),
|
||||||
Network: netStr,
|
Network: netStr,
|
||||||
Peer: peer.FQDN,
|
Peer: routePeer.FQDN,
|
||||||
Status: peer.ConnStatus.String(),
|
Status: routePeer.ConnStatus.String(),
|
||||||
|
IsSelected: routeSelector.IsSelected(id),
|
||||||
|
Domains: domains,
|
||||||
}
|
}
|
||||||
networkArray.Add(network)
|
networkArray.Add(network)
|
||||||
}
|
}
|
||||||
@@ -253,6 +295,69 @@ func (c *Client) RemoveConnectionListener() {
|
|||||||
c.recorder.RemoveConnectionListener()
|
c.recorder.RemoveConnectionListener()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) toggleRoute(command routeCommand) error {
|
||||||
|
return command.toggleRoute()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) getRouteManager() (routemanager.Manager, error) {
|
||||||
|
client := c.connectClient
|
||||||
|
if client == nil {
|
||||||
|
return nil, fmt.Errorf("not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := client.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return nil, fmt.Errorf("engine is not running")
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := engine.GetRouteManager()
|
||||||
|
if manager == nil {
|
||||||
|
return nil, fmt.Errorf("could not get route manager")
|
||||||
|
}
|
||||||
|
|
||||||
|
return manager, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) SelectRoute(route string) error {
|
||||||
|
manager, err := c.getRouteManager()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.toggleRoute(selectRouteCommand{route: route, manager: manager})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) DeselectRoute(route string) error {
|
||||||
|
manager, err := c.getRouteManager()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.toggleRoute(deselectRouteCommand{route: route, manager: manager})
|
||||||
|
}
|
||||||
|
|
||||||
|
// getNetworkDomainsFromRoute extracts domains from a route and enriches each domain
|
||||||
|
// with its resolved IP addresses from the provided resolvedDomains map.
|
||||||
|
func (c *Client) getNetworkDomainsFromRoute(route *route.Route, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) NetworkDomains {
|
||||||
|
domains := NetworkDomains{}
|
||||||
|
|
||||||
|
for _, d := range route.Domains {
|
||||||
|
networkDomain := NetworkDomain{
|
||||||
|
Address: d.SafeString(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if info, exists := resolvedDomains[d]; exists {
|
||||||
|
for _, prefix := range info.Prefixes {
|
||||||
|
networkDomain.addResolvedIP(prefix.Addr().String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
domains.Add(&networkDomain)
|
||||||
|
}
|
||||||
|
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
func exportEnvList(list *EnvList) {
|
func exportEnvList(list *EnvList) {
|
||||||
if list == nil {
|
if list == nil {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ type ErrListener interface {
|
|||||||
// URLOpener it is a callback interface. The Open function will be triggered if
|
// URLOpener it is a callback interface. The Open function will be triggered if
|
||||||
// the backend want to show an url for the user
|
// the backend want to show an url for the user
|
||||||
type URLOpener interface {
|
type URLOpener interface {
|
||||||
Open(string)
|
Open(url string, userCode string)
|
||||||
OnLoginSuccess()
|
OnLoginSuccess()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -148,9 +148,9 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Login try register the client on the server
|
// Login try register the client on the server
|
||||||
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) {
|
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidTV bool) {
|
||||||
go func() {
|
go func() {
|
||||||
err := a.login(urlOpener)
|
err := a.login(urlOpener, isAndroidTV)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resultListener.OnError(err)
|
resultListener.OnError(err)
|
||||||
} else {
|
} else {
|
||||||
@@ -159,7 +159,7 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) login(urlOpener URLOpener) error {
|
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
|
||||||
var needsLogin bool
|
var needsLogin bool
|
||||||
|
|
||||||
// check if we need to generate JWT token
|
// check if we need to generate JWT token
|
||||||
@@ -173,7 +173,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
|||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
if needsLogin {
|
if needsLogin {
|
||||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener)
|
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -199,8 +199,8 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
|
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "")
|
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -210,7 +210,7 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, err
|
|||||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
go urlOpener.Open(flowInfo.VerificationURIComplete)
|
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||||
|
|
||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
||||||
|
|||||||
56
client/android/network_domains.go
Normal file
56
client/android/network_domains.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package android
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
type ResolvedIPs struct {
|
||||||
|
resolvedIPs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResolvedIPs) Add(ipAddress string) {
|
||||||
|
r.resolvedIPs = append(r.resolvedIPs, ipAddress)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResolvedIPs) Get(i int) (string, error) {
|
||||||
|
if i < 0 || i >= len(r.resolvedIPs) {
|
||||||
|
return "", fmt.Errorf("%d is out of range", i)
|
||||||
|
}
|
||||||
|
return r.resolvedIPs[i], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResolvedIPs) Size() int {
|
||||||
|
return len(r.resolvedIPs)
|
||||||
|
}
|
||||||
|
|
||||||
|
type NetworkDomain struct {
|
||||||
|
Address string
|
||||||
|
resolvedIPs ResolvedIPs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *NetworkDomain) addResolvedIP(resolvedIP string) {
|
||||||
|
d.resolvedIPs.Add(resolvedIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *NetworkDomain) GetResolvedIPs() *ResolvedIPs {
|
||||||
|
return &d.resolvedIPs
|
||||||
|
}
|
||||||
|
|
||||||
|
type NetworkDomains struct {
|
||||||
|
domains []*NetworkDomain
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NetworkDomains) Add(domain *NetworkDomain) {
|
||||||
|
n.domains = append(n.domains, domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NetworkDomains) Get(i int) (*NetworkDomain, error) {
|
||||||
|
if i < 0 || i >= len(n.domains) {
|
||||||
|
return nil, fmt.Errorf("%d is out of range", i)
|
||||||
|
}
|
||||||
|
return n.domains[i], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NetworkDomains) Size() int {
|
||||||
|
return len(n.domains)
|
||||||
|
}
|
||||||
@@ -7,6 +7,12 @@ type Network struct {
|
|||||||
Network string
|
Network string
|
||||||
Peer string
|
Peer string
|
||||||
Status string
|
Status string
|
||||||
|
IsSelected bool
|
||||||
|
Domains NetworkDomains
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n Network) GetNetworkDomains() *NetworkDomains {
|
||||||
|
return &n.Domains
|
||||||
}
|
}
|
||||||
|
|
||||||
type NetworkArray struct {
|
type NetworkArray struct {
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
package android
|
package android
|
||||||
|
|
||||||
// PeerInfo describe information about the peers. It designed for the UI usage
|
// PeerInfo describe information about the peers. It designed for the UI usage
|
||||||
@@ -5,6 +7,11 @@ type PeerInfo struct {
|
|||||||
IP string
|
IP string
|
||||||
FQDN string
|
FQDN string
|
||||||
ConnStatus string // Todo replace to enum
|
ConnStatus string // Todo replace to enum
|
||||||
|
Routes PeerRoutes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PeerInfo) GetPeerRoutes() *PeerRoutes {
|
||||||
|
return &p.Routes
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeerInfoArray is a wrapper of []PeerInfo
|
// PeerInfoArray is a wrapper of []PeerInfo
|
||||||
|
|||||||
20
client/android/peer_routes.go
Normal file
20
client/android/peer_routes.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package android
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
type PeerRoutes struct {
|
||||||
|
routes []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PeerRoutes) Get(i int) (string, error) {
|
||||||
|
if i < 0 || i >= len(p.routes) {
|
||||||
|
return "", fmt.Errorf("%d is out of range", i)
|
||||||
|
}
|
||||||
|
return p.routes[i], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PeerRoutes) Size() int {
|
||||||
|
return len(p.routes)
|
||||||
|
}
|
||||||
10
client/android/platform_files.go
Normal file
10
client/android/platform_files.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package android
|
||||||
|
|
||||||
|
// PlatformFiles groups paths to files used internally by the engine that can't be created/modified
|
||||||
|
// at their default locations due to android OS restrictions.
|
||||||
|
type PlatformFiles interface {
|
||||||
|
ConfigurationFilePath() string
|
||||||
|
StateFilePath() string
|
||||||
|
}
|
||||||
@@ -201,6 +201,94 @@ func (p *Preferences) SetServerSSHAllowed(allowed bool) {
|
|||||||
p.configInput.ServerSSHAllowed = &allowed
|
p.configInput.ServerSSHAllowed = &allowed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetEnableSSHRoot reads SSH root login setting from config file
|
||||||
|
func (p *Preferences) GetEnableSSHRoot() (bool, error) {
|
||||||
|
if p.configInput.EnableSSHRoot != nil {
|
||||||
|
return *p.configInput.EnableSSHRoot, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if cfg.EnableSSHRoot == nil {
|
||||||
|
// Default to false for security on Android
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return *cfg.EnableSSHRoot, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetEnableSSHRoot stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetEnableSSHRoot(enabled bool) {
|
||||||
|
p.configInput.EnableSSHRoot = &enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEnableSSHSFTP reads SSH SFTP setting from config file
|
||||||
|
func (p *Preferences) GetEnableSSHSFTP() (bool, error) {
|
||||||
|
if p.configInput.EnableSSHSFTP != nil {
|
||||||
|
return *p.configInput.EnableSSHSFTP, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if cfg.EnableSSHSFTP == nil {
|
||||||
|
// Default to false for security on Android
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return *cfg.EnableSSHSFTP, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetEnableSSHSFTP stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetEnableSSHSFTP(enabled bool) {
|
||||||
|
p.configInput.EnableSSHSFTP = &enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEnableSSHLocalPortForwarding reads SSH local port forwarding setting from config file
|
||||||
|
func (p *Preferences) GetEnableSSHLocalPortForwarding() (bool, error) {
|
||||||
|
if p.configInput.EnableSSHLocalPortForwarding != nil {
|
||||||
|
return *p.configInput.EnableSSHLocalPortForwarding, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if cfg.EnableSSHLocalPortForwarding == nil {
|
||||||
|
// Default to false for security on Android
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return *cfg.EnableSSHLocalPortForwarding, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetEnableSSHLocalPortForwarding stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetEnableSSHLocalPortForwarding(enabled bool) {
|
||||||
|
p.configInput.EnableSSHLocalPortForwarding = &enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEnableSSHRemotePortForwarding reads SSH remote port forwarding setting from config file
|
||||||
|
func (p *Preferences) GetEnableSSHRemotePortForwarding() (bool, error) {
|
||||||
|
if p.configInput.EnableSSHRemotePortForwarding != nil {
|
||||||
|
return *p.configInput.EnableSSHRemotePortForwarding, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if cfg.EnableSSHRemotePortForwarding == nil {
|
||||||
|
// Default to false for security on Android
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return *cfg.EnableSSHRemotePortForwarding, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetEnableSSHRemotePortForwarding stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetEnableSSHRemotePortForwarding(enabled bool) {
|
||||||
|
p.configInput.EnableSSHRemotePortForwarding = &enabled
|
||||||
|
}
|
||||||
|
|
||||||
// GetBlockInbound reads block inbound setting from config file
|
// GetBlockInbound reads block inbound setting from config file
|
||||||
func (p *Preferences) GetBlockInbound() (bool, error) {
|
func (p *Preferences) GetBlockInbound() (bool, error) {
|
||||||
if p.configInput.BlockInbound != nil {
|
if p.configInput.BlockInbound != nil {
|
||||||
|
|||||||
257
client/android/profile_manager.go
Normal file
257
client/android/profile_manager.go
Normal file
@@ -0,0 +1,257 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package android
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Android-specific config filename (different from desktop default.json)
|
||||||
|
defaultConfigFilename = "netbird.cfg"
|
||||||
|
// Subdirectory for non-default profiles (must match Java Preferences.java)
|
||||||
|
profilesSubdir = "profiles"
|
||||||
|
// Android uses a single user context per app (non-empty username required by ServiceManager)
|
||||||
|
androidUsername = "android"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Profile represents a profile for gomobile
|
||||||
|
type Profile struct {
|
||||||
|
Name string
|
||||||
|
IsActive bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProfileArray wraps profiles for gomobile compatibility
|
||||||
|
type ProfileArray struct {
|
||||||
|
items []*Profile
|
||||||
|
}
|
||||||
|
|
||||||
|
// Length returns the number of profiles
|
||||||
|
func (p *ProfileArray) Length() int {
|
||||||
|
return len(p.items)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the profile at index i
|
||||||
|
func (p *ProfileArray) Get(i int) *Profile {
|
||||||
|
if i < 0 || i >= len(p.items) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return p.items[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
|
||||||
|
/data/data/io.netbird.client/files/ ← configDir parameter
|
||||||
|
├── netbird.cfg ← Default profile config
|
||||||
|
├── state.json ← Default profile state
|
||||||
|
├── active_profile.json ← Active profile tracker (JSON with Name + Username)
|
||||||
|
└── profiles/ ← Subdirectory for non-default profiles
|
||||||
|
├── work.json ← Work profile config
|
||||||
|
├── work.state.json ← Work profile state
|
||||||
|
├── personal.json ← Personal profile config
|
||||||
|
└── personal.state.json ← Personal profile state
|
||||||
|
*/
|
||||||
|
|
||||||
|
// ProfileManager manages profiles for Android
|
||||||
|
// It wraps the internal profilemanager to provide Android-specific behavior
|
||||||
|
type ProfileManager struct {
|
||||||
|
configDir string
|
||||||
|
serviceMgr *profilemanager.ServiceManager
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProfileManager creates a new profile manager for Android
|
||||||
|
func NewProfileManager(configDir string) *ProfileManager {
|
||||||
|
// Set the default config path for Android (stored in root configDir, not profiles/)
|
||||||
|
defaultConfigPath := filepath.Join(configDir, defaultConfigFilename)
|
||||||
|
|
||||||
|
// Set global paths for Android
|
||||||
|
profilemanager.DefaultConfigPathDir = configDir
|
||||||
|
profilemanager.DefaultConfigPath = defaultConfigPath
|
||||||
|
profilemanager.ActiveProfileStatePath = filepath.Join(configDir, "active_profile.json")
|
||||||
|
|
||||||
|
// Create ServiceManager with profiles/ subdirectory
|
||||||
|
// This avoids modifying the global ConfigDirOverride for profile listing
|
||||||
|
profilesDir := filepath.Join(configDir, profilesSubdir)
|
||||||
|
serviceMgr := profilemanager.NewServiceManagerWithProfilesDir(defaultConfigPath, profilesDir)
|
||||||
|
|
||||||
|
return &ProfileManager{
|
||||||
|
configDir: configDir,
|
||||||
|
serviceMgr: serviceMgr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListProfiles returns all available profiles
|
||||||
|
func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
|
||||||
|
// Use ServiceManager (looks in profiles/ directory, checks active_profile.json for IsActive)
|
||||||
|
internalProfiles, err := pm.serviceMgr.ListProfiles(androidUsername)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to list profiles: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert internal profiles to Android Profile type
|
||||||
|
var profiles []*Profile
|
||||||
|
for _, p := range internalProfiles {
|
||||||
|
profiles = append(profiles, &Profile{
|
||||||
|
Name: p.Name,
|
||||||
|
IsActive: p.IsActive,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ProfileArray{items: profiles}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetActiveProfile returns the currently active profile name
|
||||||
|
func (pm *ProfileManager) GetActiveProfile() (string, error) {
|
||||||
|
// Use ServiceManager to stay consistent with ListProfiles
|
||||||
|
// ServiceManager uses active_profile.json
|
||||||
|
activeState, err := pm.serviceMgr.GetActiveProfileState()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get active profile: %w", err)
|
||||||
|
}
|
||||||
|
return activeState.Name, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SwitchProfile switches to a different profile
|
||||||
|
func (pm *ProfileManager) SwitchProfile(profileName string) error {
|
||||||
|
// Use ServiceManager to stay consistent with ListProfiles
|
||||||
|
// ServiceManager uses active_profile.json
|
||||||
|
err := pm.serviceMgr.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||||
|
Name: profileName,
|
||||||
|
Username: androidUsername,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to switch profile: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("switched to profile: %s", profileName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddProfile creates a new profile
|
||||||
|
func (pm *ProfileManager) AddProfile(profileName string) error {
|
||||||
|
// Use ServiceManager (creates profile in profiles/ directory)
|
||||||
|
if err := pm.serviceMgr.AddProfile(profileName, androidUsername); err != nil {
|
||||||
|
return fmt.Errorf("failed to add profile: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("created new profile: %s", profileName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogoutProfile logs out from a profile (clears authentication)
|
||||||
|
func (pm *ProfileManager) LogoutProfile(profileName string) error {
|
||||||
|
profileName = sanitizeProfileName(profileName)
|
||||||
|
|
||||||
|
configPath, err := pm.getProfileConfigPath(profileName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if profile exists
|
||||||
|
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||||
|
return fmt.Errorf("profile '%s' does not exist", profileName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read current config using internal profilemanager
|
||||||
|
config, err := profilemanager.ReadConfig(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read profile config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear authentication by removing private key and SSH key
|
||||||
|
config.PrivateKey = ""
|
||||||
|
config.SSHKey = ""
|
||||||
|
|
||||||
|
// Save config using internal profilemanager
|
||||||
|
if err := profilemanager.WriteOutConfig(configPath, config); err != nil {
|
||||||
|
return fmt.Errorf("failed to save config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("logged out from profile: %s", profileName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveProfile deletes a profile
|
||||||
|
func (pm *ProfileManager) RemoveProfile(profileName string) error {
|
||||||
|
// Use ServiceManager (removes profile from profiles/ directory)
|
||||||
|
if err := pm.serviceMgr.RemoveProfile(profileName, androidUsername); err != nil {
|
||||||
|
return fmt.Errorf("failed to remove profile: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("removed profile: %s", profileName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getProfileConfigPath returns the config file path for a profile
|
||||||
|
// This is needed for Android-specific path handling (netbird.cfg for default profile)
|
||||||
|
func (pm *ProfileManager) getProfileConfigPath(profileName string) (string, error) {
|
||||||
|
if profileName == "" || profileName == profilemanager.DefaultProfileName {
|
||||||
|
// Android uses netbird.cfg for default profile instead of default.json
|
||||||
|
// Default profile is stored in root configDir, not in profiles/
|
||||||
|
return filepath.Join(pm.configDir, defaultConfigFilename), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non-default profiles are stored in profiles subdirectory
|
||||||
|
// This matches the Java Preferences.java expectation
|
||||||
|
profileName = sanitizeProfileName(profileName)
|
||||||
|
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
|
||||||
|
return filepath.Join(profilesDir, profileName+".json"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConfigPath returns the config file path for a given profile
|
||||||
|
// Java should call this instead of constructing paths with Preferences.configFile()
|
||||||
|
func (pm *ProfileManager) GetConfigPath(profileName string) (string, error) {
|
||||||
|
return pm.getProfileConfigPath(profileName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStateFilePath returns the state file path for a given profile
|
||||||
|
// Java should call this instead of constructing paths with Preferences.stateFile()
|
||||||
|
func (pm *ProfileManager) GetStateFilePath(profileName string) (string, error) {
|
||||||
|
if profileName == "" || profileName == profilemanager.DefaultProfileName {
|
||||||
|
return filepath.Join(pm.configDir, "state.json"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
profileName = sanitizeProfileName(profileName)
|
||||||
|
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
|
||||||
|
return filepath.Join(profilesDir, profileName+".state.json"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetActiveConfigPath returns the config file path for the currently active profile
|
||||||
|
// Java should call this instead of Preferences.getActiveProfileName() + Preferences.configFile()
|
||||||
|
func (pm *ProfileManager) GetActiveConfigPath() (string, error) {
|
||||||
|
activeProfile, err := pm.GetActiveProfile()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get active profile: %w", err)
|
||||||
|
}
|
||||||
|
return pm.GetConfigPath(activeProfile)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetActiveStateFilePath returns the state file path for the currently active profile
|
||||||
|
// Java should call this instead of Preferences.getActiveProfileName() + Preferences.stateFile()
|
||||||
|
func (pm *ProfileManager) GetActiveStateFilePath() (string, error) {
|
||||||
|
activeProfile, err := pm.GetActiveProfile()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get active profile: %w", err)
|
||||||
|
}
|
||||||
|
return pm.GetStateFilePath(activeProfile)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sanitizeProfileName removes invalid characters from profile name
|
||||||
|
func sanitizeProfileName(name string) string {
|
||||||
|
// Keep only alphanumeric, underscore, and hyphen
|
||||||
|
var result strings.Builder
|
||||||
|
for _, r := range name {
|
||||||
|
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') ||
|
||||||
|
(r >= '0' && r <= '9') || r == '_' || r == '-' {
|
||||||
|
result.WriteRune(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result.String()
|
||||||
|
}
|
||||||
67
client/android/route_command.go
Normal file
67
client/android/route_command.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package android
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
func executeRouteToggle(id string, manager routemanager.Manager,
|
||||||
|
operationName string,
|
||||||
|
routeOperation func(routes []route.NetID, allRoutes []route.NetID) error) error {
|
||||||
|
netID := route.NetID(id)
|
||||||
|
routes := []route.NetID{netID}
|
||||||
|
|
||||||
|
log.Debugf("%s with id: %s", operationName, id)
|
||||||
|
|
||||||
|
if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil {
|
||||||
|
log.Debugf("error when %s: %s", operationName, err)
|
||||||
|
return fmt.Errorf("error %s: %w", operationName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.TriggerSelection(manager.GetClientRoutes())
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type routeCommand interface {
|
||||||
|
toggleRoute() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type selectRouteCommand struct {
|
||||||
|
route string
|
||||||
|
manager routemanager.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s selectRouteCommand) toggleRoute() error {
|
||||||
|
routeSelector := s.manager.GetRouteSelector()
|
||||||
|
if routeSelector == nil {
|
||||||
|
return fmt.Errorf("no route selector available")
|
||||||
|
}
|
||||||
|
|
||||||
|
routeOperation := func(routes []route.NetID, allRoutes []route.NetID) error {
|
||||||
|
return routeSelector.SelectRoutes(routes, true, allRoutes)
|
||||||
|
}
|
||||||
|
|
||||||
|
return executeRouteToggle(s.route, s.manager, "selecting route", routeOperation)
|
||||||
|
}
|
||||||
|
|
||||||
|
type deselectRouteCommand struct {
|
||||||
|
route string
|
||||||
|
manager routemanager.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d deselectRouteCommand) toggleRoute() error {
|
||||||
|
routeSelector := d.manager.GetRouteSelector()
|
||||||
|
if routeSelector == nil {
|
||||||
|
return fmt.Errorf("no route selector available")
|
||||||
|
}
|
||||||
|
|
||||||
|
return executeRouteToggle(d.route, d.manager, "deselecting route", routeSelector.DeselectRoutes)
|
||||||
|
}
|
||||||
136
client/cmd/kubeconfig.go
Normal file
136
client/cmd/kubeconfig.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
kubeconfigOutput string
|
||||||
|
kubeconfigCluster string
|
||||||
|
kubeconfigContext string
|
||||||
|
kubeconfigUser string
|
||||||
|
kubeconfigServer string
|
||||||
|
kubeconfigNamespace string
|
||||||
|
)
|
||||||
|
|
||||||
|
var kubeconfigCmd = &cobra.Command{
|
||||||
|
Use: "kubeconfig",
|
||||||
|
Short: "Generate kubeconfig for accessing Kubernetes via NetBird",
|
||||||
|
Long: `Generate a kubeconfig file that points to a Kubernetes cluster accessible via NetBird.
|
||||||
|
|
||||||
|
The generated kubeconfig uses a dummy bearer token for authentication when the
|
||||||
|
cluster's auth proxy is running in 'auth' mode. The actual authentication is
|
||||||
|
handled by the NetBird network - the auth proxy identifies users by their
|
||||||
|
NetBird peer IP and impersonates them in the Kubernetes API.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
netbird kubeconfig --server https://k8s.example.netbird.cloud:6443 --cluster my-cluster
|
||||||
|
netbird kubeconfig --server https://10.100.0.1:6443 -o ~/.kube/netbird-config`,
|
||||||
|
RunE: kubeconfigFunc,
|
||||||
|
}
|
||||||
|
|
||||||
|
// init configures command-line flags for the kubeconfig command.
|
||||||
|
// It registers flags for output path, cluster, context, user, server, and namespace
|
||||||
|
// and marks the server flag as required.
|
||||||
|
func init() {
|
||||||
|
kubeconfigCmd.Flags().StringVarP(&kubeconfigOutput, "output", "o", "", "Output file path (default: stdout)")
|
||||||
|
kubeconfigCmd.Flags().StringVar(&kubeconfigCluster, "cluster", "netbird-cluster", "Cluster name in kubeconfig")
|
||||||
|
kubeconfigCmd.Flags().StringVar(&kubeconfigContext, "context", "netbird", "Context name in kubeconfig")
|
||||||
|
kubeconfigCmd.Flags().StringVar(&kubeconfigUser, "user", "netbird-user", "User name in kubeconfig")
|
||||||
|
kubeconfigCmd.Flags().StringVar(&kubeconfigServer, "server", "", "Kubernetes API server URL (required)")
|
||||||
|
kubeconfigCmd.Flags().StringVar(&kubeconfigNamespace, "namespace", "default", "Default namespace")
|
||||||
|
_ = kubeconfigCmd.MarkFlagRequired("server")
|
||||||
|
}
|
||||||
|
|
||||||
|
// kubeconfigFunc generates a kubeconfig file for accessing Kubernetes via the NetBird auth proxy.
|
||||||
|
// KUBECONFIG and running kubectl.
|
||||||
|
func kubeconfigFunc(cmd *cobra.Command, args []string) error {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Get current NetBird status to verify connection
|
||||||
|
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||||
|
if err != nil {
|
||||||
|
cmd.PrintErrf("Warning: Could not connect to NetBird daemon: %v\n", err)
|
||||||
|
cmd.PrintErrln("Generating kubeconfig anyway, but make sure NetBird is running before using it.")
|
||||||
|
} else {
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{})
|
||||||
|
if err != nil {
|
||||||
|
cmd.PrintErrf("Warning: Could not get NetBird status: %v\n", status.Convert(err).Message())
|
||||||
|
} else if resp.Status != "Connected" {
|
||||||
|
cmd.PrintErrf("Warning: NetBird is not connected (status: %s)\n", resp.Status)
|
||||||
|
cmd.PrintErrln("Make sure to run 'netbird up' before using the generated kubeconfig.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
kubeconfig := generateKubeconfig(kubeconfigServer, kubeconfigCluster, kubeconfigContext, kubeconfigUser, kubeconfigNamespace)
|
||||||
|
|
||||||
|
if kubeconfigOutput == "" {
|
||||||
|
fmt.Println(kubeconfig)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expand ~ in path
|
||||||
|
if strings.HasPrefix(kubeconfigOutput, "~/") {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get home directory: %w", err)
|
||||||
|
}
|
||||||
|
kubeconfigOutput = filepath.Join(home, kubeconfigOutput[2:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create directory if needed
|
||||||
|
dir := filepath.Dir(kubeconfigOutput)
|
||||||
|
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||||
|
return fmt.Errorf("failed to create directory %s: %w", dir, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(kubeconfigOutput, []byte(kubeconfig), 0600); err != nil {
|
||||||
|
return fmt.Errorf("failed to write kubeconfig: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf("Kubeconfig written to %s\n", kubeconfigOutput)
|
||||||
|
cmd.PrintErrln("\nWarning: TLS verification is disabled (insecure-skip-tls-verify: true).")
|
||||||
|
cmd.PrintErrln("This is safe when traffic is encrypted via NetBird's WireGuard tunnel.")
|
||||||
|
cmd.Printf("\nTo use this kubeconfig:\n")
|
||||||
|
cmd.Printf(" export KUBECONFIG=%s\n", kubeconfigOutput)
|
||||||
|
cmd.Printf(" kubectl get nodes\n")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateKubeconfig creates a kubeconfig YAML string with the given parameters.
|
||||||
|
// generateKubeconfig generates a kubeconfig YAML for accessing the specified Kubernetes API server via NetBird.
|
||||||
|
// The returned config sets the current context to the provided context, includes the given cluster, user, and namespace,
|
||||||
|
// enables `insecure-skip-tls-verify: true`, and embeds the static token `netbird-auth-proxy`.
|
||||||
|
func generateKubeconfig(server, cluster, context, user, namespace string) string {
|
||||||
|
return fmt.Sprintf(`apiVersion: v1
|
||||||
|
kind: Config
|
||||||
|
clusters:
|
||||||
|
- cluster:
|
||||||
|
insecure-skip-tls-verify: true
|
||||||
|
server: %s
|
||||||
|
name: %s
|
||||||
|
contexts:
|
||||||
|
- context:
|
||||||
|
cluster: %s
|
||||||
|
namespace: %s
|
||||||
|
user: %s
|
||||||
|
name: %s
|
||||||
|
current-context: %s
|
||||||
|
users:
|
||||||
|
- name: %s
|
||||||
|
user:
|
||||||
|
token: netbird-auth-proxy
|
||||||
|
`, server, cluster, cluster, namespace, user, context, context, user)
|
||||||
|
}
|
||||||
@@ -4,14 +4,12 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
|
||||||
"os/user"
|
"os/user"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/skratchdot/open-golang/open"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
@@ -332,7 +330,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
|||||||
hint = profileState.Email
|
hint = profileState.Email
|
||||||
}
|
}
|
||||||
|
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint)
|
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), false, hint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -373,21 +371,13 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
|||||||
cmd.Println("")
|
cmd.Println("")
|
||||||
|
|
||||||
if !noBrowser {
|
if !noBrowser {
|
||||||
if err := openBrowser(verificationURIComplete); err != nil {
|
if err := util.OpenBrowser(verificationURIComplete); err != nil {
|
||||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// openBrowser opens the URL in a browser, respecting the BROWSER environment variable.
|
|
||||||
func openBrowser(url string) error {
|
|
||||||
if browser := os.Getenv("BROWSER"); browser != "" {
|
|
||||||
return exec.Command(browser, url).Start()
|
|
||||||
}
|
|
||||||
return open.Run(url)
|
|
||||||
}
|
|
||||||
|
|
||||||
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
||||||
func isUnixRunningDesktop() bool {
|
func isUnixRunningDesktop() bool {
|
||||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||||
|
|||||||
@@ -35,7 +35,6 @@ const (
|
|||||||
wireguardPortFlag = "wireguard-port"
|
wireguardPortFlag = "wireguard-port"
|
||||||
networkMonitorFlag = "network-monitor"
|
networkMonitorFlag = "network-monitor"
|
||||||
disableAutoConnectFlag = "disable-auto-connect"
|
disableAutoConnectFlag = "disable-auto-connect"
|
||||||
serverSSHAllowedFlag = "allow-server-ssh"
|
|
||||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||||
dnsRouteIntervalFlag = "dns-router-interval"
|
dnsRouteIntervalFlag = "dns-router-interval"
|
||||||
enableLazyConnectionFlag = "enable-lazy-connection"
|
enableLazyConnectionFlag = "enable-lazy-connection"
|
||||||
@@ -64,7 +63,6 @@ var (
|
|||||||
customDNSAddress string
|
customDNSAddress string
|
||||||
rosenpassEnabled bool
|
rosenpassEnabled bool
|
||||||
rosenpassPermissive bool
|
rosenpassPermissive bool
|
||||||
serverSSHAllowed bool
|
|
||||||
interfaceName string
|
interfaceName string
|
||||||
wireguardPort uint16
|
wireguardPort uint16
|
||||||
networkMonitor bool
|
networkMonitor bool
|
||||||
@@ -90,6 +88,13 @@ func Execute() error {
|
|||||||
return rootCmd.Execute()
|
return rootCmd.Execute()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// init initializes package-level defaults and configures the root CLI command.
|
||||||
|
// It sets OS-specific default paths for configuration and logs, determines the default
|
||||||
|
// daemon address, registers persistent CLI flags (daemon address, management/admin URLs,
|
||||||
|
// logging, setup key options, pre-shared key, hostname, anonymization, and config path),
|
||||||
|
// and wires up all top-level and nested subcommands. It also defines upCmd-specific
|
||||||
|
// flags for external IP mapping, custom DNS resolver address, Rosenpass options,
|
||||||
|
// auto-connect control, and lazy connection.
|
||||||
func init() {
|
func init() {
|
||||||
defaultConfigPathDir = "/etc/netbird/"
|
defaultConfigPathDir = "/etc/netbird/"
|
||||||
defaultLogFileDir = "/var/log/netbird/"
|
defaultLogFileDir = "/var/log/netbird/"
|
||||||
@@ -143,6 +148,7 @@ func init() {
|
|||||||
rootCmd.AddCommand(forwardingRulesCmd)
|
rootCmd.AddCommand(forwardingRulesCmd)
|
||||||
rootCmd.AddCommand(debugCmd)
|
rootCmd.AddCommand(debugCmd)
|
||||||
rootCmd.AddCommand(profileCmd)
|
rootCmd.AddCommand(profileCmd)
|
||||||
|
rootCmd.AddCommand(kubeconfigCmd)
|
||||||
|
|
||||||
networksCMD.AddCommand(routesListCmd)
|
networksCMD.AddCommand(routesListCmd)
|
||||||
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
||||||
@@ -176,7 +182,6 @@ func init() {
|
|||||||
)
|
)
|
||||||
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
|
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
|
||||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
|
||||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||||
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.")
|
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.")
|
||||||
|
|
||||||
|
|||||||
176
client/cmd/signer/artifactkey.go
Normal file
176
client/cmd/signer/artifactkey.go
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
bundlePubKeysRootPrivKeyFile string
|
||||||
|
bundlePubKeysPubKeyFiles []string
|
||||||
|
bundlePubKeysFile string
|
||||||
|
|
||||||
|
createArtifactKeyRootPrivKeyFile string
|
||||||
|
createArtifactKeyPrivKeyFile string
|
||||||
|
createArtifactKeyPubKeyFile string
|
||||||
|
createArtifactKeyExpiration time.Duration
|
||||||
|
)
|
||||||
|
|
||||||
|
var createArtifactKeyCmd = &cobra.Command{
|
||||||
|
Use: "create-artifact-key",
|
||||||
|
Short: "Create a new artifact signing key",
|
||||||
|
Long: `Generate a new artifact signing key pair signed by the root private key.
|
||||||
|
The artifact key will be used to sign software artifacts/updates.`,
|
||||||
|
SilenceUsage: true,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
if createArtifactKeyExpiration <= 0 {
|
||||||
|
return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := handleCreateArtifactKey(cmd, createArtifactKeyRootPrivKeyFile, createArtifactKeyPrivKeyFile, createArtifactKeyPubKeyFile, createArtifactKeyExpiration); err != nil {
|
||||||
|
return fmt.Errorf("failed to create artifact key: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var bundlePubKeysCmd = &cobra.Command{
|
||||||
|
Use: "bundle-pub-keys",
|
||||||
|
Short: "Bundle multiple artifact public keys into a signed package",
|
||||||
|
Long: `Bundle one or more artifact public keys into a signed package using the root private key.
|
||||||
|
This command is typically used to distribute or authorize a set of valid artifact signing keys.`,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
if len(bundlePubKeysPubKeyFiles) == 0 {
|
||||||
|
return fmt.Errorf("at least one --artifact-pub-key-file must be provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := handleBundlePubKeys(cmd, bundlePubKeysRootPrivKeyFile, bundlePubKeysPubKeyFiles, bundlePubKeysFile); err != nil {
|
||||||
|
return fmt.Errorf("failed to bundle public keys: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rootCmd.AddCommand(createArtifactKeyCmd)
|
||||||
|
|
||||||
|
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the artifact key")
|
||||||
|
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPrivKeyFile, "artifact-priv-key-file", "", "Path where the artifact private key will be saved")
|
||||||
|
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPubKeyFile, "artifact-pub-key-file", "", "Path where the artifact public key will be saved")
|
||||||
|
createArtifactKeyCmd.Flags().DurationVar(&createArtifactKeyExpiration, "expiration", 0, "Expiration duration for the artifact key (e.g., 720h, 365d, 8760h)")
|
||||||
|
|
||||||
|
if err := createArtifactKeyCmd.MarkFlagRequired("root-private-key-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark root-private-key-file as required: %w", err))
|
||||||
|
}
|
||||||
|
if err := createArtifactKeyCmd.MarkFlagRequired("artifact-priv-key-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark artifact-priv-key-file as required: %w", err))
|
||||||
|
}
|
||||||
|
if err := createArtifactKeyCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err))
|
||||||
|
}
|
||||||
|
if err := createArtifactKeyCmd.MarkFlagRequired("expiration"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark expiration as required: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
rootCmd.AddCommand(bundlePubKeysCmd)
|
||||||
|
|
||||||
|
bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the bundle")
|
||||||
|
bundlePubKeysCmd.Flags().StringArrayVar(&bundlePubKeysPubKeyFiles, "artifact-pub-key-file", nil, "Path(s) to the artifact public key files to include in the bundle (can be repeated)")
|
||||||
|
bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysFile, "bundle-pub-key-file", "", "Path where the public keys will be saved")
|
||||||
|
|
||||||
|
if err := bundlePubKeysCmd.MarkFlagRequired("root-private-key-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark root-private-key-file as required: %w", err))
|
||||||
|
}
|
||||||
|
if err := bundlePubKeysCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err))
|
||||||
|
}
|
||||||
|
if err := bundlePubKeysCmd.MarkFlagRequired("bundle-pub-key-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark bundle-pub-key-file as required: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleCreateArtifactKey(cmd *cobra.Command, rootPrivKeyFile, artifactPrivKeyFile, artifactPubKeyFile string, expiration time.Duration) error {
|
||||||
|
cmd.Println("Creating new artifact signing key...")
|
||||||
|
|
||||||
|
privKeyPEM, err := os.ReadFile(rootPrivKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read root private key file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse private root key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
artifactKey, privPEM, pubPEM, signature, err := reposign.GenerateArtifactKey(privateRootKey, expiration)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generate artifact key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(artifactPrivKeyFile, privPEM, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("write private key file (%s): %w", artifactPrivKeyFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(artifactPubKeyFile, pubPEM, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("write public key file (%s): %w", artifactPubKeyFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
signatureFile := artifactPubKeyFile + ".sig"
|
||||||
|
if err := os.WriteFile(signatureFile, signature, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("write signature file (%s): %w", signatureFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf("✅ Artifact key created successfully.\n")
|
||||||
|
cmd.Printf("%s\n", artifactKey.String())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleBundlePubKeys(cmd *cobra.Command, rootPrivKeyFile string, artifactPubKeyFiles []string, bundlePubKeysFile string) error {
|
||||||
|
cmd.Println("📦 Bundling public keys into signed package...")
|
||||||
|
|
||||||
|
privKeyPEM, err := os.ReadFile(rootPrivKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read root private key file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse private root key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKeys := make([]reposign.PublicKey, 0, len(artifactPubKeyFiles))
|
||||||
|
for _, pubFile := range artifactPubKeyFiles {
|
||||||
|
pubPem, err := os.ReadFile(pubFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read public key file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pk, err := reposign.ParseArtifactPubKey(pubPem)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse artifact key: %w", err)
|
||||||
|
}
|
||||||
|
publicKeys = append(publicKeys, pk)
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedKeys, signature, err := reposign.BundleArtifactKeys(privateRootKey, publicKeys)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("bundle artifact keys: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(bundlePubKeysFile, parsedKeys, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("write public keys file (%s): %w", bundlePubKeysFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
signatureFile := bundlePubKeysFile + ".sig"
|
||||||
|
if err := os.WriteFile(signatureFile, signature, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("write signature file (%s): %w", signatureFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf("✅ Bundle created with %d public keys.\n", len(artifactPubKeyFiles))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
276
client/cmd/signer/artifactsign.go
Normal file
276
client/cmd/signer/artifactsign.go
Normal file
@@ -0,0 +1,276 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
envArtifactPrivateKey = "NB_ARTIFACT_PRIV_KEY"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
signArtifactPrivKeyFile string
|
||||||
|
signArtifactArtifactFile string
|
||||||
|
|
||||||
|
verifyArtifactPubKeyFile string
|
||||||
|
verifyArtifactFile string
|
||||||
|
verifyArtifactSignatureFile string
|
||||||
|
|
||||||
|
verifyArtifactKeyPubKeyFile string
|
||||||
|
verifyArtifactKeyRootPubKeyFile string
|
||||||
|
verifyArtifactKeySignatureFile string
|
||||||
|
verifyArtifactKeyRevocationFile string
|
||||||
|
)
|
||||||
|
|
||||||
|
var signArtifactCmd = &cobra.Command{
|
||||||
|
Use: "sign-artifact",
|
||||||
|
Short: "Sign an artifact using an artifact private key",
|
||||||
|
Long: `Sign a software artifact (e.g., update bundle or binary) using the artifact's private key.
|
||||||
|
This command produces a detached signature that can be verified using the corresponding artifact public key.`,
|
||||||
|
SilenceUsage: true,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
if err := handleSignArtifact(cmd, signArtifactPrivKeyFile, signArtifactArtifactFile); err != nil {
|
||||||
|
return fmt.Errorf("failed to sign artifact: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var verifyArtifactCmd = &cobra.Command{
|
||||||
|
Use: "verify-artifact",
|
||||||
|
Short: "Verify an artifact signature using an artifact public key",
|
||||||
|
Long: `Verify a software artifact signature using the artifact's public key.`,
|
||||||
|
SilenceUsage: true,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
if err := handleVerifyArtifact(cmd, verifyArtifactPubKeyFile, verifyArtifactFile, verifyArtifactSignatureFile); err != nil {
|
||||||
|
return fmt.Errorf("failed to verify artifact: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var verifyArtifactKeyCmd = &cobra.Command{
|
||||||
|
Use: "verify-artifact-key",
|
||||||
|
Short: "Verify an artifact public key was signed by a root key",
|
||||||
|
Long: `Verify that an artifact public key (or bundle) was properly signed by a root key.
|
||||||
|
This validates the chain of trust from the root key to the artifact key.`,
|
||||||
|
SilenceUsage: true,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
if err := handleVerifyArtifactKey(cmd, verifyArtifactKeyPubKeyFile, verifyArtifactKeyRootPubKeyFile, verifyArtifactKeySignatureFile, verifyArtifactKeyRevocationFile); err != nil {
|
||||||
|
return fmt.Errorf("failed to verify artifact key: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rootCmd.AddCommand(signArtifactCmd)
|
||||||
|
rootCmd.AddCommand(verifyArtifactCmd)
|
||||||
|
rootCmd.AddCommand(verifyArtifactKeyCmd)
|
||||||
|
|
||||||
|
signArtifactCmd.Flags().StringVar(&signArtifactPrivKeyFile, "artifact-key-file", "", fmt.Sprintf("Path to the artifact private key file used for signing (or set %s env var)", envArtifactPrivateKey))
|
||||||
|
signArtifactCmd.Flags().StringVar(&signArtifactArtifactFile, "artifact-file", "", "Path to the artifact to be signed")
|
||||||
|
|
||||||
|
// artifact-file is required, but artifact-key-file can come from env var
|
||||||
|
if err := signArtifactCmd.MarkFlagRequired("artifact-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark artifact-file as required: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
verifyArtifactCmd.Flags().StringVar(&verifyArtifactPubKeyFile, "artifact-public-key-file", "", "Path to the artifact public key file")
|
||||||
|
verifyArtifactCmd.Flags().StringVar(&verifyArtifactFile, "artifact-file", "", "Path to the artifact to be verified")
|
||||||
|
verifyArtifactCmd.Flags().StringVar(&verifyArtifactSignatureFile, "signature-file", "", "Path to the signature file")
|
||||||
|
|
||||||
|
if err := verifyArtifactCmd.MarkFlagRequired("artifact-public-key-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark artifact-public-key-file as required: %w", err))
|
||||||
|
}
|
||||||
|
if err := verifyArtifactCmd.MarkFlagRequired("artifact-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark artifact-file as required: %w", err))
|
||||||
|
}
|
||||||
|
if err := verifyArtifactCmd.MarkFlagRequired("signature-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark signature-file as required: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyPubKeyFile, "artifact-key-file", "", "Path to the artifact public key file or bundle")
|
||||||
|
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRootPubKeyFile, "root-key-file", "", "Path to the root public key file or bundle")
|
||||||
|
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeySignatureFile, "signature-file", "", "Path to the signature file")
|
||||||
|
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRevocationFile, "revocation-file", "", "Path to the revocation list file (optional)")
|
||||||
|
|
||||||
|
if err := verifyArtifactKeyCmd.MarkFlagRequired("artifact-key-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark artifact-key-file as required: %w", err))
|
||||||
|
}
|
||||||
|
if err := verifyArtifactKeyCmd.MarkFlagRequired("root-key-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark root-key-file as required: %w", err))
|
||||||
|
}
|
||||||
|
if err := verifyArtifactKeyCmd.MarkFlagRequired("signature-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark signature-file as required: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleSignArtifact(cmd *cobra.Command, privKeyFile, artifactFile string) error {
|
||||||
|
cmd.Println("🖋️ Signing artifact...")
|
||||||
|
|
||||||
|
// Load private key from env var or file
|
||||||
|
var privKeyPEM []byte
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if envKey := os.Getenv(envArtifactPrivateKey); envKey != "" {
|
||||||
|
// Use key from environment variable
|
||||||
|
privKeyPEM = []byte(envKey)
|
||||||
|
} else if privKeyFile != "" {
|
||||||
|
// Fall back to file
|
||||||
|
privKeyPEM, err = os.ReadFile(privKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read private key file: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("artifact private key must be provided via %s environment variable or --artifact-key-file flag", envArtifactPrivateKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKey, err := reposign.ParseArtifactKey(privKeyPEM)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse artifact private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
artifactData, err := os.ReadFile(artifactFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read artifact file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
signature, err := reposign.SignData(privateKey, artifactData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("sign artifact: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sigFile := artifactFile + ".sig"
|
||||||
|
if err := os.WriteFile(artifactFile+".sig", signature, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("write signature file (%s): %w", sigFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf("✅ Artifact signed successfully.\n")
|
||||||
|
cmd.Printf("Signature file: %s\n", sigFile)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleVerifyArtifact(cmd *cobra.Command, pubKeyFile, artifactFile, signatureFile string) error {
|
||||||
|
cmd.Println("🔍 Verifying artifact...")
|
||||||
|
|
||||||
|
// Read artifact public key
|
||||||
|
pubKeyPEM, err := os.ReadFile(pubKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read public key file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKey, err := reposign.ParseArtifactPubKey(pubKeyPEM)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse artifact public key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read artifact data
|
||||||
|
artifactData, err := os.ReadFile(artifactFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read artifact file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read signature
|
||||||
|
sigBytes, err := os.ReadFile(signatureFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read signature file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
signature, err := reposign.ParseSignature(sigBytes)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse signature: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate artifact
|
||||||
|
if err := reposign.ValidateArtifact([]reposign.PublicKey{publicKey}, artifactData, *signature); err != nil {
|
||||||
|
return fmt.Errorf("artifact verification failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Println("✅ Artifact signature is valid")
|
||||||
|
cmd.Printf("Artifact: %s\n", artifactFile)
|
||||||
|
cmd.Printf("Signed by key: %s\n", signature.KeyID)
|
||||||
|
cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST"))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleVerifyArtifactKey(cmd *cobra.Command, artifactKeyFile, rootKeyFile, signatureFile, revocationFile string) error {
|
||||||
|
cmd.Println("🔍 Verifying artifact key...")
|
||||||
|
|
||||||
|
// Read artifact key data
|
||||||
|
artifactKeyData, err := os.ReadFile(artifactKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read artifact key file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read root public key(s)
|
||||||
|
rootKeyData, err := os.ReadFile(rootKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read root key file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rootPublicKeys, err := parseRootPublicKeys(rootKeyData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse root public key(s): %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read signature
|
||||||
|
sigBytes, err := os.ReadFile(signatureFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read signature file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
signature, err := reposign.ParseSignature(sigBytes)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse signature: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read optional revocation list
|
||||||
|
var revocationList *reposign.RevocationList
|
||||||
|
if revocationFile != "" {
|
||||||
|
revData, err := os.ReadFile(revocationFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read revocation file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
revocationList, err = reposign.ParseRevocationList(revData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse revocation list: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate artifact key(s)
|
||||||
|
validKeys, err := reposign.ValidateArtifactKeys(rootPublicKeys, artifactKeyData, *signature, revocationList)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("artifact key verification failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Println("✅ Artifact key(s) verified successfully")
|
||||||
|
cmd.Printf("Signed by root key: %s\n", signature.KeyID)
|
||||||
|
cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST"))
|
||||||
|
cmd.Printf("\nValid artifact keys (%d):\n", len(validKeys))
|
||||||
|
for i, key := range validKeys {
|
||||||
|
cmd.Printf(" [%d] Key ID: %s\n", i+1, key.Metadata.ID)
|
||||||
|
cmd.Printf(" Created: %s\n", key.Metadata.CreatedAt.Format("2006-01-02 15:04:05 MST"))
|
||||||
|
if !key.Metadata.ExpiresAt.IsZero() {
|
||||||
|
cmd.Printf(" Expires: %s\n", key.Metadata.ExpiresAt.Format("2006-01-02 15:04:05 MST"))
|
||||||
|
} else {
|
||||||
|
cmd.Printf(" Expires: Never\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseRootPublicKeys parses a root public key from PEM data
|
||||||
|
func parseRootPublicKeys(data []byte) ([]reposign.PublicKey, error) {
|
||||||
|
key, err := reposign.ParseRootPublicKey(data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return []reposign.PublicKey{key}, nil
|
||||||
|
}
|
||||||
21
client/cmd/signer/main.go
Normal file
21
client/cmd/signer/main.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
var rootCmd = &cobra.Command{
|
||||||
|
Use: "signer",
|
||||||
|
Short: "A CLI tool for managing cryptographic keys and artifacts",
|
||||||
|
Long: `signer is a command-line tool that helps you manage
|
||||||
|
root keys, artifact keys, and revocation lists securely.`,
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
if err := rootCmd.Execute(); err != nil {
|
||||||
|
rootCmd.Println(err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
220
client/cmd/signer/revocation.go
Normal file
220
client/cmd/signer/revocation.go
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultRevocationListExpiration = 365 * 24 * time.Hour // 1 year
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
keyID string
|
||||||
|
revocationListFile string
|
||||||
|
privateRootKeyFile string
|
||||||
|
publicRootKeyFile string
|
||||||
|
signatureFile string
|
||||||
|
expirationDuration time.Duration
|
||||||
|
)
|
||||||
|
|
||||||
|
var createRevocationListCmd = &cobra.Command{
|
||||||
|
Use: "create-revocation-list",
|
||||||
|
Short: "Create a new revocation list signed by the private root key",
|
||||||
|
SilenceUsage: true,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
return handleCreateRevocationList(cmd, revocationListFile, privateRootKeyFile)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var extendRevocationListCmd = &cobra.Command{
|
||||||
|
Use: "extend-revocation-list",
|
||||||
|
Short: "Extend an existing revocation list with a given key ID",
|
||||||
|
SilenceUsage: true,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
return handleExtendRevocationList(cmd, keyID, revocationListFile, privateRootKeyFile)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var verifyRevocationListCmd = &cobra.Command{
|
||||||
|
Use: "verify-revocation-list",
|
||||||
|
Short: "Verify a revocation list signature using the public root key",
|
||||||
|
SilenceUsage: true,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
return handleVerifyRevocationList(cmd, revocationListFile, signatureFile, publicRootKeyFile)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rootCmd.AddCommand(createRevocationListCmd)
|
||||||
|
rootCmd.AddCommand(extendRevocationListCmd)
|
||||||
|
rootCmd.AddCommand(verifyRevocationListCmd)
|
||||||
|
|
||||||
|
createRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file")
|
||||||
|
createRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file")
|
||||||
|
createRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)")
|
||||||
|
if err := createRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := createRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
extendRevocationListCmd.Flags().StringVar(&keyID, "key-id", "", "ID of the key to extend the revocation list for")
|
||||||
|
extendRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file")
|
||||||
|
extendRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file")
|
||||||
|
extendRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)")
|
||||||
|
if err := extendRevocationListCmd.MarkFlagRequired("key-id"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := extendRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := extendRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
verifyRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the revocation list file")
|
||||||
|
verifyRevocationListCmd.Flags().StringVar(&signatureFile, "signature-file", "", "Path to the signature file")
|
||||||
|
verifyRevocationListCmd.Flags().StringVar(&publicRootKeyFile, "public-root-key", "", "Path to the public root key PEM file")
|
||||||
|
if err := verifyRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := verifyRevocationListCmd.MarkFlagRequired("signature-file"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := verifyRevocationListCmd.MarkFlagRequired("public-root-key"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleCreateRevocationList(cmd *cobra.Command, revocationListFile string, privateRootKeyFile string) error {
|
||||||
|
privKeyPEM, err := os.ReadFile(privateRootKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read private root key file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse private root key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rlBytes, sigBytes, err := reposign.CreateRevocationList(*privateRootKey, expirationDuration)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create revocation list: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", rlBytes, sigBytes); err != nil {
|
||||||
|
return fmt.Errorf("failed to write output files: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Println("✅ Revocation list created successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleExtendRevocationList(cmd *cobra.Command, keyID, revocationListFile, privateRootKeyFile string) error {
|
||||||
|
privKeyPEM, err := os.ReadFile(privateRootKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read private root key file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse private root key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rlBytes, err := os.ReadFile(revocationListFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read revocation list file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rl, err := reposign.ParseRevocationList(rlBytes)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse revocation list: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
kid, err := reposign.ParseKeyID(keyID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid key ID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newRLBytes, sigBytes, err := reposign.ExtendRevocationList(*privateRootKey, *rl, kid, expirationDuration)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to extend revocation list: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", newRLBytes, sigBytes); err != nil {
|
||||||
|
return fmt.Errorf("failed to write output files: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Println("✅ Revocation list extended successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleVerifyRevocationList(cmd *cobra.Command, revocationListFile, signatureFile, publicRootKeyFile string) error {
|
||||||
|
// Read revocation list file
|
||||||
|
rlBytes, err := os.ReadFile(revocationListFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read revocation list file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read signature file
|
||||||
|
sigBytes, err := os.ReadFile(signatureFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read signature file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read public root key file
|
||||||
|
pubKeyPEM, err := os.ReadFile(publicRootKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read public root key file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse public root key
|
||||||
|
publicKey, err := reposign.ParseRootPublicKey(pubKeyPEM)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse public root key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse signature
|
||||||
|
signature, err := reposign.ParseSignature(sigBytes)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse signature: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate revocation list
|
||||||
|
rl, err := reposign.ValidateRevocationList([]reposign.PublicKey{publicKey}, rlBytes, *signature)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to validate revocation list: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Display results
|
||||||
|
cmd.Println("✅ Revocation list signature is valid")
|
||||||
|
cmd.Printf("Last Updated: %s\n", rl.LastUpdated.Format(time.RFC3339))
|
||||||
|
cmd.Printf("Expires At: %s\n", rl.ExpiresAt.Format(time.RFC3339))
|
||||||
|
cmd.Printf("Number of revoked keys: %d\n", len(rl.Revoked))
|
||||||
|
|
||||||
|
if len(rl.Revoked) > 0 {
|
||||||
|
cmd.Println("\nRevoked Keys:")
|
||||||
|
for keyID, revokedTime := range rl.Revoked {
|
||||||
|
cmd.Printf(" - %s (revoked at: %s)\n", keyID, revokedTime.Format(time.RFC3339))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeOutputFiles(rlPath, sigPath string, rlBytes, sigBytes []byte) error {
|
||||||
|
if err := os.WriteFile(rlPath, rlBytes, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("failed to write revocation list file: %w", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(sigPath, sigBytes, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("failed to write signature file: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
74
client/cmd/signer/rootkey.go
Normal file
74
client/cmd/signer/rootkey.go
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
privKeyFile string
|
||||||
|
pubKeyFile string
|
||||||
|
rootExpiration time.Duration
|
||||||
|
)
|
||||||
|
|
||||||
|
var createRootKeyCmd = &cobra.Command{
|
||||||
|
Use: "create-root-key",
|
||||||
|
Short: "Create a new root key pair",
|
||||||
|
Long: `Create a new root key pair and specify an expiration time for it.`,
|
||||||
|
SilenceUsage: true,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
// Validate expiration
|
||||||
|
if rootExpiration <= 0 {
|
||||||
|
return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run main logic
|
||||||
|
if err := handleGenerateRootKey(cmd, privKeyFile, pubKeyFile, rootExpiration); err != nil {
|
||||||
|
return fmt.Errorf("failed to generate root key: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rootCmd.AddCommand(createRootKeyCmd)
|
||||||
|
createRootKeyCmd.Flags().StringVar(&privKeyFile, "priv-key-file", "", "Path to output private key file")
|
||||||
|
createRootKeyCmd.Flags().StringVar(&pubKeyFile, "pub-key-file", "", "Path to output public key file")
|
||||||
|
createRootKeyCmd.Flags().DurationVar(&rootExpiration, "expiration", 0, "Expiration time for the root key (e.g., 720h,)")
|
||||||
|
|
||||||
|
if err := createRootKeyCmd.MarkFlagRequired("priv-key-file"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := createRootKeyCmd.MarkFlagRequired("pub-key-file"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := createRootKeyCmd.MarkFlagRequired("expiration"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleGenerateRootKey(cmd *cobra.Command, privKeyFile, pubKeyFile string, expiration time.Duration) error {
|
||||||
|
rk, privPEM, pubPEM, err := reposign.GenerateRootKey(expiration)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generate root key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write private key
|
||||||
|
if err := os.WriteFile(privKeyFile, privPEM, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("write private key file (%s): %w", privKeyFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write public key
|
||||||
|
if err := os.WriteFile(pubKeyFile, pubPEM, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("write public key file (%s): %w", pubKeyFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf("%s\n\n", rk.String())
|
||||||
|
cmd.Printf("✅ Root key pair generated successfully.\n")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -3,86 +3,162 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"os/user"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
sshclient "github.com/netbirdio/netbird/client/ssh/client"
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||||
|
sshproxy "github.com/netbirdio/netbird/client/ssh/proxy"
|
||||||
|
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
sshUsernameDesc = "SSH username"
|
||||||
|
hostArgumentRequired = "host argument required"
|
||||||
|
|
||||||
|
serverSSHAllowedFlag = "allow-server-ssh"
|
||||||
|
enableSSHRootFlag = "enable-ssh-root"
|
||||||
|
enableSSHSFTPFlag = "enable-ssh-sftp"
|
||||||
|
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
|
||||||
|
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
|
||||||
|
disableSSHAuthFlag = "disable-ssh-auth"
|
||||||
|
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
port int
|
port int
|
||||||
userName = "root"
|
username string
|
||||||
host string
|
host string
|
||||||
|
command string
|
||||||
|
localForwards []string
|
||||||
|
remoteForwards []string
|
||||||
|
strictHostKeyChecking bool
|
||||||
|
knownHostsFile string
|
||||||
|
identityFile string
|
||||||
|
skipCachedToken bool
|
||||||
|
requestPTY bool
|
||||||
|
sshNoBrowser bool
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
serverSSHAllowed bool
|
||||||
|
enableSSHRoot bool
|
||||||
|
enableSSHSFTP bool
|
||||||
|
enableSSHLocalPortForward bool
|
||||||
|
enableSSHRemotePortForward bool
|
||||||
|
disableSSHAuth bool
|
||||||
|
sshJWTCacheTTL int
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer")
|
||||||
|
upCmd.PersistentFlags().BoolVar(&enableSSHRoot, enableSSHRootFlag, false, "Enable root login for SSH server")
|
||||||
|
upCmd.PersistentFlags().BoolVar(&enableSSHSFTP, enableSSHSFTPFlag, false, "Enable SFTP subsystem for SSH server")
|
||||||
|
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
|
||||||
|
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
|
||||||
|
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
|
||||||
|
upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)")
|
||||||
|
|
||||||
|
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
|
||||||
|
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
|
||||||
|
sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)")
|
||||||
|
sshCmd.PersistentFlags().BoolVarP(&requestPTY, "tty", "t", false, "Force pseudo-terminal allocation")
|
||||||
|
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
|
||||||
|
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
|
||||||
|
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)")
|
||||||
|
_ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used")
|
||||||
|
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||||
|
sshCmd.PersistentFlags().BoolVar(&sshNoBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||||
|
|
||||||
|
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
|
||||||
|
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
|
||||||
|
|
||||||
|
sshCmd.AddCommand(sshSftpCmd)
|
||||||
|
sshCmd.AddCommand(sshProxyCmd)
|
||||||
|
sshCmd.AddCommand(sshDetectCmd)
|
||||||
|
}
|
||||||
|
|
||||||
var sshCmd = &cobra.Command{
|
var sshCmd = &cobra.Command{
|
||||||
Use: "ssh [user@]host",
|
Use: "ssh [flags] [user@]host [command]",
|
||||||
Args: func(cmd *cobra.Command, args []string) error {
|
Short: "Connect to a NetBird peer via SSH",
|
||||||
if len(args) < 1 {
|
Long: `Connect to a NetBird peer using SSH with support for port forwarding.
|
||||||
return errors.New("requires a host argument")
|
|
||||||
|
Port Forwarding:
|
||||||
|
-L [bind_address:]port:host:hostport Local port forwarding
|
||||||
|
-L [bind_address:]port:/path/to/socket Local port forwarding to Unix socket
|
||||||
|
-R [bind_address:]port:host:hostport Remote port forwarding
|
||||||
|
-R [bind_address:]port:/path/to/socket Remote port forwarding to Unix socket
|
||||||
|
|
||||||
|
SSH Options:
|
||||||
|
-p, --port int Remote SSH port (default 22)
|
||||||
|
-u, --user string SSH username
|
||||||
|
--login string SSH username (alias for --user)
|
||||||
|
-t, --tty Force pseudo-terminal allocation
|
||||||
|
--strict-host-key-checking Enable strict host key checking (default: true)
|
||||||
|
-o, --known-hosts string Path to known_hosts file
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
netbird ssh peer-hostname
|
||||||
|
netbird ssh root@peer-hostname
|
||||||
|
netbird ssh --login root peer-hostname
|
||||||
|
netbird ssh peer-hostname ls -la
|
||||||
|
netbird ssh peer-hostname whoami
|
||||||
|
netbird ssh -t peer-hostname tmux # Force PTY for tmux/screen
|
||||||
|
netbird ssh -t peer-hostname sudo -i # Force PTY for interactive sudo
|
||||||
|
netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding
|
||||||
|
netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding
|
||||||
|
netbird ssh -L "*:8080:localhost:80" peer-hostname # Bind to all interfaces
|
||||||
|
netbird ssh -L 8080:/tmp/socket peer-hostname # Unix socket forwarding`,
|
||||||
|
DisableFlagParsing: true,
|
||||||
|
Args: validateSSHArgsWithoutFlagParsing,
|
||||||
|
RunE: sshFn,
|
||||||
|
Aliases: []string{"ssh"},
|
||||||
|
}
|
||||||
|
|
||||||
|
func sshFn(cmd *cobra.Command, args []string) error {
|
||||||
|
for _, arg := range args {
|
||||||
|
if arg == "-h" || arg == "--help" {
|
||||||
|
return cmd.Help()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
split := strings.Split(args[0], "@")
|
|
||||||
if len(split) == 2 {
|
|
||||||
userName = split[0]
|
|
||||||
host = split[1]
|
|
||||||
} else {
|
|
||||||
host = args[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
Short: "Connect to a remote SSH server",
|
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
|
||||||
SetFlagsFromEnvVars(rootCmd)
|
SetFlagsFromEnvVars(rootCmd)
|
||||||
SetFlagsFromEnvVars(cmd)
|
SetFlagsFromEnvVars(cmd)
|
||||||
|
|
||||||
cmd.SetOut(cmd.OutOrStdout())
|
cmd.SetOut(cmd.OutOrStdout())
|
||||||
|
|
||||||
err := util.InitLog(logLevel, util.LogConsole)
|
logOutput := "console"
|
||||||
if err != nil {
|
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
|
||||||
return fmt.Errorf("failed initializing log %v", err)
|
logOutput = firstLogFile
|
||||||
}
|
}
|
||||||
|
if err := util.InitLog(logLevel, logOutput); err != nil {
|
||||||
if !util.IsAdmin() {
|
return fmt.Errorf("init log: %w", err)
|
||||||
cmd.Printf("error: you must have Administrator privileges to run this command\n")
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := internal.CtxInitState(cmd.Context())
|
ctx := internal.CtxInitState(cmd.Context())
|
||||||
|
|
||||||
sm := profilemanager.NewServiceManager(configPath)
|
|
||||||
activeProf, err := sm.GetActiveProfileState()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("get active profile: %v", err)
|
|
||||||
}
|
|
||||||
profPath, err := activeProf.FilePath()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("get active profile path: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
config, err := profilemanager.ReadConfig(profPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("read profile config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
sig := make(chan os.Signal, 1)
|
sig := make(chan os.Signal, 1)
|
||||||
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
||||||
sshctx, cancel := context.WithCancel(ctx)
|
sshctx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
|
errCh := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
// blocking
|
if err := runSSH(sshctx, host, cmd); err != nil {
|
||||||
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
errCh <- err
|
||||||
cmd.Printf("Error: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
cancel()
|
cancel()
|
||||||
}()
|
}()
|
||||||
@@ -90,38 +166,686 @@ var sshCmd = &cobra.Command{
|
|||||||
select {
|
select {
|
||||||
case <-sig:
|
case <-sig:
|
||||||
cancel()
|
cancel()
|
||||||
|
<-sshctx.Done()
|
||||||
|
return nil
|
||||||
|
case err := <-errCh:
|
||||||
|
return err
|
||||||
case <-sshctx.Done():
|
case <-sshctx.Done():
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error {
|
// getEnvOrDefault checks for environment variables with WT_ and NB_ prefixes
|
||||||
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), userName, pemKey)
|
func getEnvOrDefault(flagName, defaultValue string) string {
|
||||||
if err != nil {
|
if envValue := os.Getenv("WT_" + flagName); envValue != "" {
|
||||||
cmd.Printf("Error: %v\n", err)
|
return envValue
|
||||||
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
|
|
||||||
"\nYou can verify the connection by running:\n\n" +
|
|
||||||
" netbird status\n\n")
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
go func() {
|
if envValue := os.Getenv("NB_" + flagName); envValue != "" {
|
||||||
<-ctx.Done()
|
return envValue
|
||||||
err = c.Close()
|
}
|
||||||
if err != nil {
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// getBoolEnvOrDefault checks for boolean environment variables with WT_ and NB_ prefixes
|
||||||
|
func getBoolEnvOrDefault(flagName string, defaultValue bool) bool {
|
||||||
|
if envValue := os.Getenv("WT_" + flagName); envValue != "" {
|
||||||
|
if parsed, err := strconv.ParseBool(envValue); err == nil {
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if envValue := os.Getenv("NB_" + flagName); envValue != "" {
|
||||||
|
if parsed, err := strconv.ParseBool(envValue); err == nil {
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// resetSSHGlobals sets SSH globals to their default values
|
||||||
|
func resetSSHGlobals() {
|
||||||
|
port = sshserver.DefaultSSHPort
|
||||||
|
username = ""
|
||||||
|
host = ""
|
||||||
|
command = ""
|
||||||
|
localForwards = nil
|
||||||
|
remoteForwards = nil
|
||||||
|
strictHostKeyChecking = true
|
||||||
|
knownHostsFile = ""
|
||||||
|
identityFile = ""
|
||||||
|
sshNoBrowser = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseCustomSSHFlags extracts -L, -R flags and returns filtered args
|
||||||
|
func parseCustomSSHFlags(args []string) ([]string, []string, []string) {
|
||||||
|
var localForwardFlags []string
|
||||||
|
var remoteForwardFlags []string
|
||||||
|
var filteredArgs []string
|
||||||
|
|
||||||
|
for i := 0; i < len(args); i++ {
|
||||||
|
arg := args[i]
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(arg, "-L"):
|
||||||
|
localForwardFlags, i = parseForwardFlag(arg, args, i, localForwardFlags)
|
||||||
|
case strings.HasPrefix(arg, "-R"):
|
||||||
|
remoteForwardFlags, i = parseForwardFlag(arg, args, i, remoteForwardFlags)
|
||||||
|
default:
|
||||||
|
filteredArgs = append(filteredArgs, arg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return filteredArgs, localForwardFlags, remoteForwardFlags
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseForwardFlag(arg string, args []string, i int, flags []string) ([]string, int) {
|
||||||
|
if arg == "-L" || arg == "-R" {
|
||||||
|
if i+1 < len(args) {
|
||||||
|
flags = append(flags, args[i+1])
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
} else if len(arg) > 2 {
|
||||||
|
flags = append(flags, arg[2:])
|
||||||
|
}
|
||||||
|
return flags, i
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractGlobalFlags parses global flags that were passed before 'ssh' command
|
||||||
|
func extractGlobalFlags(args []string) {
|
||||||
|
sshPos := findSSHCommandPosition(args)
|
||||||
|
if sshPos == -1 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
|
|
||||||
err = c.OpenTerminal()
|
globalArgs := args[:sshPos]
|
||||||
if err != nil {
|
parseGlobalArgs(globalArgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// findSSHCommandPosition locates the 'ssh' command in the argument list
|
||||||
|
func findSSHCommandPosition(args []string) int {
|
||||||
|
for i, arg := range args {
|
||||||
|
if arg == "ssh" {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
configFlag = "config"
|
||||||
|
logLevelFlag = "log-level"
|
||||||
|
logFileFlag = "log-file"
|
||||||
|
)
|
||||||
|
|
||||||
|
// parseGlobalArgs processes the global arguments and sets the corresponding variables
|
||||||
|
func parseGlobalArgs(globalArgs []string) {
|
||||||
|
flagHandlers := map[string]func(string){
|
||||||
|
configFlag: func(value string) { configPath = value },
|
||||||
|
logLevelFlag: func(value string) { logLevel = value },
|
||||||
|
logFileFlag: func(value string) {
|
||||||
|
if !slices.Contains(logFiles, value) {
|
||||||
|
logFiles = append(logFiles, value)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
shortFlags := map[string]string{
|
||||||
|
"c": configFlag,
|
||||||
|
"l": logLevelFlag,
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(globalArgs); i++ {
|
||||||
|
arg := globalArgs[i]
|
||||||
|
|
||||||
|
if handled, nextIndex := parseFlag(arg, globalArgs, i, flagHandlers, shortFlags); handled {
|
||||||
|
i = nextIndex
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseFlag handles generic flag parsing for both long and short forms
|
||||||
|
func parseFlag(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (bool, int) {
|
||||||
|
if parsedValue, found := parseEqualsFormat(arg, flagHandlers, shortFlags); found {
|
||||||
|
flagHandlers[parsedValue.flagName](parsedValue.value)
|
||||||
|
return true, currentIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
if parsedValue, found := parseSpacedFormat(arg, args, currentIndex, flagHandlers, shortFlags); found {
|
||||||
|
flagHandlers[parsedValue.flagName](parsedValue.value)
|
||||||
|
return true, currentIndex + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, currentIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
type parsedFlag struct {
|
||||||
|
flagName string
|
||||||
|
value string
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseEqualsFormat handles --flag=value and -f=value formats
|
||||||
|
func parseEqualsFormat(arg string, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) {
|
||||||
|
if !strings.Contains(arg, "=") {
|
||||||
|
return parsedFlag{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.SplitN(arg, "=", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return parsedFlag{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(parts[0], "--") {
|
||||||
|
flagName := strings.TrimPrefix(parts[0], "--")
|
||||||
|
if _, exists := flagHandlers[flagName]; exists {
|
||||||
|
return parsedFlag{flagName: flagName, value: parts[1]}, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(parts[0], "-") && len(parts[0]) == 2 {
|
||||||
|
shortFlag := strings.TrimPrefix(parts[0], "-")
|
||||||
|
if longFlag, exists := shortFlags[shortFlag]; exists {
|
||||||
|
if _, exists := flagHandlers[longFlag]; exists {
|
||||||
|
return parsedFlag{flagName: longFlag, value: parts[1]}, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsedFlag{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseSpacedFormat handles --flag value and -f value formats
|
||||||
|
func parseSpacedFormat(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) {
|
||||||
|
if currentIndex+1 >= len(args) {
|
||||||
|
return parsedFlag{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(arg, "--") {
|
||||||
|
flagName := strings.TrimPrefix(arg, "--")
|
||||||
|
if _, exists := flagHandlers[flagName]; exists {
|
||||||
|
return parsedFlag{flagName: flagName, value: args[currentIndex+1]}, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(arg, "-") && len(arg) == 2 {
|
||||||
|
shortFlag := strings.TrimPrefix(arg, "-")
|
||||||
|
if longFlag, exists := shortFlags[shortFlag]; exists {
|
||||||
|
if _, exists := flagHandlers[longFlag]; exists {
|
||||||
|
return parsedFlag{flagName: longFlag, value: args[currentIndex+1]}, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsedFlag{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// createSSHFlagSet creates and configures the flag set for SSH command parsing
|
||||||
|
// sshFlags contains all SSH-related flags and parameters
|
||||||
|
type sshFlags struct {
|
||||||
|
Port int
|
||||||
|
Username string
|
||||||
|
Login string
|
||||||
|
RequestPTY bool
|
||||||
|
StrictHostKeyChecking bool
|
||||||
|
KnownHostsFile string
|
||||||
|
IdentityFile string
|
||||||
|
SkipCachedToken bool
|
||||||
|
NoBrowser bool
|
||||||
|
ConfigPath string
|
||||||
|
LogLevel string
|
||||||
|
LocalForwards []string
|
||||||
|
RemoteForwards []string
|
||||||
|
Host string
|
||||||
|
Command string
|
||||||
|
}
|
||||||
|
|
||||||
|
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
|
||||||
|
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
|
||||||
|
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||||
|
defaultNoBrowser := getBoolEnvOrDefault("NO_BROWSER", false)
|
||||||
|
|
||||||
|
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
|
||||||
|
fs.SetOutput(nil)
|
||||||
|
|
||||||
|
flags := &sshFlags{}
|
||||||
|
|
||||||
|
fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port")
|
||||||
|
fs.IntVar(&flags.Port, "port", sshserver.DefaultSSHPort, "SSH port")
|
||||||
|
fs.StringVar(&flags.Username, "u", "", sshUsernameDesc)
|
||||||
|
fs.StringVar(&flags.Username, "user", "", sshUsernameDesc)
|
||||||
|
fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)")
|
||||||
|
fs.BoolVar(&flags.RequestPTY, "t", false, "Force pseudo-terminal allocation")
|
||||||
|
fs.BoolVar(&flags.RequestPTY, "tty", false, "Force pseudo-terminal allocation")
|
||||||
|
|
||||||
|
fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking")
|
||||||
|
fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file")
|
||||||
|
fs.StringVar(&flags.KnownHostsFile, "known-hosts", "", "Path to known_hosts file")
|
||||||
|
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
|
||||||
|
fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file")
|
||||||
|
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||||
|
fs.BoolVar(&flags.NoBrowser, "no-browser", defaultNoBrowser, noBrowserDesc)
|
||||||
|
|
||||||
|
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
|
||||||
|
fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location")
|
||||||
|
fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level")
|
||||||
|
fs.StringVar(&flags.LogLevel, "log-level", defaultLogLevel, "sets Netbird log level")
|
||||||
|
|
||||||
|
return fs, flags
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
||||||
|
if len(args) < 1 {
|
||||||
|
return errors.New(hostArgumentRequired)
|
||||||
|
}
|
||||||
|
|
||||||
|
resetSSHGlobals()
|
||||||
|
|
||||||
|
if len(os.Args) > 2 {
|
||||||
|
extractGlobalFlags(os.Args[1:])
|
||||||
|
}
|
||||||
|
|
||||||
|
filteredArgs, localForwardFlags, remoteForwardFlags := parseCustomSSHFlags(args)
|
||||||
|
|
||||||
|
fs, flags := createSSHFlagSet()
|
||||||
|
|
||||||
|
if err := fs.Parse(filteredArgs); err != nil {
|
||||||
|
if errors.Is(err, flag.ErrHelp) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
remaining := fs.Args()
|
||||||
|
if len(remaining) < 1 {
|
||||||
|
return errors.New(hostArgumentRequired)
|
||||||
|
}
|
||||||
|
|
||||||
|
port = flags.Port
|
||||||
|
if flags.Username != "" {
|
||||||
|
username = flags.Username
|
||||||
|
} else if flags.Login != "" {
|
||||||
|
username = flags.Login
|
||||||
|
}
|
||||||
|
|
||||||
|
requestPTY = flags.RequestPTY
|
||||||
|
strictHostKeyChecking = flags.StrictHostKeyChecking
|
||||||
|
knownHostsFile = flags.KnownHostsFile
|
||||||
|
identityFile = flags.IdentityFile
|
||||||
|
skipCachedToken = flags.SkipCachedToken
|
||||||
|
sshNoBrowser = flags.NoBrowser
|
||||||
|
|
||||||
|
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
|
||||||
|
configPath = flags.ConfigPath
|
||||||
|
}
|
||||||
|
if flags.LogLevel != getEnvOrDefault("LOG_LEVEL", logLevel) {
|
||||||
|
logLevel = flags.LogLevel
|
||||||
|
}
|
||||||
|
|
||||||
|
localForwards = localForwardFlags
|
||||||
|
remoteForwards = remoteForwardFlags
|
||||||
|
|
||||||
|
return parseHostnameAndCommand(remaining)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseHostnameAndCommand(args []string) error {
|
||||||
|
if len(args) < 1 {
|
||||||
|
return errors.New(hostArgumentRequired)
|
||||||
|
}
|
||||||
|
|
||||||
|
arg := args[0]
|
||||||
|
if strings.Contains(arg, "@") {
|
||||||
|
parts := strings.SplitN(arg, "@", 2)
|
||||||
|
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||||
|
return errors.New("invalid user@host format")
|
||||||
|
}
|
||||||
|
if username == "" {
|
||||||
|
username = parts[0]
|
||||||
|
}
|
||||||
|
host = parts[1]
|
||||||
|
} else {
|
||||||
|
host = arg
|
||||||
|
}
|
||||||
|
|
||||||
|
if username == "" {
|
||||||
|
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
|
||||||
|
username = sudoUser
|
||||||
|
} else if currentUser, err := user.Current(); err == nil {
|
||||||
|
username = currentUser.Username
|
||||||
|
} else {
|
||||||
|
username = "root"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Everything after hostname becomes the command
|
||||||
|
if len(args) > 1 {
|
||||||
|
command = strings.Join(args[1:], " ")
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
|
||||||
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", nbssh.DefaultSSHPort, "Sets remote SSH port. Defaults to "+fmt.Sprint(nbssh.DefaultSSHPort))
|
target := fmt.Sprintf("%s:%d", addr, port)
|
||||||
|
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
|
||||||
|
KnownHostsFile: knownHostsFile,
|
||||||
|
IdentityFile: identityFile,
|
||||||
|
DaemonAddr: daemonAddr,
|
||||||
|
SkipCachedToken: skipCachedToken,
|
||||||
|
InsecureSkipVerify: !strictHostKeyChecking,
|
||||||
|
NoBrowser: sshNoBrowser,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
cmd.Printf("Failed to connect to %s@%s\n", username, target)
|
||||||
|
cmd.Printf("\nTroubleshooting steps:\n")
|
||||||
|
cmd.Printf(" 1. Check peer connectivity: netbird status -d\n")
|
||||||
|
cmd.Printf(" 2. Verify SSH server is enabled on the peer\n")
|
||||||
|
cmd.Printf(" 3. Ensure correct hostname/IP is used\n")
|
||||||
|
return fmt.Errorf("dial %s: %w", target, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sshCtx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-sshCtx.Done()
|
||||||
|
if err := c.Close(); err != nil {
|
||||||
|
cmd.Printf("Error closing SSH connection: %v\n", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := startPortForwarding(sshCtx, c, cmd); err != nil {
|
||||||
|
return fmt.Errorf("start port forwarding: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if command != "" {
|
||||||
|
return executeSSHCommand(sshCtx, c, command)
|
||||||
|
}
|
||||||
|
return openSSHTerminal(sshCtx, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeSSHCommand executes a command over SSH.
|
||||||
|
func executeSSHCommand(ctx context.Context, c *sshclient.Client, command string) error {
|
||||||
|
var err error
|
||||||
|
if requestPTY {
|
||||||
|
err = c.ExecuteCommandWithPTY(ctx, command)
|
||||||
|
} else {
|
||||||
|
err = c.ExecuteCommandWithIO(ctx, command)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var exitErr *ssh.ExitError
|
||||||
|
if errors.As(err, &exitErr) {
|
||||||
|
os.Exit(exitErr.ExitStatus())
|
||||||
|
}
|
||||||
|
|
||||||
|
var exitMissingErr *ssh.ExitMissingError
|
||||||
|
if errors.As(err, &exitMissingErr) {
|
||||||
|
log.Debugf("Remote command exited without exit status: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("execute command: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// openSSHTerminal opens an interactive SSH terminal.
|
||||||
|
func openSSHTerminal(ctx context.Context, c *sshclient.Client) error {
|
||||||
|
if err := c.OpenTerminal(ctx); err != nil {
|
||||||
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var exitMissingErr *ssh.ExitMissingError
|
||||||
|
if errors.As(err, &exitMissingErr) {
|
||||||
|
log.Debugf("Remote terminal exited without exit status: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("open terminal: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// startPortForwarding starts local and remote port forwarding based on command line flags
|
||||||
|
func startPortForwarding(ctx context.Context, c *sshclient.Client, cmd *cobra.Command) error {
|
||||||
|
for _, forward := range localForwards {
|
||||||
|
if err := parseAndStartLocalForward(ctx, c, forward, cmd); err != nil {
|
||||||
|
return fmt.Errorf("local port forward %s: %w", forward, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, forward := range remoteForwards {
|
||||||
|
if err := parseAndStartRemoteForward(ctx, c, forward, cmd); err != nil {
|
||||||
|
return fmt.Errorf("remote port forward %s: %w", forward, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseAndStartLocalForward parses and starts a local port forward (-L)
|
||||||
|
func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error {
|
||||||
|
localAddr, remoteAddr, err := parsePortForwardSpec(forward)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) {
|
||||||
|
cmd.Printf("Local port forward error: %v\n", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseAndStartRemoteForward parses and starts a remote port forward (-R)
|
||||||
|
func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error {
|
||||||
|
remoteAddr, localAddr, err := parsePortForwardSpec(forward)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) {
|
||||||
|
cmd.Printf("Remote port forward error: %v\n", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80".
|
||||||
|
// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket".
|
||||||
|
func parsePortForwardSpec(spec string) (string, string, error) {
|
||||||
|
// Support formats:
|
||||||
|
// port:host:hostport -> localhost:port -> host:hostport
|
||||||
|
// host:port:host:hostport -> host:port -> host:hostport
|
||||||
|
// [host]:port:host:hostport -> [host]:port -> host:hostport
|
||||||
|
// port:unix_socket_path -> localhost:port -> unix_socket_path
|
||||||
|
// host:port:unix_socket_path -> host:port -> unix_socket_path
|
||||||
|
|
||||||
|
if strings.HasPrefix(spec, "[") && strings.Contains(spec, "]:") {
|
||||||
|
return parseIPv6ForwardSpec(spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(spec, ":")
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_target)", spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch len(parts) {
|
||||||
|
case 2:
|
||||||
|
return parseTwoPartForwardSpec(parts, spec)
|
||||||
|
case 3:
|
||||||
|
return parseThreePartForwardSpec(parts)
|
||||||
|
case 4:
|
||||||
|
return parseFourPartForwardSpec(parts)
|
||||||
|
default:
|
||||||
|
return "", "", fmt.Errorf("invalid port forward specification: %s", spec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseTwoPartForwardSpec handles "port:unix_socket" format.
|
||||||
|
func parseTwoPartForwardSpec(parts []string, spec string) (string, string, error) {
|
||||||
|
if isUnixSocket(parts[1]) {
|
||||||
|
localAddr := "localhost:" + parts[0]
|
||||||
|
remoteAddr := parts[1]
|
||||||
|
return localAddr, remoteAddr, nil
|
||||||
|
}
|
||||||
|
return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_host:remote_port or [local_host:]local_port:unix_socket)", spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseThreePartForwardSpec handles "port:host:hostport" or "host:port:unix_socket" formats.
|
||||||
|
func parseThreePartForwardSpec(parts []string) (string, string, error) {
|
||||||
|
if isUnixSocket(parts[2]) {
|
||||||
|
localHost := normalizeLocalHost(parts[0])
|
||||||
|
localAddr := localHost + ":" + parts[1]
|
||||||
|
remoteAddr := parts[2]
|
||||||
|
return localAddr, remoteAddr, nil
|
||||||
|
}
|
||||||
|
localAddr := "localhost:" + parts[0]
|
||||||
|
remoteAddr := parts[1] + ":" + parts[2]
|
||||||
|
return localAddr, remoteAddr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseFourPartForwardSpec handles "host:port:host:hostport" format.
|
||||||
|
func parseFourPartForwardSpec(parts []string) (string, string, error) {
|
||||||
|
localHost := normalizeLocalHost(parts[0])
|
||||||
|
localAddr := localHost + ":" + parts[1]
|
||||||
|
remoteAddr := parts[2] + ":" + parts[3]
|
||||||
|
return localAddr, remoteAddr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseIPv6ForwardSpec handles "[host]:port:host:hostport" format.
|
||||||
|
func parseIPv6ForwardSpec(spec string) (string, string, error) {
|
||||||
|
idx := strings.Index(spec, "]:")
|
||||||
|
if idx == -1 {
|
||||||
|
return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s", spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
ipv6Host := spec[:idx+1]
|
||||||
|
remaining := spec[idx+2:]
|
||||||
|
|
||||||
|
parts := strings.Split(remaining, ":")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s (expected [ipv6]:port:host:hostport)", spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
localAddr := ipv6Host + ":" + parts[0]
|
||||||
|
remoteAddr := parts[1] + ":" + parts[2]
|
||||||
|
return localAddr, remoteAddr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isUnixSocket checks if a path is a Unix socket path.
|
||||||
|
func isUnixSocket(path string) bool {
|
||||||
|
return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./")
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces.
|
||||||
|
func normalizeLocalHost(host string) string {
|
||||||
|
if host == "*" {
|
||||||
|
return "0.0.0.0"
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
var sshProxyCmd = &cobra.Command{
|
||||||
|
Use: "proxy <host> <port>",
|
||||||
|
Short: "Internal SSH proxy for native SSH client integration",
|
||||||
|
Long: "Internal command used by SSH ProxyCommand to handle JWT authentication",
|
||||||
|
Hidden: true,
|
||||||
|
Args: cobra.ExactArgs(2),
|
||||||
|
RunE: sshProxyFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
func sshProxyFn(cmd *cobra.Command, args []string) error {
|
||||||
|
logOutput := "console"
|
||||||
|
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
|
||||||
|
logOutput = firstLogFile
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||||
|
if err := util.InitLog(proxyLogLevel, logOutput); err != nil {
|
||||||
|
return fmt.Errorf("init log: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
host := args[0]
|
||||||
|
portStr := args[1]
|
||||||
|
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid port: %s", portStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check env var for browser setting since this command is invoked via SSH ProxyCommand
|
||||||
|
// where command-line flags cannot be passed. Default is to open browser.
|
||||||
|
noBrowser := getBoolEnvOrDefault("NO_BROWSER", false)
|
||||||
|
var browserOpener func(string) error
|
||||||
|
if !noBrowser {
|
||||||
|
browserOpener = util.OpenBrowser
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr(), browserOpener)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create SSH proxy: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := proxy.Close(); err != nil {
|
||||||
|
log.Debugf("close SSH proxy: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := proxy.Connect(cmd.Context()); err != nil {
|
||||||
|
return fmt.Errorf("SSH proxy: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var sshDetectCmd = &cobra.Command{
|
||||||
|
Use: "detect <host> <port>",
|
||||||
|
Short: "Detect if a host is running NetBird SSH",
|
||||||
|
Long: "Internal command used by SSH Match exec to detect NetBird SSH servers. Exit codes: 0=JWT, 1=no-JWT, 2=regular SSH",
|
||||||
|
Hidden: true,
|
||||||
|
Args: cobra.ExactArgs(2),
|
||||||
|
RunE: sshDetectFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
func sshDetectFn(cmd *cobra.Command, args []string) error {
|
||||||
|
detectLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||||
|
if err := util.InitLog(detectLogLevel, "console"); err != nil {
|
||||||
|
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||||
|
}
|
||||||
|
|
||||||
|
host := args[0]
|
||||||
|
portStr := args[1]
|
||||||
|
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("invalid port %q: %v", portStr, err)
|
||||||
|
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(cmd.Context(), detection.DefaultTimeout)
|
||||||
|
|
||||||
|
dialer := &net.Dialer{}
|
||||||
|
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("SSH server detection failed: %v", err)
|
||||||
|
cancel()
|
||||||
|
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
os.Exit(serverType.ExitCode())
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
74
client/cmd/ssh_exec_unix.go
Normal file
74
client/cmd/ssh_exec_unix.go
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
//go:build unix
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
sshExecUID uint32
|
||||||
|
sshExecGID uint32
|
||||||
|
sshExecGroups []uint
|
||||||
|
sshExecWorkingDir string
|
||||||
|
sshExecShell string
|
||||||
|
sshExecCommand string
|
||||||
|
sshExecPTY bool
|
||||||
|
)
|
||||||
|
|
||||||
|
// sshExecCmd represents the hidden ssh exec subcommand for privilege dropping
|
||||||
|
var sshExecCmd = &cobra.Command{
|
||||||
|
Use: "exec",
|
||||||
|
Short: "Internal SSH execution with privilege dropping (hidden)",
|
||||||
|
Hidden: true,
|
||||||
|
RunE: runSSHExec,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
sshExecCmd.Flags().Uint32Var(&sshExecUID, "uid", 0, "Target user ID")
|
||||||
|
sshExecCmd.Flags().Uint32Var(&sshExecGID, "gid", 0, "Target group ID")
|
||||||
|
sshExecCmd.Flags().UintSliceVar(&sshExecGroups, "groups", nil, "Supplementary group IDs (can be repeated)")
|
||||||
|
sshExecCmd.Flags().StringVar(&sshExecWorkingDir, "working-dir", "", "Working directory")
|
||||||
|
sshExecCmd.Flags().StringVar(&sshExecShell, "shell", "/bin/sh", "Shell to execute")
|
||||||
|
sshExecCmd.Flags().BoolVar(&sshExecPTY, "pty", false, "Request PTY (will fail as executor doesn't support PTY)")
|
||||||
|
sshExecCmd.Flags().StringVar(&sshExecCommand, "cmd", "", "Command to execute")
|
||||||
|
|
||||||
|
if err := sshExecCmd.MarkFlagRequired("uid"); err != nil {
|
||||||
|
_, _ = fmt.Fprintf(os.Stderr, "failed to mark uid flag as required: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
if err := sshExecCmd.MarkFlagRequired("gid"); err != nil {
|
||||||
|
_, _ = fmt.Fprintf(os.Stderr, "failed to mark gid flag as required: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
sshCmd.AddCommand(sshExecCmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// runSSHExec handles the SSH exec subcommand execution.
|
||||||
|
func runSSHExec(cmd *cobra.Command, _ []string) error {
|
||||||
|
privilegeDropper := sshserver.NewPrivilegeDropper()
|
||||||
|
|
||||||
|
var groups []uint32
|
||||||
|
for _, groupInt := range sshExecGroups {
|
||||||
|
groups = append(groups, uint32(groupInt))
|
||||||
|
}
|
||||||
|
|
||||||
|
config := sshserver.ExecutorConfig{
|
||||||
|
UID: sshExecUID,
|
||||||
|
GID: sshExecGID,
|
||||||
|
Groups: groups,
|
||||||
|
WorkingDir: sshExecWorkingDir,
|
||||||
|
Shell: sshExecShell,
|
||||||
|
Command: sshExecCommand,
|
||||||
|
PTY: sshExecPTY,
|
||||||
|
}
|
||||||
|
|
||||||
|
privilegeDropper.ExecuteWithPrivilegeDrop(cmd.Context(), config)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
94
client/cmd/ssh_sftp_unix.go
Normal file
94
client/cmd/ssh_sftp_unix.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
//go:build unix
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/pkg/sftp"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
sftpUID uint32
|
||||||
|
sftpGID uint32
|
||||||
|
sftpGroupsInt []uint
|
||||||
|
sftpWorkingDir string
|
||||||
|
)
|
||||||
|
|
||||||
|
var sshSftpCmd = &cobra.Command{
|
||||||
|
Use: "sftp",
|
||||||
|
Short: "SFTP server with privilege dropping (internal use)",
|
||||||
|
Hidden: true,
|
||||||
|
RunE: sftpMain,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
sshSftpCmd.Flags().Uint32Var(&sftpUID, "uid", 0, "Target user ID")
|
||||||
|
sshSftpCmd.Flags().Uint32Var(&sftpGID, "gid", 0, "Target group ID")
|
||||||
|
sshSftpCmd.Flags().UintSliceVar(&sftpGroupsInt, "groups", nil, "Supplementary group IDs (can be repeated)")
|
||||||
|
sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory")
|
||||||
|
}
|
||||||
|
|
||||||
|
func sftpMain(cmd *cobra.Command, _ []string) error {
|
||||||
|
privilegeDropper := sshserver.NewPrivilegeDropper()
|
||||||
|
|
||||||
|
var groups []uint32
|
||||||
|
for _, groupInt := range sftpGroupsInt {
|
||||||
|
groups = append(groups, uint32(groupInt))
|
||||||
|
}
|
||||||
|
|
||||||
|
config := sshserver.ExecutorConfig{
|
||||||
|
UID: sftpUID,
|
||||||
|
GID: sftpGID,
|
||||||
|
Groups: groups,
|
||||||
|
WorkingDir: sftpWorkingDir,
|
||||||
|
Shell: "",
|
||||||
|
Command: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Tracef("dropping privileges for SFTP to UID=%d, GID=%d, groups=%v", config.UID, config.GID, config.Groups)
|
||||||
|
|
||||||
|
if err := privilegeDropper.DropPrivileges(config.UID, config.GID, config.Groups); err != nil {
|
||||||
|
cmd.PrintErrf("privilege drop failed: %v\n", err)
|
||||||
|
os.Exit(sshserver.ExitCodePrivilegeDropFail)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.WorkingDir != "" {
|
||||||
|
if err := os.Chdir(config.WorkingDir); err != nil {
|
||||||
|
cmd.PrintErrf("failed to change to working directory %s: %v\n", config.WorkingDir, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sftpServer, err := sftp.NewServer(struct {
|
||||||
|
io.Reader
|
||||||
|
io.WriteCloser
|
||||||
|
}{
|
||||||
|
Reader: os.Stdin,
|
||||||
|
WriteCloser: os.Stdout,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
cmd.PrintErrf("SFTP server creation failed: %v\n", err)
|
||||||
|
os.Exit(sshserver.ExitCodeShellExecFail)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Tracef("starting SFTP server with dropped privileges")
|
||||||
|
if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
cmd.PrintErrf("SFTP server error: %v\n", err)
|
||||||
|
if closeErr := sftpServer.Close(); closeErr != nil {
|
||||||
|
cmd.PrintErrf("SFTP server close error: %v\n", closeErr)
|
||||||
|
}
|
||||||
|
os.Exit(sshserver.ExitCodeShellExecFail)
|
||||||
|
}
|
||||||
|
|
||||||
|
if closeErr := sftpServer.Close(); closeErr != nil {
|
||||||
|
cmd.PrintErrf("SFTP server close error: %v\n", closeErr)
|
||||||
|
}
|
||||||
|
os.Exit(sshserver.ExitCodeSuccess)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
94
client/cmd/ssh_sftp_windows.go
Normal file
94
client/cmd/ssh_sftp_windows.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"os/user"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pkg/sftp"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
sftpWorkingDir string
|
||||||
|
windowsUsername string
|
||||||
|
windowsDomain string
|
||||||
|
)
|
||||||
|
|
||||||
|
var sshSftpCmd = &cobra.Command{
|
||||||
|
Use: "sftp",
|
||||||
|
Short: "SFTP server with user switching for Windows (internal use)",
|
||||||
|
Hidden: true,
|
||||||
|
RunE: sftpMain,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory")
|
||||||
|
sshSftpCmd.Flags().StringVar(&windowsUsername, "windows-username", "", "Windows username for user switching")
|
||||||
|
sshSftpCmd.Flags().StringVar(&windowsDomain, "windows-domain", "", "Windows domain for user switching")
|
||||||
|
}
|
||||||
|
|
||||||
|
func sftpMain(cmd *cobra.Command, _ []string) error {
|
||||||
|
return sftpMainDirect(cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sftpMainDirect(cmd *cobra.Command) error {
|
||||||
|
currentUser, err := user.Current()
|
||||||
|
if err != nil {
|
||||||
|
cmd.PrintErrf("failed to get current user: %v\n", err)
|
||||||
|
os.Exit(sshserver.ExitCodeValidationFail)
|
||||||
|
}
|
||||||
|
|
||||||
|
if windowsUsername != "" {
|
||||||
|
expectedUsername := windowsUsername
|
||||||
|
if windowsDomain != "" {
|
||||||
|
expectedUsername = fmt.Sprintf(`%s\%s`, windowsDomain, windowsUsername)
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(currentUser.Username, expectedUsername) && !strings.EqualFold(currentUser.Username, windowsUsername) {
|
||||||
|
cmd.PrintErrf("user switching failed\n")
|
||||||
|
os.Exit(sshserver.ExitCodeValidationFail)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("SFTP process running as: %s (UID: %s, Name: %s)", currentUser.Username, currentUser.Uid, currentUser.Name)
|
||||||
|
|
||||||
|
if sftpWorkingDir != "" {
|
||||||
|
if err := os.Chdir(sftpWorkingDir); err != nil {
|
||||||
|
cmd.PrintErrf("failed to change to working directory %s: %v\n", sftpWorkingDir, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sftpServer, err := sftp.NewServer(struct {
|
||||||
|
io.Reader
|
||||||
|
io.WriteCloser
|
||||||
|
}{
|
||||||
|
Reader: os.Stdin,
|
||||||
|
WriteCloser: os.Stdout,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
cmd.PrintErrf("SFTP server creation failed: %v\n", err)
|
||||||
|
os.Exit(sshserver.ExitCodeShellExecFail)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("starting SFTP server")
|
||||||
|
exitCode := sshserver.ExitCodeSuccess
|
||||||
|
if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
cmd.PrintErrf("SFTP server error: %v\n", err)
|
||||||
|
exitCode = sshserver.ExitCodeShellExecFail
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sftpServer.Close(); err != nil {
|
||||||
|
log.Debugf("SFTP server close error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
os.Exit(exitCode)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
717
client/cmd/ssh_test.go
Normal file
717
client/cmd/ssh_test.go
Normal file
@@ -0,0 +1,717 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSSHCommand_FlagParsing(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
expectedHost string
|
||||||
|
expectedUser string
|
||||||
|
expectedPort int
|
||||||
|
expectedCmd string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic host",
|
||||||
|
args: []string{"hostname"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedUser: "",
|
||||||
|
expectedPort: 22,
|
||||||
|
expectedCmd: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user@host format",
|
||||||
|
args: []string{"user@hostname"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedUser: "user",
|
||||||
|
expectedPort: 22,
|
||||||
|
expectedCmd: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "host with command",
|
||||||
|
args: []string{"hostname", "echo", "hello"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedUser: "",
|
||||||
|
expectedPort: 22,
|
||||||
|
expectedCmd: "echo hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "command with flags should be preserved",
|
||||||
|
args: []string{"hostname", "ls", "-la", "/tmp"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedUser: "",
|
||||||
|
expectedPort: 22,
|
||||||
|
expectedCmd: "ls -la /tmp",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "double dash separator",
|
||||||
|
args: []string{"hostname", "--", "ls", "-la"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedUser: "",
|
||||||
|
expectedPort: 22,
|
||||||
|
expectedCmd: "-- ls -la",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Reset global variables
|
||||||
|
host = ""
|
||||||
|
username = ""
|
||||||
|
port = 22
|
||||||
|
command = ""
|
||||||
|
|
||||||
|
// Mock command for testing
|
||||||
|
cmd := sshCmd
|
||||||
|
cmd.SetArgs(tt.args)
|
||||||
|
|
||||||
|
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||||
|
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||||
|
if tt.expectedUser != "" {
|
||||||
|
assert.Equal(t, tt.expectedUser, username, "username mismatch")
|
||||||
|
}
|
||||||
|
assert.Equal(t, tt.expectedPort, port, "port mismatch")
|
||||||
|
assert.Equal(t, tt.expectedCmd, command, "command mismatch")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHCommand_FlagConflictPrevention(t *testing.T) {
|
||||||
|
// Test that SSH flags don't conflict with command flags
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
expectedCmd string
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ls with -la flags",
|
||||||
|
args: []string{"hostname", "ls", "-la"},
|
||||||
|
expectedCmd: "ls -la",
|
||||||
|
description: "ls flags should be passed to remote command",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "grep with -r flag",
|
||||||
|
args: []string{"hostname", "grep", "-r", "pattern", "/path"},
|
||||||
|
expectedCmd: "grep -r pattern /path",
|
||||||
|
description: "grep flags should be passed to remote command",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ps with aux flags",
|
||||||
|
args: []string{"hostname", "ps", "aux"},
|
||||||
|
expectedCmd: "ps aux",
|
||||||
|
description: "ps flags should be passed to remote command",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "command with double dash",
|
||||||
|
args: []string{"hostname", "--", "ls", "-la"},
|
||||||
|
expectedCmd: "-- ls -la",
|
||||||
|
description: "double dash should be preserved in command",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Reset global variables
|
||||||
|
host = ""
|
||||||
|
username = ""
|
||||||
|
port = 22
|
||||||
|
command = ""
|
||||||
|
|
||||||
|
cmd := sshCmd
|
||||||
|
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||||
|
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHCommand_NonInteractiveExecution(t *testing.T) {
|
||||||
|
// Test that commands with arguments should execute the command and exit,
|
||||||
|
// not drop to an interactive shell
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
expectedCmd string
|
||||||
|
shouldExit bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ls command should execute and exit",
|
||||||
|
args: []string{"hostname", "ls"},
|
||||||
|
expectedCmd: "ls",
|
||||||
|
shouldExit: true,
|
||||||
|
description: "ls command should execute and exit, not drop to shell",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ls with flags should execute and exit",
|
||||||
|
args: []string{"hostname", "ls", "-la"},
|
||||||
|
expectedCmd: "ls -la",
|
||||||
|
shouldExit: true,
|
||||||
|
description: "ls with flags should execute and exit, not drop to shell",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "pwd command should execute and exit",
|
||||||
|
args: []string{"hostname", "pwd"},
|
||||||
|
expectedCmd: "pwd",
|
||||||
|
shouldExit: true,
|
||||||
|
description: "pwd command should execute and exit, not drop to shell",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "echo command should execute and exit",
|
||||||
|
args: []string{"hostname", "echo", "hello"},
|
||||||
|
expectedCmd: "echo hello",
|
||||||
|
shouldExit: true,
|
||||||
|
description: "echo command should execute and exit, not drop to shell",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no command should open shell",
|
||||||
|
args: []string{"hostname"},
|
||||||
|
expectedCmd: "",
|
||||||
|
shouldExit: false,
|
||||||
|
description: "no command should open interactive shell",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Reset global variables
|
||||||
|
host = ""
|
||||||
|
username = ""
|
||||||
|
port = 22
|
||||||
|
command = ""
|
||||||
|
|
||||||
|
cmd := sshCmd
|
||||||
|
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||||
|
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||||
|
|
||||||
|
// When command is present, it should execute the command and exit
|
||||||
|
// When command is empty, it should open interactive shell
|
||||||
|
hasCommand := command != ""
|
||||||
|
assert.Equal(t, tt.shouldExit, hasCommand, "Command presence should match expected behavior")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHCommand_FlagHandling(t *testing.T) {
|
||||||
|
// Test that flags after hostname are not parsed by netbird but passed to SSH command
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
expectedHost string
|
||||||
|
expectedCmd string
|
||||||
|
expectError bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ls with -la flag should not be parsed by netbird",
|
||||||
|
args: []string{"debian2", "ls", "-la"},
|
||||||
|
expectedHost: "debian2",
|
||||||
|
expectedCmd: "ls -la",
|
||||||
|
expectError: false,
|
||||||
|
description: "ls -la should be passed as SSH command, not parsed as netbird flags",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "command with netbird-like flags should be passed through",
|
||||||
|
args: []string{"hostname", "echo", "--help"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedCmd: "echo --help",
|
||||||
|
expectError: false,
|
||||||
|
description: "--help should be passed to echo, not parsed by netbird",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "command with -p flag should not conflict with SSH port flag",
|
||||||
|
args: []string{"hostname", "ps", "-p", "1234"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedCmd: "ps -p 1234",
|
||||||
|
expectError: false,
|
||||||
|
description: "ps -p should be passed to ps command, not parsed as port",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tar with flags should be passed through",
|
||||||
|
args: []string{"hostname", "tar", "-czf", "backup.tar.gz", "/home"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedCmd: "tar -czf backup.tar.gz /home",
|
||||||
|
expectError: false,
|
||||||
|
description: "tar flags should be passed to tar command",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Reset global variables
|
||||||
|
host = ""
|
||||||
|
username = ""
|
||||||
|
port = 22
|
||||||
|
command = ""
|
||||||
|
|
||||||
|
cmd := sshCmd
|
||||||
|
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||||
|
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||||
|
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHCommand_RegressionFlagParsing(t *testing.T) {
|
||||||
|
// Regression test for the specific issue: "sudo ./netbird ssh debian2 ls -la"
|
||||||
|
// should not parse -la as netbird flags but pass them to the SSH command
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
expectedHost string
|
||||||
|
expectedCmd string
|
||||||
|
expectError bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "original issue: ls -la should be preserved",
|
||||||
|
args: []string{"debian2", "ls", "-la"},
|
||||||
|
expectedHost: "debian2",
|
||||||
|
expectedCmd: "ls -la",
|
||||||
|
expectError: false,
|
||||||
|
description: "The original failing case should now work",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ls -l should be preserved",
|
||||||
|
args: []string{"hostname", "ls", "-l"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedCmd: "ls -l",
|
||||||
|
expectError: false,
|
||||||
|
description: "Single letter flags should be preserved",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SSH port flag should work",
|
||||||
|
args: []string{"-p", "2222", "hostname", "ls", "-la"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedCmd: "ls -la",
|
||||||
|
expectError: false,
|
||||||
|
description: "SSH -p flag should be parsed, command flags preserved",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Reset global variables
|
||||||
|
host = ""
|
||||||
|
username = ""
|
||||||
|
port = 22
|
||||||
|
command = ""
|
||||||
|
|
||||||
|
cmd := sshCmd
|
||||||
|
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||||
|
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||||
|
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||||
|
|
||||||
|
// Check port for the test case with -p flag
|
||||||
|
if len(tt.args) > 0 && tt.args[0] == "-p" {
|
||||||
|
assert.Equal(t, 2222, port, "port should be parsed from -p flag")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHCommand_PortForwardingFlagParsing(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
expectedHost string
|
||||||
|
expectedLocal []string
|
||||||
|
expectedRemote []string
|
||||||
|
expectError bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "local port forwarding -L",
|
||||||
|
args: []string{"-L", "8080:localhost:80", "hostname"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedLocal: []string{"8080:localhost:80"},
|
||||||
|
expectedRemote: []string{},
|
||||||
|
expectError: false,
|
||||||
|
description: "Single -L flag should be parsed correctly",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "remote port forwarding -R",
|
||||||
|
args: []string{"-R", "8080:localhost:80", "hostname"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedLocal: []string{},
|
||||||
|
expectedRemote: []string{"8080:localhost:80"},
|
||||||
|
expectError: false,
|
||||||
|
description: "Single -R flag should be parsed correctly",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple local port forwards",
|
||||||
|
args: []string{"-L", "8080:localhost:80", "-L", "9090:localhost:443", "hostname"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedLocal: []string{"8080:localhost:80", "9090:localhost:443"},
|
||||||
|
expectedRemote: []string{},
|
||||||
|
expectError: false,
|
||||||
|
description: "Multiple -L flags should be parsed correctly",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple remote port forwards",
|
||||||
|
args: []string{"-R", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedLocal: []string{},
|
||||||
|
expectedRemote: []string{"8080:localhost:80", "9090:localhost:443"},
|
||||||
|
expectError: false,
|
||||||
|
description: "Multiple -R flags should be parsed correctly",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed local and remote forwards",
|
||||||
|
args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedLocal: []string{"8080:localhost:80"},
|
||||||
|
expectedRemote: []string{"9090:localhost:443"},
|
||||||
|
expectError: false,
|
||||||
|
description: "Mixed -L and -R flags should be parsed correctly",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "port forwarding with bind address",
|
||||||
|
args: []string{"-L", "127.0.0.1:8080:localhost:80", "hostname"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedLocal: []string{"127.0.0.1:8080:localhost:80"},
|
||||||
|
expectedRemote: []string{},
|
||||||
|
expectError: false,
|
||||||
|
description: "Port forwarding with bind address should work",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "port forwarding with command",
|
||||||
|
args: []string{"-L", "8080:localhost:80", "hostname", "ls", "-la"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedLocal: []string{"8080:localhost:80"},
|
||||||
|
expectedRemote: []string{},
|
||||||
|
expectError: false,
|
||||||
|
description: "Port forwarding with command should work",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Reset global variables
|
||||||
|
host = ""
|
||||||
|
username = ""
|
||||||
|
port = 22
|
||||||
|
command = ""
|
||||||
|
localForwards = nil
|
||||||
|
remoteForwards = nil
|
||||||
|
|
||||||
|
cmd := sshCmd
|
||||||
|
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||||
|
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||||
|
// Handle nil vs empty slice comparison
|
||||||
|
if len(tt.expectedLocal) == 0 {
|
||||||
|
assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty")
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards")
|
||||||
|
}
|
||||||
|
if len(tt.expectedRemote) == 0 {
|
||||||
|
assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty")
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePortForward(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
spec string
|
||||||
|
expectedLocal string
|
||||||
|
expectedRemote string
|
||||||
|
expectError bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple port forward",
|
||||||
|
spec: "8080:localhost:80",
|
||||||
|
expectedLocal: "localhost:8080",
|
||||||
|
expectedRemote: "localhost:80",
|
||||||
|
expectError: false,
|
||||||
|
description: "Simple port:host:port format should work",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "port forward with bind address",
|
||||||
|
spec: "127.0.0.1:8080:localhost:80",
|
||||||
|
expectedLocal: "127.0.0.1:8080",
|
||||||
|
expectedRemote: "localhost:80",
|
||||||
|
expectError: false,
|
||||||
|
description: "bind_address:port:host:port format should work",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "port forward to different host",
|
||||||
|
spec: "8080:example.com:443",
|
||||||
|
expectedLocal: "localhost:8080",
|
||||||
|
expectedRemote: "example.com:443",
|
||||||
|
expectError: false,
|
||||||
|
description: "Forwarding to different host should work",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "port forward with IPv6 (needs bracket support)",
|
||||||
|
spec: "::1:8080:localhost:80",
|
||||||
|
expectError: true,
|
||||||
|
description: "IPv6 without brackets fails as expected (feature to implement)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid format - too few parts",
|
||||||
|
spec: "8080:localhost",
|
||||||
|
expectError: true,
|
||||||
|
description: "Invalid format with too few parts should fail",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid format - too many parts",
|
||||||
|
spec: "127.0.0.1:8080:localhost:80:extra",
|
||||||
|
expectError: true,
|
||||||
|
description: "Invalid format with too many parts should fail",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty spec",
|
||||||
|
spec: "",
|
||||||
|
expectError: true,
|
||||||
|
description: "Empty spec should fail",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unix socket local forward",
|
||||||
|
spec: "8080:/tmp/socket",
|
||||||
|
expectedLocal: "localhost:8080",
|
||||||
|
expectedRemote: "/tmp/socket",
|
||||||
|
expectError: false,
|
||||||
|
description: "Unix socket forwarding should work",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unix socket with bind address",
|
||||||
|
spec: "127.0.0.1:8080:/tmp/socket",
|
||||||
|
expectedLocal: "127.0.0.1:8080",
|
||||||
|
expectedRemote: "/tmp/socket",
|
||||||
|
expectError: false,
|
||||||
|
description: "Unix socket with bind address should work",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard bind all interfaces",
|
||||||
|
spec: "*:8080:localhost:80",
|
||||||
|
expectedLocal: "0.0.0.0:8080",
|
||||||
|
expectedRemote: "localhost:80",
|
||||||
|
expectError: false,
|
||||||
|
description: "Wildcard * should bind to all interfaces (0.0.0.0)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard for port only",
|
||||||
|
spec: "8080:*:80",
|
||||||
|
expectedLocal: "localhost:8080",
|
||||||
|
expectedRemote: "*:80",
|
||||||
|
expectError: false,
|
||||||
|
description: "Wildcard in remote host should be preserved",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
localAddr, remoteAddr, err := parsePortForwardSpec(tt.spec)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err, tt.description)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err, tt.description)
|
||||||
|
assert.Equal(t, tt.expectedLocal, localAddr, tt.description+" - local address")
|
||||||
|
assert.Equal(t, tt.expectedRemote, remoteAddr, tt.description+" - remote address")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHCommand_IntegrationPortForwarding(t *testing.T) {
|
||||||
|
// Integration test for port forwarding with the actual SSH command implementation
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
expectedHost string
|
||||||
|
expectedLocal []string
|
||||||
|
expectedRemote []string
|
||||||
|
expectedCmd string
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "local forward with command",
|
||||||
|
args: []string{"-L", "8080:localhost:80", "hostname", "echo", "test"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedLocal: []string{"8080:localhost:80"},
|
||||||
|
expectedRemote: []string{},
|
||||||
|
expectedCmd: "echo test",
|
||||||
|
description: "Local forwarding should work with commands",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "remote forward with command",
|
||||||
|
args: []string{"-R", "8080:localhost:80", "hostname", "ls", "-la"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedLocal: []string{},
|
||||||
|
expectedRemote: []string{"8080:localhost:80"},
|
||||||
|
expectedCmd: "ls -la",
|
||||||
|
description: "Remote forwarding should work with commands",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple forwards with user and command",
|
||||||
|
args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "user@hostname", "ps", "aux"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedLocal: []string{"8080:localhost:80"},
|
||||||
|
expectedRemote: []string{"9090:localhost:443"},
|
||||||
|
expectedCmd: "ps aux",
|
||||||
|
description: "Complex case with multiple forwards, user, and command",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Reset global variables
|
||||||
|
host = ""
|
||||||
|
username = ""
|
||||||
|
port = 22
|
||||||
|
command = ""
|
||||||
|
localForwards = nil
|
||||||
|
remoteForwards = nil
|
||||||
|
|
||||||
|
cmd := sshCmd
|
||||||
|
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||||
|
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||||
|
// Handle nil vs empty slice comparison
|
||||||
|
if len(tt.expectedLocal) == 0 {
|
||||||
|
assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty")
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards")
|
||||||
|
}
|
||||||
|
if len(tt.expectedRemote) == 0 {
|
||||||
|
assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty")
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards")
|
||||||
|
}
|
||||||
|
assert.Equal(t, tt.expectedCmd, command, tt.description+" - command")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHCommand_ParameterIsolation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
expectedCmd string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "cmd flag passed as command",
|
||||||
|
args: []string{"hostname", "--cmd", "echo test"},
|
||||||
|
expectedCmd: "--cmd echo test",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "uid flag passed as command",
|
||||||
|
args: []string{"hostname", "--uid", "1000"},
|
||||||
|
expectedCmd: "--uid 1000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "shell flag passed as command",
|
||||||
|
args: []string{"hostname", "--shell", "/bin/bash"},
|
||||||
|
expectedCmd: "--shell /bin/bash",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
host = ""
|
||||||
|
username = ""
|
||||||
|
port = 22
|
||||||
|
command = ""
|
||||||
|
|
||||||
|
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "hostname", host)
|
||||||
|
assert.Equal(t, tt.expectedCmd, command)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHCommand_InvalidFlagRejection(t *testing.T) {
|
||||||
|
// Test that invalid flags are properly rejected and not misinterpreted as hostnames
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "invalid long flag before hostname",
|
||||||
|
args: []string{"--invalid-flag", "hostname"},
|
||||||
|
description: "Invalid flag should return parse error, not treat flag as hostname",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid short flag before hostname",
|
||||||
|
args: []string{"-x", "hostname"},
|
||||||
|
description: "Invalid short flag should return parse error",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid flag with value before hostname",
|
||||||
|
args: []string{"--invalid-option=value", "hostname"},
|
||||||
|
description: "Invalid flag with value should return parse error",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "typo in known flag",
|
||||||
|
args: []string{"--por", "2222", "hostname"},
|
||||||
|
description: "Typo in flag name should return parse error (not silently ignored)",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Reset global variables
|
||||||
|
host = ""
|
||||||
|
username = ""
|
||||||
|
port = 22
|
||||||
|
command = ""
|
||||||
|
|
||||||
|
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
|
||||||
|
|
||||||
|
// Should return an error for invalid flags
|
||||||
|
assert.Error(t, err, tt.description)
|
||||||
|
|
||||||
|
// Should not have set host to the invalid flag
|
||||||
|
assert.NotEqual(t, tt.args[0], host, "Invalid flag should not be interpreted as hostname")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -109,7 +109,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
case yamlFlag:
|
case yamlFlag:
|
||||||
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
|
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
|
||||||
default:
|
default:
|
||||||
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
|
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -12,8 +12,11 @@ import (
|
|||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
|
||||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||||
@@ -23,8 +26,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/peers"
|
|
||||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
@@ -115,15 +116,18 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
||||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
|
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
|
||||||
|
|
||||||
accountManager, err := mgmt.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -197,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
|||||||
r := peer.NewRecorder(config.ManagementURL.String())
|
r := peer.NewRecorder(config.ManagementURL.String())
|
||||||
r.GetFullStatus()
|
r.GetFullStatus()
|
||||||
|
|
||||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
connectClient := internal.NewConnectClient(ctx, config, r, false)
|
||||||
SetupDebugHandler(ctx, config, r, connectClient, "")
|
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||||
|
|
||||||
return connectClient.Run(nil)
|
return connectClient.Run(nil)
|
||||||
@@ -355,6 +355,25 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
|||||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||||
req.ServerSSHAllowed = &serverSSHAllowed
|
req.ServerSSHAllowed = &serverSSHAllowed
|
||||||
}
|
}
|
||||||
|
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||||
|
req.EnableSSHRoot = &enableSSHRoot
|
||||||
|
}
|
||||||
|
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||||
|
req.EnableSSHSFTP = &enableSSHSFTP
|
||||||
|
}
|
||||||
|
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||||
|
req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||||
|
}
|
||||||
|
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||||
|
req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||||
|
}
|
||||||
|
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||||
|
req.DisableSSHAuth = &disableSSHAuth
|
||||||
|
}
|
||||||
|
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||||
|
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
||||||
|
req.SshJWTCacheTTL = &sshJWTCacheTTL32
|
||||||
|
}
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
if err := parseInterfaceName(interfaceName); err != nil {
|
if err := parseInterfaceName(interfaceName); err != nil {
|
||||||
log.Errorf("parse interface name: %v", err)
|
log.Errorf("parse interface name: %v", err)
|
||||||
@@ -439,6 +458,30 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
|||||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||||
|
ic.EnableSSHRoot = &enableSSHRoot
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||||
|
ic.EnableSSHSFTP = &enableSSHSFTP
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||||
|
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||||
|
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||||
|
ic.DisableSSHAuth = &disableSSHAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||||
|
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
|
||||||
|
}
|
||||||
|
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
if err := parseInterfaceName(interfaceName); err != nil {
|
if err := parseInterfaceName(interfaceName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -539,6 +582,31 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
|||||||
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||||
|
loginRequest.EnableSSHRoot = &enableSSHRoot
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||||
|
loginRequest.EnableSSHSFTP = &enableSSHSFTP
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||||
|
loginRequest.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||||
|
loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||||
|
loginRequest.DisableSSHAuth = &disableSSHAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||||
|
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
||||||
|
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
|
||||||
|
}
|
||||||
|
|
||||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||||
loginRequest.DisableAutoConnect = &autoConnectDisabled
|
loginRequest.DisableAutoConnect = &autoConnectDisabled
|
||||||
}
|
}
|
||||||
|
|||||||
13
client/cmd/update.go
Normal file
13
client/cmd/update.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
//go:build !windows && !darwin
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
var updateCmd *cobra.Command
|
||||||
|
|
||||||
|
func isUpdateBinary() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
75
client/cmd/update_supported.go
Normal file
75
client/cmd/update_supported.go
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
//go:build windows || darwin
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
updateCmd = &cobra.Command{
|
||||||
|
Use: "update",
|
||||||
|
Short: "Update the NetBird client application",
|
||||||
|
RunE: updateFunc,
|
||||||
|
}
|
||||||
|
|
||||||
|
tempDirFlag string
|
||||||
|
installerFile string
|
||||||
|
serviceDirFlag string
|
||||||
|
dryRunFlag bool
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
updateCmd.Flags().StringVar(&tempDirFlag, "temp-dir", "", "temporary dir")
|
||||||
|
updateCmd.Flags().StringVar(&installerFile, "installer-file", "", "installer file")
|
||||||
|
updateCmd.Flags().StringVar(&serviceDirFlag, "service-dir", "", "service directory")
|
||||||
|
updateCmd.Flags().BoolVar(&dryRunFlag, "dry-run", false, "dry run the update process without making any changes")
|
||||||
|
}
|
||||||
|
|
||||||
|
// isUpdateBinary checks if the current executable is named "update" or "update.exe"
|
||||||
|
func isUpdateBinary() bool {
|
||||||
|
// Remove extension for cross-platform compatibility
|
||||||
|
execPath, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
baseName := filepath.Base(execPath)
|
||||||
|
name := strings.TrimSuffix(baseName, filepath.Ext(baseName))
|
||||||
|
|
||||||
|
return name == installer.UpdaterBinaryNameWithoutExtension()
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateFunc(cmd *cobra.Command, args []string) error {
|
||||||
|
if err := setupLogToFile(tempDirFlag); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("updater started: %s", serviceDirFlag)
|
||||||
|
updater := installer.NewWithDir(tempDirFlag)
|
||||||
|
if err := updater.Setup(context.Background(), dryRunFlag, installerFile, serviceDirFlag); err != nil {
|
||||||
|
log.Errorf("failed to update application: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupLogToFile(dir string) error {
|
||||||
|
logFile := filepath.Join(dir, installer.LogFile)
|
||||||
|
|
||||||
|
if _, err := os.Stat(logFile); err == nil {
|
||||||
|
if err := os.Remove(logFile); err != nil {
|
||||||
|
log.Errorf("failed to remove existing log file: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.InitLog(logLevel, util.LogConsole, logFile)
|
||||||
|
}
|
||||||
@@ -18,12 +18,16 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrClientAlreadyStarted = errors.New("client already started")
|
var (
|
||||||
var ErrClientNotStarted = errors.New("client not started")
|
ErrClientAlreadyStarted = errors.New("client already started")
|
||||||
var ErrConfigNotInitialized = errors.New("config not initialized")
|
ErrClientNotStarted = errors.New("client not started")
|
||||||
|
ErrEngineNotStarted = errors.New("engine not started")
|
||||||
|
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||||
|
)
|
||||||
|
|
||||||
// Client manages a netbird embedded client instance.
|
// Client manages a netbird embedded client instance.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
@@ -169,7 +173,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
||||||
client := internal.NewConnectClient(ctx, c.config, recorder)
|
client := internal.NewConnectClient(ctx, c.config, recorder, false)
|
||||||
|
|
||||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||||
// TODO: make after-startup backoff err available
|
// TODO: make after-startup backoff err available
|
||||||
@@ -238,17 +242,9 @@ func (c *Client) GetConfig() (profilemanager.Config, error) {
|
|||||||
// Dial dials a network address in the netbird network.
|
// Dial dials a network address in the netbird network.
|
||||||
// Not applicable if the userspace networking mode is disabled.
|
// Not applicable if the userspace networking mode is disabled.
|
||||||
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
c.mu.Lock()
|
engine, err := c.getEngine()
|
||||||
connect := c.connect
|
if err != nil {
|
||||||
if connect == nil {
|
return nil, err
|
||||||
c.mu.Unlock()
|
|
||||||
return nil, ErrClientNotStarted
|
|
||||||
}
|
|
||||||
c.mu.Unlock()
|
|
||||||
|
|
||||||
engine := connect.Engine()
|
|
||||||
if engine == nil {
|
|
||||||
return nil, errors.New("engine not started")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
nsnet, err := engine.GetNet()
|
nsnet, err := engine.GetNet()
|
||||||
@@ -259,6 +255,11 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
|
|||||||
return nsnet.DialContext(ctx, network, address)
|
return nsnet.DialContext(ctx, network, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DialContext dials a network address in the netbird network with context
|
||||||
|
func (c *Client) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return c.Dial(ctx, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
// ListenTCP listens on the given address in the netbird network.
|
// ListenTCP listens on the given address in the netbird network.
|
||||||
// Not applicable if the userspace networking mode is disabled.
|
// Not applicable if the userspace networking mode is disabled.
|
||||||
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||||
@@ -314,18 +315,47 @@ func (c *Client) NewHTTPClient() *http.Client {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
|
// VerifySSHHostKey verifies an SSH host key against stored peer keys.
|
||||||
|
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
|
||||||
|
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.
|
||||||
|
func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
|
||||||
|
engine, err := c.getEngine()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
storedKey, found := engine.GetPeerSSHKey(peerAddress)
|
||||||
|
if !found {
|
||||||
|
return sshcommon.ErrPeerNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getEngine safely retrieves the engine from the client with proper locking.
|
||||||
|
// Returns ErrClientNotStarted if the client is not started.
|
||||||
|
// Returns ErrEngineNotStarted if the engine is not available.
|
||||||
|
func (c *Client) getEngine() (*internal.Engine, error) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
connect := c.connect
|
connect := c.connect
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
if connect == nil {
|
if connect == nil {
|
||||||
c.mu.Unlock()
|
return nil, ErrClientNotStarted
|
||||||
return nil, netip.Addr{}, errors.New("client not started")
|
|
||||||
}
|
}
|
||||||
c.mu.Unlock()
|
|
||||||
|
|
||||||
engine := connect.Engine()
|
engine := connect.Engine()
|
||||||
if engine == nil {
|
if engine == nil {
|
||||||
return nil, netip.Addr{}, errors.New("engine not started")
|
return nil, ErrEngineNotStarted
|
||||||
|
}
|
||||||
|
|
||||||
|
return engine, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
|
||||||
|
engine, err := c.getEngine()
|
||||||
|
if err != nil {
|
||||||
|
return nil, netip.Addr{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
addr, err := engine.Address()
|
addr, err := engine.Address()
|
||||||
|
|||||||
@@ -28,6 +28,10 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
tableNat = "nat"
|
tableNat = "nat"
|
||||||
|
tableMangle = "mangle"
|
||||||
|
tableRaw = "raw"
|
||||||
|
tableSecurity = "security"
|
||||||
|
|
||||||
chainNameNatPrerouting = "PREROUTING"
|
chainNameNatPrerouting = "PREROUTING"
|
||||||
chainNameRoutingFw = "netbird-rt-fwd"
|
chainNameRoutingFw = "netbird-rt-fwd"
|
||||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||||
@@ -91,11 +95,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
|
|||||||
var err error
|
var err error
|
||||||
r.filterTable, err = r.loadFilterTable()
|
r.filterTable, err = r.loadFilterTable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, errFilterTableNotFound) {
|
log.Debugf("ip filter table not found: %v", err)
|
||||||
log.Warnf("table 'filter' not found for forward rules")
|
|
||||||
} else {
|
|
||||||
return nil, fmt.Errorf("load filter table: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return r, nil
|
return r, nil
|
||||||
@@ -175,7 +175,7 @@ func (r *router) removeNatPreroutingRules() error {
|
|||||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to list tables: %v", err)
|
return nil, fmt.Errorf("list tables: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, table := range tables {
|
for _, table := range tables {
|
||||||
@@ -187,14 +187,39 @@ func (r *router) loadFilterTable() (*nftables.Table, error) {
|
|||||||
return nil, errFilterTableNotFound
|
return nil, errFilterTableNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hookName(hook *nftables.ChainHook) string {
|
||||||
|
if hook == nil {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
switch *hook {
|
||||||
|
case *nftables.ChainHookForward:
|
||||||
|
return chainNameForward
|
||||||
|
case *nftables.ChainHookInput:
|
||||||
|
return chainNameInput
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("hook(%d)", *hook)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func familyName(family nftables.TableFamily) string {
|
||||||
|
switch family {
|
||||||
|
case nftables.TableFamilyIPv4:
|
||||||
|
return "ip"
|
||||||
|
case nftables.TableFamilyIPv6:
|
||||||
|
return "ip6"
|
||||||
|
case nftables.TableFamilyINet:
|
||||||
|
return "inet"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("family(%d)", family)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *router) createContainers() error {
|
func (r *router) createContainers() error {
|
||||||
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
||||||
Name: chainNameRoutingFw,
|
Name: chainNameRoutingFw,
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
})
|
})
|
||||||
|
|
||||||
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
|
||||||
|
|
||||||
prio := *nftables.ChainPriorityNATSource - 1
|
prio := *nftables.ChainPriorityNATSource - 1
|
||||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||||
Name: chainNameRoutingNat,
|
Name: chainNameRoutingNat,
|
||||||
@@ -236,9 +261,12 @@ func (r *router) createContainers() error {
|
|||||||
Type: nftables.ChainTypeFilter,
|
Type: nftables.ChainTypeFilter,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Add the single NAT rule that matches on mark
|
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||||
if err := r.addPostroutingRules(); err != nil {
|
|
||||||
return fmt.Errorf("add single nat rule: %v", err)
|
r.addPostroutingRules()
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("initialize tables: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.addMSSClampingRules(); err != nil {
|
if err := r.addMSSClampingRules(); err != nil {
|
||||||
@@ -250,11 +278,7 @@ func (r *router) createContainers() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
log.Errorf("failed to refresh rules: %s", err)
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("initialize tables: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -695,7 +719,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addPostroutingRules adds the masquerade rules
|
// addPostroutingRules adds the masquerade rules
|
||||||
func (r *router) addPostroutingRules() error {
|
func (r *router) addPostroutingRules() {
|
||||||
// First masquerade rule for traffic coming in from WireGuard interface
|
// First masquerade rule for traffic coming in from WireGuard interface
|
||||||
exprs := []expr.Any{
|
exprs := []expr.Any{
|
||||||
// Match on the first fwmark
|
// Match on the first fwmark
|
||||||
@@ -761,8 +785,6 @@ func (r *router) addPostroutingRules() error {
|
|||||||
Chain: r.chains[chainNameRoutingNat],
|
Chain: r.chains[chainNameRoutingNat],
|
||||||
Exprs: exprs2,
|
Exprs: exprs2,
|
||||||
})
|
})
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||||
@@ -839,7 +861,7 @@ func (r *router) addMSSClampingRules() error {
|
|||||||
Exprs: exprsOut,
|
Exprs: exprsOut,
|
||||||
})
|
})
|
||||||
|
|
||||||
return nil
|
return r.conn.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||||
@@ -939,8 +961,21 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
|||||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||||
// This method also adds INPUT chain rules to allow traffic to the local interface.
|
// This method also adds INPUT chain rules to allow traffic to the local interface.
|
||||||
func (r *router) acceptForwardRules() error {
|
func (r *router) acceptForwardRules() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := r.acceptFilterTableRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.acceptExternalChainsRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) acceptFilterTableRules() error {
|
||||||
if r.filterTable == nil {
|
if r.filterTable == nil {
|
||||||
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -953,11 +988,11 @@ func (r *router) acceptForwardRules() error {
|
|||||||
// Try iptables first and fallback to nftables if iptables is not available
|
// Try iptables first and fallback to nftables if iptables is not available
|
||||||
ipt, err := iptables.New()
|
ipt, err := iptables.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// filter table exists but iptables is not
|
// iptables is not available but the filter table exists
|
||||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||||
|
|
||||||
fw = "nftables"
|
fw = "nftables"
|
||||||
return r.acceptFilterRulesNftables()
|
return r.acceptFilterRulesNftables(r.filterTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.acceptFilterRulesIptables(ipt)
|
return r.acceptFilterRulesIptables(ipt)
|
||||||
@@ -968,7 +1003,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
|||||||
|
|
||||||
for _, rule := range r.getAcceptForwardRules() {
|
for _, rule := range r.getAcceptForwardRules() {
|
||||||
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
||||||
merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err))
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("added iptables forward rule: %v", rule)
|
log.Debugf("added iptables forward rule: %v", rule)
|
||||||
}
|
}
|
||||||
@@ -976,7 +1011,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
|||||||
|
|
||||||
inputRule := r.getAcceptInputRule()
|
inputRule := r.getAcceptInputRule()
|
||||||
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
|
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
|
||||||
merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err))
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("added iptables input rule: %v", inputRule)
|
log.Debugf("added iptables input rule: %v", inputRule)
|
||||||
}
|
}
|
||||||
@@ -996,18 +1031,70 @@ func (r *router) getAcceptInputRule() []string {
|
|||||||
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
|
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) acceptFilterRulesNftables() error {
|
// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables.
|
||||||
|
// This is used when iptables is not available.
|
||||||
|
func (r *router) acceptFilterRulesNftables(table *nftables.Table) error {
|
||||||
intf := ifname(r.wgIface.Name())
|
intf := ifname(r.wgIface.Name())
|
||||||
|
|
||||||
iifRule := &nftables.Rule{
|
forwardChain := &nftables.Chain{
|
||||||
Table: r.filterTable,
|
|
||||||
Chain: &nftables.Chain{
|
|
||||||
Name: chainNameForward,
|
Name: chainNameForward,
|
||||||
Table: r.filterTable,
|
Table: table,
|
||||||
Type: nftables.ChainTypeFilter,
|
Type: nftables.ChainTypeFilter,
|
||||||
Hooknum: nftables.ChainHookForward,
|
Hooknum: nftables.ChainHookForward,
|
||||||
Priority: nftables.ChainPriorityFilter,
|
Priority: nftables.ChainPriorityFilter,
|
||||||
},
|
}
|
||||||
|
r.insertForwardAcceptRules(forwardChain, intf)
|
||||||
|
|
||||||
|
inputChain := &nftables.Chain{
|
||||||
|
Name: chainNameInput,
|
||||||
|
Table: table,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookInput,
|
||||||
|
Priority: nftables.ChainPriorityFilter,
|
||||||
|
}
|
||||||
|
r.insertInputAcceptRule(inputChain, intf)
|
||||||
|
|
||||||
|
return r.conn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables).
|
||||||
|
// It dynamically finds chains at call time to handle chains that may have been created after startup.
|
||||||
|
func (r *router) acceptExternalChainsRules() error {
|
||||||
|
chains := r.findExternalChains()
|
||||||
|
if len(chains) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
intf := ifname(r.wgIface.Name())
|
||||||
|
|
||||||
|
for _, chain := range chains {
|
||||||
|
if chain.Hooknum == nil {
|
||||||
|
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
|
||||||
|
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
|
||||||
|
|
||||||
|
switch *chain.Hooknum {
|
||||||
|
case *nftables.ChainHookForward:
|
||||||
|
r.insertForwardAcceptRules(chain, intf)
|
||||||
|
case *nftables.ChainHookInput:
|
||||||
|
r.insertInputAcceptRule(chain, intf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush external chain rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
|
||||||
|
iifRule := &nftables.Rule{
|
||||||
|
Table: chain.Table,
|
||||||
|
Chain: chain,
|
||||||
Exprs: []expr.Any{
|
Exprs: []expr.Any{
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
@@ -1030,30 +1117,19 @@ func (r *router) acceptFilterRulesNftables() error {
|
|||||||
Data: intf,
|
Data: intf,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
oifRule := &nftables.Rule{
|
oifRule := &nftables.Rule{
|
||||||
Table: r.filterTable,
|
Table: chain.Table,
|
||||||
Chain: &nftables.Chain{
|
Chain: chain,
|
||||||
Name: chainNameForward,
|
|
||||||
Table: r.filterTable,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookForward,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
},
|
|
||||||
Exprs: append(oifExprs, getEstablishedExprs(2)...),
|
Exprs: append(oifExprs, getEstablishedExprs(2)...),
|
||||||
UserData: []byte(userDataAcceptForwardRuleOif),
|
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||||
}
|
}
|
||||||
r.conn.InsertRule(oifRule)
|
r.conn.InsertRule(oifRule)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
|
||||||
inputRule := &nftables.Rule{
|
inputRule := &nftables.Rule{
|
||||||
Table: r.filterTable,
|
Table: chain.Table,
|
||||||
Chain: &nftables.Chain{
|
Chain: chain,
|
||||||
Name: chainNameInput,
|
|
||||||
Table: r.filterTable,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookInput,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
},
|
|
||||||
Exprs: []expr.Any{
|
Exprs: []expr.Any{
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
@@ -1067,32 +1143,44 @@ func (r *router) acceptFilterRulesNftables() error {
|
|||||||
UserData: []byte(userDataAcceptInputRule),
|
UserData: []byte(userDataAcceptInputRule),
|
||||||
}
|
}
|
||||||
r.conn.InsertRule(inputRule)
|
r.conn.InsertRule(inputRule)
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) removeAcceptFilterRules() error {
|
func (r *router) removeAcceptFilterRules() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := r.removeFilterTableRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.removeExternalChainsRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) removeFilterTableRules() error {
|
||||||
if r.filterTable == nil {
|
if r.filterTable == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ipt, err := iptables.New()
|
ipt, err := iptables.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
|
||||||
return r.removeAcceptFilterRulesNftables()
|
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.removeAcceptFilterRulesIptables(ipt)
|
return r.removeAcceptFilterRulesIptables(ipt)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) removeAcceptFilterRulesNftables() error {
|
func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
|
||||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
chains, err := r.conn.ListChainsOfTableFamily(table.Family)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("list chains: %v", err)
|
return fmt.Errorf("list chains: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, chain := range chains {
|
for _, chain := range chains {
|
||||||
if chain.Table.Name != r.filterTable.Name {
|
if chain.Table.Name != table.Name {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1100,9 +1188,18 @@ func (r *router) removeAcceptFilterRulesNftables() error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
rules, err := r.conn.GetRules(r.filterTable, chain)
|
if err := r.removeAcceptRulesFromChain(table, chain); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.conn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error {
|
||||||
|
rules, err := r.conn.GetRules(table, chain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get rules: %v", err)
|
return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
@@ -1110,31 +1207,96 @@ func (r *router) removeAcceptFilterRulesNftables() error {
|
|||||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
|
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
|
||||||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
|
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
return fmt.Errorf("delete rule: %v", err)
|
return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf(flushError, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// removeExternalChainsRules removes our accept rules from all external chains.
|
||||||
|
// This is deterministic - it scans for chains at removal time rather than relying on saved state,
|
||||||
|
// ensuring cleanup works even after a crash or if chains changed.
|
||||||
|
func (r *router) removeExternalChainsRules() error {
|
||||||
|
chains := r.findExternalChains()
|
||||||
|
if len(chains) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, chain := range chains {
|
||||||
|
if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil {
|
||||||
|
log.Warnf("remove rules from external chain %s/%s: %v", chain.Table.Name, chain.Name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.conn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks.
|
||||||
|
// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal).
|
||||||
|
func (r *router) findExternalChains() []*nftables.Chain {
|
||||||
|
var chains []*nftables.Chain
|
||||||
|
|
||||||
|
families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet}
|
||||||
|
|
||||||
|
for _, family := range families {
|
||||||
|
allChains, err := r.conn.ListChainsOfTableFamily(family)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("list chains for family %d: %v", family, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, chain := range allChains {
|
||||||
|
if r.isExternalChain(chain) {
|
||||||
|
chains = append(chains, chain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return chains
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) isExternalChain(chain *nftables.Chain) bool {
|
||||||
|
if r.workTable != nil && chain.Table.Name == r.workTable.Name {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip all iptables-managed tables in the ip family
|
||||||
|
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if chain.Type != nftables.ChainTypeFilter {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if chain.Hooknum == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIptablesTable(name string) bool {
|
||||||
|
switch name {
|
||||||
|
case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
for _, rule := range r.getAcceptForwardRules() {
|
for _, rule := range r.getAcceptForwardRules() {
|
||||||
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
||||||
merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inputRule := r.getAcceptInputRule()
|
inputRule := r.getAcceptInputRule()
|
||||||
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
|
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
|
||||||
merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
@@ -1196,7 +1358,7 @@ func (r *router) refreshRulesMap() error {
|
|||||||
for _, chain := range r.chains {
|
for _, chain := range r.chains {
|
||||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf(" unable to list rules: %v", err)
|
return fmt.Errorf("list rules: %w", err)
|
||||||
}
|
}
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 {
|
if len(rule.UserData) > 0 {
|
||||||
|
|||||||
@@ -35,6 +35,12 @@ const (
|
|||||||
ipTCPHeaderMinSize = 40
|
ipTCPHeaderMinSize = 40
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// serviceKey represents a protocol/port combination for netstack service registry
|
||||||
|
type serviceKey struct {
|
||||||
|
protocol gopacket.LayerType
|
||||||
|
port uint16
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
|
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
|
||||||
EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
||||||
@@ -59,12 +65,6 @@ const (
|
|||||||
|
|
||||||
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
||||||
|
|
||||||
// serviceKey represents a protocol/port combination for netstack service registry
|
|
||||||
type serviceKey struct {
|
|
||||||
protocol gopacket.LayerType
|
|
||||||
port uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
// RuleSet is a set of rules grouped by a string key
|
// RuleSet is a set of rules grouped by a string key
|
||||||
type RuleSet map[string]PeerRule
|
type RuleSet map[string]PeerRule
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1114,3 +1115,138 @@ func generateTCPPacketWithFlags(tb testing.TB, srcIP, dstIP net.IP, srcPort, dst
|
|||||||
|
|
||||||
return buf.Bytes()
|
return buf.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestShouldForward(t *testing.T) {
|
||||||
|
// Set up test addresses
|
||||||
|
wgIP := netip.MustParseAddr("100.10.0.1")
|
||||||
|
otherIP := netip.MustParseAddr("100.10.0.2")
|
||||||
|
|
||||||
|
// Create test manager with mock interface
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
// Set the mock to return our test WG IP
|
||||||
|
ifaceMock.AddressFunc = func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{IP: wgIP, Network: netip.PrefixFrom(wgIP, 24)}
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Helper to create decoder with TCP packet
|
||||||
|
createTCPDecoder := func(dstPort uint16) *decoder {
|
||||||
|
ipv4 := &layers.IPv4{
|
||||||
|
Version: 4,
|
||||||
|
Protocol: layers.IPProtocolTCP,
|
||||||
|
SrcIP: net.ParseIP("192.168.1.100"),
|
||||||
|
DstIP: wgIP.AsSlice(),
|
||||||
|
}
|
||||||
|
tcp := &layers.TCP{
|
||||||
|
SrcPort: 54321,
|
||||||
|
DstPort: layers.TCPPort(dstPort),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := tcp.SetNetworkLayerForChecksum(ipv4)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
buf := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||||
|
err = gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
d := &decoder{
|
||||||
|
decoded: []gopacket.LayerType{},
|
||||||
|
}
|
||||||
|
d.parser = gopacket.NewDecodingLayerParser(
|
||||||
|
layers.LayerTypeIPv4,
|
||||||
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
|
)
|
||||||
|
d.parser.IgnoreUnsupported = true
|
||||||
|
|
||||||
|
err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
localForwarding bool
|
||||||
|
netstack bool
|
||||||
|
dstIP netip.Addr
|
||||||
|
serviceRegistered bool
|
||||||
|
servicePort uint16
|
||||||
|
expected bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no local forwarding",
|
||||||
|
localForwarding: false,
|
||||||
|
netstack: true,
|
||||||
|
dstIP: wgIP,
|
||||||
|
expected: false,
|
||||||
|
description: "should never forward when local forwarding disabled",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "traffic to other local interface",
|
||||||
|
localForwarding: true,
|
||||||
|
netstack: false,
|
||||||
|
dstIP: otherIP,
|
||||||
|
expected: true,
|
||||||
|
description: "should forward traffic to our other local interfaces (not NetBird IP)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "traffic to NetBird IP, no netstack",
|
||||||
|
localForwarding: true,
|
||||||
|
netstack: false,
|
||||||
|
dstIP: wgIP,
|
||||||
|
expected: false,
|
||||||
|
description: "should send to netstack listeners (final return false path)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "traffic to our IP, netstack mode, no service",
|
||||||
|
localForwarding: true,
|
||||||
|
netstack: true,
|
||||||
|
dstIP: wgIP,
|
||||||
|
expected: true,
|
||||||
|
description: "should forward when in netstack mode with no matching service",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "traffic to our IP, netstack mode, with service",
|
||||||
|
localForwarding: true,
|
||||||
|
netstack: true,
|
||||||
|
dstIP: wgIP,
|
||||||
|
serviceRegistered: true,
|
||||||
|
servicePort: 22,
|
||||||
|
expected: false,
|
||||||
|
description: "should send to netstack listeners when service is registered",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Configure manager
|
||||||
|
manager.localForwarding = tt.localForwarding
|
||||||
|
manager.netstack = tt.netstack
|
||||||
|
|
||||||
|
// Register service if needed
|
||||||
|
if tt.serviceRegistered {
|
||||||
|
manager.RegisterNetstackService(nftypes.TCP, tt.servicePort)
|
||||||
|
defer manager.UnregisterNetstackService(nftypes.TCP, tt.servicePort)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create decoder for the test
|
||||||
|
decoder := createTCPDecoder(tt.servicePort)
|
||||||
|
if !tt.serviceRegistered {
|
||||||
|
decoder = createTCPDecoder(8080) // Use non-registered port
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test the method
|
||||||
|
result := manager.shouldForward(decoder, tt.dstIP)
|
||||||
|
require.Equal(t, tt.expected, result, tt.description)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
85
client/firewall/uspfilter/nat_stateful_test.go
Normal file
85
client/firewall/uspfilter/nat_stateful_test.go
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestPortDNATBasic tests basic port DNAT functionality
|
||||||
|
func TestPortDNATBasic(t *testing.T) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Define peer IPs
|
||||||
|
peerA := netip.MustParseAddr("100.10.0.50")
|
||||||
|
peerB := netip.MustParseAddr("100.10.0.51")
|
||||||
|
|
||||||
|
// Add SSH port redirection rule for peer B (the target)
|
||||||
|
err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Scenario: Peer A connects to Peer B on port 22 (should get NAT)
|
||||||
|
packetAtoB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22)
|
||||||
|
d := parsePacket(t, packetAtoB)
|
||||||
|
translatedAtoB := manager.translateInboundPortDNAT(packetAtoB, d, peerA, peerB)
|
||||||
|
require.True(t, translatedAtoB, "Peer A to Peer B should be translated (NAT applied)")
|
||||||
|
|
||||||
|
// Verify port was translated to 22022
|
||||||
|
d = parsePacket(t, packetAtoB)
|
||||||
|
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be rewritten to 22022")
|
||||||
|
|
||||||
|
// Scenario: Return traffic from Peer B to Peer A should NOT be translated
|
||||||
|
// (prevents double NAT - original port stored in conntrack)
|
||||||
|
returnPacket := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 22022, 54321)
|
||||||
|
d2 := parsePacket(t, returnPacket)
|
||||||
|
translatedReturn := manager.translateInboundPortDNAT(returnPacket, d2, peerB, peerA)
|
||||||
|
require.False(t, translatedReturn, "Return traffic from same IP should not be translated")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPortDNATMultipleRules tests multiple port DNAT rules
|
||||||
|
func TestPortDNATMultipleRules(t *testing.T) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Define peer IPs
|
||||||
|
peerA := netip.MustParseAddr("100.10.0.50")
|
||||||
|
peerB := netip.MustParseAddr("100.10.0.51")
|
||||||
|
|
||||||
|
// Add SSH port redirection rules for both peers
|
||||||
|
err = manager.addPortRedirection(peerA, layers.LayerTypeTCP, 22, 22022)
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Test traffic to peer B gets translated
|
||||||
|
packetToB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22)
|
||||||
|
d1 := parsePacket(t, packetToB)
|
||||||
|
translatedToB := manager.translateInboundPortDNAT(packetToB, d1, peerA, peerB)
|
||||||
|
require.True(t, translatedToB, "Traffic to peer B should be translated")
|
||||||
|
d1 = parsePacket(t, packetToB)
|
||||||
|
require.Equal(t, uint16(22022), uint16(d1.tcp.DstPort), "Port should be 22022")
|
||||||
|
|
||||||
|
// Test traffic to peer A gets translated
|
||||||
|
packetToA := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 54322, 22)
|
||||||
|
d2 := parsePacket(t, packetToA)
|
||||||
|
translatedToA := manager.translateInboundPortDNAT(packetToA, d2, peerB, peerA)
|
||||||
|
require.True(t, translatedToA, "Traffic to peer A should be translated")
|
||||||
|
d2 = parsePacket(t, packetToA)
|
||||||
|
require.Equal(t, uint16(22022), uint16(d2.tcp.DstPort), "Port should be 22022")
|
||||||
|
}
|
||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
@@ -12,7 +11,6 @@ import (
|
|||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/connectivity"
|
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
@@ -20,9 +18,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready
|
|
||||||
var ErrConnectionShutdown = errors.New("connection shutdown before ready")
|
|
||||||
|
|
||||||
// Backoff returns a backoff configuration for gRPC calls
|
// Backoff returns a backoff configuration for gRPC calls
|
||||||
func Backoff(ctx context.Context) backoff.BackOff {
|
func Backoff(ctx context.Context) backoff.BackOff {
|
||||||
b := backoff.NewExponentialBackOff()
|
b := backoff.NewExponentialBackOff()
|
||||||
@@ -31,26 +26,6 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
|||||||
return backoff.WithContext(b, ctx)
|
return backoff.WithContext(b, ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// waitForConnectionReady blocks until the connection becomes ready or fails.
|
|
||||||
// Returns an error if the connection times out, is cancelled, or enters shutdown state.
|
|
||||||
func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error {
|
|
||||||
conn.Connect()
|
|
||||||
|
|
||||||
state := conn.GetState()
|
|
||||||
for state != connectivity.Ready && state != connectivity.Shutdown {
|
|
||||||
if !conn.WaitForStateChange(ctx, state) {
|
|
||||||
return fmt.Errorf("wait state change from %s: %w", state, ctx.Err())
|
|
||||||
}
|
|
||||||
state = conn.GetState()
|
|
||||||
}
|
|
||||||
|
|
||||||
if state == connectivity.Shutdown {
|
|
||||||
return ErrConnectionShutdown
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
||||||
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||||
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
||||||
@@ -68,25 +43,22 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := grpc.NewClient(
|
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := grpc.DialContext(
|
||||||
|
connCtx,
|
||||||
addr,
|
addr,
|
||||||
transportOption,
|
transportOption,
|
||||||
WithCustomDialer(tlsEnabled, component),
|
WithCustomDialer(tlsEnabled, component),
|
||||||
|
grpc.WithBlock(),
|
||||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||||
Time: 30 * time.Second,
|
Time: 30 * time.Second,
|
||||||
Timeout: 10 * time.Second,
|
Timeout: 10 * time.Second,
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("new client: %w", err)
|
return nil, fmt.Errorf("dial context: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
if err := waitForConnectionReady(ctx, conn); err != nil {
|
|
||||||
_ = conn.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -24,6 +25,7 @@ type WGTunDevice struct {
|
|||||||
key string
|
key string
|
||||||
mtu uint16
|
mtu uint16
|
||||||
iceBind *bind.ICEBind
|
iceBind *bind.ICEBind
|
||||||
|
// todo: review if we can eliminate the TunAdapter
|
||||||
tunAdapter TunAdapter
|
tunAdapter TunAdapter
|
||||||
disableDNS bool
|
disableDNS bool
|
||||||
|
|
||||||
@@ -32,6 +34,7 @@ type WGTunDevice struct {
|
|||||||
filteredDevice *FilteredDevice
|
filteredDevice *FilteredDevice
|
||||||
udpMux *udpmux.UniversalUDPMuxDefault
|
udpMux *udpmux.UniversalUDPMuxDefault
|
||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
|
renewableTun *RenewableTUN
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
|
func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
|
||||||
@@ -43,6 +46,7 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceB
|
|||||||
iceBind: iceBind,
|
iceBind: iceBind,
|
||||||
tunAdapter: tunAdapter,
|
tunAdapter: tunAdapter,
|
||||||
disableDNS: disableDNS,
|
disableDNS: disableDNS,
|
||||||
|
renewableTun: NewRenewableTUN(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,14 +69,17 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
tunDevice, name, err := tun.CreateUnmonitoredTUNFromFD(fd)
|
unmonitoredTUN, name, err := tun.CreateUnmonitoredTUNFromFD(fd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = unix.Close(fd)
|
_ = unix.Close(fd)
|
||||||
log.Errorf("failed to create Android interface: %s", err)
|
log.Errorf("failed to create Android interface: %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.renewableTun.AddDevice(unmonitoredTUN)
|
||||||
|
|
||||||
t.name = name
|
t.name = name
|
||||||
t.filteredDevice = newDeviceFilter(tunDevice)
|
t.filteredDevice = newDeviceFilter(t.renewableTun)
|
||||||
|
|
||||||
log.Debugf("attaching to interface %v", name)
|
log.Debugf("attaching to interface %v", name)
|
||||||
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
|
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
|
||||||
@@ -104,6 +111,23 @@ func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
|||||||
return udpMux, nil
|
return udpMux, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *WGTunDevice) RenewTun(fd int) error {
|
||||||
|
if t.device == nil {
|
||||||
|
return fmt.Errorf("device not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
unmonitoredTUN, _, err := tun.CreateUnmonitoredTUNFromFD(fd)
|
||||||
|
if err != nil {
|
||||||
|
_ = unix.Close(fd)
|
||||||
|
log.Errorf("failed to renew Android interface: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
t.renewableTun.AddDevice(unmonitoredTUN)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
|
func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
|
||||||
// todo implement
|
// todo implement
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -2,6 +2,13 @@
|
|||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
func (t *TunNetstackDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
|
func (t *TunNetstackDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
|
||||||
return t.create()
|
return t.create()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *TunNetstackDevice) RenewTun(fd int) error {
|
||||||
|
// Doesn't make sense in Android for Netstack.
|
||||||
|
return fmt.Errorf("this function has not been implemented in Netstack for Android")
|
||||||
|
}
|
||||||
|
|||||||
309
client/iface/device/renewable_tun.go
Normal file
309
client/iface/device/renewable_tun.go
Normal file
@@ -0,0 +1,309 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
)
|
||||||
|
|
||||||
|
// closeAwareDevice wraps a tun.Device along with a flag
|
||||||
|
// indicating whether its Close method was called.
|
||||||
|
//
|
||||||
|
// It also redirects tun.Device's Events() to a separate goroutine
|
||||||
|
// and closes it when Close is called.
|
||||||
|
//
|
||||||
|
// The WaitGroup and CloseOnce fields are used to ensure that the
|
||||||
|
// goroutine is awaited and closed only once.
|
||||||
|
type closeAwareDevice struct {
|
||||||
|
isClosed atomic.Bool
|
||||||
|
tun.Device
|
||||||
|
closeEventCh chan struct{}
|
||||||
|
wg sync.WaitGroup
|
||||||
|
closeOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func newClosableDevice(tunDevice tun.Device) *closeAwareDevice {
|
||||||
|
return &closeAwareDevice{
|
||||||
|
Device: tunDevice,
|
||||||
|
isClosed: atomic.Bool{},
|
||||||
|
closeEventCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// redirectEvents redirects the Events() method of the underlying tun.Device
|
||||||
|
// to the given channel (RenewableTUN's events channel).
|
||||||
|
func (c *closeAwareDevice) redirectEvents(out chan tun.Event) {
|
||||||
|
c.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer c.wg.Done()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case ev, ok := <-c.Device.Events():
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if ev == tun.EventDown {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case out <- ev:
|
||||||
|
case <-c.closeEventCh:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-c.closeEventCh:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close calls the underlying Device's Close method
|
||||||
|
// after setting isClosed to true.
|
||||||
|
func (c *closeAwareDevice) Close() (err error) {
|
||||||
|
c.closeOnce.Do(func() {
|
||||||
|
c.isClosed.Store(true)
|
||||||
|
close(c.closeEventCh)
|
||||||
|
err = c.Device.Close()
|
||||||
|
c.wg.Wait()
|
||||||
|
})
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *closeAwareDevice) IsClosed() bool {
|
||||||
|
return c.isClosed.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
type RenewableTUN struct {
|
||||||
|
devices []*closeAwareDevice
|
||||||
|
mu sync.Mutex
|
||||||
|
cond *sync.Cond
|
||||||
|
events chan tun.Event
|
||||||
|
closed atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRenewableTUN() *RenewableTUN {
|
||||||
|
r := &RenewableTUN{
|
||||||
|
devices: make([]*closeAwareDevice, 0),
|
||||||
|
mu: sync.Mutex{},
|
||||||
|
events: make(chan tun.Event, 16),
|
||||||
|
}
|
||||||
|
r.cond = sync.NewCond(&r.mu)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RenewableTUN) File() *os.File {
|
||||||
|
for {
|
||||||
|
dev := r.peekLast()
|
||||||
|
if dev == nil {
|
||||||
|
if !r.waitForDevice() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
file := dev.File()
|
||||||
|
|
||||||
|
if dev.IsClosed() {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return file
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read reads from an underlying tun.Device kept in the r.devices slice.
|
||||||
|
// If no device is available, it waits for one to be added via AddDevice().
|
||||||
|
//
|
||||||
|
// On error, it retries reading from the newest device instead of returning the error
|
||||||
|
// if the device is closed; if not, it propagates the error.
|
||||||
|
func (r *RenewableTUN) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||||
|
for {
|
||||||
|
dev := r.peekLast()
|
||||||
|
if dev == nil {
|
||||||
|
// wait until AddDevice() signals a new device via cond.Broadcast()
|
||||||
|
if !r.waitForDevice() { // returns false if the renewable TUN itself is closed
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err = dev.Read(bufs, sizes, offset)
|
||||||
|
if err == nil {
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// swap in progress; retry on the newest instead of returning the error
|
||||||
|
if dev.IsClosed() {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return n, err // propagate non-swap error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write writes to underlying tun.Device kept in the r.devices slice.
|
||||||
|
// If no device is available, it waits for one to be added via AddDevice().
|
||||||
|
//
|
||||||
|
// On error, it retries writing to the newest device instead of returning the error
|
||||||
|
// if the device is closed; if not, it propagates the error.
|
||||||
|
func (r *RenewableTUN) Write(bufs [][]byte, offset int) (int, error) {
|
||||||
|
for {
|
||||||
|
dev := r.peekLast()
|
||||||
|
if dev == nil {
|
||||||
|
if !r.waitForDevice() {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := dev.Write(bufs, offset)
|
||||||
|
if err == nil {
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if dev.IsClosed() {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RenewableTUN) MTU() (int, error) {
|
||||||
|
for {
|
||||||
|
dev := r.peekLast()
|
||||||
|
if dev == nil {
|
||||||
|
if !r.waitForDevice() {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
mtu, err := dev.MTU()
|
||||||
|
if err == nil {
|
||||||
|
return mtu, nil
|
||||||
|
}
|
||||||
|
if dev.IsClosed() {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RenewableTUN) Name() (string, error) {
|
||||||
|
for {
|
||||||
|
dev := r.peekLast()
|
||||||
|
if dev == nil {
|
||||||
|
if !r.waitForDevice() {
|
||||||
|
return "", io.EOF
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name, err := dev.Name()
|
||||||
|
if err == nil {
|
||||||
|
return name, nil
|
||||||
|
}
|
||||||
|
if dev.IsClosed() {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Events returns a channel that is fed events from the underlying tun.Device's events channel
|
||||||
|
// once it is added.
|
||||||
|
func (r *RenewableTUN) Events() <-chan tun.Event {
|
||||||
|
return r.events
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RenewableTUN) Close() error {
|
||||||
|
// Attempts to set the RenewableTUN closed flag to true.
|
||||||
|
// If it's already true, returns immediately.
|
||||||
|
if !r.closed.CompareAndSwap(false, true) {
|
||||||
|
return nil // already closed: idempotent
|
||||||
|
}
|
||||||
|
r.mu.Lock()
|
||||||
|
devices := r.devices
|
||||||
|
r.devices = nil
|
||||||
|
r.cond.Broadcast()
|
||||||
|
r.mu.Unlock()
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
|
||||||
|
log.Debugf("closing %d devices", len(devices))
|
||||||
|
for _, device := range devices {
|
||||||
|
if err := device.Close(); err != nil {
|
||||||
|
log.Debugf("error closing a device: %v", err)
|
||||||
|
lastErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
close(r.events)
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RenewableTUN) BatchSize() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RenewableTUN) AddDevice(device tun.Device) {
|
||||||
|
r.mu.Lock()
|
||||||
|
if r.closed.Load() {
|
||||||
|
r.mu.Unlock()
|
||||||
|
_ = device.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var toClose *closeAwareDevice
|
||||||
|
if len(r.devices) > 0 {
|
||||||
|
toClose = r.devices[len(r.devices)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
cad := newClosableDevice(device)
|
||||||
|
cad.redirectEvents(r.events)
|
||||||
|
|
||||||
|
r.devices = []*closeAwareDevice{cad}
|
||||||
|
r.cond.Broadcast()
|
||||||
|
|
||||||
|
r.mu.Unlock()
|
||||||
|
|
||||||
|
if toClose != nil {
|
||||||
|
if err := toClose.Close(); err != nil {
|
||||||
|
log.Debugf("error closing last device: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RenewableTUN) waitForDevice() bool {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
for len(r.devices) == 0 && !r.closed.Load() {
|
||||||
|
r.cond.Wait()
|
||||||
|
}
|
||||||
|
return !r.closed.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RenewableTUN) peekLast() *closeAwareDevice {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
if len(r.devices) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.devices[len(r.devices)-1]
|
||||||
|
}
|
||||||
@@ -21,5 +21,6 @@ type WGTunDevice interface {
|
|||||||
FilteredDevice() *device.FilteredDevice
|
FilteredDevice() *device.FilteredDevice
|
||||||
Device() *wgdevice.Device
|
Device() *wgdevice.Device
|
||||||
GetNet() *netstack.Net
|
GetNet() *netstack.Net
|
||||||
|
RenewTun(fd int) error
|
||||||
GetICEBind() device.EndpointManager
|
GetICEBind() device.EndpointManager
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,3 +24,7 @@ func (w *WGIface) Create() error {
|
|||||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
||||||
return fmt.Errorf("this function has not implemented on non mobile")
|
return fmt.Errorf("this function has not implemented on non mobile")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *WGIface) RenewTun(fd int) error {
|
||||||
|
return fmt.Errorf("this function has not been implemented on non-android")
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
|
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
|
||||||
// Will reuse an existing one.
|
// Will reuse an existing one.
|
||||||
|
// todo: review does this function really necessary or can we merge it with iOS
|
||||||
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
|
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
@@ -22,3 +23,9 @@ func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []s
|
|||||||
func (w *WGIface) Create() error {
|
func (w *WGIface) Create() error {
|
||||||
return fmt.Errorf("this function has not implemented on this platform")
|
return fmt.Errorf("this function has not implemented on this platform")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *WGIface) RenewTun(fd int) error {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
return w.tun.RenewTun(fd)
|
||||||
|
}
|
||||||
|
|||||||
@@ -39,3 +39,7 @@ func (w *WGIface) Create() error {
|
|||||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
||||||
return fmt.Errorf("this function has not implemented on this platform")
|
return fmt.Errorf("this function has not implemented on this platform")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *WGIface) RenewTun(fd int) error {
|
||||||
|
return fmt.Errorf("this function has not been implemented on this platform")
|
||||||
|
}
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ import (
|
|||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
@@ -83,22 +82,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
|||||||
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||||
rules := networkMap.FirewallRules
|
rules := networkMap.FirewallRules
|
||||||
|
|
||||||
enableSSH := networkMap.PeerConfig != nil &&
|
|
||||||
networkMap.PeerConfig.SshConfig != nil &&
|
|
||||||
networkMap.PeerConfig.SshConfig.SshEnabled
|
|
||||||
|
|
||||||
// If SSH enabled, add default firewall rule which accepts connection to any peer
|
|
||||||
// in the network by SSH (TCP port defined by ssh.DefaultSSHPort).
|
|
||||||
if enableSSH {
|
|
||||||
rules = append(rules, &mgmProto.FirewallRule{
|
|
||||||
PeerIP: "0.0.0.0",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: strconv.Itoa(ssh.DefaultSSHPort),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
|
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
|
||||||
// we have old version of management without rules handling, we should allow all traffic
|
// we have old version of management without rules handling, we should allow all traffic
|
||||||
if len(networkMap.FirewallRules) == 0 && !networkMap.FirewallRulesIsEmpty {
|
if len(networkMap.FirewallRules) == 0 && !networkMap.FirewallRulesIsEmpty {
|
||||||
|
|||||||
@@ -272,70 +272,3 @@ func TestPortInfoEmpty(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
|
||||||
PeerConfig: &mgmProto.PeerConfig{
|
|
||||||
SshConfig: &mgmProto.SSHConfig{
|
|
||||||
SshEnabled: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
|
||||||
{AllowedIps: []string{"10.93.0.1"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.2"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.3"}},
|
|
||||||
},
|
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_OUT,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_UDP,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
defer ctrl.Finish()
|
|
||||||
|
|
||||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
|
||||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
|
||||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
|
||||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
|
||||||
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
|
||||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
|
||||||
IP: network.Addr(),
|
|
||||||
Network: network,
|
|
||||||
}).AnyTimes()
|
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
|
||||||
|
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
err = fw.Close(nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
acl := NewDefaultManager(fw)
|
|
||||||
|
|
||||||
acl.ApplyFiltering(networkMap, false)
|
|
||||||
|
|
||||||
expectedRules := 3
|
|
||||||
if fw.IsStateful() {
|
|
||||||
expectedRules = 3 // 2 inbound rules + SSH rule
|
|
||||||
}
|
|
||||||
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -60,14 +60,19 @@ func (t TokenInfo) GetTokenToUse() string {
|
|||||||
return t.AccessToken
|
return t.AccessToken
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldUseDeviceFlow(force bool, isUnixDesktopClient bool) bool {
|
||||||
|
return force || (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient
|
||||||
|
}
|
||||||
|
|
||||||
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
|
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
|
||||||
//
|
//
|
||||||
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
|
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
|
||||||
// and if that also fails, the authentication process is deemed unsuccessful
|
// and if that also fails, the authentication process is deemed unsuccessful
|
||||||
//
|
//
|
||||||
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||||
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, hint string) (OAuthFlow, error) {
|
// forceDeviceCodeFlow can be used to skip PKCE and go directly to Device Code Flow (e.g., for Android TV)
|
||||||
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
|
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, forceDeviceCodeFlow bool, hint string) (OAuthFlow, error) {
|
||||||
|
if shouldUseDeviceFlow(forceDeviceCodeFlow, isUnixDesktopClient) {
|
||||||
return authenticateWithDeviceCodeFlow(ctx, config, hint)
|
return authenticateWithDeviceCodeFlow(ctx, config, hint)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -21,6 +22,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/templates"
|
"github.com/netbirdio/netbird/client/internal/templates"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ OAuthFlow = &PKCEAuthorizationFlow{}
|
var _ OAuthFlow = &PKCEAuthorizationFlow{}
|
||||||
@@ -46,9 +48,10 @@ type PKCEAuthorizationFlow struct {
|
|||||||
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||||
var availableRedirectURL string
|
var availableRedirectURL string
|
||||||
|
|
||||||
// find the first available redirect URL
|
excludedRanges := getSystemExcludedPortRanges()
|
||||||
|
|
||||||
for _, redirectURL := range config.RedirectURLs {
|
for _, redirectURL := range config.RedirectURLs {
|
||||||
if !isRedirectURLPortUsed(redirectURL) {
|
if !isRedirectURLPortUsed(redirectURL, excludedRanges) {
|
||||||
availableRedirectURL = redirectURL
|
availableRedirectURL = redirectURL
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -102,10 +105,10 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
|||||||
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
||||||
}
|
}
|
||||||
if !p.providerConfig.DisablePromptLogin {
|
if !p.providerConfig.DisablePromptLogin {
|
||||||
if p.providerConfig.LoginFlag.IsPromptLogin() {
|
switch p.providerConfig.LoginFlag {
|
||||||
|
case common.LoginFlagPromptLogin:
|
||||||
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
||||||
}
|
case common.LoginFlagMaxAge0:
|
||||||
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
|
|
||||||
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -192,17 +195,20 @@ func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token,
|
|||||||
|
|
||||||
if authError := query.Get(queryError); authError != "" {
|
if authError := query.Get(queryError); authError != "" {
|
||||||
authErrorDesc := query.Get(queryErrorDesc)
|
authErrorDesc := query.Get(queryErrorDesc)
|
||||||
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
|
if authErrorDesc != "" {
|
||||||
|
return nil, fmt.Errorf("authentication failed: %s", authErrorDesc)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("authentication failed: %s", authError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prevent timing attacks on the state
|
// Prevent timing attacks on the state
|
||||||
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
||||||
return nil, fmt.Errorf("invalid state")
|
return nil, fmt.Errorf("authentication failed: Invalid state")
|
||||||
}
|
}
|
||||||
|
|
||||||
code := query.Get(queryCode)
|
code := query.Get(queryCode)
|
||||||
if code == "" {
|
if code == "" {
|
||||||
return nil, fmt.Errorf("missing code")
|
return nil, fmt.Errorf("authentication failed: missing code")
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.oAuthConfig.Exchange(
|
return p.oAuthConfig.Exchange(
|
||||||
@@ -231,7 +237,7 @@ func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil {
|
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil {
|
||||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
return TokenInfo{}, fmt.Errorf("authentication failed: invalid access token - %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
email, err := parseEmailFromIDToken(tokenInfo.IDToken)
|
email, err := parseEmailFromIDToken(tokenInfo.IDToken)
|
||||||
@@ -279,15 +285,22 @@ func createCodeChallenge(codeVerifier string) string {
|
|||||||
return base64.RawURLEncoding.EncodeToString(sha2[:])
|
return base64.RawURLEncoding.EncodeToString(sha2[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use.
|
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use or excluded on Windows.
|
||||||
func isRedirectURLPortUsed(redirectURL string) bool {
|
func isRedirectURLPortUsed(redirectURL string, excludedRanges []excludedPortRange) bool {
|
||||||
parsedURL, err := url.Parse(redirectURL)
|
parsedURL, err := url.Parse(redirectURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to parse redirect URL: %v", err)
|
log.Errorf("failed to parse redirect URL: %v", err)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
addr := fmt.Sprintf(":%s", parsedURL.Port())
|
port := parsedURL.Port()
|
||||||
|
|
||||||
|
if isPortInExcludedRange(port, excludedRanges) {
|
||||||
|
log.Warnf("port %s is in Windows excluded port range, skipping", port)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := fmt.Sprintf(":%s", port)
|
||||||
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
|
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
@@ -301,6 +314,33 @@ func isRedirectURLPortUsed(redirectURL string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// excludedPortRange represents a range of excluded ports.
|
||||||
|
type excludedPortRange struct {
|
||||||
|
start int
|
||||||
|
end int
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPortInExcludedRange checks if the given port is in any of the excluded ranges.
|
||||||
|
func isPortInExcludedRange(port string, excludedRanges []excludedPortRange) bool {
|
||||||
|
if len(excludedRanges) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
portNum, err := strconv.Atoi(port)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("invalid port number %s: %v", port, err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range excludedRanges {
|
||||||
|
if portNum >= r.start && portNum <= r.end {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) {
|
func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) {
|
||||||
tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl)
|
tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
8
client/internal/auth/pkce_flow_other.go
Normal file
8
client/internal/auth/pkce_flow_other.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package auth
|
||||||
|
|
||||||
|
// getSystemExcludedPortRanges returns nil on non-Windows platforms.
|
||||||
|
func getSystemExcludedPortRanges() []excludedPortRange {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -2,8 +2,11 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
@@ -20,22 +23,28 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
loginFlag mgm.LoginFlag
|
loginFlag mgm.LoginFlag
|
||||||
disablePromptLogin bool
|
disablePromptLogin bool
|
||||||
expect string
|
expectContains []string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Prompt login",
|
name: "Prompt login",
|
||||||
loginFlag: mgm.LoginFlagPrompt,
|
loginFlag: mgm.LoginFlagPromptLogin,
|
||||||
expect: promptLogin,
|
expectContains: []string{promptLogin},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Max age 0 login",
|
name: "Max age 0",
|
||||||
loginFlag: mgm.LoginFlagMaxAge0,
|
loginFlag: mgm.LoginFlagMaxAge0,
|
||||||
expect: maxAge0,
|
expectContains: []string{maxAge0},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Disable prompt login",
|
name: "Disable prompt login",
|
||||||
loginFlag: mgm.LoginFlagPrompt,
|
loginFlag: mgm.LoginFlagPromptLogin,
|
||||||
disablePromptLogin: true,
|
disablePromptLogin: true,
|
||||||
|
expectContains: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "None flag should not add parameters",
|
||||||
|
loginFlag: mgm.LoginFlagNone,
|
||||||
|
expectContains: []string{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,6 +59,7 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
||||||
UseIDToken: true,
|
UseIDToken: true,
|
||||||
LoginFlag: tc.loginFlag,
|
LoginFlag: tc.loginFlag,
|
||||||
|
DisablePromptLogin: tc.disablePromptLogin,
|
||||||
}
|
}
|
||||||
pkce, err := NewPKCEAuthorizationFlow(config)
|
pkce, err := NewPKCEAuthorizationFlow(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -60,12 +70,153 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
t.Fatalf("Failed to request auth info: %v", err)
|
t.Fatalf("Failed to request auth info: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tc.disablePromptLogin {
|
for _, expected := range tc.expectContains {
|
||||||
require.Contains(t, authInfo.VerificationURIComplete, tc.expect)
|
require.Contains(t, authInfo.VerificationURIComplete, expected)
|
||||||
} else {
|
|
||||||
require.Contains(t, authInfo.VerificationURIComplete, promptLogin)
|
|
||||||
require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsPortInExcludedRange(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
port string
|
||||||
|
excludedRanges []excludedPortRange
|
||||||
|
expectedBlocked bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Port in excluded range",
|
||||||
|
port: "8080",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port at start of range",
|
||||||
|
port: "8000",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port at end of range",
|
||||||
|
port: "8100",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port before range",
|
||||||
|
port: "7999",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port after range",
|
||||||
|
port: "8101",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty excluded ranges",
|
||||||
|
port: "8080",
|
||||||
|
excludedRanges: []excludedPortRange{},
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Nil excluded ranges",
|
||||||
|
port: "8080",
|
||||||
|
excludedRanges: nil,
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple ranges - port in second range",
|
||||||
|
port: "9050",
|
||||||
|
excludedRanges: []excludedPortRange{
|
||||||
|
{start: 8000, end: 8100},
|
||||||
|
{start: 9000, end: 9100},
|
||||||
|
},
|
||||||
|
expectedBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple ranges - port not in any range",
|
||||||
|
port: "8500",
|
||||||
|
excludedRanges: []excludedPortRange{
|
||||||
|
{start: 8000, end: 8100},
|
||||||
|
{start: 9000, end: 9100},
|
||||||
|
},
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid port string",
|
||||||
|
port: "invalid",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty port string",
|
||||||
|
port: "",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := isPortInExcludedRange(tt.port, tt.excludedRanges)
|
||||||
|
assert.Equal(t, tt.expectedBlocked, result, "Port exclusion check mismatch")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsRedirectURLPortUsed(t *testing.T) {
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
_ = listener.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
usedPort := listener.Addr().(*net.TCPAddr).Port
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
redirectURL string
|
||||||
|
excludedRanges []excludedPortRange
|
||||||
|
expectedUsed bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Port in excluded range",
|
||||||
|
redirectURL: "http://127.0.0.1:8080/",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedUsed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port actually in use",
|
||||||
|
redirectURL: fmt.Sprintf("http://127.0.0.1:%d/", usedPort),
|
||||||
|
excludedRanges: nil,
|
||||||
|
expectedUsed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port not in use and not excluded",
|
||||||
|
redirectURL: "http://127.0.0.1:65432/",
|
||||||
|
excludedRanges: nil,
|
||||||
|
expectedUsed: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid URL without port",
|
||||||
|
redirectURL: "not-a-valid-url",
|
||||||
|
excludedRanges: nil,
|
||||||
|
expectedUsed: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port excluded even if not in use",
|
||||||
|
redirectURL: "http://127.0.0.1:8050/",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedUsed: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := isRedirectURLPortUsed(tt.redirectURL, tt.excludedRanges)
|
||||||
|
assert.Equal(t, tt.expectedUsed, result, "Port usage check mismatch")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
86
client/internal/auth/pkce_flow_windows.go
Normal file
86
client/internal/auth/pkce_flow_windows.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// getSystemExcludedPortRanges retrieves the excluded port ranges from Windows using netsh.
|
||||||
|
func getSystemExcludedPortRanges() []excludedPortRange {
|
||||||
|
ranges, err := getExcludedPortRangesFromNetsh()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to get Windows excluded port ranges: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return ranges
|
||||||
|
}
|
||||||
|
|
||||||
|
// getExcludedPortRangesFromNetsh retrieves excluded port ranges using netsh command.
|
||||||
|
func getExcludedPortRangesFromNetsh() ([]excludedPortRange, error) {
|
||||||
|
cmd := exec.Command("netsh", "interface", "ipv4", "show", "excludedportrange", "protocol=tcp")
|
||||||
|
output, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("netsh command: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return parseExcludedPortRanges(string(output))
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseExcludedPortRanges parses the output of the netsh command to extract port ranges.
|
||||||
|
func parseExcludedPortRanges(output string) ([]excludedPortRange, error) {
|
||||||
|
var ranges []excludedPortRange
|
||||||
|
scanner := bufio.NewScanner(strings.NewReader(output))
|
||||||
|
|
||||||
|
foundHeader := false
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := strings.TrimSpace(scanner.Text())
|
||||||
|
|
||||||
|
if strings.Contains(line, "Start Port") && strings.Contains(line, "End Port") {
|
||||||
|
foundHeader = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundHeader {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(line, "----------") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
if len(fields) < 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
startPort, err := strconv.Atoi(fields[0])
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
endPort, err := strconv.Atoi(fields[1])
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ranges = append(ranges, excludedPortRange{start: startPort, end: endPort})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("scan output: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ranges, nil
|
||||||
|
}
|
||||||
116
client/internal/auth/pkce_flow_windows_test.go
Normal file
116
client/internal/auth/pkce_flow_windows_test.go
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseExcludedPortRanges(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
netshOutput string
|
||||||
|
expectedRanges []excludedPortRange
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Valid netsh output with multiple ranges",
|
||||||
|
netshOutput: `
|
||||||
|
Protocol tcp Dynamic Port Range
|
||||||
|
---------------------------------
|
||||||
|
Start Port : 49152
|
||||||
|
Number of Ports : 16384
|
||||||
|
|
||||||
|
Protocol tcp Excluded Port Ranges
|
||||||
|
---------------------------------
|
||||||
|
Start Port End Port
|
||||||
|
---------- --------
|
||||||
|
5357 5357 *
|
||||||
|
50000 50059 *
|
||||||
|
`,
|
||||||
|
expectedRanges: []excludedPortRange{
|
||||||
|
{start: 5357, end: 5357},
|
||||||
|
{start: 50000, end: 50059},
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty output",
|
||||||
|
netshOutput: `
|
||||||
|
Protocol tcp Dynamic Port Range
|
||||||
|
---------------------------------
|
||||||
|
Start Port : 49152
|
||||||
|
Number of Ports : 16384
|
||||||
|
`,
|
||||||
|
expectedRanges: nil,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Single range",
|
||||||
|
netshOutput: `
|
||||||
|
Protocol tcp Excluded Port Ranges
|
||||||
|
---------------------------------
|
||||||
|
Start Port End Port
|
||||||
|
---------- --------
|
||||||
|
8080 8090
|
||||||
|
`,
|
||||||
|
expectedRanges: []excludedPortRange{
|
||||||
|
{start: 8080, end: 8090},
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ranges, err := parseExcludedPortRanges(tt.netshOutput)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expectedRanges, ranges)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) {
|
||||||
|
ranges := getSystemExcludedPortRanges()
|
||||||
|
t.Logf("Found %d excluded port ranges on this system", len(ranges))
|
||||||
|
|
||||||
|
listener1, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
_ = listener1.Close()
|
||||||
|
}()
|
||||||
|
usedPort1 := listener1.Addr().(*net.TCPAddr).Port
|
||||||
|
|
||||||
|
availablePort := 65432
|
||||||
|
|
||||||
|
config := internal.PKCEAuthProviderConfig{
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
Audience: "test-audience",
|
||||||
|
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||||
|
Scope: "openid email profile",
|
||||||
|
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
|
||||||
|
RedirectURLs: []string{
|
||||||
|
fmt.Sprintf("http://127.0.0.1:%d/", usedPort1),
|
||||||
|
fmt.Sprintf("http://127.0.0.1:%d/", availablePort),
|
||||||
|
},
|
||||||
|
UseIDToken: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
flow, err := NewPKCEAuthorizationFlow(config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, flow)
|
||||||
|
assert.Contains(t, flow.oAuthConfig.RedirectURL, fmt.Sprintf(":%d", availablePort),
|
||||||
|
"Should skip port in use and select available port")
|
||||||
|
}
|
||||||
@@ -24,10 +24,14 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
cProto "github.com/netbirdio/netbird/client/proto"
|
cProto "github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
|
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
@@ -42,6 +46,8 @@ type ConnectClient struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
config *profilemanager.Config
|
config *profilemanager.Config
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
|
doInitialAutoUpdate bool
|
||||||
|
|
||||||
engine *Engine
|
engine *Engine
|
||||||
engineMutex sync.Mutex
|
engineMutex sync.Mutex
|
||||||
|
|
||||||
@@ -52,12 +58,14 @@ func NewConnectClient(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
config *profilemanager.Config,
|
config *profilemanager.Config,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
|
doInitalAutoUpdate bool,
|
||||||
|
|
||||||
) *ConnectClient {
|
) *ConnectClient {
|
||||||
return &ConnectClient{
|
return &ConnectClient{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
config: config,
|
config: config,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
|
doInitialAutoUpdate: doInitalAutoUpdate,
|
||||||
engineMutex: sync.Mutex{},
|
engineMutex: sync.Mutex{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -74,6 +82,7 @@ func (c *ConnectClient) RunOnAndroid(
|
|||||||
networkChangeListener listener.NetworkChangeListener,
|
networkChangeListener listener.NetworkChangeListener,
|
||||||
dnsAddresses []netip.AddrPort,
|
dnsAddresses []netip.AddrPort,
|
||||||
dnsReadyListener dns.ReadyListener,
|
dnsReadyListener dns.ReadyListener,
|
||||||
|
stateFilePath string,
|
||||||
) error {
|
) error {
|
||||||
// in case of non Android os these variables will be nil
|
// in case of non Android os these variables will be nil
|
||||||
mobileDependency := MobileDependency{
|
mobileDependency := MobileDependency{
|
||||||
@@ -82,6 +91,7 @@ func (c *ConnectClient) RunOnAndroid(
|
|||||||
NetworkChangeListener: networkChangeListener,
|
NetworkChangeListener: networkChangeListener,
|
||||||
HostDNSAddresses: dnsAddresses,
|
HostDNSAddresses: dnsAddresses,
|
||||||
DnsReadyListener: dnsReadyListener,
|
DnsReadyListener: dnsReadyListener,
|
||||||
|
StateFilePath: stateFilePath,
|
||||||
}
|
}
|
||||||
return c.run(mobileDependency, nil)
|
return c.run(mobileDependency, nil)
|
||||||
}
|
}
|
||||||
@@ -160,6 +170,33 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var path string
|
||||||
|
if runtime.GOOS == "ios" || runtime.GOOS == "android" {
|
||||||
|
// On mobile, use the provided state file path directly
|
||||||
|
if !fileExists(mobileDependency.StateFilePath) {
|
||||||
|
if err := createFile(mobileDependency.StateFilePath); err != nil {
|
||||||
|
log.Errorf("failed to create state file: %v", err)
|
||||||
|
// we are not exiting as we can run without the state manager
|
||||||
|
}
|
||||||
|
}
|
||||||
|
path = mobileDependency.StateFilePath
|
||||||
|
} else {
|
||||||
|
sm := profilemanager.NewServiceManager("")
|
||||||
|
path = sm.GetStatePath()
|
||||||
|
}
|
||||||
|
stateManager := statemanager.New(path)
|
||||||
|
stateManager.RegisterState(&sshconfig.ShutdownState{})
|
||||||
|
|
||||||
|
updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager)
|
||||||
|
if err == nil {
|
||||||
|
updateManager.CheckUpdateSuccess(c.ctx)
|
||||||
|
|
||||||
|
inst := installer.New()
|
||||||
|
if err := inst.CleanUpInstallerFiles(); err != nil {
|
||||||
|
log.Errorf("failed to clean up temporary installer file: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
defer c.statusRecorder.ClientStop()
|
defer c.statusRecorder.ClientStop()
|
||||||
operation := func() error {
|
operation := func() error {
|
||||||
// if context cancelled we not start new backoff cycle
|
// if context cancelled we not start new backoff cycle
|
||||||
@@ -271,15 +308,25 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
checks := loginResp.GetChecks()
|
checks := loginResp.GetChecks()
|
||||||
|
|
||||||
c.engineMutex.Lock()
|
c.engineMutex.Lock()
|
||||||
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
|
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager)
|
||||||
c.engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||||
|
c.engine = engine
|
||||||
c.engineMutex.Unlock()
|
c.engineMutex.Unlock()
|
||||||
|
|
||||||
if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
|
if err := engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
|
||||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil {
|
||||||
|
// AutoUpdate will be true when the user click on "Connect" menu on the UI
|
||||||
|
if c.doInitialAutoUpdate {
|
||||||
|
log.Infof("start engine by ui, run auto-update check")
|
||||||
|
c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate)
|
||||||
|
c.doInitialAutoUpdate = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||||
state.Set(StatusConnected)
|
state.Set(StatusConnected)
|
||||||
|
|
||||||
@@ -291,12 +338,14 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
<-engineCtx.Done()
|
<-engineCtx.Done()
|
||||||
|
|
||||||
c.engineMutex.Lock()
|
c.engineMutex.Lock()
|
||||||
engine := c.engine
|
|
||||||
c.engine = nil
|
c.engine = nil
|
||||||
c.engineMutex.Unlock()
|
c.engineMutex.Unlock()
|
||||||
|
|
||||||
if engine != nil && engine.wgInterface != nil {
|
// todo: consider to remove this condition. Is not thread safe.
|
||||||
|
// We should always call Stop(), but we need to verify that it is idempotent
|
||||||
|
if engine.wgInterface != nil {
|
||||||
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
|
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
|
||||||
|
|
||||||
if err := engine.Stop(); err != nil {
|
if err := engine.Stop(); err != nil {
|
||||||
log.Errorf("Failed to stop engine: %v", err)
|
log.Errorf("Failed to stop engine: %v", err)
|
||||||
}
|
}
|
||||||
@@ -429,6 +478,11 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
|||||||
RosenpassEnabled: config.RosenpassEnabled,
|
RosenpassEnabled: config.RosenpassEnabled,
|
||||||
RosenpassPermissive: config.RosenpassPermissive,
|
RosenpassPermissive: config.RosenpassPermissive,
|
||||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||||
|
EnableSSHRoot: config.EnableSSHRoot,
|
||||||
|
EnableSSHSFTP: config.EnableSSHSFTP,
|
||||||
|
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
|
||||||
|
EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding,
|
||||||
|
DisableSSHAuth: config.DisableSSHAuth,
|
||||||
DNSRouteInterval: config.DNSRouteInterval,
|
DNSRouteInterval: config.DNSRouteInterval,
|
||||||
|
|
||||||
DisableClientRoutes: config.DisableClientRoutes,
|
DisableClientRoutes: config.DisableClientRoutes,
|
||||||
@@ -515,6 +569,11 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
|||||||
config.BlockLANAccess,
|
config.BlockLANAccess,
|
||||||
config.BlockInbound,
|
config.BlockInbound,
|
||||||
config.LazyConnectionEnabled,
|
config.LazyConnectionEnabled,
|
||||||
|
config.EnableSSHRoot,
|
||||||
|
config.EnableSSHSFTP,
|
||||||
|
config.EnableSSHLocalPortForwarding,
|
||||||
|
config.EnableSSHRemotePortForwarding,
|
||||||
|
config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/anonymize"
|
"github.com/netbirdio/netbird/client/anonymize"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
@@ -56,6 +57,7 @@ block.prof: Block profiling information.
|
|||||||
heap.prof: Heap profiling information (snapshot of memory allocations).
|
heap.prof: Heap profiling information (snapshot of memory allocations).
|
||||||
allocs.prof: Allocations profiling information.
|
allocs.prof: Allocations profiling information.
|
||||||
threadcreate.prof: Thread creation profiling information.
|
threadcreate.prof: Thread creation profiling information.
|
||||||
|
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
|
||||||
|
|
||||||
|
|
||||||
Anonymization Process
|
Anonymization Process
|
||||||
@@ -109,6 +111,9 @@ go tool pprof -http=:8088 heap.prof
|
|||||||
|
|
||||||
This will open a web browser tab with the profiling information.
|
This will open a web browser tab with the profiling information.
|
||||||
|
|
||||||
|
Stack Trace
|
||||||
|
The stack_trace.txt file contains a complete snapshot of all goroutine stack traces at the time the debug bundle was created.
|
||||||
|
|
||||||
Routes
|
Routes
|
||||||
The routes.txt file contains detailed routing table information in a tabular format:
|
The routes.txt file contains detailed routing table information in a tabular format:
|
||||||
|
|
||||||
@@ -327,6 +332,10 @@ func (g *BundleGenerator) createArchive() error {
|
|||||||
log.Errorf("failed to add profiles to debug bundle: %v", err)
|
log.Errorf("failed to add profiles to debug bundle: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := g.addStackTrace(); err != nil {
|
||||||
|
log.Errorf("failed to add stack trace to debug bundle: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := g.addSyncResponse(); err != nil {
|
if err := g.addSyncResponse(); err != nil {
|
||||||
return fmt.Errorf("add sync response: %w", err)
|
return fmt.Errorf("add sync response: %w", err)
|
||||||
}
|
}
|
||||||
@@ -354,6 +363,10 @@ func (g *BundleGenerator) createArchive() error {
|
|||||||
log.Errorf("failed to add systemd logs: %v", err)
|
log.Errorf("failed to add systemd logs: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := g.addUpdateLogs(); err != nil {
|
||||||
|
log.Errorf("failed to add updater logs: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -453,6 +466,18 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
|||||||
if g.internalConfig.ServerSSHAllowed != nil {
|
if g.internalConfig.ServerSSHAllowed != nil {
|
||||||
configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed))
|
configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed))
|
||||||
}
|
}
|
||||||
|
if g.internalConfig.EnableSSHRoot != nil {
|
||||||
|
configContent.WriteString(fmt.Sprintf("EnableSSHRoot: %v\n", *g.internalConfig.EnableSSHRoot))
|
||||||
|
}
|
||||||
|
if g.internalConfig.EnableSSHSFTP != nil {
|
||||||
|
configContent.WriteString(fmt.Sprintf("EnableSSHSFTP: %v\n", *g.internalConfig.EnableSSHSFTP))
|
||||||
|
}
|
||||||
|
if g.internalConfig.EnableSSHLocalPortForwarding != nil {
|
||||||
|
configContent.WriteString(fmt.Sprintf("EnableSSHLocalPortForwarding: %v\n", *g.internalConfig.EnableSSHLocalPortForwarding))
|
||||||
|
}
|
||||||
|
if g.internalConfig.EnableSSHRemotePortForwarding != nil {
|
||||||
|
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
|
||||||
|
}
|
||||||
|
|
||||||
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
||||||
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
||||||
@@ -510,6 +535,18 @@ func (g *BundleGenerator) addProf() (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *BundleGenerator) addStackTrace() error {
|
||||||
|
buf := make([]byte, 5242880) // 5 MB buffer
|
||||||
|
n := runtime.Stack(buf, true)
|
||||||
|
|
||||||
|
stackTrace := bytes.NewReader(buf[:n])
|
||||||
|
if err := g.addFileToZip(stackTrace, "stack_trace.txt"); err != nil {
|
||||||
|
return fmt.Errorf("add stack trace file to zip: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (g *BundleGenerator) addInterfaces() error {
|
func (g *BundleGenerator) addInterfaces() error {
|
||||||
interfaces, err := net.Interfaces()
|
interfaces, err := net.Interfaces()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -618,6 +655,29 @@ func (g *BundleGenerator) addStateFile() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *BundleGenerator) addUpdateLogs() error {
|
||||||
|
inst := installer.New()
|
||||||
|
logFiles := inst.LogFiles()
|
||||||
|
if len(logFiles) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("adding updater logs")
|
||||||
|
for _, logFile := range logFiles {
|
||||||
|
data, err := os.ReadFile(logFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to read update log file %s: %v", logFile, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
baseName := filepath.Base(logFile)
|
||||||
|
if err := g.addFileToZip(bytes.NewReader(data), filepath.Join("update-logs", baseName)); err != nil {
|
||||||
|
return fmt.Errorf("add update log file %s to zip: %w", baseName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (g *BundleGenerator) addCorruptedStateFiles() error {
|
func (g *BundleGenerator) addCorruptedStateFiles() error {
|
||||||
sm := profilemanager.NewServiceManager("")
|
sm := profilemanager.NewServiceManager("")
|
||||||
pattern := sm.GetStatePath()
|
pattern := sm.GetStatePath()
|
||||||
|
|||||||
@@ -76,6 +76,9 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
|
|||||||
var records []nbdns.SimpleRecord
|
var records []nbdns.SimpleRecord
|
||||||
|
|
||||||
for _, zone := range config.CustomZones {
|
for _, zone := range config.CustomZones {
|
||||||
|
if zone.SkipPTRProcess {
|
||||||
|
continue
|
||||||
|
}
|
||||||
for _, record := range zone.Records {
|
for _, record := range zone.Records {
|
||||||
if record.Type != int(dns.TypeA) {
|
if record.Type != int(dns.TypeA) {
|
||||||
continue
|
continue
|
||||||
@@ -108,6 +111,7 @@ func addReverseZone(config *nbdns.Config, network netip.Prefix) {
|
|||||||
reverseZone := nbdns.CustomZone{
|
reverseZone := nbdns.CustomZone{
|
||||||
Domain: zoneName,
|
Domain: zoneName,
|
||||||
Records: records,
|
Records: records,
|
||||||
|
SearchDomainDisabled: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
config.CustomZones = append(config.CustomZones, reverseZone)
|
config.CustomZones = append(config.CustomZones, reverseZone)
|
||||||
|
|||||||
@@ -11,11 +11,6 @@ import (
|
|||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
ipv4ReverseZone = ".in-addr.arpa."
|
|
||||||
ipv6ReverseZone = ".ip6.arpa."
|
|
||||||
)
|
|
||||||
|
|
||||||
type hostManager interface {
|
type hostManager interface {
|
||||||
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
|
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
|
||||||
restoreHostDNS() error
|
restoreHostDNS() error
|
||||||
@@ -110,10 +105,9 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip netip.Addr, port int) H
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, customZone := range dnsConfig.CustomZones {
|
for _, customZone := range dnsConfig.CustomZones {
|
||||||
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
|
|
||||||
config.Domains = append(config.Domains, DomainConfig{
|
config.Domains = append(config.Domains, DomainConfig{
|
||||||
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
|
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
|
||||||
MatchOnly: matchOnly,
|
MatchOnly: customZone.SearchDomainDisabled,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -26,6 +27,11 @@ type Resolver struct {
|
|||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ipsResponse struct {
|
||||||
|
ips []netip.Addr
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
// NewResolver creates a new management domains cache resolver.
|
// NewResolver creates a new management domains cache resolver.
|
||||||
func NewResolver() *Resolver {
|
func NewResolver() *Resolver {
|
||||||
return &Resolver{
|
return &Resolver{
|
||||||
@@ -99,9 +105,9 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
|||||||
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
|
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
|
ips, err := lookupIPWithExtraTimeout(ctx, d)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var aRecords, aaaaRecords []dns.RR
|
var aRecords, aaaaRecords []dns.RR
|
||||||
@@ -159,6 +165,36 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
|
||||||
|
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
|
||||||
|
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
|
||||||
|
resultChan := make(chan *ipsResponse, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
|
||||||
|
resultChan <- &ipsResponse{
|
||||||
|
err: err,
|
||||||
|
ips: ips,
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var resp *ipsResponse
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(dnsTimeout + time.Millisecond*500):
|
||||||
|
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
|
||||||
|
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case resp = <-resultChan:
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.err != nil {
|
||||||
|
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
|
||||||
|
}
|
||||||
|
return resp.ips, nil
|
||||||
|
}
|
||||||
|
|
||||||
// PopulateFromConfig extracts and caches domains from the client configuration.
|
// PopulateFromConfig extracts and caches domains from the client configuration.
|
||||||
func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error {
|
func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error {
|
||||||
if mgmtURL == nil {
|
if mgmtURL == nil {
|
||||||
|
|||||||
@@ -197,7 +197,7 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
|
|||||||
timeoutMsg += " " + peerInfo
|
timeoutMsg += " " + peerInfo
|
||||||
}
|
}
|
||||||
timeoutMsg += fmt.Sprintf(" - error: %v", err)
|
timeoutMsg += fmt.Sprintf(" - error: %v", err)
|
||||||
logger.Warnf(timeoutMsg)
|
logger.Warn(timeoutMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
||||||
|
|||||||
@@ -234,6 +234,11 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
|
||||||
|
for i, ip := range ips {
|
||||||
|
ips[i] = ip.Unmap()
|
||||||
|
}
|
||||||
|
|
||||||
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
||||||
f.addIPsToResponse(resp, domain, ips)
|
f.addIPsToResponse(resp, domain, ips)
|
||||||
f.cache.set(domain, question.Qtype, ips)
|
f.cache.set(domain, question.Qtype, ips)
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"slices"
|
"slices"
|
||||||
"sort"
|
"sort"
|
||||||
@@ -30,7 +29,6 @@ import (
|
|||||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
@@ -44,17 +42,16 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/relay"
|
"github.com/netbirdio/netbird/client/internal/relay"
|
||||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
||||||
cProto "github.com/netbirdio/netbird/client/proto"
|
cProto "github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||||
|
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
@@ -75,6 +72,7 @@ const (
|
|||||||
PeerConnectionTimeoutMax = 45000 // ms
|
PeerConnectionTimeoutMax = 45000 // ms
|
||||||
PeerConnectionTimeoutMin = 30000 // ms
|
PeerConnectionTimeoutMin = 30000 // ms
|
||||||
connInitLimit = 200
|
connInitLimit = 200
|
||||||
|
disableAutoUpdate = "disabled"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrResetConnection = fmt.Errorf("reset connection")
|
var ErrResetConnection = fmt.Errorf("reset connection")
|
||||||
@@ -116,6 +114,11 @@ type EngineConfig struct {
|
|||||||
RosenpassPermissive bool
|
RosenpassPermissive bool
|
||||||
|
|
||||||
ServerSSHAllowed bool
|
ServerSSHAllowed bool
|
||||||
|
EnableSSHRoot *bool
|
||||||
|
EnableSSHSFTP *bool
|
||||||
|
EnableSSHLocalPortForwarding *bool
|
||||||
|
EnableSSHRemotePortForwarding *bool
|
||||||
|
DisableSSHAuth *bool
|
||||||
|
|
||||||
DNSRouteInterval time.Duration
|
DNSRouteInterval time.Duration
|
||||||
|
|
||||||
@@ -148,8 +151,6 @@ type Engine struct {
|
|||||||
|
|
||||||
// syncMsgMux is used to guarantee sequential Management Service message processing
|
// syncMsgMux is used to guarantee sequential Management Service message processing
|
||||||
syncMsgMux *sync.Mutex
|
syncMsgMux *sync.Mutex
|
||||||
// sshMux protects sshServer field access
|
|
||||||
sshMux sync.Mutex
|
|
||||||
|
|
||||||
config *EngineConfig
|
config *EngineConfig
|
||||||
mobileDep MobileDependency
|
mobileDep MobileDependency
|
||||||
@@ -175,8 +176,7 @@ type Engine struct {
|
|||||||
|
|
||||||
networkMonitor *networkmonitor.NetworkMonitor
|
networkMonitor *networkmonitor.NetworkMonitor
|
||||||
|
|
||||||
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
|
sshServer sshServer
|
||||||
sshServer nbssh.Server
|
|
||||||
|
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
|
|
||||||
@@ -201,6 +201,9 @@ type Engine struct {
|
|||||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||||
flowManager nftypes.FlowManager
|
flowManager nftypes.FlowManager
|
||||||
|
|
||||||
|
// auto-update
|
||||||
|
updateManager *updatemanager.Manager
|
||||||
|
|
||||||
// WireGuard interface monitor
|
// WireGuard interface monitor
|
||||||
wgIfaceMonitor *WGIfaceMonitor
|
wgIfaceMonitor *WGIfaceMonitor
|
||||||
|
|
||||||
@@ -221,17 +224,7 @@ type localIpUpdater interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewEngine creates a new Connection Engine with probes attached
|
// NewEngine creates a new Connection Engine with probes attached
|
||||||
func NewEngine(
|
func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, checks []*mgmProto.Checks, stateManager *statemanager.Manager) *Engine {
|
||||||
clientCtx context.Context,
|
|
||||||
clientCancel context.CancelFunc,
|
|
||||||
signalClient signal.Client,
|
|
||||||
mgmClient mgm.Client,
|
|
||||||
relayManager *relayClient.Manager,
|
|
||||||
config *EngineConfig,
|
|
||||||
mobileDep MobileDependency,
|
|
||||||
statusRecorder *peer.Status,
|
|
||||||
checks []*mgmProto.Checks,
|
|
||||||
) *Engine {
|
|
||||||
engine := &Engine{
|
engine := &Engine{
|
||||||
clientCtx: clientCtx,
|
clientCtx: clientCtx,
|
||||||
clientCancel: clientCancel,
|
clientCancel: clientCancel,
|
||||||
@@ -246,29 +239,13 @@ func NewEngine(
|
|||||||
STUNs: []*stun.URI{},
|
STUNs: []*stun.URI{},
|
||||||
TURNs: []*stun.URI{},
|
TURNs: []*stun.URI{},
|
||||||
networkSerial: 0,
|
networkSerial: 0,
|
||||||
sshServerFunc: nbssh.DefaultSSHServer,
|
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
|
stateManager: stateManager,
|
||||||
checks: checks,
|
checks: checks,
|
||||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||||
}
|
}
|
||||||
|
|
||||||
sm := profilemanager.NewServiceManager("")
|
|
||||||
|
|
||||||
path := sm.GetStatePath()
|
|
||||||
if runtime.GOOS == "ios" {
|
|
||||||
if !fileExists(mobileDep.StateFilePath) {
|
|
||||||
err := createFile(mobileDep.StateFilePath)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to create state file: %v", err)
|
|
||||||
// we are not exiting as we can run without the state manager
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
path = mobileDep.StateFilePath
|
|
||||||
}
|
|
||||||
engine.stateManager = statemanager.New(path)
|
|
||||||
|
|
||||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||||
return engine
|
return engine
|
||||||
}
|
}
|
||||||
@@ -280,7 +257,6 @@ func (e *Engine) Stop() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
|
||||||
|
|
||||||
if e.connMgr != nil {
|
if e.connMgr != nil {
|
||||||
e.connMgr.Close()
|
e.connMgr.Close()
|
||||||
@@ -292,8 +268,11 @@ func (e *Engine) Stop() error {
|
|||||||
}
|
}
|
||||||
log.Info("Network monitor: stopped")
|
log.Info("Network monitor: stopped")
|
||||||
|
|
||||||
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
if err := e.stopSSHServer(); err != nil {
|
||||||
e.stopDNSServer()
|
log.Warnf("failed to stop SSH server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.cleanupSSHConfig()
|
||||||
|
|
||||||
if e.ingressGatewayMgr != nil {
|
if e.ingressGatewayMgr != nil {
|
||||||
if err := e.ingressGatewayMgr.Close(); err != nil {
|
if err := e.ingressGatewayMgr.Close(); err != nil {
|
||||||
@@ -302,24 +281,33 @@ func (e *Engine) Stop() error {
|
|||||||
e.ingressGatewayMgr = nil
|
e.ingressGatewayMgr = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
e.stopDNSForwarder()
|
|
||||||
|
|
||||||
if e.routeManager != nil {
|
|
||||||
e.routeManager.Stop(e.stateManager)
|
|
||||||
}
|
|
||||||
|
|
||||||
if e.srWatcher != nil {
|
if e.srWatcher != nil {
|
||||||
e.srWatcher.Close()
|
e.srWatcher.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if e.updateManager != nil {
|
||||||
|
e.updateManager.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("cleaning up status recorder states")
|
||||||
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
|
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
|
||||||
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
||||||
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
|
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
|
||||||
|
|
||||||
if err := e.removeAllPeers(); err != nil {
|
if err := e.removeAllPeers(); err != nil {
|
||||||
return fmt.Errorf("failed to remove all peers: %s", err)
|
log.Errorf("failed to remove all peers: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if e.routeManager != nil {
|
||||||
|
e.routeManager.Stop(e.stateManager)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.stopDNSForwarder()
|
||||||
|
|
||||||
|
// stop/restore DNS after peers are closed but before interface goes down
|
||||||
|
// so dbus and friends don't complain because of a missing interface
|
||||||
|
e.stopDNSServer()
|
||||||
|
|
||||||
if e.cancel != nil {
|
if e.cancel != nil {
|
||||||
e.cancel()
|
e.cancel()
|
||||||
}
|
}
|
||||||
@@ -331,16 +319,18 @@ func (e *Engine) Stop() error {
|
|||||||
e.flowManager.Close()
|
e.flowManager.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
stateCtx, stateCancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
defer cancel()
|
defer stateCancel()
|
||||||
|
|
||||||
if err := e.stateManager.Stop(ctx); err != nil {
|
if err := e.stateManager.Stop(stateCtx); err != nil {
|
||||||
return fmt.Errorf("failed to stop state manager: %w", err)
|
log.Errorf("failed to stop state manager: %v", err)
|
||||||
}
|
}
|
||||||
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
||||||
log.Errorf("failed to persist state: %v", err)
|
log.Errorf("failed to persist state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
timeout := e.calculateShutdownTimeout()
|
timeout := e.calculateShutdownTimeout()
|
||||||
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
@@ -426,8 +416,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create rosenpass manager: %w", err)
|
return fmt.Errorf("create rosenpass manager: %w", err)
|
||||||
}
|
}
|
||||||
err := e.rpManager.Run()
|
if err := e.rpManager.Run(); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("run rosenpass manager: %w", err)
|
return fmt.Errorf("run rosenpass manager: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -479,6 +468,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := e.createFirewall(); err != nil {
|
if err := e.createFirewall(); err != nil {
|
||||||
|
e.close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -532,6 +522,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
e.handleAutoUpdateVersion(autoUpdateSettings, true)
|
||||||
|
}
|
||||||
|
|
||||||
func (e *Engine) createFirewall() error {
|
func (e *Engine) createFirewall() error {
|
||||||
if e.config.DisableFirewall {
|
if e.config.DisableFirewall {
|
||||||
log.Infof("firewall is disabled")
|
log.Infof("firewall is disabled")
|
||||||
@@ -703,16 +700,10 @@ func (e *Engine) removeAllPeers() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// removePeer closes an existing peer connection, removes a peer, and clears authorized key of the SSH server
|
// removePeer closes an existing peer connection and removes a peer
|
||||||
func (e *Engine) removePeer(peerKey string) error {
|
func (e *Engine) removePeer(peerKey string) error {
|
||||||
log.Debugf("removing peer from engine %s", peerKey)
|
log.Debugf("removing peer from engine %s", peerKey)
|
||||||
|
|
||||||
e.sshMux.Lock()
|
|
||||||
if !isNil(e.sshServer) {
|
|
||||||
e.sshServer.RemoveAuthorizedKey(peerKey)
|
|
||||||
}
|
|
||||||
e.sshMux.Unlock()
|
|
||||||
|
|
||||||
e.connMgr.RemovePeerConn(peerKey)
|
e.connMgr.RemovePeerConn(peerKey)
|
||||||
|
|
||||||
err := e.statusRecorder.RemovePeer(peerKey)
|
err := e.statusRecorder.RemovePeer(peerKey)
|
||||||
@@ -746,10 +737,54 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) {
|
||||||
|
if autoUpdateSettings == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
disabled := autoUpdateSettings.Version == disableAutoUpdate
|
||||||
|
|
||||||
|
// Stop and cleanup if disabled
|
||||||
|
if e.updateManager != nil && disabled {
|
||||||
|
log.Infof("auto-update is disabled, stopping update manager")
|
||||||
|
e.updateManager.Stop()
|
||||||
|
e.updateManager = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip check unless AlwaysUpdate is enabled or this is the initial check at startup
|
||||||
|
if !autoUpdateSettings.AlwaysUpdate && !initialCheck {
|
||||||
|
log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start manager if needed
|
||||||
|
if e.updateManager == nil {
|
||||||
|
log.Infof("starting auto-update manager")
|
||||||
|
updateManager, err := updatemanager.NewManager(e.statusRecorder, e.stateManager)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
e.updateManager = updateManager
|
||||||
|
e.updateManager.Start(e.ctx)
|
||||||
|
}
|
||||||
|
log.Infof("handling auto-update version: %s", autoUpdateSettings.Version)
|
||||||
|
e.updateManager.SetVersion(autoUpdateSettings.Version)
|
||||||
|
}
|
||||||
|
|
||||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||||
|
if e.ctx.Err() != nil {
|
||||||
|
return e.ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
||||||
|
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false)
|
||||||
|
}
|
||||||
|
|
||||||
if update.GetNetbirdConfig() != nil {
|
if update.GetNetbirdConfig() != nil {
|
||||||
wCfg := update.GetNetbirdConfig()
|
wCfg := update.GetNetbirdConfig()
|
||||||
err := e.updateTURNs(wCfg.GetTurns())
|
err := e.updateTURNs(wCfg.GetTurns())
|
||||||
@@ -884,6 +919,11 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
|||||||
e.config.BlockLANAccess,
|
e.config.BlockLANAccess,
|
||||||
e.config.BlockInbound,
|
e.config.BlockInbound,
|
||||||
e.config.LazyConnectionEnabled,
|
e.config.LazyConnectionEnabled,
|
||||||
|
e.config.EnableSSHRoot,
|
||||||
|
e.config.EnableSSHSFTP,
|
||||||
|
e.config.EnableSSHLocalPortForwarding,
|
||||||
|
e.config.EnableSSHRemotePortForwarding,
|
||||||
|
e.config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||||
@@ -893,74 +933,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isNil(server nbssh.Server) bool {
|
|
||||||
return server == nil || reflect.ValueOf(server).IsNil()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
|
||||||
if e.config.BlockInbound {
|
|
||||||
log.Infof("SSH server is disabled because inbound connections are blocked")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if !e.config.ServerSSHAllowed {
|
|
||||||
log.Info("SSH server is not enabled")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if sshConf.GetSshEnabled() {
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
e.sshMux.Lock()
|
|
||||||
// start SSH server if it wasn't running
|
|
||||||
if isNil(e.sshServer) {
|
|
||||||
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
|
|
||||||
if nbnetstack.IsEnabled() {
|
|
||||||
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
|
|
||||||
}
|
|
||||||
// nil sshServer means it has not yet been started
|
|
||||||
server, err := e.sshServerFunc(e.config.SSHKey, listenAddr)
|
|
||||||
if err != nil {
|
|
||||||
e.sshMux.Unlock()
|
|
||||||
return fmt.Errorf("create ssh server: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
e.sshServer = server
|
|
||||||
e.sshMux.Unlock()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
// blocking
|
|
||||||
err = server.Start()
|
|
||||||
if err != nil {
|
|
||||||
// will throw error when we stop it even if it is a graceful stop
|
|
||||||
log.Debugf("stopped SSH server with error %v", err)
|
|
||||||
}
|
|
||||||
e.sshMux.Lock()
|
|
||||||
e.sshServer = nil
|
|
||||||
e.sshMux.Unlock()
|
|
||||||
log.Infof("stopped SSH server")
|
|
||||||
}()
|
|
||||||
} else {
|
|
||||||
e.sshMux.Unlock()
|
|
||||||
log.Debugf("SSH server is already running")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
e.sshMux.Lock()
|
|
||||||
if !isNil(e.sshServer) {
|
|
||||||
// Disable SSH server request, so stop it if it was running
|
|
||||||
err := e.sshServer.Stop()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to stop SSH server %v", err)
|
|
||||||
}
|
|
||||||
e.sshServer = nil
|
|
||||||
}
|
|
||||||
e.sshMux.Unlock()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||||
if e.wgInterface == nil {
|
if e.wgInterface == nil {
|
||||||
return errors.New("wireguard interface is not initialized")
|
return errors.New("wireguard interface is not initialized")
|
||||||
@@ -973,8 +945,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if conf.GetSshConfig() != nil {
|
if conf.GetSshConfig() != nil {
|
||||||
err := e.updateSSH(conf.GetSshConfig())
|
if err := e.updateSSH(conf.GetSshConfig()); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed handling SSH server setup: %v", err)
|
log.Warnf("failed handling SSH server setup: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1012,6 +983,11 @@ func (e *Engine) receiveManagementEvents() {
|
|||||||
e.config.BlockLANAccess,
|
e.config.BlockLANAccess,
|
||||||
e.config.BlockInbound,
|
e.config.BlockInbound,
|
||||||
e.config.LazyConnectionEnabled,
|
e.config.LazyConnectionEnabled,
|
||||||
|
e.config.EnableSSHRoot,
|
||||||
|
e.config.EnableSSHSFTP,
|
||||||
|
e.config.EnableSSHLocalPortForwarding,
|
||||||
|
e.config.EnableSSHRemotePortForwarding,
|
||||||
|
e.config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
|
|
||||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||||
@@ -1170,20 +1146,12 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
|
|
||||||
e.statusRecorder.FinishPeerListModifications()
|
e.statusRecorder.FinishPeerListModifications()
|
||||||
|
|
||||||
// update SSHServer by adding remote peer SSH keys
|
e.updatePeerSSHHostKeys(networkMap.GetRemotePeers())
|
||||||
e.sshMux.Lock()
|
|
||||||
if !isNil(e.sshServer) {
|
if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil {
|
||||||
for _, config := range networkMap.GetRemotePeers() {
|
log.Warnf("failed to update SSH client config: %v", err)
|
||||||
if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil {
|
|
||||||
err := e.sshServer.AddAuthorizedKey(config.WgPubKey, string(config.GetSshConfig().GetSshPubKey()))
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed adding authorized key to SSH DefaultServer %v", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
e.sshMux.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||||
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, networkMap.GetRemotePeers())
|
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, networkMap.GetRemotePeers())
|
||||||
@@ -1259,6 +1227,7 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
|
|||||||
}
|
}
|
||||||
|
|
||||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
|
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
|
||||||
|
//nolint
|
||||||
forwarderPort := uint16(protoDNSConfig.GetForwarderPort())
|
forwarderPort := uint16(protoDNSConfig.GetForwarderPort())
|
||||||
if forwarderPort == 0 {
|
if forwarderPort == 0 {
|
||||||
forwarderPort = nbdns.ForwarderClientPort
|
forwarderPort = nbdns.ForwarderClientPort
|
||||||
@@ -1274,6 +1243,8 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns
|
|||||||
for _, zone := range protoDNSConfig.GetCustomZones() {
|
for _, zone := range protoDNSConfig.GetCustomZones() {
|
||||||
dnsZone := nbdns.CustomZone{
|
dnsZone := nbdns.CustomZone{
|
||||||
Domain: zone.GetDomain(),
|
Domain: zone.GetDomain(),
|
||||||
|
SearchDomainDisabled: zone.GetSearchDomainDisabled(),
|
||||||
|
SkipPTRProcess: zone.GetSkipPTRProcess(),
|
||||||
}
|
}
|
||||||
for _, record := range zone.Records {
|
for _, record := range zone.Records {
|
||||||
dnsRecord := nbdns.SimpleRecord{
|
dnsRecord := nbdns.SimpleRecord{
|
||||||
@@ -1433,6 +1404,11 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||||
|
if e.ctx.Err() != nil {
|
||||||
|
return e.ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
conn, ok := e.peerStore.PeerConn(msg.Key)
|
conn, ok := e.peerStore.PeerConn(msg.Key)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
||||||
@@ -1544,15 +1520,6 @@ func (e *Engine) close() {
|
|||||||
e.statusRecorder.SetWgIface(nil)
|
e.statusRecorder.SetWgIface(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
e.sshMux.Lock()
|
|
||||||
if !isNil(e.sshServer) {
|
|
||||||
err := e.sshServer.Stop()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed stopping the SSH server: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
e.sshMux.Unlock()
|
|
||||||
|
|
||||||
if e.firewall != nil {
|
if e.firewall != nil {
|
||||||
err := e.firewall.Close(e.stateManager)
|
err := e.firewall.Close(e.stateManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1583,6 +1550,11 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
|||||||
e.config.BlockLANAccess,
|
e.config.BlockLANAccess,
|
||||||
e.config.BlockInbound,
|
e.config.BlockInbound,
|
||||||
e.config.LazyConnectionEnabled,
|
e.config.LazyConnectionEnabled,
|
||||||
|
e.config.EnableSSHRoot,
|
||||||
|
e.config.EnableSSHSFTP,
|
||||||
|
e.config.EnableSSHLocalPortForwarding,
|
||||||
|
e.config.EnableSSHRemotePortForwarding,
|
||||||
|
e.config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
|
|
||||||
netMap, err := e.mgmClient.GetNetworkMap(info)
|
netMap, err := e.mgmClient.GetNetworkMap(info)
|
||||||
@@ -1901,6 +1873,18 @@ func (e *Engine) GetWgAddr() netip.Addr {
|
|||||||
return e.wgInterface.Address().IP
|
return e.wgInterface.Address().IP
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Engine) RenewTun(fd int) error {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
wgInterface := e.wgInterface
|
||||||
|
e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
if wgInterface == nil {
|
||||||
|
return fmt.Errorf("wireguard interface not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
return wgInterface.RenewTun(fd)
|
||||||
|
}
|
||||||
|
|
||||||
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
|
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
|
||||||
func (e *Engine) updateDNSForwarder(
|
func (e *Engine) updateDNSForwarder(
|
||||||
enabled bool,
|
enabled bool,
|
||||||
|
|||||||
355
client/internal/engine_ssh.go
Normal file
355
client/internal/engine_ssh.go
Normal file
@@ -0,0 +1,355 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||||
|
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type sshServer interface {
|
||||||
|
Start(ctx context.Context, addr netip.AddrPort) error
|
||||||
|
Stop() error
|
||||||
|
GetStatus() (bool, []sshserver.SessionInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) setupSSHPortRedirection() error {
|
||||||
|
if e.firewall == nil || e.wgInterface == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
localAddr := e.wgInterface.Address().IP
|
||||||
|
if !localAddr.IsValid() {
|
||||||
|
return errors.New("invalid local NetBird address")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, 22, 22022); err != nil {
|
||||||
|
return fmt.Errorf("add SSH port redirection: %w", err)
|
||||||
|
}
|
||||||
|
log.Infof("SSH port redirection enabled: %s:22 -> %s:22022", localAddr, localAddr)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||||
|
if e.config.BlockInbound {
|
||||||
|
log.Info("SSH server is disabled because inbound connections are blocked")
|
||||||
|
return e.stopSSHServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
if !e.config.ServerSSHAllowed {
|
||||||
|
log.Info("SSH server is disabled in config")
|
||||||
|
return e.stopSSHServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
if !sshConf.GetSshEnabled() {
|
||||||
|
if e.config.ServerSSHAllowed {
|
||||||
|
log.Info("SSH server is locally allowed but disabled by management server")
|
||||||
|
}
|
||||||
|
return e.stopSSHServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.sshServer != nil {
|
||||||
|
log.Debug("SSH server is already running")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.config.DisableSSHAuth != nil && *e.config.DisableSSHAuth {
|
||||||
|
log.Info("starting SSH server without JWT authentication (authentication disabled by config)")
|
||||||
|
return e.startSSHServer(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
|
||||||
|
jwtConfig := &sshserver.JWTConfig{
|
||||||
|
Issuer: protoJWT.GetIssuer(),
|
||||||
|
Audience: protoJWT.GetAudience(),
|
||||||
|
KeysLocation: protoJWT.GetKeysLocation(),
|
||||||
|
MaxTokenAge: protoJWT.GetMaxTokenAge(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.startSSHServer(jwtConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors.New("SSH server requires valid JWT configuration")
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateSSHClientConfig updates the SSH client configuration with peer information
|
||||||
|
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
|
||||||
|
peerInfo := e.extractPeerSSHInfo(remotePeers)
|
||||||
|
if len(peerInfo) == 0 {
|
||||||
|
log.Debug("no SSH-enabled peers found, skipping SSH config update")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
configMgr := sshconfig.New()
|
||||||
|
if err := configMgr.SetupSSHClientConfig(peerInfo); err != nil {
|
||||||
|
log.Warnf("failed to update SSH client config: %v", err)
|
||||||
|
return nil // Don't fail engine startup on SSH config issues
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("updated SSH client config with %d peers", len(peerInfo))
|
||||||
|
|
||||||
|
if err := e.stateManager.UpdateState(&sshconfig.ShutdownState{
|
||||||
|
SSHConfigDir: configMgr.GetSSHConfigDir(),
|
||||||
|
SSHConfigFile: configMgr.GetSSHConfigFile(),
|
||||||
|
}); err != nil {
|
||||||
|
log.Warnf("failed to update SSH config state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractPeerSSHInfo extracts SSH information from peer configurations
|
||||||
|
func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []sshconfig.PeerSSHInfo {
|
||||||
|
var peerInfo []sshconfig.PeerSSHInfo
|
||||||
|
|
||||||
|
for _, peerConfig := range remotePeers {
|
||||||
|
if peerConfig.GetSshConfig() == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
|
||||||
|
if len(sshPubKeyBytes) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
peerIP := e.extractPeerIP(peerConfig)
|
||||||
|
hostname := e.extractHostname(peerConfig)
|
||||||
|
|
||||||
|
peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{
|
||||||
|
Hostname: hostname,
|
||||||
|
IP: peerIP,
|
||||||
|
FQDN: peerConfig.GetFqdn(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return peerInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractPeerIP extracts IP address from peer's allowed IPs
|
||||||
|
func (e *Engine) extractPeerIP(peerConfig *mgmProto.RemotePeerConfig) string {
|
||||||
|
if len(peerConfig.GetAllowedIps()) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefix, err := netip.ParsePrefix(peerConfig.GetAllowedIps()[0]); err == nil {
|
||||||
|
return prefix.Addr().String()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractHostname extracts short hostname from FQDN
|
||||||
|
func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string {
|
||||||
|
fqdn := peerConfig.GetFqdn()
|
||||||
|
if fqdn == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(fqdn, ".")
|
||||||
|
if len(parts) > 0 && parts[0] != "" {
|
||||||
|
return parts[0]
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// updatePeerSSHHostKeys updates peer SSH host keys in the status recorder for daemon API access
|
||||||
|
func (e *Engine) updatePeerSSHHostKeys(remotePeers []*mgmProto.RemotePeerConfig) {
|
||||||
|
for _, peerConfig := range remotePeers {
|
||||||
|
if peerConfig.GetSshConfig() == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
|
||||||
|
if len(sshPubKeyBytes) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.statusRecorder.UpdatePeerSSHHostKey(peerConfig.GetWgPubKey(), sshPubKeyBytes); err != nil {
|
||||||
|
log.Warnf("failed to update SSH host key for peer %s: %v", peerConfig.GetWgPubKey(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("updated peer SSH host keys for daemon API access")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeerSSHKey returns the SSH host key for a specific peer by IP or FQDN
|
||||||
|
func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
statusRecorder := e.statusRecorder
|
||||||
|
e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
if statusRecorder == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
fullStatus := statusRecorder.GetFullStatus()
|
||||||
|
for _, peerState := range fullStatus.Peers {
|
||||||
|
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
|
||||||
|
if len(peerState.SSHHostKey) > 0 {
|
||||||
|
return peerState.SSHHostKey, true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
|
||||||
|
func (e *Engine) cleanupSSHConfig() {
|
||||||
|
configMgr := sshconfig.New()
|
||||||
|
|
||||||
|
if err := configMgr.RemoveSSHClientConfig(); err != nil {
|
||||||
|
log.Warnf("failed to remove SSH client config: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Debugf("SSH client config cleanup completed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// startSSHServer initializes and starts the SSH server with proper configuration.
|
||||||
|
func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
|
||||||
|
if e.wgInterface == nil {
|
||||||
|
return errors.New("wg interface not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
serverConfig := &sshserver.Config{
|
||||||
|
HostKeyPEM: e.config.SSHKey,
|
||||||
|
JWT: jwtConfig,
|
||||||
|
}
|
||||||
|
server := sshserver.New(serverConfig)
|
||||||
|
|
||||||
|
wgAddr := e.wgInterface.Address()
|
||||||
|
server.SetNetworkValidation(wgAddr)
|
||||||
|
|
||||||
|
netbirdIP := wgAddr.IP
|
||||||
|
listenAddr := netip.AddrPortFrom(netbirdIP, sshserver.InternalSSHPort)
|
||||||
|
|
||||||
|
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||||
|
server.SetNetstackNet(netstackNet)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.configureSSHServer(server)
|
||||||
|
|
||||||
|
if err := server.Start(e.ctx, listenAddr); err != nil {
|
||||||
|
return fmt.Errorf("start SSH server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.sshServer = server
|
||||||
|
|
||||||
|
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||||
|
if registrar, ok := e.firewall.(interface {
|
||||||
|
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||||
|
}); ok {
|
||||||
|
registrar.RegisterNetstackService(nftypes.TCP, sshserver.InternalSSHPort)
|
||||||
|
log.Debugf("registered SSH service with netstack for TCP:%d", sshserver.InternalSSHPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.setupSSHPortRedirection(); err != nil {
|
||||||
|
log.Warnf("failed to setup SSH port redirection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// configureSSHServer applies SSH configuration options to the server.
|
||||||
|
func (e *Engine) configureSSHServer(server *sshserver.Server) {
|
||||||
|
if e.config.EnableSSHRoot != nil && *e.config.EnableSSHRoot {
|
||||||
|
server.SetAllowRootLogin(true)
|
||||||
|
log.Info("SSH root login enabled")
|
||||||
|
} else {
|
||||||
|
server.SetAllowRootLogin(false)
|
||||||
|
log.Info("SSH root login disabled (default)")
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.config.EnableSSHSFTP != nil && *e.config.EnableSSHSFTP {
|
||||||
|
server.SetAllowSFTP(true)
|
||||||
|
log.Info("SSH SFTP subsystem enabled")
|
||||||
|
} else {
|
||||||
|
server.SetAllowSFTP(false)
|
||||||
|
log.Info("SSH SFTP subsystem disabled (default)")
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.config.EnableSSHLocalPortForwarding != nil && *e.config.EnableSSHLocalPortForwarding {
|
||||||
|
server.SetAllowLocalPortForwarding(true)
|
||||||
|
log.Info("SSH local port forwarding enabled")
|
||||||
|
} else {
|
||||||
|
server.SetAllowLocalPortForwarding(false)
|
||||||
|
log.Info("SSH local port forwarding disabled (default)")
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.config.EnableSSHRemotePortForwarding != nil && *e.config.EnableSSHRemotePortForwarding {
|
||||||
|
server.SetAllowRemotePortForwarding(true)
|
||||||
|
log.Info("SSH remote port forwarding enabled")
|
||||||
|
} else {
|
||||||
|
server.SetAllowRemotePortForwarding(false)
|
||||||
|
log.Info("SSH remote port forwarding disabled (default)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) cleanupSSHPortRedirection() error {
|
||||||
|
if e.firewall == nil || e.wgInterface == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
localAddr := e.wgInterface.Address().IP
|
||||||
|
if !localAddr.IsValid() {
|
||||||
|
return errors.New("invalid local NetBird address")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, 22, 22022); err != nil {
|
||||||
|
return fmt.Errorf("remove SSH port redirection: %w", err)
|
||||||
|
}
|
||||||
|
log.Debugf("SSH port redirection removed: %s:22 -> %s:22022", localAddr, localAddr)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) stopSSHServer() error {
|
||||||
|
if e.sshServer == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.cleanupSSHPortRedirection(); err != nil {
|
||||||
|
log.Warnf("failed to cleanup SSH port redirection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||||
|
if registrar, ok := e.firewall.(interface {
|
||||||
|
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||||
|
}); ok {
|
||||||
|
registrar.UnregisterNetstackService(nftypes.TCP, sshserver.InternalSSHPort)
|
||||||
|
log.Debugf("unregistered SSH service from netstack for TCP:%d", sshserver.InternalSSHPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("stopping SSH server")
|
||||||
|
err := e.sshServer.Stop()
|
||||||
|
e.sshServer = nil
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("stop: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSSHServerStatus returns the SSH server status and active sessions
|
||||||
|
func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.SessionInfo) {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
sshServer := e.sshServer
|
||||||
|
e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
if sshServer == nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return sshServer.GetStatus()
|
||||||
|
}
|
||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -25,14 +24,18 @@ import (
|
|||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
@@ -46,13 +49,12 @@ import (
|
|||||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/peers"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
@@ -108,6 +110,10 @@ type MockWGIface struct {
|
|||||||
LastActivitiesFunc func() map[string]monotime.Time
|
LastActivitiesFunc func() map[string]monotime.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockWGIface) RenewTun(_ int) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *MockWGIface) RemoveEndpointAddress(_ string) error {
|
func (m *MockWGIface) RemoveEndpointAddress(_ string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -214,11 +220,13 @@ func TestMain(m *testing.M) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestEngine_SSH(t *testing.T) {
|
func TestEngine_SSH(t *testing.T) {
|
||||||
if runtime.GOOS == "windows" {
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
t.Skip("skipping TestEngine_SSH")
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
return
|
return
|
||||||
@@ -240,45 +248,20 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
ServerSSHAllowed: true,
|
ServerSSHAllowed: true,
|
||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
|
SSHKey: sshKey,
|
||||||
},
|
},
|
||||||
MobileDependency{},
|
MobileDependency{},
|
||||||
peer.NewRecorder("https://mgm"),
|
peer.NewRecorder("https://mgm"),
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
var sshKeysAdded []string
|
|
||||||
var sshPeersRemoved []string
|
|
||||||
|
|
||||||
sshCtx, cancel := context.WithCancel(context.Background())
|
|
||||||
|
|
||||||
engine.sshServerFunc = func(hostKeyPEM []byte, addr string) (ssh.Server, error) {
|
|
||||||
return &ssh.MockServer{
|
|
||||||
Ctx: sshCtx,
|
|
||||||
StopFunc: func() error {
|
|
||||||
cancel()
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
StartFunc: func() error {
|
|
||||||
<-ctx.Done()
|
|
||||||
return ctx.Err()
|
|
||||||
},
|
|
||||||
AddAuthorizedKeyFunc: func(peer, newKey string) error {
|
|
||||||
sshKeysAdded = append(sshKeysAdded, newKey)
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
RemoveAuthorizedKeyFunc: func(peer string) {
|
|
||||||
sshPeersRemoved = append(sshPeersRemoved, peer)
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
err = engine.Start(nil, nil)
|
err = engine.Start(nil, nil)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := engine.Stop()
|
err := engine.Stop()
|
||||||
@@ -304,9 +287,7 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
err = engine.updateNetworkMap(networkMap)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Nil(t, engine.sshServer)
|
assert.Nil(t, engine.sshServer)
|
||||||
|
|
||||||
@@ -314,19 +295,24 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
networkMap = &mgmtProto.NetworkMap{
|
networkMap = &mgmtProto.NetworkMap{
|
||||||
Serial: 7,
|
Serial: 7,
|
||||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||||
SshConfig: &mgmtProto.SSHConfig{SshEnabled: true}},
|
SshConfig: &mgmtProto.SSHConfig{
|
||||||
|
SshEnabled: true,
|
||||||
|
JwtConfig: &mgmtProto.JWTConfig{
|
||||||
|
Issuer: "test-issuer",
|
||||||
|
Audience: "test-audience",
|
||||||
|
KeysLocation: "test-keys",
|
||||||
|
MaxTokenAge: 3600,
|
||||||
|
},
|
||||||
|
}},
|
||||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||||
RemotePeersIsEmpty: false,
|
RemotePeersIsEmpty: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
err = engine.updateNetworkMap(networkMap)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(250 * time.Millisecond)
|
time.Sleep(250 * time.Millisecond)
|
||||||
assert.NotNil(t, engine.sshServer)
|
assert.NotNil(t, engine.sshServer)
|
||||||
assert.Contains(t, sshKeysAdded, "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ")
|
|
||||||
|
|
||||||
// now remove peer
|
// now remove peer
|
||||||
networkMap = &mgmtProto.NetworkMap{
|
networkMap = &mgmtProto.NetworkMap{
|
||||||
@@ -336,13 +322,10 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
err = engine.updateNetworkMap(networkMap)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// time.Sleep(250 * time.Millisecond)
|
// time.Sleep(250 * time.Millisecond)
|
||||||
assert.NotNil(t, engine.sshServer)
|
assert.NotNil(t, engine.sshServer)
|
||||||
assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=")
|
|
||||||
|
|
||||||
// now disable SSH server
|
// now disable SSH server
|
||||||
networkMap = &mgmtProto.NetworkMap{
|
networkMap = &mgmtProto.NetworkMap{
|
||||||
@@ -354,12 +337,70 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
err = engine.updateNetworkMap(networkMap)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Nil(t, engine.sshServer)
|
assert.Nil(t, engine.sshServer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEngine_SSHUpdateLogic(t *testing.T) {
|
||||||
|
// Test that SSH server start/stop logic works based on config
|
||||||
|
engine := &Engine{
|
||||||
|
config: &EngineConfig{
|
||||||
|
ServerSSHAllowed: false, // Start with SSH disabled
|
||||||
|
},
|
||||||
|
syncMsgMux: &sync.Mutex{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test SSH disabled config
|
||||||
|
sshConfig := &mgmtProto.SSHConfig{SshEnabled: false}
|
||||||
|
err := engine.updateSSH(sshConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, engine.sshServer)
|
||||||
|
|
||||||
|
// Test inbound blocked
|
||||||
|
engine.config.BlockInbound = true
|
||||||
|
err = engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, engine.sshServer)
|
||||||
|
engine.config.BlockInbound = false
|
||||||
|
|
||||||
|
// Test with server SSH not allowed
|
||||||
|
err = engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, engine.sshServer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEngine_SSHServerConsistency(t *testing.T) {
|
||||||
|
|
||||||
|
t.Run("server set only on successful creation", func(t *testing.T) {
|
||||||
|
engine := &Engine{
|
||||||
|
config: &EngineConfig{
|
||||||
|
ServerSSHAllowed: true,
|
||||||
|
SSHKey: []byte("test-key"),
|
||||||
|
},
|
||||||
|
syncMsgMux: &sync.Mutex{},
|
||||||
|
}
|
||||||
|
|
||||||
|
engine.wgInterface = nil
|
||||||
|
|
||||||
|
err := engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true})
|
||||||
|
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, engine.sshServer)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cleanup handles nil gracefully", func(t *testing.T) {
|
||||||
|
engine := &Engine{
|
||||||
|
config: &EngineConfig{
|
||||||
|
ServerSSHAllowed: false,
|
||||||
|
},
|
||||||
|
syncMsgMux: &sync.Mutex{},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := engine.stopSSHServer()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, engine.sshServer)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEngine_UpdateNetworkMap(t *testing.T) {
|
func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||||
@@ -374,21 +415,13 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
engine := NewEngine(
|
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
||||||
ctx, cancel,
|
|
||||||
&signal.MockClient{},
|
|
||||||
&mgmt.MockClient{},
|
|
||||||
relayMgr,
|
|
||||||
&EngineConfig{
|
|
||||||
WgIfaceName: "utun102",
|
WgIfaceName: "utun102",
|
||||||
WgAddr: "100.64.0.1/24",
|
WgAddr: "100.64.0.1/24",
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
},
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||||
MobileDependency{},
|
|
||||||
peer.NewRecorder("https://mgm"),
|
|
||||||
nil)
|
|
||||||
|
|
||||||
wgIface := &MockWGIface{
|
wgIface := &MockWGIface{
|
||||||
NameFunc: func() string { return "utun102" },
|
NameFunc: func() string { return "utun102" },
|
||||||
@@ -607,7 +640,7 @@ func TestEngine_Sync(t *testing.T) {
|
|||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||||
engine.ctx = ctx
|
engine.ctx = ctx
|
||||||
|
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
@@ -772,7 +805,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||||
engine.ctx = ctx
|
engine.ctx = ctx
|
||||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -974,7 +1007,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||||
engine.ctx = ctx
|
engine.ctx = ctx
|
||||||
|
|
||||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
@@ -1500,7 +1533,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
|
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil
|
||||||
e.ctx = ctx
|
e.ctx = ctx
|
||||||
return e, err
|
return e, err
|
||||||
}
|
}
|
||||||
@@ -1588,14 +1621,17 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
|
|
||||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
|
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||||
accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
type wgIfaceBase interface {
|
type wgIfaceBase interface {
|
||||||
Create() error
|
Create() error
|
||||||
CreateOnAndroid(routeRange []string, ip string, domains []string) error
|
CreateOnAndroid(routeRange []string, ip string, domains []string) error
|
||||||
|
RenewTun(fd int) error
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
Name() string
|
Name() string
|
||||||
Address() wgaddr.Address
|
Address() wgaddr.Address
|
||||||
|
|||||||
@@ -124,6 +124,11 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte
|
|||||||
config.BlockLANAccess,
|
config.BlockLANAccess,
|
||||||
config.BlockInbound,
|
config.BlockInbound,
|
||||||
config.LazyConnectionEnabled,
|
config.LazyConnectionEnabled,
|
||||||
|
config.EnableSSHRoot,
|
||||||
|
config.EnableSSHSFTP,
|
||||||
|
config.EnableSSHLocalPortForwarding,
|
||||||
|
config.EnableSSHRemotePortForwarding,
|
||||||
|
config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
|
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||||
return serverKey, loginResp, err
|
return serverKey, loginResp, err
|
||||||
@@ -150,6 +155,11 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
|
|||||||
config.BlockLANAccess,
|
config.BlockLANAccess,
|
||||||
config.BlockInbound,
|
config.BlockInbound,
|
||||||
config.LazyConnectionEnabled,
|
config.LazyConnectionEnabled,
|
||||||
|
config.EnableSSHRoot,
|
||||||
|
config.EnableSSHSFTP,
|
||||||
|
config.EnableSSHLocalPortForwarding,
|
||||||
|
config.EnableSSHRemotePortForwarding,
|
||||||
|
config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
|
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -666,7 +666,7 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
if runtime.GOOS != "js" && conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ type EndpointUpdater struct {
|
|||||||
wgConfig WgConfig
|
wgConfig WgConfig
|
||||||
initiator bool
|
initiator bool
|
||||||
|
|
||||||
// mu protects updateWireGuardPeer and cancelFunc
|
// mu protects cancelFunc
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
cancelFunc func()
|
cancelFunc func()
|
||||||
updateWg sync.WaitGroup
|
updateWg sync.WaitGroup
|
||||||
@@ -86,11 +86,9 @@ func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.U
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case <-t.C:
|
case <-t.C:
|
||||||
e.mu.Lock()
|
|
||||||
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
|
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
|
||||||
e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err)
|
e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err)
|
||||||
}
|
}
|
||||||
e.mu.Unlock()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package peer
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -10,5 +11,8 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func isForceRelayed() bool {
|
func isForceRelayed() bool {
|
||||||
|
if runtime.GOOS == "js" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
|
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,9 +21,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||||
"github.com/netbirdio/netbird/client/internal/relay"
|
"github.com/netbirdio/netbird/client/internal/relay"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const eventQueueSize = 10
|
const eventQueueSize = 10
|
||||||
@@ -67,6 +67,7 @@ type State struct {
|
|||||||
BytesRx int64
|
BytesRx int64
|
||||||
Latency time.Duration
|
Latency time.Duration
|
||||||
RosenpassEnabled bool
|
RosenpassEnabled bool
|
||||||
|
SSHHostKey []byte
|
||||||
routes map[string]struct{}
|
routes map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -572,6 +573,22 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdatePeerSSHHostKey updates peer's SSH host key
|
||||||
|
func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) error {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
peerState, ok := d.peers[peerPubKey]
|
||||||
|
if !ok {
|
||||||
|
return errors.New("peer doesn't exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
peerState.SSHHostKey = sshHostKey
|
||||||
|
d.peers[peerPubKey] = peerState
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// FinishPeerListModifications this event invoke the notification
|
// FinishPeerListModifications this event invoke the notification
|
||||||
func (d *Status) FinishPeerListModifications() {
|
func (d *Status) FinishPeerListModifications() {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
"os/user"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
@@ -50,6 +51,12 @@ type ConfigInput struct {
|
|||||||
StateFilePath string
|
StateFilePath string
|
||||||
PreSharedKey *string
|
PreSharedKey *string
|
||||||
ServerSSHAllowed *bool
|
ServerSSHAllowed *bool
|
||||||
|
EnableSSHRoot *bool
|
||||||
|
EnableSSHSFTP *bool
|
||||||
|
EnableSSHLocalPortForwarding *bool
|
||||||
|
EnableSSHRemotePortForwarding *bool
|
||||||
|
DisableSSHAuth *bool
|
||||||
|
SSHJWTCacheTTL *int
|
||||||
NATExternalIPs []string
|
NATExternalIPs []string
|
||||||
CustomDNSAddress []byte
|
CustomDNSAddress []byte
|
||||||
RosenpassEnabled *bool
|
RosenpassEnabled *bool
|
||||||
@@ -94,6 +101,12 @@ type Config struct {
|
|||||||
RosenpassEnabled bool
|
RosenpassEnabled bool
|
||||||
RosenpassPermissive bool
|
RosenpassPermissive bool
|
||||||
ServerSSHAllowed *bool
|
ServerSSHAllowed *bool
|
||||||
|
EnableSSHRoot *bool
|
||||||
|
EnableSSHSFTP *bool
|
||||||
|
EnableSSHLocalPortForwarding *bool
|
||||||
|
EnableSSHRemotePortForwarding *bool
|
||||||
|
DisableSSHAuth *bool
|
||||||
|
SSHJWTCacheTTL *int
|
||||||
|
|
||||||
DisableClientRoutes bool
|
DisableClientRoutes bool
|
||||||
DisableServerRoutes bool
|
DisableServerRoutes bool
|
||||||
@@ -153,21 +166,28 @@ func getConfigDir() (string, error) {
|
|||||||
if ConfigDirOverride != "" {
|
if ConfigDirOverride != "" {
|
||||||
return ConfigDirOverride, nil
|
return ConfigDirOverride, nil
|
||||||
}
|
}
|
||||||
configDir, err := os.UserConfigDir()
|
|
||||||
|
base, err := baseConfigDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
configDir = filepath.Join(configDir, "netbird")
|
configDir := filepath.Join(base, "netbird")
|
||||||
if _, err := os.Stat(configDir); os.IsNotExist(err) {
|
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return configDir, nil
|
return configDir, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func baseConfigDir() (string, error) {
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
if u, err := user.Current(); err == nil && u.HomeDir != "" {
|
||||||
|
return filepath.Join(u.HomeDir, "Library", "Application Support"), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return os.UserConfigDir()
|
||||||
|
}
|
||||||
|
|
||||||
func getConfigDirForUser(username string) (string, error) {
|
func getConfigDirForUser(username string) (string, error) {
|
||||||
if ConfigDirOverride != "" {
|
if ConfigDirOverride != "" {
|
||||||
return ConfigDirOverride, nil
|
return ConfigDirOverride, nil
|
||||||
@@ -376,6 +396,62 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
|
||||||
|
if *input.EnableSSHRoot {
|
||||||
|
log.Infof("enabling SSH root login")
|
||||||
|
} else {
|
||||||
|
log.Infof("disabling SSH root login")
|
||||||
|
}
|
||||||
|
config.EnableSSHRoot = input.EnableSSHRoot
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.EnableSSHSFTP != nil && input.EnableSSHSFTP != config.EnableSSHSFTP {
|
||||||
|
if *input.EnableSSHSFTP {
|
||||||
|
log.Infof("enabling SSH SFTP subsystem")
|
||||||
|
} else {
|
||||||
|
log.Infof("disabling SSH SFTP subsystem")
|
||||||
|
}
|
||||||
|
config.EnableSSHSFTP = input.EnableSSHSFTP
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.EnableSSHLocalPortForwarding != nil && input.EnableSSHLocalPortForwarding != config.EnableSSHLocalPortForwarding {
|
||||||
|
if *input.EnableSSHLocalPortForwarding {
|
||||||
|
log.Infof("enabling SSH local port forwarding")
|
||||||
|
} else {
|
||||||
|
log.Infof("disabling SSH local port forwarding")
|
||||||
|
}
|
||||||
|
config.EnableSSHLocalPortForwarding = input.EnableSSHLocalPortForwarding
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.EnableSSHRemotePortForwarding != nil && input.EnableSSHRemotePortForwarding != config.EnableSSHRemotePortForwarding {
|
||||||
|
if *input.EnableSSHRemotePortForwarding {
|
||||||
|
log.Infof("enabling SSH remote port forwarding")
|
||||||
|
} else {
|
||||||
|
log.Infof("disabling SSH remote port forwarding")
|
||||||
|
}
|
||||||
|
config.EnableSSHRemotePortForwarding = input.EnableSSHRemotePortForwarding
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.DisableSSHAuth != nil && input.DisableSSHAuth != config.DisableSSHAuth {
|
||||||
|
if *input.DisableSSHAuth {
|
||||||
|
log.Infof("disabling SSH authentication")
|
||||||
|
} else {
|
||||||
|
log.Infof("enabling SSH authentication")
|
||||||
|
}
|
||||||
|
config.DisableSSHAuth = input.DisableSSHAuth
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
|
||||||
|
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
|
||||||
|
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
|
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
|
||||||
log.Infof("updating DNS route interval to %s (old value %s)",
|
log.Infof("updating DNS route interval to %s (old value %s)",
|
||||||
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
|
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
|
||||||
|
|||||||
@@ -132,3 +132,21 @@ func (pm *ProfileManager) setActiveProfileState(profileName string) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetLoginHint retrieves the email from the active profile to use as login_hint.
|
||||||
|
func GetLoginHint() string {
|
||||||
|
pm := NewProfileManager()
|
||||||
|
activeProf, err := pm.GetActiveProfile()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to get active profile for login hint: %v", err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return profileState.Email
|
||||||
|
}
|
||||||
|
|||||||
@@ -76,6 +76,7 @@ func (a *ActiveProfileState) FilePath() (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ServiceManager struct {
|
type ServiceManager struct {
|
||||||
|
profilesDir string // If set, overrides ConfigDirOverride for profile operations
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServiceManager(defaultConfigPath string) *ServiceManager {
|
func NewServiceManager(defaultConfigPath string) *ServiceManager {
|
||||||
@@ -85,6 +86,17 @@ func NewServiceManager(defaultConfigPath string) *ServiceManager {
|
|||||||
return &ServiceManager{}
|
return &ServiceManager{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewServiceManagerWithProfilesDir creates a ServiceManager with a specific profiles directory
|
||||||
|
// This allows setting the profiles directory without modifying the global ConfigDirOverride
|
||||||
|
func NewServiceManagerWithProfilesDir(defaultConfigPath string, profilesDir string) *ServiceManager {
|
||||||
|
if defaultConfigPath != "" {
|
||||||
|
DefaultConfigPath = defaultConfigPath
|
||||||
|
}
|
||||||
|
return &ServiceManager{
|
||||||
|
profilesDir: profilesDir,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *ServiceManager) CopyDefaultProfileIfNotExists() (bool, error) {
|
func (s *ServiceManager) CopyDefaultProfileIfNotExists() (bool, error) {
|
||||||
|
|
||||||
if err := os.MkdirAll(DefaultConfigPathDir, 0600); err != nil {
|
if err := os.MkdirAll(DefaultConfigPathDir, 0600); err != nil {
|
||||||
@@ -240,7 +252,7 @@ func (s *ServiceManager) DefaultProfilePath() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServiceManager) AddProfile(profileName, username string) error {
|
func (s *ServiceManager) AddProfile(profileName, username string) error {
|
||||||
configDir, err := getConfigDirForUser(username)
|
configDir, err := s.getConfigDir(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get config directory: %w", err)
|
return fmt.Errorf("failed to get config directory: %w", err)
|
||||||
}
|
}
|
||||||
@@ -270,7 +282,7 @@ func (s *ServiceManager) AddProfile(profileName, username string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServiceManager) RemoveProfile(profileName, username string) error {
|
func (s *ServiceManager) RemoveProfile(profileName, username string) error {
|
||||||
configDir, err := getConfigDirForUser(username)
|
configDir, err := s.getConfigDir(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get config directory: %w", err)
|
return fmt.Errorf("failed to get config directory: %w", err)
|
||||||
}
|
}
|
||||||
@@ -302,7 +314,7 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServiceManager) ListProfiles(username string) ([]Profile, error) {
|
func (s *ServiceManager) ListProfiles(username string) ([]Profile, error) {
|
||||||
configDir, err := getConfigDirForUser(username)
|
configDir, err := s.getConfigDir(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get config directory: %w", err)
|
return nil, fmt.Errorf("failed to get config directory: %w", err)
|
||||||
}
|
}
|
||||||
@@ -361,7 +373,7 @@ func (s *ServiceManager) GetStatePath() string {
|
|||||||
return defaultStatePath
|
return defaultStatePath
|
||||||
}
|
}
|
||||||
|
|
||||||
configDir, err := getConfigDirForUser(activeProf.Username)
|
configDir, err := s.getConfigDir(activeProf.Username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to get config directory for user %s: %v", activeProf.Username, err)
|
log.Warnf("failed to get config directory for user %s: %v", activeProf.Username, err)
|
||||||
return defaultStatePath
|
return defaultStatePath
|
||||||
@@ -369,3 +381,12 @@ func (s *ServiceManager) GetStatePath() string {
|
|||||||
|
|
||||||
return filepath.Join(configDir, activeProf.Name+".state.json")
|
return filepath.Join(configDir, activeProf.Name+".state.json")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getConfigDir returns the profiles directory, using profilesDir if set, otherwise getConfigDirForUser
|
||||||
|
func (s *ServiceManager) getConfigDir(username string) (string, error) {
|
||||||
|
if s.profilesDir != "" {
|
||||||
|
return s.profilesDir, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return getConfigDirForUser(username)
|
||||||
|
}
|
||||||
|
|||||||
@@ -18,8 +18,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/client"
|
"github.com/netbirdio/netbird/client/internal/routemanager/client"
|
||||||
@@ -39,6 +38,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
|
|||||||
218
client/internal/sleep/detector_darwin.go
Normal file
218
client/internal/sleep/detector_darwin.go
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
//go:build darwin && !ios
|
||||||
|
|
||||||
|
package sleep
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo LDFLAGS: -framework IOKit -framework CoreFoundation
|
||||||
|
#include <IOKit/pwr_mgt/IOPMLib.h>
|
||||||
|
#include <IOKit/IOMessage.h>
|
||||||
|
#include <CoreFoundation/CoreFoundation.h>
|
||||||
|
|
||||||
|
extern void sleepCallbackBridge();
|
||||||
|
extern void poweredOnCallbackBridge();
|
||||||
|
extern void suspendedCallbackBridge();
|
||||||
|
extern void resumedCallbackBridge();
|
||||||
|
|
||||||
|
|
||||||
|
// C global variables for IOKit state
|
||||||
|
static IONotificationPortRef g_notifyPortRef = NULL;
|
||||||
|
static io_object_t g_notifierObject = 0;
|
||||||
|
static io_object_t g_generalInterestNotifier = 0;
|
||||||
|
static io_connect_t g_rootPort = 0;
|
||||||
|
static CFRunLoopRef g_runLoop = NULL;
|
||||||
|
|
||||||
|
static void sleepCallback(void* refCon, io_service_t service, natural_t messageType, void* messageArgument) {
|
||||||
|
switch (messageType) {
|
||||||
|
case kIOMessageSystemWillSleep:
|
||||||
|
sleepCallbackBridge();
|
||||||
|
IOAllowPowerChange(g_rootPort, (long)messageArgument);
|
||||||
|
break;
|
||||||
|
case kIOMessageSystemHasPoweredOn:
|
||||||
|
poweredOnCallbackBridge();
|
||||||
|
break;
|
||||||
|
case kIOMessageServiceIsSuspended:
|
||||||
|
suspendedCallbackBridge();
|
||||||
|
break;
|
||||||
|
case kIOMessageServiceIsResumed:
|
||||||
|
resumedCallbackBridge();
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void registerNotifications() {
|
||||||
|
g_rootPort = IORegisterForSystemPower(
|
||||||
|
NULL,
|
||||||
|
&g_notifyPortRef,
|
||||||
|
(IOServiceInterestCallback)sleepCallback,
|
||||||
|
&g_notifierObject
|
||||||
|
);
|
||||||
|
|
||||||
|
if (g_rootPort == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
CFRunLoopAddSource(CFRunLoopGetCurrent(),
|
||||||
|
IONotificationPortGetRunLoopSource(g_notifyPortRef),
|
||||||
|
kCFRunLoopCommonModes);
|
||||||
|
|
||||||
|
g_runLoop = CFRunLoopGetCurrent();
|
||||||
|
CFRunLoopRun();
|
||||||
|
}
|
||||||
|
|
||||||
|
static void unregisterNotifications() {
|
||||||
|
CFRunLoopRemoveSource(g_runLoop,
|
||||||
|
IONotificationPortGetRunLoopSource(g_notifyPortRef),
|
||||||
|
kCFRunLoopCommonModes);
|
||||||
|
|
||||||
|
IODeregisterForSystemPower(&g_notifierObject);
|
||||||
|
IOServiceClose(g_rootPort);
|
||||||
|
IONotificationPortDestroy(g_notifyPortRef);
|
||||||
|
CFRunLoopStop(g_runLoop);
|
||||||
|
|
||||||
|
g_notifyPortRef = NULL;
|
||||||
|
g_notifierObject = 0;
|
||||||
|
g_rootPort = 0;
|
||||||
|
g_runLoop = NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
serviceRegistry = make(map[*Detector]struct{})
|
||||||
|
serviceRegistryMu sync.Mutex
|
||||||
|
)
|
||||||
|
|
||||||
|
//export sleepCallbackBridge
|
||||||
|
func sleepCallbackBridge() {
|
||||||
|
log.Info("sleepCallbackBridge event triggered")
|
||||||
|
|
||||||
|
serviceRegistryMu.Lock()
|
||||||
|
defer serviceRegistryMu.Unlock()
|
||||||
|
|
||||||
|
for svc := range serviceRegistry {
|
||||||
|
svc.triggerCallback(EventTypeSleep)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//export resumedCallbackBridge
|
||||||
|
func resumedCallbackBridge() {
|
||||||
|
log.Info("resumedCallbackBridge event triggered")
|
||||||
|
}
|
||||||
|
|
||||||
|
//export suspendedCallbackBridge
|
||||||
|
func suspendedCallbackBridge() {
|
||||||
|
log.Info("suspendedCallbackBridge event triggered")
|
||||||
|
}
|
||||||
|
|
||||||
|
//export poweredOnCallbackBridge
|
||||||
|
func poweredOnCallbackBridge() {
|
||||||
|
log.Info("poweredOnCallbackBridge event triggered")
|
||||||
|
serviceRegistryMu.Lock()
|
||||||
|
defer serviceRegistryMu.Unlock()
|
||||||
|
|
||||||
|
for svc := range serviceRegistry {
|
||||||
|
svc.triggerCallback(EventTypeWakeUp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Detector struct {
|
||||||
|
callback func(event EventType)
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDetector() (*Detector, error) {
|
||||||
|
return &Detector{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Detector) Register(callback func(event EventType)) error {
|
||||||
|
serviceRegistryMu.Lock()
|
||||||
|
defer serviceRegistryMu.Unlock()
|
||||||
|
|
||||||
|
if _, exists := serviceRegistry[d]; exists {
|
||||||
|
return fmt.Errorf("detector service already registered")
|
||||||
|
}
|
||||||
|
|
||||||
|
d.callback = callback
|
||||||
|
|
||||||
|
d.ctx, d.cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
if len(serviceRegistry) > 0 {
|
||||||
|
serviceRegistry[d] = struct{}{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
serviceRegistry[d] = struct{}{}
|
||||||
|
|
||||||
|
// CFRunLoop must run on a single fixed OS thread
|
||||||
|
go func() {
|
||||||
|
runtime.LockOSThread()
|
||||||
|
defer runtime.UnlockOSThread()
|
||||||
|
|
||||||
|
C.registerNotifications()
|
||||||
|
}()
|
||||||
|
|
||||||
|
log.Info("sleep detection service started on macOS")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deregister removes the detector. When the last detector is removed, IOKit registration is torn down
|
||||||
|
// and the runloop is stopped and cleaned up.
|
||||||
|
func (d *Detector) Deregister() error {
|
||||||
|
serviceRegistryMu.Lock()
|
||||||
|
defer serviceRegistryMu.Unlock()
|
||||||
|
_, exists := serviceRegistry[d]
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cancel and remove this detector
|
||||||
|
d.cancel()
|
||||||
|
delete(serviceRegistry, d)
|
||||||
|
|
||||||
|
// If other Detectors still exist, leave IOKit running
|
||||||
|
if len(serviceRegistry) > 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("sleep detection service stopping (deregister)")
|
||||||
|
|
||||||
|
// Deregister IOKit notifications, stop runloop, and free resources
|
||||||
|
C.unregisterNotifications()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Detector) triggerCallback(event EventType) {
|
||||||
|
doneChan := make(chan struct{})
|
||||||
|
|
||||||
|
timeout := time.NewTimer(500 * time.Millisecond)
|
||||||
|
defer timeout.Stop()
|
||||||
|
|
||||||
|
cb := d.callback
|
||||||
|
go func(callback func(event EventType)) {
|
||||||
|
log.Info("sleep detection event fired")
|
||||||
|
callback(event)
|
||||||
|
close(doneChan)
|
||||||
|
}(cb)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-doneChan:
|
||||||
|
case <-d.ctx.Done():
|
||||||
|
case <-timeout.C:
|
||||||
|
log.Warnf("sleep callback timed out")
|
||||||
|
}
|
||||||
|
}
|
||||||
9
client/internal/sleep/detector_notsupported.go
Normal file
9
client/internal/sleep/detector_notsupported.go
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
//go:build !darwin || ios
|
||||||
|
|
||||||
|
package sleep
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
func NewDetector() (detector, error) {
|
||||||
|
return nil, fmt.Errorf("sleep not supported on this platform")
|
||||||
|
}
|
||||||
37
client/internal/sleep/service.go
Normal file
37
client/internal/sleep/service.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package sleep
|
||||||
|
|
||||||
|
var (
|
||||||
|
EventTypeUnknown EventType = 0
|
||||||
|
EventTypeSleep EventType = 1
|
||||||
|
EventTypeWakeUp EventType = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
type EventType int
|
||||||
|
|
||||||
|
type detector interface {
|
||||||
|
Register(callback func(eventType EventType)) error
|
||||||
|
Deregister() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Service struct {
|
||||||
|
detector detector
|
||||||
|
}
|
||||||
|
|
||||||
|
func New() (*Service, error) {
|
||||||
|
d, err := NewDetector()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Service{
|
||||||
|
detector: d,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Register(callback func(eventType EventType)) error {
|
||||||
|
return s.detector.Register(callback)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Deregister() error {
|
||||||
|
return s.detector.Deregister()
|
||||||
|
}
|
||||||
File diff suppressed because one or more lines are too long
299
client/internal/templates/pkce_auth_msg_test.go
Normal file
299
client/internal/templates/pkce_auth_msg_test.go
Normal file
@@ -0,0 +1,299 @@
|
|||||||
|
package templates
|
||||||
|
|
||||||
|
import (
|
||||||
|
"html/template"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPKCEAuthMsgTemplate(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data map[string]string
|
||||||
|
outputFile string
|
||||||
|
expectedTitle string
|
||||||
|
expectedInContent []string
|
||||||
|
notExpectedInContent []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "error_state",
|
||||||
|
data: map[string]string{
|
||||||
|
"Error": "authentication failed: invalid state",
|
||||||
|
},
|
||||||
|
outputFile: "pkce-auth-error.html",
|
||||||
|
expectedTitle: "Login Failed",
|
||||||
|
expectedInContent: []string{
|
||||||
|
"authentication failed: invalid state",
|
||||||
|
"Login Failed",
|
||||||
|
},
|
||||||
|
notExpectedInContent: []string{
|
||||||
|
"Login Successful",
|
||||||
|
"Your device is now registered and logged in to NetBird",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success_state",
|
||||||
|
data: map[string]string{
|
||||||
|
// No error field means success
|
||||||
|
},
|
||||||
|
outputFile: "pkce-auth-success.html",
|
||||||
|
expectedTitle: "Login Successful",
|
||||||
|
expectedInContent: []string{
|
||||||
|
"Login Successful",
|
||||||
|
"Your device is now registered and logged in to NetBird. You can now close this window.",
|
||||||
|
},
|
||||||
|
notExpectedInContent: []string{
|
||||||
|
"Login Failed",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error_state_timeout",
|
||||||
|
data: map[string]string{
|
||||||
|
"Error": "authentication timeout: request expired after 5 minutes",
|
||||||
|
},
|
||||||
|
outputFile: "pkce-auth-timeout.html",
|
||||||
|
expectedTitle: "Login Failed",
|
||||||
|
expectedInContent: []string{
|
||||||
|
"authentication timeout: request expired after 5 minutes",
|
||||||
|
"Login Failed",
|
||||||
|
},
|
||||||
|
notExpectedInContent: []string{
|
||||||
|
"Login Successful",
|
||||||
|
"Your device is now registered and logged in to NetBird",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Parse the template
|
||||||
|
tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to parse template: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create temp directory for this test
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
outputPath := filepath.Join(tempDir, tt.outputFile)
|
||||||
|
|
||||||
|
// Create output file
|
||||||
|
file, err := os.Create(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create output file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the template
|
||||||
|
if err := tmpl.Execute(file, tt.data); err != nil {
|
||||||
|
file.Close()
|
||||||
|
t.Fatalf("Failed to execute template: %v", err)
|
||||||
|
}
|
||||||
|
file.Close()
|
||||||
|
|
||||||
|
t.Logf("Generated test output: %s", outputPath)
|
||||||
|
|
||||||
|
// Read the generated file
|
||||||
|
content, err := os.ReadFile(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read output file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
contentStr := string(content)
|
||||||
|
|
||||||
|
// Verify file has content
|
||||||
|
if len(contentStr) == 0 {
|
||||||
|
t.Error("Output file is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify basic HTML structure
|
||||||
|
basicElements := []string{
|
||||||
|
"<!DOCTYPE html>",
|
||||||
|
"<html",
|
||||||
|
"<head>",
|
||||||
|
"<body>",
|
||||||
|
"NetBird",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, elem := range basicElements {
|
||||||
|
if !contains(contentStr, elem) {
|
||||||
|
t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify expected title
|
||||||
|
if !contains(contentStr, tt.expectedTitle) {
|
||||||
|
t.Errorf("Expected HTML to contain title '%s', but it was not found", tt.expectedTitle)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify expected content is present
|
||||||
|
for _, expected := range tt.expectedInContent {
|
||||||
|
if !contains(contentStr, expected) {
|
||||||
|
t.Errorf("Expected HTML to contain '%s', but it was not found", expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify unexpected content is not present
|
||||||
|
for _, notExpected := range tt.notExpectedInContent {
|
||||||
|
if contains(contentStr, notExpected) {
|
||||||
|
t.Errorf("Expected HTML to NOT contain '%s', but it was found", notExpected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPKCEAuthMsgTemplateValidation(t *testing.T) {
|
||||||
|
// Test that the template can be parsed without errors
|
||||||
|
tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Template parsing failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with empty data
|
||||||
|
t.Run("empty_data", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
outputPath := filepath.Join(tempDir, "empty-data.html")
|
||||||
|
|
||||||
|
file, err := os.Create(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create output file: %v", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
if err := tmpl.Execute(file, nil); err != nil {
|
||||||
|
t.Errorf("Template execution with nil data failed: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test with error data
|
||||||
|
t.Run("with_error", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
outputPath := filepath.Join(tempDir, "with-error.html")
|
||||||
|
|
||||||
|
file, err := os.Create(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create output file: %v", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
data := map[string]string{
|
||||||
|
"Error": "test error message",
|
||||||
|
}
|
||||||
|
if err := tmpl.Execute(file, data); err != nil {
|
||||||
|
t.Errorf("Template execution with error data failed: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPKCEAuthMsgTemplateContent(t *testing.T) {
|
||||||
|
// Test that the template contains expected elements
|
||||||
|
tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Template parsing failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("success_content", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
outputPath := filepath.Join(tempDir, "success.html")
|
||||||
|
|
||||||
|
file, err := os.Create(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create output file: %v", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
data := map[string]string{}
|
||||||
|
if err := tmpl.Execute(file, data); err != nil {
|
||||||
|
t.Fatalf("Template execution failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the file and verify it contains expected content
|
||||||
|
content, err := os.ReadFile(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read output file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for success indicators
|
||||||
|
contentStr := string(content)
|
||||||
|
if len(contentStr) == 0 {
|
||||||
|
t.Error("Generated HTML is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basic HTML structure checks
|
||||||
|
requiredElements := []string{
|
||||||
|
"<!DOCTYPE html>",
|
||||||
|
"<html",
|
||||||
|
"<head>",
|
||||||
|
"<body>",
|
||||||
|
"Login Successful",
|
||||||
|
"NetBird",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, elem := range requiredElements {
|
||||||
|
if !contains(contentStr, elem) {
|
||||||
|
t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("error_content", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
outputPath := filepath.Join(tempDir, "error.html")
|
||||||
|
|
||||||
|
file, err := os.Create(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create output file: %v", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
errorMsg := "test error message"
|
||||||
|
data := map[string]string{
|
||||||
|
"Error": errorMsg,
|
||||||
|
}
|
||||||
|
if err := tmpl.Execute(file, data); err != nil {
|
||||||
|
t.Fatalf("Template execution failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the file and verify it contains expected content
|
||||||
|
content, err := os.ReadFile(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read output file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for error indicators
|
||||||
|
contentStr := string(content)
|
||||||
|
if len(contentStr) == 0 {
|
||||||
|
t.Error("Generated HTML is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basic HTML structure checks
|
||||||
|
requiredElements := []string{
|
||||||
|
"<!DOCTYPE html>",
|
||||||
|
"<html",
|
||||||
|
"<head>",
|
||||||
|
"<body>",
|
||||||
|
"Login Failed",
|
||||||
|
errorMsg,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, elem := range requiredElements {
|
||||||
|
if !contains(contentStr, elem) {
|
||||||
|
t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func contains(s, substr string) bool {
|
||||||
|
return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
|
||||||
|
(len(s) > 0 && len(substr) > 0 && containsHelper(s, substr)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsHelper(s, substr string) bool {
|
||||||
|
for i := 0; i <= len(s)-len(substr); i++ {
|
||||||
|
if s[i:i+len(substr)] == substr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
35
client/internal/updatemanager/doc.go
Normal file
35
client/internal/updatemanager/doc.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
// Package updatemanager provides automatic update management for the NetBird client.
|
||||||
|
// It monitors for new versions, handles update triggers from management server directives,
|
||||||
|
// and orchestrates the download and installation of client updates.
|
||||||
|
//
|
||||||
|
// # Overview
|
||||||
|
//
|
||||||
|
// The update manager operates as a background service that continuously monitors for
|
||||||
|
// available updates and automatically initiates the update process when conditions are met.
|
||||||
|
// It integrates with the installer package to perform the actual installation.
|
||||||
|
//
|
||||||
|
// # Update Flow
|
||||||
|
//
|
||||||
|
// The complete update process follows these steps:
|
||||||
|
//
|
||||||
|
// 1. Manager receives update directive via SetVersion() or detects new version
|
||||||
|
// 2. Manager validates update should proceed (version comparison, rate limiting)
|
||||||
|
// 3. Manager publishes "updating" event to status recorder
|
||||||
|
// 4. Manager persists UpdateState to track update attempt
|
||||||
|
// 5. Manager downloads installer file (.msi or .exe) to temporary directory
|
||||||
|
// 6. Manager triggers installation via installer.RunInstallation()
|
||||||
|
// 7. Installer package handles the actual installation process
|
||||||
|
// 8. On next startup, CheckUpdateSuccess() verifies update completion
|
||||||
|
// 9. Manager publishes success/failure event to status recorder
|
||||||
|
// 10. Manager cleans up UpdateState
|
||||||
|
//
|
||||||
|
// # State Management
|
||||||
|
//
|
||||||
|
// Update state is persisted across restarts to track update attempts:
|
||||||
|
//
|
||||||
|
// - PreUpdateVersion: Version before update attempt
|
||||||
|
// - TargetVersion: Version attempting to update to
|
||||||
|
//
|
||||||
|
// This enables verification of successful updates and appropriate user notification
|
||||||
|
// after the client restarts with the new version.
|
||||||
|
package updatemanager
|
||||||
138
client/internal/updatemanager/downloader/downloader.go
Normal file
138
client/internal/updatemanager/downloader/downloader.go
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
package downloader
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
userAgent = "NetBird agent installer/%s"
|
||||||
|
DefaultRetryDelay = 3 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
func DownloadToFile(ctx context.Context, retryDelay time.Duration, url, dstFile string) error {
|
||||||
|
log.Debugf("starting download from %s", url)
|
||||||
|
|
||||||
|
out, err := os.Create(dstFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create destination file %q: %w", dstFile, err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if cerr := out.Close(); cerr != nil {
|
||||||
|
log.Warnf("error closing file %q: %v", dstFile, cerr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// First attempt
|
||||||
|
err = downloadToFileOnce(ctx, url, out)
|
||||||
|
if err == nil {
|
||||||
|
log.Infof("successfully downloaded file to %s", dstFile)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If retryDelay is 0, don't retry
|
||||||
|
if retryDelay == 0 {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Warnf("download failed, retrying after %v: %v", retryDelay, err)
|
||||||
|
|
||||||
|
// Sleep before retry
|
||||||
|
if sleepErr := sleepWithContext(ctx, retryDelay); sleepErr != nil {
|
||||||
|
return fmt.Errorf("download cancelled during retry delay: %w", sleepErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Truncate file before retry
|
||||||
|
if err := out.Truncate(0); err != nil {
|
||||||
|
return fmt.Errorf("failed to truncate file on retry: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := out.Seek(0, 0); err != nil {
|
||||||
|
return fmt.Errorf("failed to seek to beginning of file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second attempt
|
||||||
|
if err := downloadToFileOnce(ctx, url, out); err != nil {
|
||||||
|
return fmt.Errorf("download failed after retry: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("successfully downloaded file to %s", dstFile)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DownloadToMemory(ctx context.Context, url string, limit int64) ([]byte, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add User-Agent header
|
||||||
|
req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion()))
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to perform HTTP request: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if cerr := resp.Body.Close(); cerr != nil {
|
||||||
|
log.Warnf("error closing response body: %v", cerr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := io.ReadAll(io.LimitReader(resp.Body, limit))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func downloadToFileOnce(ctx context.Context, url string, out *os.File) error {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create HTTP request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add User-Agent header
|
||||||
|
req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion()))
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to perform HTTP request: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if cerr := resp.Body.Close(); cerr != nil {
|
||||||
|
log.Warnf("error closing response body: %v", cerr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := io.Copy(out, resp.Body); err != nil {
|
||||||
|
return fmt.Errorf("failed to write response body to file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sleepWithContext(ctx context.Context, duration time.Duration) error {
|
||||||
|
select {
|
||||||
|
case <-time.After(duration):
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
199
client/internal/updatemanager/downloader/downloader_test.go
Normal file
199
client/internal/updatemanager/downloader/downloader_test.go
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
package downloader
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
retryDelay = 100 * time.Millisecond
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDownloadToFile_Success(t *testing.T) {
|
||||||
|
// Create a test server that responds successfully
|
||||||
|
content := "test file content"
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte(content))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Create a temporary file for download
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
dstFile := filepath.Join(tempDir, "downloaded.txt")
|
||||||
|
|
||||||
|
// Download the file
|
||||||
|
err := DownloadToFile(context.Background(), retryDelay, server.URL, dstFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the file content
|
||||||
|
data, err := os.ReadFile(dstFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read downloaded file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(data) != content {
|
||||||
|
t.Errorf("expected content %q, got %q", content, string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownloadToFile_SuccessAfterRetry(t *testing.T) {
|
||||||
|
content := "test file content after retry"
|
||||||
|
var attemptCount atomic.Int32
|
||||||
|
|
||||||
|
// Create a test server that fails on first attempt, succeeds on second
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
attempt := attemptCount.Add(1)
|
||||||
|
if attempt == 1 {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
_, _ = w.Write([]byte("error"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte(content))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Create a temporary file for download
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
dstFile := filepath.Join(tempDir, "downloaded.txt")
|
||||||
|
|
||||||
|
// Download the file (should succeed after retry)
|
||||||
|
if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err != nil {
|
||||||
|
t.Fatalf("expected no error after retry, got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the file content
|
||||||
|
data, err := os.ReadFile(dstFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read downloaded file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(data) != content {
|
||||||
|
t.Errorf("expected content %q, got %q", content, string(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it took 2 attempts
|
||||||
|
if attemptCount.Load() != 2 {
|
||||||
|
t.Errorf("expected 2 attempts, got %d", attemptCount.Load())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownloadToFile_FailsAfterRetry(t *testing.T) {
|
||||||
|
var attemptCount atomic.Int32
|
||||||
|
|
||||||
|
// Create a test server that always fails
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
attemptCount.Add(1)
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
_, _ = w.Write([]byte("error"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Create a temporary file for download
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
dstFile := filepath.Join(tempDir, "downloaded.txt")
|
||||||
|
|
||||||
|
// Download the file (should fail after retry)
|
||||||
|
if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err == nil {
|
||||||
|
t.Fatal("expected error after retry, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it tried 2 times
|
||||||
|
if attemptCount.Load() != 2 {
|
||||||
|
t.Errorf("expected 2 attempts, got %d", attemptCount.Load())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownloadToFile_ContextCancellationDuringRetry(t *testing.T) {
|
||||||
|
var attemptCount atomic.Int32
|
||||||
|
|
||||||
|
// Create a test server that always fails
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
attemptCount.Add(1)
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Create a temporary file for download
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
dstFile := filepath.Join(tempDir, "downloaded.txt")
|
||||||
|
|
||||||
|
// Create a context that will be cancelled during retry delay
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
// Cancel after a short delay (during the retry sleep)
|
||||||
|
go func() {
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Download the file (should fail due to context cancellation during retry)
|
||||||
|
err := DownloadToFile(ctx, 1*time.Second, server.URL, dstFile)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error due to context cancellation, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have only made 1 attempt (cancelled during retry delay)
|
||||||
|
if attemptCount.Load() != 1 {
|
||||||
|
t.Errorf("expected 1 attempt, got %d", attemptCount.Load())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownloadToFile_InvalidURL(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
dstFile := filepath.Join(tempDir, "downloaded.txt")
|
||||||
|
|
||||||
|
err := DownloadToFile(context.Background(), retryDelay, "://invalid-url", dstFile)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for invalid URL, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownloadToFile_InvalidDestination(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("test"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Use an invalid destination path
|
||||||
|
err := DownloadToFile(context.Background(), retryDelay, server.URL, "/invalid/path/that/does/not/exist/file.txt")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for invalid destination, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownloadToFile_NoRetry(t *testing.T) {
|
||||||
|
var attemptCount atomic.Int32
|
||||||
|
|
||||||
|
// Create a test server that always fails
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
attemptCount.Add(1)
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
_, _ = w.Write([]byte("error"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Create a temporary file for download
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
dstFile := filepath.Join(tempDir, "downloaded.txt")
|
||||||
|
|
||||||
|
// Download the file with retryDelay = 0 (should not retry)
|
||||||
|
if err := DownloadToFile(context.Background(), 0, server.URL, dstFile); err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it only made 1 attempt (no retry)
|
||||||
|
if attemptCount.Load() != 1 {
|
||||||
|
t.Errorf("expected 1 attempt, got %d", attemptCount.Load())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package installer
|
||||||
|
|
||||||
|
func UpdaterBinaryNameWithoutExtension() string {
|
||||||
|
return updaterBinary
|
||||||
|
}
|
||||||
11
client/internal/updatemanager/installer/binary_windows.go
Normal file
11
client/internal/updatemanager/installer/binary_windows.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
package installer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func UpdaterBinaryNameWithoutExtension() string {
|
||||||
|
ext := filepath.Ext(updaterBinary)
|
||||||
|
return strings.TrimSuffix(updaterBinary, ext)
|
||||||
|
}
|
||||||
111
client/internal/updatemanager/installer/doc.go
Normal file
111
client/internal/updatemanager/installer/doc.go
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
// Package installer provides functionality for managing NetBird application
|
||||||
|
// updates and installations across Windows, macOS. It handles
|
||||||
|
// the complete update lifecycle including artifact download, cryptographic verification,
|
||||||
|
// installation execution, process management, and result reporting.
|
||||||
|
//
|
||||||
|
// # Architecture
|
||||||
|
//
|
||||||
|
// The installer package uses a two-process architecture to enable self-updates:
|
||||||
|
//
|
||||||
|
// 1. Service Process: The main NetBird daemon process that initiates updates
|
||||||
|
// 2. Updater Process: A detached child process that performs the actual installation
|
||||||
|
//
|
||||||
|
// This separation is critical because:
|
||||||
|
// - The service binary cannot update itself while running
|
||||||
|
// - The installer (EXE/MSI/PKG) will terminate the service during installation
|
||||||
|
// - The updater process survives service termination and restarts it after installation
|
||||||
|
// - Results can be communicated back to the service after it restarts
|
||||||
|
//
|
||||||
|
// # Update Flow
|
||||||
|
//
|
||||||
|
// Service Process (RunInstallation):
|
||||||
|
//
|
||||||
|
// 1. Validates target version format (semver)
|
||||||
|
// 2. Determines installer type (EXE, MSI, PKG, or Homebrew)
|
||||||
|
// 3. Downloads installer file from GitHub releases (if applicable)
|
||||||
|
// 4. Verifies installer signature using reposign package (cryptographic verification in service process before
|
||||||
|
// launching updater)
|
||||||
|
// 5. Copies service binary to tempDir as "updater" (or "updater.exe" on Windows)
|
||||||
|
// 6. Launches updater process with detached mode:
|
||||||
|
// - --temp-dir: Temporary directory path
|
||||||
|
// - --service-dir: Service installation directory
|
||||||
|
// - --installer-file: Path to downloaded installer (if applicable)
|
||||||
|
// - --dry-run: Optional flag to test without actually installing
|
||||||
|
// 7. Service process continues running (will be terminated by installer later)
|
||||||
|
// 8. Service can watch for result.json using ResultHandler.Watch() to detect completion
|
||||||
|
//
|
||||||
|
// Updater Process (Setup):
|
||||||
|
//
|
||||||
|
// 1. Receives parameters from service via command-line arguments
|
||||||
|
// 2. Runs installer with appropriate silent/quiet flags:
|
||||||
|
// - Windows EXE: installer.exe /S
|
||||||
|
// - Windows MSI: msiexec.exe /i installer.msi /quiet /qn /l*v msi.log
|
||||||
|
// - macOS PKG: installer -pkg installer.pkg -target /
|
||||||
|
// - macOS Homebrew: brew upgrade netbirdio/tap/netbird
|
||||||
|
// 3. Installer terminates daemon and UI processes
|
||||||
|
// 4. Installer replaces binaries with new version
|
||||||
|
// 5. Updater waits for installer to complete
|
||||||
|
// 6. Updater restarts daemon:
|
||||||
|
// - Windows: netbird.exe service start
|
||||||
|
// - macOS/Linux: netbird service start
|
||||||
|
// 7. Updater restarts UI:
|
||||||
|
// - Windows: Launches netbird-ui.exe as active console user using CreateProcessAsUser
|
||||||
|
// - macOS: Uses launchctl asuser to launch NetBird.app for console user
|
||||||
|
// - Linux: Not implemented (UI typically auto-starts)
|
||||||
|
// 8. Updater writes result.json with success/error status
|
||||||
|
// 9. Updater process exits
|
||||||
|
//
|
||||||
|
// # Result Communication
|
||||||
|
//
|
||||||
|
// The ResultHandler (result.go) manages communication between updater and service:
|
||||||
|
//
|
||||||
|
// Result Structure:
|
||||||
|
//
|
||||||
|
// type Result struct {
|
||||||
|
// Success bool // true if installation succeeded
|
||||||
|
// Error string // error message if Success is false
|
||||||
|
// ExecutedAt time.Time // when installation completed
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Result files are automatically cleaned up after being read.
|
||||||
|
//
|
||||||
|
// # File Locations
|
||||||
|
//
|
||||||
|
// Temporary Directory (platform-specific):
|
||||||
|
//
|
||||||
|
// Windows:
|
||||||
|
// - Path: %ProgramData%\Netbird\tmp-install
|
||||||
|
// - Example: C:\ProgramData\Netbird\tmp-install
|
||||||
|
//
|
||||||
|
// macOS:
|
||||||
|
// - Path: /var/lib/netbird/tmp-install
|
||||||
|
// - Requires root permissions
|
||||||
|
//
|
||||||
|
// Files created during installation:
|
||||||
|
//
|
||||||
|
// tmp-install/
|
||||||
|
// installer.log
|
||||||
|
// updater[.exe] # Copy of service binary
|
||||||
|
// netbird_installer_*.[exe|msi|pkg] # Downloaded installer
|
||||||
|
// result.json # Installation result
|
||||||
|
// msi.log # MSI verbose log (Windows MSI only)
|
||||||
|
//
|
||||||
|
// # API Reference
|
||||||
|
//
|
||||||
|
// # Cleanup
|
||||||
|
//
|
||||||
|
// CleanUpInstallerFiles() removes temporary files after successful installation:
|
||||||
|
// - Downloaded installer files (*.exe, *.msi, *.pkg)
|
||||||
|
// - Updater binary copy
|
||||||
|
// - Does NOT remove result.json (cleaned by ResultHandler after read)
|
||||||
|
// - Does NOT remove msi.log (kept for debugging)
|
||||||
|
//
|
||||||
|
// # Dry-Run Mode
|
||||||
|
//
|
||||||
|
// Dry-run mode allows testing the update process without actually installing:
|
||||||
|
//
|
||||||
|
// Enable via environment variable:
|
||||||
|
//
|
||||||
|
// export NB_AUTO_UPDATE_DRY_RUN=true
|
||||||
|
// netbird service install-update 0.29.0
|
||||||
|
package installer
|
||||||
50
client/internal/updatemanager/installer/installer.go
Normal file
50
client/internal/updatemanager/installer/installer.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
//go:build !windows && !darwin
|
||||||
|
|
||||||
|
package installer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
updaterBinary = "updater"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Installer struct {
|
||||||
|
tempDir string
|
||||||
|
}
|
||||||
|
|
||||||
|
// New used by the service
|
||||||
|
func New() *Installer {
|
||||||
|
return &Installer{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWithDir used by the updater process, get the tempDir from the service via cmd line
|
||||||
|
func NewWithDir(tempDir string) *Installer {
|
||||||
|
return &Installer{
|
||||||
|
tempDir: tempDir,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Installer) TempDir() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Installer) LogFiles() []string {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Installer) CleanUpInstallerFiles() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) error {
|
||||||
|
return fmt.Errorf("unsupported platform")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup runs the installer with appropriate arguments and manages the daemon/UI state
|
||||||
|
// This will be run by the updater process
|
||||||
|
func (u *Installer) Setup(ctx context.Context, dryRun bool, targetVersion string, daemonFolder string) (resultErr error) {
|
||||||
|
return fmt.Errorf("unsupported platform")
|
||||||
|
}
|
||||||
293
client/internal/updatemanager/installer/installer_common.go
Normal file
293
client/internal/updatemanager/installer/installer_common.go
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
//go:build windows || darwin
|
||||||
|
|
||||||
|
package installer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
goversion "github.com/hashicorp/go-version"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager/downloader"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Installer struct {
|
||||||
|
tempDir string
|
||||||
|
}
|
||||||
|
|
||||||
|
// New used by the service
|
||||||
|
func New() *Installer {
|
||||||
|
return &Installer{
|
||||||
|
tempDir: defaultTempDir,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWithDir used by the updater process, get the tempDir from the service via cmd line
|
||||||
|
func NewWithDir(tempDir string) *Installer {
|
||||||
|
return &Installer{
|
||||||
|
tempDir: tempDir,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunInstallation starts the updater process to run the installation
|
||||||
|
// This will run by the original service process
|
||||||
|
func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) (err error) {
|
||||||
|
resultHandler := NewResultHandler(u.tempDir)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
if writeErr := resultHandler.WriteErr(err); writeErr != nil {
|
||||||
|
log.Errorf("failed to write error result: %v", writeErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := validateTargetVersion(targetVersion); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := u.mkTempDir(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var installerFile string
|
||||||
|
// Download files only when not using any third-party store
|
||||||
|
if installerType := TypeOfInstaller(ctx); installerType.Downloadable() {
|
||||||
|
log.Infof("download installer")
|
||||||
|
var err error
|
||||||
|
installerFile, err = u.downloadInstaller(ctx, installerType, targetVersion)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to download installer: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
artifactVerify, err := reposign.NewArtifactVerify(DefaultSigningKeysBaseURL)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to create artifact verify: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := artifactVerify.Verify(ctx, targetVersion, installerFile); err != nil {
|
||||||
|
log.Errorf("artifact verification error: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("running installer")
|
||||||
|
updaterPath, err := u.copyUpdater()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// the directory where the service has been installed
|
||||||
|
workspace, err := getServiceDir()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
args := []string{
|
||||||
|
"--temp-dir", u.tempDir,
|
||||||
|
"--service-dir", workspace,
|
||||||
|
}
|
||||||
|
|
||||||
|
if isDryRunEnabled() {
|
||||||
|
args = append(args, "--dry-run=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if installerFile != "" {
|
||||||
|
args = append(args, "--installer-file", installerFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
updateCmd := exec.Command(updaterPath, args...)
|
||||||
|
log.Infof("starting updater process: %s", updateCmd.String())
|
||||||
|
|
||||||
|
// Configure the updater to run in a separate session/process group
|
||||||
|
// so it survives the parent daemon being stopped
|
||||||
|
setUpdaterProcAttr(updateCmd)
|
||||||
|
|
||||||
|
// Start the updater process asynchronously
|
||||||
|
if err := updateCmd.Start(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
pid := updateCmd.Process.Pid
|
||||||
|
log.Infof("updater started with PID %d", pid)
|
||||||
|
|
||||||
|
// Release the process so the OS can fully detach it
|
||||||
|
if err := updateCmd.Process.Release(); err != nil {
|
||||||
|
log.Warnf("failed to release updater process: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanUpInstallerFiles
|
||||||
|
// - the installer file (pkg, exe, msi)
|
||||||
|
// - the selfcopy updater.exe
|
||||||
|
func (u *Installer) CleanUpInstallerFiles() error {
|
||||||
|
// Check if tempDir exists
|
||||||
|
info, err := os.Stat(u.tempDir)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !info.IsDir() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := os.Remove(filepath.Join(u.tempDir, updaterBinary)); err != nil && !os.IsNotExist(err) {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("failed to remove updater binary: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := os.ReadDir(u.tempDir)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
name := entry.Name()
|
||||||
|
for _, ext := range binaryExtensions {
|
||||||
|
if strings.HasSuffix(strings.ToLower(name), strings.ToLower(ext)) {
|
||||||
|
if err := os.Remove(filepath.Join(u.tempDir, name)); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("failed to remove %s: %w", name, err))
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return merr.ErrorOrNil()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Installer) downloadInstaller(ctx context.Context, installerType Type, targetVersion string) (string, error) {
|
||||||
|
fileURL := urlWithVersionArch(installerType, targetVersion)
|
||||||
|
|
||||||
|
// Clean up temp directory on error
|
||||||
|
var success bool
|
||||||
|
defer func() {
|
||||||
|
if !success {
|
||||||
|
if err := os.RemoveAll(u.tempDir); err != nil {
|
||||||
|
log.Errorf("error cleaning up temporary directory: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
fileName := path.Base(fileURL)
|
||||||
|
if fileName == "." || fileName == "/" || fileName == "" {
|
||||||
|
return "", fmt.Errorf("invalid file URL: %s", fileURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
outputFilePath := filepath.Join(u.tempDir, fileName)
|
||||||
|
if err := downloader.DownloadToFile(ctx, downloader.DefaultRetryDelay, fileURL, outputFilePath); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
success = true
|
||||||
|
return outputFilePath, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Installer) TempDir() string {
|
||||||
|
return u.tempDir
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Installer) mkTempDir() error {
|
||||||
|
if err := os.MkdirAll(u.tempDir, 0o755); err != nil {
|
||||||
|
log.Debugf("failed to create tempdir: %s", u.tempDir)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Installer) copyUpdater() (string, error) {
|
||||||
|
src, err := getServiceBinary()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get updater binary: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dst := filepath.Join(u.tempDir, updaterBinary)
|
||||||
|
if err := copyFile(src, dst); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to copy updater binary: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.Chmod(dst, 0o755); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to set permissions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return dst, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateTargetVersion(targetVersion string) error {
|
||||||
|
if targetVersion == "" {
|
||||||
|
return fmt.Errorf("target version cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := goversion.NewVersion(targetVersion)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid target version %q: %w", targetVersion, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyFile(src, dst string) error {
|
||||||
|
log.Infof("copying %s to %s", src, dst)
|
||||||
|
in, err := os.Open(src)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("open source: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := in.Close(); err != nil {
|
||||||
|
log.Warnf("failed to close source file: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
out, err := os.Create(dst)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create destination: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := out.Close(); err != nil {
|
||||||
|
log.Warnf("failed to close destination file: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if _, err := io.Copy(out, in); err != nil {
|
||||||
|
return fmt.Errorf("copy: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getServiceDir() (string, error) {
|
||||||
|
exePath, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return filepath.Dir(exePath), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getServiceBinary() (string, error) {
|
||||||
|
return os.Executable()
|
||||||
|
}
|
||||||
|
|
||||||
|
func isDryRunEnabled() bool {
|
||||||
|
return strings.EqualFold(strings.TrimSpace(os.Getenv("NB_AUTO_UPDATE_DRY_RUN")), "true")
|
||||||
|
}
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
package installer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (u *Installer) LogFiles() []string {
|
||||||
|
return []string{
|
||||||
|
filepath.Join(u.tempDir, LogFile),
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
package installer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (u *Installer) LogFiles() []string {
|
||||||
|
return []string{
|
||||||
|
filepath.Join(u.tempDir, msiLogFile),
|
||||||
|
filepath.Join(u.tempDir, LogFile),
|
||||||
|
}
|
||||||
|
}
|
||||||
238
client/internal/updatemanager/installer/installer_run_darwin.go
Normal file
238
client/internal/updatemanager/installer/installer_run_darwin.go
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
package installer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
daemonName = "netbird"
|
||||||
|
updaterBinary = "updater"
|
||||||
|
uiBinary = "/Applications/NetBird.app"
|
||||||
|
|
||||||
|
defaultTempDir = "/var/lib/netbird/tmp-install"
|
||||||
|
|
||||||
|
pkgDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_%version_darwin_%arch.pkg"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
binaryExtensions = []string{"pkg"}
|
||||||
|
)
|
||||||
|
|
||||||
|
// Setup runs the installer with appropriate arguments and manages the daemon/UI state
|
||||||
|
// This will be run by the updater process
|
||||||
|
func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) {
|
||||||
|
resultHandler := NewResultHandler(u.tempDir)
|
||||||
|
|
||||||
|
// Always ensure daemon and UI are restarted after setup
|
||||||
|
defer func() {
|
||||||
|
log.Infof("write out result")
|
||||||
|
var err error
|
||||||
|
if resultErr == nil {
|
||||||
|
err = resultHandler.WriteSuccess()
|
||||||
|
} else {
|
||||||
|
err = resultHandler.WriteErr(resultErr)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to write update result: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// skip service restart if dry-run mode is enabled
|
||||||
|
if dryRun {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("starting daemon back")
|
||||||
|
if err := u.startDaemon(daemonFolder); err != nil {
|
||||||
|
log.Errorf("failed to start daemon: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("starting UI back")
|
||||||
|
if err := u.startUIAsUser(); err != nil {
|
||||||
|
log.Errorf("failed to start UI: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}()
|
||||||
|
|
||||||
|
if dryRun {
|
||||||
|
time.Sleep(7 * time.Second)
|
||||||
|
log.Infof("dry-run mode enabled, skipping actual installation")
|
||||||
|
resultErr = fmt.Errorf("dry-run mode enabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch TypeOfInstaller(ctx) {
|
||||||
|
case TypePKG:
|
||||||
|
resultErr = u.installPkgFile(ctx, installerFile)
|
||||||
|
case TypeHomebrew:
|
||||||
|
resultErr = u.updateHomeBrew(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resultErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Installer) startDaemon(daemonFolder string) error {
|
||||||
|
log.Infof("starting netbird service")
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start")
|
||||||
|
if output, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
log.Warnf("failed to start netbird service: %v, output: %s", err, string(output))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Infof("netbird service started successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Installer) startUIAsUser() error {
|
||||||
|
log.Infof("starting netbird-ui: %s", uiBinary)
|
||||||
|
|
||||||
|
// Get the current console user
|
||||||
|
cmd := exec.Command("stat", "-f", "%Su", "/dev/console")
|
||||||
|
output, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get console user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
username := strings.TrimSpace(string(output))
|
||||||
|
if username == "" || username == "root" {
|
||||||
|
return fmt.Errorf("no active user session found")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("starting UI for user: %s", username)
|
||||||
|
|
||||||
|
// Get user's UID
|
||||||
|
userInfo, err := user.Lookup(username)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to lookup user %s: %w", username, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the UI process as the console user using launchctl
|
||||||
|
// This ensures the app runs in the user's context with proper GUI access
|
||||||
|
launchCmd := exec.Command("launchctl", "asuser", userInfo.Uid, "open", "-a", uiBinary)
|
||||||
|
log.Infof("launchCmd: %s", launchCmd.String())
|
||||||
|
// Set the user's home directory for proper macOS app behavior
|
||||||
|
launchCmd.Env = append(os.Environ(), "HOME="+userInfo.HomeDir)
|
||||||
|
log.Infof("set HOME environment variable: %s", userInfo.HomeDir)
|
||||||
|
|
||||||
|
if err := launchCmd.Start(); err != nil {
|
||||||
|
return fmt.Errorf("failed to start UI process: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release the process so it can run independently
|
||||||
|
if err := launchCmd.Process.Release(); err != nil {
|
||||||
|
log.Warnf("failed to release UI process: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("netbird-ui started successfully for user %s", username)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Installer) installPkgFile(ctx context.Context, path string) error {
|
||||||
|
log.Infof("installing pkg file: %s", path)
|
||||||
|
|
||||||
|
// Kill any existing UI processes before installation
|
||||||
|
// This ensures the postinstall script's "open $APP" will start the new version
|
||||||
|
u.killUI()
|
||||||
|
|
||||||
|
volume := "/"
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, "installer", "-pkg", path, "-target", volume)
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
return fmt.Errorf("error running pkg file: %w", err)
|
||||||
|
}
|
||||||
|
log.Infof("installer started with PID %d", cmd.Process.Pid)
|
||||||
|
if err := cmd.Wait(); err != nil {
|
||||||
|
return fmt.Errorf("error running pkg file: %w", err)
|
||||||
|
}
|
||||||
|
log.Infof("pkg file installed successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Installer) updateHomeBrew(ctx context.Context) error {
|
||||||
|
log.Infof("updating homebrew")
|
||||||
|
|
||||||
|
// Kill any existing UI processes before upgrade
|
||||||
|
// This ensures the new version will be started after upgrade
|
||||||
|
u.killUI()
|
||||||
|
|
||||||
|
// Homebrew must be run as a non-root user
|
||||||
|
// To find out which user installed NetBird using HomeBrew we can check the owner of our brew tap directory
|
||||||
|
// Check both Apple Silicon and Intel Mac paths
|
||||||
|
brewTapPath := "/opt/homebrew/Library/Taps/netbirdio/homebrew-tap/"
|
||||||
|
brewBinPath := "/opt/homebrew/bin/brew"
|
||||||
|
if _, err := os.Stat(brewTapPath); os.IsNotExist(err) {
|
||||||
|
// Try Intel Mac path
|
||||||
|
brewTapPath = "/usr/local/Homebrew/Library/Taps/netbirdio/homebrew-tap/"
|
||||||
|
brewBinPath = "/usr/local/bin/brew"
|
||||||
|
}
|
||||||
|
|
||||||
|
fileInfo, err := os.Stat(brewTapPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error getting homebrew installation path info: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fileSysInfo, ok := fileInfo.Sys().(*syscall.Stat_t)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("error checking file owner, sysInfo type is %T not *syscall.Stat_t", fileInfo.Sys())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get username from UID
|
||||||
|
brewUser, err := user.LookupId(fmt.Sprintf("%d", fileSysInfo.Uid))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error looking up brew installer user: %w", err)
|
||||||
|
}
|
||||||
|
userName := brewUser.Username
|
||||||
|
// Get user HOME, required for brew to run correctly
|
||||||
|
// https://github.com/Homebrew/brew/issues/15833
|
||||||
|
homeDir := brewUser.HomeDir
|
||||||
|
|
||||||
|
// Check if netbird-ui is installed (must run as the brew user, not root)
|
||||||
|
checkUICmd := exec.CommandContext(ctx, "sudo", "-u", userName, brewBinPath, "list", "--formula", "netbirdio/tap/netbird-ui")
|
||||||
|
checkUICmd.Env = append(os.Environ(), "HOME="+homeDir)
|
||||||
|
uiInstalled := checkUICmd.Run() == nil
|
||||||
|
|
||||||
|
// Homebrew does not support installing specific versions
|
||||||
|
// Thus it will always update to latest and ignore targetVersion
|
||||||
|
upgradeArgs := []string{"-u", userName, brewBinPath, "upgrade", "netbirdio/tap/netbird"}
|
||||||
|
if uiInstalled {
|
||||||
|
upgradeArgs = append(upgradeArgs, "netbirdio/tap/netbird-ui")
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, "sudo", upgradeArgs...)
|
||||||
|
cmd.Env = append(os.Environ(), "HOME="+homeDir)
|
||||||
|
|
||||||
|
if output, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("error running brew upgrade: %w, output: %s", err, string(output))
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("homebrew updated successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Installer) killUI() {
|
||||||
|
log.Infof("killing existing netbird-ui processes")
|
||||||
|
cmd := exec.Command("pkill", "-x", "netbird-ui")
|
||||||
|
if output, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
// pkill returns exit code 1 if no processes matched, which is fine
|
||||||
|
log.Debugf("pkill netbird-ui result: %v, output: %s", err, string(output))
|
||||||
|
} else {
|
||||||
|
log.Infof("netbird-ui processes killed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func urlWithVersionArch(_ Type, version string) string {
|
||||||
|
url := strings.ReplaceAll(pkgDownloadURL, "%version", version)
|
||||||
|
return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
|
||||||
|
}
|
||||||
213
client/internal/updatemanager/installer/installer_run_windows.go
Normal file
213
client/internal/updatemanager/installer/installer_run_windows.go
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
package installer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
daemonName = "netbird.exe"
|
||||||
|
uiName = "netbird-ui.exe"
|
||||||
|
updaterBinary = "updater.exe"
|
||||||
|
|
||||||
|
msiLogFile = "msi.log"
|
||||||
|
|
||||||
|
msiDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.msi"
|
||||||
|
exeDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.exe"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
defaultTempDir = filepath.Join(os.Getenv("ProgramData"), "Netbird", "tmp-install")
|
||||||
|
|
||||||
|
// for the cleanup
|
||||||
|
binaryExtensions = []string{"msi", "exe"}
|
||||||
|
)
|
||||||
|
|
||||||
|
// Setup runs the installer with appropriate arguments and manages the daemon/UI state
|
||||||
|
// This will be run by the updater process
|
||||||
|
func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) {
|
||||||
|
resultHandler := NewResultHandler(u.tempDir)
|
||||||
|
|
||||||
|
// Always ensure daemon and UI are restarted after setup
|
||||||
|
defer func() {
|
||||||
|
log.Infof("starting daemon back")
|
||||||
|
if err := u.startDaemon(daemonFolder); err != nil {
|
||||||
|
log.Errorf("failed to start daemon: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("starting UI back")
|
||||||
|
if err := u.startUIAsUser(daemonFolder); err != nil {
|
||||||
|
log.Errorf("failed to start UI: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("write out result")
|
||||||
|
var err error
|
||||||
|
if resultErr == nil {
|
||||||
|
err = resultHandler.WriteSuccess()
|
||||||
|
} else {
|
||||||
|
err = resultHandler.WriteErr(resultErr)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to write update result: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if dryRun {
|
||||||
|
log.Infof("dry-run mode enabled, skipping actual installation")
|
||||||
|
resultErr = fmt.Errorf("dry-run mode enabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
installerType, err := typeByFileExtension(installerFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("%v", err)
|
||||||
|
resultErr = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd *exec.Cmd
|
||||||
|
switch installerType {
|
||||||
|
case TypeExe:
|
||||||
|
log.Infof("run exe installer: %s", installerFile)
|
||||||
|
cmd = exec.CommandContext(ctx, installerFile, "/S")
|
||||||
|
default:
|
||||||
|
installerDir := filepath.Dir(installerFile)
|
||||||
|
logPath := filepath.Join(installerDir, msiLogFile)
|
||||||
|
log.Infof("run msi installer: %s", installerFile)
|
||||||
|
cmd = exec.CommandContext(ctx, "msiexec.exe", "/i", filepath.Base(installerFile), "/quiet", "/qn", "/l*v", logPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Dir = filepath.Dir(installerFile)
|
||||||
|
|
||||||
|
if resultErr = cmd.Start(); resultErr != nil {
|
||||||
|
log.Errorf("error starting installer: %v", resultErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("installer started with PID %d", cmd.Process.Pid)
|
||||||
|
if resultErr = cmd.Wait(); resultErr != nil {
|
||||||
|
log.Errorf("installer process finished with error: %v", resultErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Installer) startDaemon(daemonFolder string) error {
|
||||||
|
log.Infof("starting netbird service")
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start")
|
||||||
|
if output, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
log.Debugf("failed to start netbird service: %v, output: %s", err, string(output))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Infof("netbird service started successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Installer) startUIAsUser(daemonFolder string) error {
|
||||||
|
uiPath := filepath.Join(daemonFolder, uiName)
|
||||||
|
log.Infof("starting netbird-ui: %s", uiPath)
|
||||||
|
|
||||||
|
// Get the active console session ID
|
||||||
|
sessionID := windows.WTSGetActiveConsoleSessionId()
|
||||||
|
if sessionID == 0xFFFFFFFF {
|
||||||
|
return fmt.Errorf("no active user session found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the user token for that session
|
||||||
|
var userToken windows.Token
|
||||||
|
err := windows.WTSQueryUserToken(sessionID, &userToken)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to query user token: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := userToken.Close(); err != nil {
|
||||||
|
log.Warnf("failed to close user token: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Duplicate the token to a primary token
|
||||||
|
var primaryToken windows.Token
|
||||||
|
err = windows.DuplicateTokenEx(
|
||||||
|
userToken,
|
||||||
|
windows.MAXIMUM_ALLOWED,
|
||||||
|
nil,
|
||||||
|
windows.SecurityImpersonation,
|
||||||
|
windows.TokenPrimary,
|
||||||
|
&primaryToken,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to duplicate token: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := primaryToken.Close(); err != nil {
|
||||||
|
log.Warnf("failed to close token: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Prepare startup info
|
||||||
|
var si windows.StartupInfo
|
||||||
|
si.Cb = uint32(unsafe.Sizeof(si))
|
||||||
|
si.Desktop = windows.StringToUTF16Ptr("winsta0\\default")
|
||||||
|
|
||||||
|
var pi windows.ProcessInformation
|
||||||
|
|
||||||
|
cmdLine, err := windows.UTF16PtrFromString(fmt.Sprintf("\"%s\"", uiPath))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to convert path to UTF16: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
creationFlags := uint32(0x00000200 | 0x00000008 | 0x00000400) // CREATE_NEW_PROCESS_GROUP | DETACHED_PROCESS | CREATE_UNICODE_ENVIRONMENT
|
||||||
|
|
||||||
|
err = windows.CreateProcessAsUser(
|
||||||
|
primaryToken,
|
||||||
|
nil,
|
||||||
|
cmdLine,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
false,
|
||||||
|
creationFlags,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
&si,
|
||||||
|
&pi,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("CreateProcessAsUser failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close handles
|
||||||
|
if err := windows.CloseHandle(pi.Process); err != nil {
|
||||||
|
log.Warnf("failed to close process handle: %v", err)
|
||||||
|
}
|
||||||
|
if err := windows.CloseHandle(pi.Thread); err != nil {
|
||||||
|
log.Warnf("failed to close thread handle: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("netbird-ui started successfully in session %d", sessionID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func urlWithVersionArch(it Type, version string) string {
|
||||||
|
var url string
|
||||||
|
if it == TypeExe {
|
||||||
|
url = exeDownloadURL
|
||||||
|
} else {
|
||||||
|
url = msiDownloadURL
|
||||||
|
}
|
||||||
|
url = strings.ReplaceAll(url, "%version", version)
|
||||||
|
return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user