mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 00:36:38 +00:00
Compare commits
59 Commits
vk/debug/n
...
sync-clien
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
72513d7522 | ||
|
|
a1f1bf1f19 | ||
|
|
b5dec3df39 | ||
|
|
447cd287f5 | ||
|
|
5748bdd64e | ||
|
|
08f31fbcb3 | ||
|
|
932c02eaab | ||
|
|
abcbde26f9 | ||
|
|
90e3b8009f | ||
|
|
94d34dc0c5 | ||
|
|
44851e06fb | ||
|
|
3f4f825ec1 | ||
|
|
f538e6e9ae | ||
|
|
cb6b086164 | ||
|
|
71b6855e09 | ||
|
|
9bdc4908fb | ||
|
|
031ab11178 | ||
|
|
d2e48d4f5e | ||
|
|
27dd97c9c4 | ||
|
|
e87b4ace11 | ||
|
|
a232cf614c | ||
|
|
a293f760af | ||
|
|
10e9cf8c62 | ||
|
|
7193bd2da7 | ||
|
|
52948ccd61 | ||
|
|
4b77359042 | ||
|
|
387d43bcc1 | ||
|
|
e47d815dd2 | ||
|
|
cb83b7c0d3 | ||
|
|
ddcd182859 | ||
|
|
aca0398105 | ||
|
|
02200d790b | ||
|
|
f31bba87b4 | ||
|
|
7285fef0f0 | ||
|
|
20973063d8 | ||
|
|
ba2e9b6d88 | ||
|
|
131d7a3694 | ||
|
|
290fe2d8b9 | ||
|
|
7fb1a2fe31 | ||
|
|
32146e576d | ||
|
|
1311364397 | ||
|
|
68f56b797d | ||
|
|
3351b38434 | ||
|
|
05cbead39b | ||
|
|
60f4d5f9b0 | ||
|
|
4eeb2d8deb | ||
|
|
d71a82769c | ||
|
|
0d79301141 | ||
|
|
e4b41d0ad7 | ||
|
|
9cc9462cd5 | ||
|
|
3176b53968 | ||
|
|
27957036c9 | ||
|
|
6fb568728f | ||
|
|
cc97cffff1 | ||
|
|
20f5f00635 | ||
|
|
fc141cf3a3 | ||
|
|
d0c65fa08e | ||
|
|
f241bfa339 | ||
|
|
4b2cd97d5f |
11
.githooks/pre-push
Executable file
11
.githooks/pre-push
Executable file
@@ -0,0 +1,11 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
echo "Running pre-push hook..."
|
||||||
|
if ! make lint; then
|
||||||
|
echo ""
|
||||||
|
echo "Hint: To push without verification, run:"
|
||||||
|
echo " git push --no-verify"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "All checks passed!"
|
||||||
119
.github/workflows/check-license-dependencies.yml
vendored
119
.github/workflows/check-license-dependencies.yml
vendored
@@ -3,39 +3,108 @@ name: Check License Dependencies
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: [ main ]
|
branches: [ main ]
|
||||||
|
paths:
|
||||||
|
- 'go.mod'
|
||||||
|
- 'go.sum'
|
||||||
|
- '.github/workflows/check-license-dependencies.yml'
|
||||||
pull_request:
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- 'go.mod'
|
||||||
|
- 'go.sum'
|
||||||
|
- '.github/workflows/check-license-dependencies.yml'
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
check-dependencies:
|
check-internal-dependencies:
|
||||||
|
name: Check Internal AGPL Dependencies
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Check for problematic license dependencies
|
||||||
|
run: |
|
||||||
|
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Find all directories except the problematic ones and system dirs
|
||||||
|
FOUND_ISSUES=0
|
||||||
|
while IFS= read -r dir; do
|
||||||
|
echo "=== Checking $dir ==="
|
||||||
|
# Search for problematic imports, excluding test files
|
||||||
|
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||||
|
if [ -n "$RESULTS" ]; then
|
||||||
|
echo "❌ Found problematic dependencies:"
|
||||||
|
echo "$RESULTS"
|
||||||
|
FOUND_ISSUES=1
|
||||||
|
else
|
||||||
|
echo "✓ No problematic dependencies found"
|
||||||
|
fi
|
||||||
|
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||||
|
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
||||||
|
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
||||||
|
exit 1
|
||||||
|
else
|
||||||
|
echo ""
|
||||||
|
echo "✅ All internal license dependencies are clean"
|
||||||
|
fi
|
||||||
|
|
||||||
|
check-external-licenses:
|
||||||
|
name: Check External GPL/AGPL Licenses
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Check for problematic license dependencies
|
- name: Set up Go
|
||||||
run: |
|
uses: actions/setup-go@v5
|
||||||
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
with:
|
||||||
|
go-version-file: 'go.mod'
|
||||||
|
cache: true
|
||||||
|
|
||||||
# Find all directories except the problematic ones and system dirs
|
- name: Install go-licenses
|
||||||
FOUND_ISSUES=0
|
run: go install github.com/google/go-licenses@v1.6.0
|
||||||
find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort | while read dir; do
|
|
||||||
echo "=== Checking $dir ==="
|
- name: Check for GPL/AGPL licensed dependencies
|
||||||
# Search for problematic imports, excluding test files
|
run: |
|
||||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
|
||||||
if [ ! -z "$RESULTS" ]; then
|
echo ""
|
||||||
echo "❌ Found problematic dependencies:"
|
|
||||||
echo "$RESULTS"
|
# Check all Go packages for copyleft licenses, excluding internal netbird packages
|
||||||
FOUND_ISSUES=1
|
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
|
||||||
else
|
|
||||||
echo "✓ No problematic dependencies found"
|
if [ -n "$COPYLEFT_DEPS" ]; then
|
||||||
|
echo "Found copyleft licensed dependencies:"
|
||||||
|
echo "$COPYLEFT_DEPS"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Filter out dependencies that are only pulled in by internal AGPL packages
|
||||||
|
INCOMPATIBLE=""
|
||||||
|
while IFS=',' read -r package url license; do
|
||||||
|
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
|
||||||
|
# Find ALL packages that import this GPL package using go list
|
||||||
|
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||||
|
|
||||||
|
# Check if any importer is NOT in management/signal/relay
|
||||||
|
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
|
||||||
|
|
||||||
|
if [ -n "$BSD_IMPORTER" ]; then
|
||||||
|
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||||
|
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
|
||||||
|
else
|
||||||
|
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
done <<< "$COPYLEFT_DEPS"
|
||||||
|
|
||||||
|
if [ -n "$INCOMPATIBLE" ]; then
|
||||||
|
echo ""
|
||||||
|
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
|
||||||
|
echo -e "$INCOMPATIBLE"
|
||||||
|
exit 1
|
||||||
fi
|
fi
|
||||||
done
|
|
||||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
|
||||||
echo ""
|
|
||||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
|
||||||
echo "These packages will change license and should not be imported by client or shared code"
|
|
||||||
exit 1
|
|
||||||
else
|
|
||||||
echo ""
|
|
||||||
echo "✅ All license dependencies are clean"
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
|
||||||
|
|||||||
7
.github/workflows/golang-test-darwin.yml
vendored
7
.github/workflows/golang-test-darwin.yml
vendored
@@ -15,13 +15,14 @@ jobs:
|
|||||||
name: "Client / Unit"
|
name: "Client / Unit"
|
||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
|
|||||||
2
.github/workflows/golang-test-freebsd.yml
vendored
2
.github/workflows/golang-test-freebsd.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
|||||||
release: "14.2"
|
release: "14.2"
|
||||||
prepare: |
|
prepare: |
|
||||||
pkg install -y curl pkgconf xorg
|
pkg install -y curl pkgconf xorg
|
||||||
GO_TARBALL="go1.23.12.freebsd-amd64.tar.gz"
|
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
|
||||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||||
curl -vLO "$GO_URL"
|
curl -vLO "$GO_URL"
|
||||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||||
|
|||||||
71
.github/workflows/golang-test-linux.yml
vendored
71
.github/workflows/golang-test-linux.yml
vendored
@@ -30,7 +30,7 @@ jobs:
|
|||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
@@ -106,15 +106,15 @@ jobs:
|
|||||||
arch: [ '386','amd64' ]
|
arch: [ '386','amd64' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
@@ -151,15 +151,15 @@ jobs:
|
|||||||
needs: [ build-cache ]
|
needs: [ build-cache ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
id: go-env
|
id: go-env
|
||||||
run: |
|
run: |
|
||||||
@@ -200,7 +200,7 @@ jobs:
|
|||||||
-e GOCACHE=${CONTAINER_GOCACHE} \
|
-e GOCACHE=${CONTAINER_GOCACHE} \
|
||||||
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
|
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
|
||||||
-e CONTAINER=${CONTAINER} \
|
-e CONTAINER=${CONTAINER} \
|
||||||
golang:1.23-alpine \
|
golang:1.24-alpine \
|
||||||
sh -c ' \
|
sh -c ' \
|
||||||
apk update; apk add --no-cache \
|
apk update; apk add --no-cache \
|
||||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||||
@@ -220,15 +220,15 @@ jobs:
|
|||||||
raceFlag: "-race"
|
raceFlag: "-race"
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
if: steps.cache.outputs.cache-hit != 'true'
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||||
@@ -270,15 +270,15 @@ jobs:
|
|||||||
arch: [ '386','amd64' ]
|
arch: [ '386','amd64' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
if: steps.cache.outputs.cache-hit != 'true'
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||||
@@ -321,15 +321,15 @@ jobs:
|
|||||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
@@ -408,15 +408,16 @@ jobs:
|
|||||||
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
|
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
|
||||||
-p 9090:9090 \
|
-p 9090:9090 \
|
||||||
prom/prometheus
|
prom/prometheus
|
||||||
- name: Install Go
|
|
||||||
uses: actions/setup-go@v5
|
|
||||||
with:
|
|
||||||
go-version: "1.23.x"
|
|
||||||
cache: false
|
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: "go.mod"
|
||||||
|
cache: false
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
@@ -497,15 +498,15 @@ jobs:
|
|||||||
-p 9090:9090 \
|
-p 9090:9090 \
|
||||||
prom/prometheus
|
prom/prometheus
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
@@ -561,15 +562,15 @@ jobs:
|
|||||||
store: [ 'sqlite', 'postgres']
|
store: [ 'sqlite', 'postgres']
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
|||||||
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
|||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
id: go
|
id: go
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
|
|||||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -46,7 +46,7 @@ jobs:
|
|||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
if: matrix.os == 'ubuntu-latest'
|
if: matrix.os == 'ubuntu-latest'
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ jobs:
|
|||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
- name: Setup Android SDK
|
- name: Setup Android SDK
|
||||||
uses: android-actions/setup-android@v3
|
uses: android-actions/setup-android@v3
|
||||||
with:
|
with:
|
||||||
@@ -39,7 +39,7 @@ jobs:
|
|||||||
- name: Setup NDK
|
- name: Setup NDK
|
||||||
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
|
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
|
||||||
- name: install gomobile
|
- name: install gomobile
|
||||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
|
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab
|
||||||
- name: gomobile init
|
- name: gomobile init
|
||||||
run: gomobile init
|
run: gomobile init
|
||||||
- name: build android netbird lib
|
- name: build android netbird lib
|
||||||
@@ -56,9 +56,9 @@ jobs:
|
|||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
- name: install gomobile
|
- name: install gomobile
|
||||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
|
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab
|
||||||
- name: gomobile init
|
- name: gomobile init
|
||||||
run: gomobile init
|
run: gomobile init
|
||||||
- name: build iOS netbird lib
|
- name: build iOS netbird lib
|
||||||
|
|||||||
8
.github/workflows/release.yml
vendored
8
.github/workflows/release.yml
vendored
@@ -20,7 +20,7 @@ concurrency:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
release:
|
release:
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-latest-m
|
||||||
env:
|
env:
|
||||||
flags: ""
|
flags: ""
|
||||||
steps:
|
steps:
|
||||||
@@ -40,7 +40,7 @@ jobs:
|
|||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
@@ -136,7 +136,7 @@ jobs:
|
|||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
@@ -200,7 +200,7 @@ jobs:
|
|||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
|
|||||||
@@ -67,10 +67,13 @@ jobs:
|
|||||||
- name: Install curl
|
- name: Install curl
|
||||||
run: sudo apt-get install -y curl
|
run: sudo apt-get install -y curl
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
@@ -80,9 +83,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-go-
|
${{ runner.os }}-go-
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Setup MySQL privileges
|
- name: Setup MySQL privileges
|
||||||
if: matrix.store == 'mysql'
|
if: matrix.store == 'mysql'
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
8
.github/workflows/wasm-build-validation.yml
vendored
8
.github/workflows/wasm-build-validation.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
|||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||||
- name: Install golangci-lint
|
- name: Install golangci-lint
|
||||||
@@ -45,7 +45,7 @@ jobs:
|
|||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
- name: Build Wasm client
|
- name: Build Wasm client
|
||||||
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
|
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
|
||||||
env:
|
env:
|
||||||
@@ -60,8 +60,8 @@ jobs:
|
|||||||
|
|
||||||
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||||
|
|
||||||
if [ ${SIZE} -gt 52428800 ]; then
|
if [ ${SIZE} -gt 57671680 ]; then
|
||||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!"
|
echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|||||||
@@ -136,6 +136,14 @@ checked out and set up:
|
|||||||
go mod tidy
|
go mod tidy
|
||||||
```
|
```
|
||||||
|
|
||||||
|
6. Configure Git hooks for automatic linting:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make setup-hooks
|
||||||
|
```
|
||||||
|
|
||||||
|
This will configure Git to run linting automatically before each push, helping catch issues early.
|
||||||
|
|
||||||
### Dev Container Support
|
### Dev Container Support
|
||||||
|
|
||||||
If you prefer using a dev container for development, NetBird now includes support for dev containers.
|
If you prefer using a dev container for development, NetBird now includes support for dev containers.
|
||||||
|
|||||||
27
Makefile
Normal file
27
Makefile
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
.PHONY: lint lint-all lint-install setup-hooks
|
||||||
|
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
|
||||||
|
|
||||||
|
# Install golangci-lint locally if needed
|
||||||
|
$(GOLANGCI_LINT):
|
||||||
|
@echo "Installing golangci-lint..."
|
||||||
|
@mkdir -p ./bin
|
||||||
|
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||||
|
|
||||||
|
# Lint only changed files (fast, for pre-push)
|
||||||
|
lint: $(GOLANGCI_LINT)
|
||||||
|
@echo "Running lint on changed files..."
|
||||||
|
@$(GOLANGCI_LINT) run --new-from-rev=origin/main --timeout=2m
|
||||||
|
|
||||||
|
# Lint entire codebase (slow, matches CI)
|
||||||
|
lint-all: $(GOLANGCI_LINT)
|
||||||
|
@echo "Running lint on all files..."
|
||||||
|
@$(GOLANGCI_LINT) run --timeout=12m
|
||||||
|
|
||||||
|
# Just install the linter
|
||||||
|
lint-install: $(GOLANGCI_LINT)
|
||||||
|
|
||||||
|
# Setup git hooks for all developers
|
||||||
|
setup-hooks:
|
||||||
|
@git config core.hooksPath .githooks
|
||||||
|
@chmod +x .githooks/pre-push
|
||||||
|
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"
|
||||||
@@ -4,10 +4,13 @@ package android
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
@@ -16,10 +19,13 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
"github.com/netbirdio/netbird/client/net"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
"github.com/netbirdio/netbird/client/net"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionListener export internal Listener for mobile
|
// ConnectionListener export internal Listener for mobile
|
||||||
@@ -62,17 +68,18 @@ type Client struct {
|
|||||||
deviceName string
|
deviceName string
|
||||||
uiVersion string
|
uiVersion string
|
||||||
networkChangeListener listener.NetworkChangeListener
|
networkChangeListener listener.NetworkChangeListener
|
||||||
|
stateFile string
|
||||||
|
|
||||||
connectClient *internal.ConnectClient
|
connectClient *internal.ConnectClient
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient instantiate a new Client
|
// NewClient instantiate a new Client
|
||||||
func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
||||||
execWorkaround(androidSDKVersion)
|
execWorkaround(androidSDKVersion)
|
||||||
|
|
||||||
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
|
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
|
||||||
return &Client{
|
return &Client{
|
||||||
cfgFile: cfgFile,
|
cfgFile: platformFiles.ConfigurationFilePath(),
|
||||||
deviceName: deviceName,
|
deviceName: deviceName,
|
||||||
uiVersion: uiVersion,
|
uiVersion: uiVersion,
|
||||||
tunAdapter: tunAdapter,
|
tunAdapter: tunAdapter,
|
||||||
@@ -80,11 +87,12 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi
|
|||||||
recorder: peer.NewRecorder(""),
|
recorder: peer.NewRecorder(""),
|
||||||
ctxCancelLock: &sync.Mutex{},
|
ctxCancelLock: &sync.Mutex{},
|
||||||
networkChangeListener: networkChangeListener,
|
networkChangeListener: networkChangeListener,
|
||||||
|
stateFile: platformFiles.StateFilePath(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run start the internal client. It is a blocker function
|
// Run start the internal client. It is a blocker function
|
||||||
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
||||||
exportEnvList(envList)
|
exportEnvList(envList)
|
||||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||||
ConfigPath: c.cfgFile,
|
ConfigPath: c.cfgFile,
|
||||||
@@ -107,7 +115,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
|||||||
c.ctxCancelLock.Unlock()
|
c.ctxCancelLock.Unlock()
|
||||||
|
|
||||||
auth := NewAuthWithConfig(ctx, cfg)
|
auth := NewAuthWithConfig(ctx, cfg)
|
||||||
err = auth.login(urlOpener)
|
err = auth.login(urlOpener, isAndroidTV)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -115,7 +123,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
|||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
|
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||||
@@ -142,7 +150,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
|
|||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
|
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
@@ -156,6 +164,19 @@ func (c *Client) Stop() {
|
|||||||
c.ctxCancel()
|
c.ctxCancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) RenewTun(fd int) error {
|
||||||
|
if c.connectClient == nil {
|
||||||
|
return fmt.Errorf("engine not running")
|
||||||
|
}
|
||||||
|
|
||||||
|
e := c.connectClient.Engine()
|
||||||
|
if e == nil {
|
||||||
|
return fmt.Errorf("engine not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.RenewTun(fd)
|
||||||
|
}
|
||||||
|
|
||||||
// SetTraceLogLevel configure the logger to trace level
|
// SetTraceLogLevel configure the logger to trace level
|
||||||
func (c *Client) SetTraceLogLevel() {
|
func (c *Client) SetTraceLogLevel() {
|
||||||
log.SetLevel(log.TraceLevel)
|
log.SetLevel(log.TraceLevel)
|
||||||
@@ -177,6 +198,7 @@ func (c *Client) PeersList() *PeerInfoArray {
|
|||||||
p.IP,
|
p.IP,
|
||||||
p.FQDN,
|
p.FQDN,
|
||||||
p.ConnStatus.String(),
|
p.ConnStatus.String(),
|
||||||
|
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
|
||||||
}
|
}
|
||||||
peerInfos[n] = pi
|
peerInfos[n] = pi
|
||||||
}
|
}
|
||||||
@@ -201,31 +223,43 @@ func (c *Client) Networks() *NetworkArray {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
routeSelector := routeManager.GetRouteSelector()
|
||||||
|
if routeSelector == nil {
|
||||||
|
log.Error("could not get route selector")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
networkArray := &NetworkArray{
|
networkArray := &NetworkArray{
|
||||||
items: make([]Network, 0),
|
items: make([]Network, 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
resolvedDomains := c.recorder.GetResolvedDomainsStates()
|
||||||
|
|
||||||
for id, routes := range routeManager.GetClientRoutesWithNetID() {
|
for id, routes := range routeManager.GetClientRoutesWithNetID() {
|
||||||
if len(routes) == 0 {
|
if len(routes) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
r := routes[0]
|
r := routes[0]
|
||||||
|
domains := c.getNetworkDomainsFromRoute(r, resolvedDomains)
|
||||||
netStr := r.Network.String()
|
netStr := r.Network.String()
|
||||||
|
|
||||||
if r.IsDynamic() {
|
if r.IsDynamic() {
|
||||||
netStr = r.Domains.SafeString()
|
netStr = r.Domains.SafeString()
|
||||||
}
|
}
|
||||||
|
|
||||||
peer, err := c.recorder.GetPeer(routes[0].Peer)
|
routePeer, err := c.recorder.GetPeer(routes[0].Peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
|
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
network := Network{
|
network := Network{
|
||||||
Name: string(id),
|
Name: string(id),
|
||||||
Network: netStr,
|
Network: netStr,
|
||||||
Peer: peer.FQDN,
|
Peer: routePeer.FQDN,
|
||||||
Status: peer.ConnStatus.String(),
|
Status: routePeer.ConnStatus.String(),
|
||||||
|
IsSelected: routeSelector.IsSelected(id),
|
||||||
|
Domains: domains,
|
||||||
}
|
}
|
||||||
networkArray.Add(network)
|
networkArray.Add(network)
|
||||||
}
|
}
|
||||||
@@ -253,6 +287,69 @@ func (c *Client) RemoveConnectionListener() {
|
|||||||
c.recorder.RemoveConnectionListener()
|
c.recorder.RemoveConnectionListener()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) toggleRoute(command routeCommand) error {
|
||||||
|
return command.toggleRoute()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) getRouteManager() (routemanager.Manager, error) {
|
||||||
|
client := c.connectClient
|
||||||
|
if client == nil {
|
||||||
|
return nil, fmt.Errorf("not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := client.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return nil, fmt.Errorf("engine is not running")
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := engine.GetRouteManager()
|
||||||
|
if manager == nil {
|
||||||
|
return nil, fmt.Errorf("could not get route manager")
|
||||||
|
}
|
||||||
|
|
||||||
|
return manager, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) SelectRoute(route string) error {
|
||||||
|
manager, err := c.getRouteManager()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.toggleRoute(selectRouteCommand{route: route, manager: manager})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) DeselectRoute(route string) error {
|
||||||
|
manager, err := c.getRouteManager()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.toggleRoute(deselectRouteCommand{route: route, manager: manager})
|
||||||
|
}
|
||||||
|
|
||||||
|
// getNetworkDomainsFromRoute extracts domains from a route and enriches each domain
|
||||||
|
// with its resolved IP addresses from the provided resolvedDomains map.
|
||||||
|
func (c *Client) getNetworkDomainsFromRoute(route *route.Route, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) NetworkDomains {
|
||||||
|
domains := NetworkDomains{}
|
||||||
|
|
||||||
|
for _, d := range route.Domains {
|
||||||
|
networkDomain := NetworkDomain{
|
||||||
|
Address: d.SafeString(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if info, exists := resolvedDomains[d]; exists {
|
||||||
|
for _, prefix := range info.Prefixes {
|
||||||
|
networkDomain.addResolvedIP(prefix.Addr().String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
domains.Add(&networkDomain)
|
||||||
|
}
|
||||||
|
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
func exportEnvList(list *EnvList) {
|
func exportEnvList(list *EnvList) {
|
||||||
if list == nil {
|
if list == nil {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ type ErrListener interface {
|
|||||||
// URLOpener it is a callback interface. The Open function will be triggered if
|
// URLOpener it is a callback interface. The Open function will be triggered if
|
||||||
// the backend want to show an url for the user
|
// the backend want to show an url for the user
|
||||||
type URLOpener interface {
|
type URLOpener interface {
|
||||||
Open(string)
|
Open(url string, userCode string)
|
||||||
OnLoginSuccess()
|
OnLoginSuccess()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -148,9 +148,9 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Login try register the client on the server
|
// Login try register the client on the server
|
||||||
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) {
|
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidTV bool) {
|
||||||
go func() {
|
go func() {
|
||||||
err := a.login(urlOpener)
|
err := a.login(urlOpener, isAndroidTV)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resultListener.OnError(err)
|
resultListener.OnError(err)
|
||||||
} else {
|
} else {
|
||||||
@@ -159,7 +159,7 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) login(urlOpener URLOpener) error {
|
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
|
||||||
var needsLogin bool
|
var needsLogin bool
|
||||||
|
|
||||||
// check if we need to generate JWT token
|
// check if we need to generate JWT token
|
||||||
@@ -173,7 +173,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
|||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
if needsLogin {
|
if needsLogin {
|
||||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener)
|
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -199,8 +199,8 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
|
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "")
|
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -210,7 +210,7 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, err
|
|||||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
go urlOpener.Open(flowInfo.VerificationURIComplete)
|
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||||
|
|
||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
||||||
|
|||||||
56
client/android/network_domains.go
Normal file
56
client/android/network_domains.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package android
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
type ResolvedIPs struct {
|
||||||
|
resolvedIPs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResolvedIPs) Add(ipAddress string) {
|
||||||
|
r.resolvedIPs = append(r.resolvedIPs, ipAddress)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResolvedIPs) Get(i int) (string, error) {
|
||||||
|
if i < 0 || i >= len(r.resolvedIPs) {
|
||||||
|
return "", fmt.Errorf("%d is out of range", i)
|
||||||
|
}
|
||||||
|
return r.resolvedIPs[i], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResolvedIPs) Size() int {
|
||||||
|
return len(r.resolvedIPs)
|
||||||
|
}
|
||||||
|
|
||||||
|
type NetworkDomain struct {
|
||||||
|
Address string
|
||||||
|
resolvedIPs ResolvedIPs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *NetworkDomain) addResolvedIP(resolvedIP string) {
|
||||||
|
d.resolvedIPs.Add(resolvedIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *NetworkDomain) GetResolvedIPs() *ResolvedIPs {
|
||||||
|
return &d.resolvedIPs
|
||||||
|
}
|
||||||
|
|
||||||
|
type NetworkDomains struct {
|
||||||
|
domains []*NetworkDomain
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NetworkDomains) Add(domain *NetworkDomain) {
|
||||||
|
n.domains = append(n.domains, domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NetworkDomains) Get(i int) (*NetworkDomain, error) {
|
||||||
|
if i < 0 || i >= len(n.domains) {
|
||||||
|
return nil, fmt.Errorf("%d is out of range", i)
|
||||||
|
}
|
||||||
|
return n.domains[i], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NetworkDomains) Size() int {
|
||||||
|
return len(n.domains)
|
||||||
|
}
|
||||||
@@ -3,10 +3,16 @@
|
|||||||
package android
|
package android
|
||||||
|
|
||||||
type Network struct {
|
type Network struct {
|
||||||
Name string
|
Name string
|
||||||
Network string
|
Network string
|
||||||
Peer string
|
Peer string
|
||||||
Status string
|
Status string
|
||||||
|
IsSelected bool
|
||||||
|
Domains NetworkDomains
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n Network) GetNetworkDomains() *NetworkDomains {
|
||||||
|
return &n.Domains
|
||||||
}
|
}
|
||||||
|
|
||||||
type NetworkArray struct {
|
type NetworkArray struct {
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
package android
|
package android
|
||||||
|
|
||||||
// PeerInfo describe information about the peers. It designed for the UI usage
|
// PeerInfo describe information about the peers. It designed for the UI usage
|
||||||
@@ -5,6 +7,11 @@ type PeerInfo struct {
|
|||||||
IP string
|
IP string
|
||||||
FQDN string
|
FQDN string
|
||||||
ConnStatus string // Todo replace to enum
|
ConnStatus string // Todo replace to enum
|
||||||
|
Routes PeerRoutes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PeerInfo) GetPeerRoutes() *PeerRoutes {
|
||||||
|
return &p.Routes
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeerInfoArray is a wrapper of []PeerInfo
|
// PeerInfoArray is a wrapper of []PeerInfo
|
||||||
|
|||||||
20
client/android/peer_routes.go
Normal file
20
client/android/peer_routes.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package android
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
type PeerRoutes struct {
|
||||||
|
routes []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PeerRoutes) Get(i int) (string, error) {
|
||||||
|
if i < 0 || i >= len(p.routes) {
|
||||||
|
return "", fmt.Errorf("%d is out of range", i)
|
||||||
|
}
|
||||||
|
return p.routes[i], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PeerRoutes) Size() int {
|
||||||
|
return len(p.routes)
|
||||||
|
}
|
||||||
10
client/android/platform_files.go
Normal file
10
client/android/platform_files.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package android
|
||||||
|
|
||||||
|
// PlatformFiles groups paths to files used internally by the engine that can't be created/modified
|
||||||
|
// at their default locations due to android OS restrictions.
|
||||||
|
type PlatformFiles interface {
|
||||||
|
ConfigurationFilePath() string
|
||||||
|
StateFilePath() string
|
||||||
|
}
|
||||||
@@ -201,6 +201,94 @@ func (p *Preferences) SetServerSSHAllowed(allowed bool) {
|
|||||||
p.configInput.ServerSSHAllowed = &allowed
|
p.configInput.ServerSSHAllowed = &allowed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetEnableSSHRoot reads SSH root login setting from config file
|
||||||
|
func (p *Preferences) GetEnableSSHRoot() (bool, error) {
|
||||||
|
if p.configInput.EnableSSHRoot != nil {
|
||||||
|
return *p.configInput.EnableSSHRoot, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if cfg.EnableSSHRoot == nil {
|
||||||
|
// Default to false for security on Android
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return *cfg.EnableSSHRoot, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetEnableSSHRoot stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetEnableSSHRoot(enabled bool) {
|
||||||
|
p.configInput.EnableSSHRoot = &enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEnableSSHSFTP reads SSH SFTP setting from config file
|
||||||
|
func (p *Preferences) GetEnableSSHSFTP() (bool, error) {
|
||||||
|
if p.configInput.EnableSSHSFTP != nil {
|
||||||
|
return *p.configInput.EnableSSHSFTP, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if cfg.EnableSSHSFTP == nil {
|
||||||
|
// Default to false for security on Android
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return *cfg.EnableSSHSFTP, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetEnableSSHSFTP stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetEnableSSHSFTP(enabled bool) {
|
||||||
|
p.configInput.EnableSSHSFTP = &enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEnableSSHLocalPortForwarding reads SSH local port forwarding setting from config file
|
||||||
|
func (p *Preferences) GetEnableSSHLocalPortForwarding() (bool, error) {
|
||||||
|
if p.configInput.EnableSSHLocalPortForwarding != nil {
|
||||||
|
return *p.configInput.EnableSSHLocalPortForwarding, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if cfg.EnableSSHLocalPortForwarding == nil {
|
||||||
|
// Default to false for security on Android
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return *cfg.EnableSSHLocalPortForwarding, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetEnableSSHLocalPortForwarding stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetEnableSSHLocalPortForwarding(enabled bool) {
|
||||||
|
p.configInput.EnableSSHLocalPortForwarding = &enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEnableSSHRemotePortForwarding reads SSH remote port forwarding setting from config file
|
||||||
|
func (p *Preferences) GetEnableSSHRemotePortForwarding() (bool, error) {
|
||||||
|
if p.configInput.EnableSSHRemotePortForwarding != nil {
|
||||||
|
return *p.configInput.EnableSSHRemotePortForwarding, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if cfg.EnableSSHRemotePortForwarding == nil {
|
||||||
|
// Default to false for security on Android
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return *cfg.EnableSSHRemotePortForwarding, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetEnableSSHRemotePortForwarding stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetEnableSSHRemotePortForwarding(enabled bool) {
|
||||||
|
p.configInput.EnableSSHRemotePortForwarding = &enabled
|
||||||
|
}
|
||||||
|
|
||||||
// GetBlockInbound reads block inbound setting from config file
|
// GetBlockInbound reads block inbound setting from config file
|
||||||
func (p *Preferences) GetBlockInbound() (bool, error) {
|
func (p *Preferences) GetBlockInbound() (bool, error) {
|
||||||
if p.configInput.BlockInbound != nil {
|
if p.configInput.BlockInbound != nil {
|
||||||
|
|||||||
67
client/android/route_command.go
Normal file
67
client/android/route_command.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package android
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
func executeRouteToggle(id string, manager routemanager.Manager,
|
||||||
|
operationName string,
|
||||||
|
routeOperation func(routes []route.NetID, allRoutes []route.NetID) error) error {
|
||||||
|
netID := route.NetID(id)
|
||||||
|
routes := []route.NetID{netID}
|
||||||
|
|
||||||
|
log.Debugf("%s with id: %s", operationName, id)
|
||||||
|
|
||||||
|
if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil {
|
||||||
|
log.Debugf("error when %s: %s", operationName, err)
|
||||||
|
return fmt.Errorf("error %s: %w", operationName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.TriggerSelection(manager.GetClientRoutes())
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type routeCommand interface {
|
||||||
|
toggleRoute() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type selectRouteCommand struct {
|
||||||
|
route string
|
||||||
|
manager routemanager.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s selectRouteCommand) toggleRoute() error {
|
||||||
|
routeSelector := s.manager.GetRouteSelector()
|
||||||
|
if routeSelector == nil {
|
||||||
|
return fmt.Errorf("no route selector available")
|
||||||
|
}
|
||||||
|
|
||||||
|
routeOperation := func(routes []route.NetID, allRoutes []route.NetID) error {
|
||||||
|
return routeSelector.SelectRoutes(routes, true, allRoutes)
|
||||||
|
}
|
||||||
|
|
||||||
|
return executeRouteToggle(s.route, s.manager, "selecting route", routeOperation)
|
||||||
|
}
|
||||||
|
|
||||||
|
type deselectRouteCommand struct {
|
||||||
|
route string
|
||||||
|
manager routemanager.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d deselectRouteCommand) toggleRoute() error {
|
||||||
|
routeSelector := d.manager.GetRouteSelector()
|
||||||
|
if routeSelector == nil {
|
||||||
|
return fmt.Errorf("no route selector available")
|
||||||
|
}
|
||||||
|
|
||||||
|
return executeRouteToggle(d.route, d.manager, "deselecting route", routeSelector.DeselectRoutes)
|
||||||
|
}
|
||||||
@@ -4,14 +4,12 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
|
||||||
"os/user"
|
"os/user"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/skratchdot/open-golang/open"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
@@ -332,7 +330,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
|||||||
hint = profileState.Email
|
hint = profileState.Email
|
||||||
}
|
}
|
||||||
|
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint)
|
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), false, hint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -373,21 +371,13 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
|||||||
cmd.Println("")
|
cmd.Println("")
|
||||||
|
|
||||||
if !noBrowser {
|
if !noBrowser {
|
||||||
if err := openBrowser(verificationURIComplete); err != nil {
|
if err := util.OpenBrowser(verificationURIComplete); err != nil {
|
||||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// openBrowser opens the URL in a browser, respecting the BROWSER environment variable.
|
|
||||||
func openBrowser(url string) error {
|
|
||||||
if browser := os.Getenv("BROWSER"); browser != "" {
|
|
||||||
return exec.Command(browser, url).Start()
|
|
||||||
}
|
|
||||||
return open.Run(url)
|
|
||||||
}
|
|
||||||
|
|
||||||
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
||||||
func isUnixRunningDesktop() bool {
|
func isUnixRunningDesktop() bool {
|
||||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||||
|
|||||||
@@ -35,7 +35,6 @@ const (
|
|||||||
wireguardPortFlag = "wireguard-port"
|
wireguardPortFlag = "wireguard-port"
|
||||||
networkMonitorFlag = "network-monitor"
|
networkMonitorFlag = "network-monitor"
|
||||||
disableAutoConnectFlag = "disable-auto-connect"
|
disableAutoConnectFlag = "disable-auto-connect"
|
||||||
serverSSHAllowedFlag = "allow-server-ssh"
|
|
||||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||||
dnsRouteIntervalFlag = "dns-router-interval"
|
dnsRouteIntervalFlag = "dns-router-interval"
|
||||||
enableLazyConnectionFlag = "enable-lazy-connection"
|
enableLazyConnectionFlag = "enable-lazy-connection"
|
||||||
@@ -64,7 +63,6 @@ var (
|
|||||||
customDNSAddress string
|
customDNSAddress string
|
||||||
rosenpassEnabled bool
|
rosenpassEnabled bool
|
||||||
rosenpassPermissive bool
|
rosenpassPermissive bool
|
||||||
serverSSHAllowed bool
|
|
||||||
interfaceName string
|
interfaceName string
|
||||||
wireguardPort uint16
|
wireguardPort uint16
|
||||||
networkMonitor bool
|
networkMonitor bool
|
||||||
@@ -176,7 +174,6 @@ func init() {
|
|||||||
)
|
)
|
||||||
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
|
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
|
||||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
|
||||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||||
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.")
|
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.")
|
||||||
|
|
||||||
|
|||||||
@@ -3,125 +3,849 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"os/user"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
sshclient "github.com/netbirdio/netbird/client/ssh/client"
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||||
|
sshproxy "github.com/netbirdio/netbird/client/ssh/proxy"
|
||||||
|
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
const (
|
||||||
port int
|
sshUsernameDesc = "SSH username"
|
||||||
userName = "root"
|
hostArgumentRequired = "host argument required"
|
||||||
host string
|
|
||||||
|
serverSSHAllowedFlag = "allow-server-ssh"
|
||||||
|
enableSSHRootFlag = "enable-ssh-root"
|
||||||
|
enableSSHSFTPFlag = "enable-ssh-sftp"
|
||||||
|
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
|
||||||
|
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
|
||||||
|
disableSSHAuthFlag = "disable-ssh-auth"
|
||||||
|
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
|
||||||
)
|
)
|
||||||
|
|
||||||
var sshCmd = &cobra.Command{
|
var (
|
||||||
Use: "ssh [user@]host",
|
port int
|
||||||
Args: func(cmd *cobra.Command, args []string) error {
|
username string
|
||||||
if len(args) < 1 {
|
host string
|
||||||
return errors.New("requires a host argument")
|
command string
|
||||||
}
|
localForwards []string
|
||||||
|
remoteForwards []string
|
||||||
|
strictHostKeyChecking bool
|
||||||
|
knownHostsFile string
|
||||||
|
identityFile string
|
||||||
|
skipCachedToken bool
|
||||||
|
requestPTY bool
|
||||||
|
sshNoBrowser bool
|
||||||
|
)
|
||||||
|
|
||||||
split := strings.Split(args[0], "@")
|
var (
|
||||||
if len(split) == 2 {
|
serverSSHAllowed bool
|
||||||
userName = split[0]
|
enableSSHRoot bool
|
||||||
host = split[1]
|
enableSSHSFTP bool
|
||||||
} else {
|
enableSSHLocalPortForward bool
|
||||||
host = args[0]
|
enableSSHRemotePortForward bool
|
||||||
}
|
disableSSHAuth bool
|
||||||
|
sshJWTCacheTTL int
|
||||||
|
)
|
||||||
|
|
||||||
return nil
|
func init() {
|
||||||
},
|
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer")
|
||||||
Short: "Connect to a remote SSH server",
|
upCmd.PersistentFlags().BoolVar(&enableSSHRoot, enableSSHRootFlag, false, "Enable root login for SSH server")
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
upCmd.PersistentFlags().BoolVar(&enableSSHSFTP, enableSSHSFTPFlag, false, "Enable SFTP subsystem for SSH server")
|
||||||
SetFlagsFromEnvVars(rootCmd)
|
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
|
||||||
SetFlagsFromEnvVars(cmd)
|
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
|
||||||
|
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
|
||||||
|
upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)")
|
||||||
|
|
||||||
cmd.SetOut(cmd.OutOrStdout())
|
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
|
||||||
|
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
|
||||||
|
sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)")
|
||||||
|
sshCmd.PersistentFlags().BoolVarP(&requestPTY, "tty", "t", false, "Force pseudo-terminal allocation")
|
||||||
|
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
|
||||||
|
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
|
||||||
|
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)")
|
||||||
|
_ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used")
|
||||||
|
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||||
|
sshCmd.PersistentFlags().BoolVar(&sshNoBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||||
|
|
||||||
err := util.InitLog(logLevel, util.LogConsole)
|
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
|
||||||
if err != nil {
|
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
|
||||||
return fmt.Errorf("failed initializing log %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !util.IsAdmin() {
|
sshCmd.AddCommand(sshSftpCmd)
|
||||||
cmd.Printf("error: you must have Administrator privileges to run this command\n")
|
sshCmd.AddCommand(sshProxyCmd)
|
||||||
return nil
|
sshCmd.AddCommand(sshDetectCmd)
|
||||||
}
|
|
||||||
|
|
||||||
ctx := internal.CtxInitState(cmd.Context())
|
|
||||||
|
|
||||||
sm := profilemanager.NewServiceManager(configPath)
|
|
||||||
activeProf, err := sm.GetActiveProfileState()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("get active profile: %v", err)
|
|
||||||
}
|
|
||||||
profPath, err := activeProf.FilePath()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("get active profile path: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
config, err := profilemanager.ReadConfig(profPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("read profile config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
sig := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
|
||||||
sshctx, cancel := context.WithCancel(ctx)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
// blocking
|
|
||||||
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
|
||||||
cmd.Printf("Error: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
cancel()
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-sig:
|
|
||||||
cancel()
|
|
||||||
case <-sshctx.Done():
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error {
|
var sshCmd = &cobra.Command{
|
||||||
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), userName, pemKey)
|
Use: "ssh [flags] [user@]host [command]",
|
||||||
if err != nil {
|
Short: "Connect to a NetBird peer via SSH",
|
||||||
cmd.Printf("Error: %v\n", err)
|
Long: `Connect to a NetBird peer using SSH with support for port forwarding.
|
||||||
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
|
|
||||||
"\nYou can verify the connection by running:\n\n" +
|
Port Forwarding:
|
||||||
" netbird status\n\n")
|
-L [bind_address:]port:host:hostport Local port forwarding
|
||||||
return err
|
-L [bind_address:]port:/path/to/socket Local port forwarding to Unix socket
|
||||||
}
|
-R [bind_address:]port:host:hostport Remote port forwarding
|
||||||
go func() {
|
-R [bind_address:]port:/path/to/socket Remote port forwarding to Unix socket
|
||||||
<-ctx.Done()
|
|
||||||
err = c.Close()
|
SSH Options:
|
||||||
if err != nil {
|
-p, --port int Remote SSH port (default 22)
|
||||||
return
|
-u, --user string SSH username
|
||||||
|
--login string SSH username (alias for --user)
|
||||||
|
-t, --tty Force pseudo-terminal allocation
|
||||||
|
--strict-host-key-checking Enable strict host key checking (default: true)
|
||||||
|
-o, --known-hosts string Path to known_hosts file
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
netbird ssh peer-hostname
|
||||||
|
netbird ssh root@peer-hostname
|
||||||
|
netbird ssh --login root peer-hostname
|
||||||
|
netbird ssh peer-hostname ls -la
|
||||||
|
netbird ssh peer-hostname whoami
|
||||||
|
netbird ssh -t peer-hostname tmux # Force PTY for tmux/screen
|
||||||
|
netbird ssh -t peer-hostname sudo -i # Force PTY for interactive sudo
|
||||||
|
netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding
|
||||||
|
netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding
|
||||||
|
netbird ssh -L "*:8080:localhost:80" peer-hostname # Bind to all interfaces
|
||||||
|
netbird ssh -L 8080:/tmp/socket peer-hostname # Unix socket forwarding`,
|
||||||
|
DisableFlagParsing: true,
|
||||||
|
Args: validateSSHArgsWithoutFlagParsing,
|
||||||
|
RunE: sshFn,
|
||||||
|
Aliases: []string{"ssh"},
|
||||||
|
}
|
||||||
|
|
||||||
|
func sshFn(cmd *cobra.Command, args []string) error {
|
||||||
|
for _, arg := range args {
|
||||||
|
if arg == "-h" || arg == "--help" {
|
||||||
|
return cmd.Help()
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SetFlagsFromEnvVars(rootCmd)
|
||||||
|
SetFlagsFromEnvVars(cmd)
|
||||||
|
|
||||||
|
cmd.SetOut(cmd.OutOrStdout())
|
||||||
|
|
||||||
|
logOutput := "console"
|
||||||
|
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
|
||||||
|
logOutput = firstLogFile
|
||||||
|
}
|
||||||
|
if err := util.InitLog(logLevel, logOutput); err != nil {
|
||||||
|
return fmt.Errorf("init log: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := internal.CtxInitState(cmd.Context())
|
||||||
|
|
||||||
|
sig := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
||||||
|
sshctx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
if err := runSSH(sshctx, host, cmd); err != nil {
|
||||||
|
errCh <- err
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = c.OpenTerminal()
|
select {
|
||||||
if err != nil {
|
case <-sig:
|
||||||
|
cancel()
|
||||||
|
<-sshctx.Done()
|
||||||
|
return nil
|
||||||
|
case err := <-errCh:
|
||||||
return err
|
return err
|
||||||
|
case <-sshctx.Done():
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
// getEnvOrDefault checks for environment variables with WT_ and NB_ prefixes
|
||||||
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", nbssh.DefaultSSHPort, "Sets remote SSH port. Defaults to "+fmt.Sprint(nbssh.DefaultSSHPort))
|
func getEnvOrDefault(flagName, defaultValue string) string {
|
||||||
|
if envValue := os.Getenv("WT_" + flagName); envValue != "" {
|
||||||
|
return envValue
|
||||||
|
}
|
||||||
|
if envValue := os.Getenv("NB_" + flagName); envValue != "" {
|
||||||
|
return envValue
|
||||||
|
}
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// getBoolEnvOrDefault checks for boolean environment variables with WT_ and NB_ prefixes
|
||||||
|
func getBoolEnvOrDefault(flagName string, defaultValue bool) bool {
|
||||||
|
if envValue := os.Getenv("WT_" + flagName); envValue != "" {
|
||||||
|
if parsed, err := strconv.ParseBool(envValue); err == nil {
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if envValue := os.Getenv("NB_" + flagName); envValue != "" {
|
||||||
|
if parsed, err := strconv.ParseBool(envValue); err == nil {
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// resetSSHGlobals sets SSH globals to their default values
|
||||||
|
func resetSSHGlobals() {
|
||||||
|
port = sshserver.DefaultSSHPort
|
||||||
|
username = ""
|
||||||
|
host = ""
|
||||||
|
command = ""
|
||||||
|
localForwards = nil
|
||||||
|
remoteForwards = nil
|
||||||
|
strictHostKeyChecking = true
|
||||||
|
knownHostsFile = ""
|
||||||
|
identityFile = ""
|
||||||
|
sshNoBrowser = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseCustomSSHFlags extracts -L, -R flags and returns filtered args
|
||||||
|
func parseCustomSSHFlags(args []string) ([]string, []string, []string) {
|
||||||
|
var localForwardFlags []string
|
||||||
|
var remoteForwardFlags []string
|
||||||
|
var filteredArgs []string
|
||||||
|
|
||||||
|
for i := 0; i < len(args); i++ {
|
||||||
|
arg := args[i]
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(arg, "-L"):
|
||||||
|
localForwardFlags, i = parseForwardFlag(arg, args, i, localForwardFlags)
|
||||||
|
case strings.HasPrefix(arg, "-R"):
|
||||||
|
remoteForwardFlags, i = parseForwardFlag(arg, args, i, remoteForwardFlags)
|
||||||
|
default:
|
||||||
|
filteredArgs = append(filteredArgs, arg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return filteredArgs, localForwardFlags, remoteForwardFlags
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseForwardFlag(arg string, args []string, i int, flags []string) ([]string, int) {
|
||||||
|
if arg == "-L" || arg == "-R" {
|
||||||
|
if i+1 < len(args) {
|
||||||
|
flags = append(flags, args[i+1])
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
} else if len(arg) > 2 {
|
||||||
|
flags = append(flags, arg[2:])
|
||||||
|
}
|
||||||
|
return flags, i
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractGlobalFlags parses global flags that were passed before 'ssh' command
|
||||||
|
func extractGlobalFlags(args []string) {
|
||||||
|
sshPos := findSSHCommandPosition(args)
|
||||||
|
if sshPos == -1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
globalArgs := args[:sshPos]
|
||||||
|
parseGlobalArgs(globalArgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// findSSHCommandPosition locates the 'ssh' command in the argument list
|
||||||
|
func findSSHCommandPosition(args []string) int {
|
||||||
|
for i, arg := range args {
|
||||||
|
if arg == "ssh" {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
configFlag = "config"
|
||||||
|
logLevelFlag = "log-level"
|
||||||
|
logFileFlag = "log-file"
|
||||||
|
)
|
||||||
|
|
||||||
|
// parseGlobalArgs processes the global arguments and sets the corresponding variables
|
||||||
|
func parseGlobalArgs(globalArgs []string) {
|
||||||
|
flagHandlers := map[string]func(string){
|
||||||
|
configFlag: func(value string) { configPath = value },
|
||||||
|
logLevelFlag: func(value string) { logLevel = value },
|
||||||
|
logFileFlag: func(value string) {
|
||||||
|
if !slices.Contains(logFiles, value) {
|
||||||
|
logFiles = append(logFiles, value)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
shortFlags := map[string]string{
|
||||||
|
"c": configFlag,
|
||||||
|
"l": logLevelFlag,
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(globalArgs); i++ {
|
||||||
|
arg := globalArgs[i]
|
||||||
|
|
||||||
|
if handled, nextIndex := parseFlag(arg, globalArgs, i, flagHandlers, shortFlags); handled {
|
||||||
|
i = nextIndex
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseFlag handles generic flag parsing for both long and short forms
|
||||||
|
func parseFlag(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (bool, int) {
|
||||||
|
if parsedValue, found := parseEqualsFormat(arg, flagHandlers, shortFlags); found {
|
||||||
|
flagHandlers[parsedValue.flagName](parsedValue.value)
|
||||||
|
return true, currentIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
if parsedValue, found := parseSpacedFormat(arg, args, currentIndex, flagHandlers, shortFlags); found {
|
||||||
|
flagHandlers[parsedValue.flagName](parsedValue.value)
|
||||||
|
return true, currentIndex + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, currentIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
type parsedFlag struct {
|
||||||
|
flagName string
|
||||||
|
value string
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseEqualsFormat handles --flag=value and -f=value formats
|
||||||
|
func parseEqualsFormat(arg string, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) {
|
||||||
|
if !strings.Contains(arg, "=") {
|
||||||
|
return parsedFlag{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.SplitN(arg, "=", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return parsedFlag{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(parts[0], "--") {
|
||||||
|
flagName := strings.TrimPrefix(parts[0], "--")
|
||||||
|
if _, exists := flagHandlers[flagName]; exists {
|
||||||
|
return parsedFlag{flagName: flagName, value: parts[1]}, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(parts[0], "-") && len(parts[0]) == 2 {
|
||||||
|
shortFlag := strings.TrimPrefix(parts[0], "-")
|
||||||
|
if longFlag, exists := shortFlags[shortFlag]; exists {
|
||||||
|
if _, exists := flagHandlers[longFlag]; exists {
|
||||||
|
return parsedFlag{flagName: longFlag, value: parts[1]}, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsedFlag{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseSpacedFormat handles --flag value and -f value formats
|
||||||
|
func parseSpacedFormat(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) {
|
||||||
|
if currentIndex+1 >= len(args) {
|
||||||
|
return parsedFlag{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(arg, "--") {
|
||||||
|
flagName := strings.TrimPrefix(arg, "--")
|
||||||
|
if _, exists := flagHandlers[flagName]; exists {
|
||||||
|
return parsedFlag{flagName: flagName, value: args[currentIndex+1]}, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(arg, "-") && len(arg) == 2 {
|
||||||
|
shortFlag := strings.TrimPrefix(arg, "-")
|
||||||
|
if longFlag, exists := shortFlags[shortFlag]; exists {
|
||||||
|
if _, exists := flagHandlers[longFlag]; exists {
|
||||||
|
return parsedFlag{flagName: longFlag, value: args[currentIndex+1]}, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsedFlag{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// createSSHFlagSet creates and configures the flag set for SSH command parsing
|
||||||
|
// sshFlags contains all SSH-related flags and parameters
|
||||||
|
type sshFlags struct {
|
||||||
|
Port int
|
||||||
|
Username string
|
||||||
|
Login string
|
||||||
|
RequestPTY bool
|
||||||
|
StrictHostKeyChecking bool
|
||||||
|
KnownHostsFile string
|
||||||
|
IdentityFile string
|
||||||
|
SkipCachedToken bool
|
||||||
|
NoBrowser bool
|
||||||
|
ConfigPath string
|
||||||
|
LogLevel string
|
||||||
|
LocalForwards []string
|
||||||
|
RemoteForwards []string
|
||||||
|
Host string
|
||||||
|
Command string
|
||||||
|
}
|
||||||
|
|
||||||
|
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
|
||||||
|
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
|
||||||
|
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||||
|
defaultNoBrowser := getBoolEnvOrDefault("NO_BROWSER", false)
|
||||||
|
|
||||||
|
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
|
||||||
|
fs.SetOutput(nil)
|
||||||
|
|
||||||
|
flags := &sshFlags{}
|
||||||
|
|
||||||
|
fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port")
|
||||||
|
fs.IntVar(&flags.Port, "port", sshserver.DefaultSSHPort, "SSH port")
|
||||||
|
fs.StringVar(&flags.Username, "u", "", sshUsernameDesc)
|
||||||
|
fs.StringVar(&flags.Username, "user", "", sshUsernameDesc)
|
||||||
|
fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)")
|
||||||
|
fs.BoolVar(&flags.RequestPTY, "t", false, "Force pseudo-terminal allocation")
|
||||||
|
fs.BoolVar(&flags.RequestPTY, "tty", false, "Force pseudo-terminal allocation")
|
||||||
|
|
||||||
|
fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking")
|
||||||
|
fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file")
|
||||||
|
fs.StringVar(&flags.KnownHostsFile, "known-hosts", "", "Path to known_hosts file")
|
||||||
|
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
|
||||||
|
fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file")
|
||||||
|
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||||
|
fs.BoolVar(&flags.NoBrowser, "no-browser", defaultNoBrowser, noBrowserDesc)
|
||||||
|
|
||||||
|
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
|
||||||
|
fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location")
|
||||||
|
fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level")
|
||||||
|
fs.StringVar(&flags.LogLevel, "log-level", defaultLogLevel, "sets Netbird log level")
|
||||||
|
|
||||||
|
return fs, flags
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
||||||
|
if len(args) < 1 {
|
||||||
|
return errors.New(hostArgumentRequired)
|
||||||
|
}
|
||||||
|
|
||||||
|
resetSSHGlobals()
|
||||||
|
|
||||||
|
if len(os.Args) > 2 {
|
||||||
|
extractGlobalFlags(os.Args[1:])
|
||||||
|
}
|
||||||
|
|
||||||
|
filteredArgs, localForwardFlags, remoteForwardFlags := parseCustomSSHFlags(args)
|
||||||
|
|
||||||
|
fs, flags := createSSHFlagSet()
|
||||||
|
|
||||||
|
if err := fs.Parse(filteredArgs); err != nil {
|
||||||
|
if errors.Is(err, flag.ErrHelp) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
remaining := fs.Args()
|
||||||
|
if len(remaining) < 1 {
|
||||||
|
return errors.New(hostArgumentRequired)
|
||||||
|
}
|
||||||
|
|
||||||
|
port = flags.Port
|
||||||
|
if flags.Username != "" {
|
||||||
|
username = flags.Username
|
||||||
|
} else if flags.Login != "" {
|
||||||
|
username = flags.Login
|
||||||
|
}
|
||||||
|
|
||||||
|
requestPTY = flags.RequestPTY
|
||||||
|
strictHostKeyChecking = flags.StrictHostKeyChecking
|
||||||
|
knownHostsFile = flags.KnownHostsFile
|
||||||
|
identityFile = flags.IdentityFile
|
||||||
|
skipCachedToken = flags.SkipCachedToken
|
||||||
|
sshNoBrowser = flags.NoBrowser
|
||||||
|
|
||||||
|
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
|
||||||
|
configPath = flags.ConfigPath
|
||||||
|
}
|
||||||
|
if flags.LogLevel != getEnvOrDefault("LOG_LEVEL", logLevel) {
|
||||||
|
logLevel = flags.LogLevel
|
||||||
|
}
|
||||||
|
|
||||||
|
localForwards = localForwardFlags
|
||||||
|
remoteForwards = remoteForwardFlags
|
||||||
|
|
||||||
|
return parseHostnameAndCommand(remaining)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseHostnameAndCommand(args []string) error {
|
||||||
|
if len(args) < 1 {
|
||||||
|
return errors.New(hostArgumentRequired)
|
||||||
|
}
|
||||||
|
|
||||||
|
arg := args[0]
|
||||||
|
if strings.Contains(arg, "@") {
|
||||||
|
parts := strings.SplitN(arg, "@", 2)
|
||||||
|
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||||
|
return errors.New("invalid user@host format")
|
||||||
|
}
|
||||||
|
if username == "" {
|
||||||
|
username = parts[0]
|
||||||
|
}
|
||||||
|
host = parts[1]
|
||||||
|
} else {
|
||||||
|
host = arg
|
||||||
|
}
|
||||||
|
|
||||||
|
if username == "" {
|
||||||
|
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
|
||||||
|
username = sudoUser
|
||||||
|
} else if currentUser, err := user.Current(); err == nil {
|
||||||
|
username = currentUser.Username
|
||||||
|
} else {
|
||||||
|
username = "root"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Everything after hostname becomes the command
|
||||||
|
if len(args) > 1 {
|
||||||
|
command = strings.Join(args[1:], " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
|
||||||
|
target := fmt.Sprintf("%s:%d", addr, port)
|
||||||
|
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
|
||||||
|
KnownHostsFile: knownHostsFile,
|
||||||
|
IdentityFile: identityFile,
|
||||||
|
DaemonAddr: daemonAddr,
|
||||||
|
SkipCachedToken: skipCachedToken,
|
||||||
|
InsecureSkipVerify: !strictHostKeyChecking,
|
||||||
|
NoBrowser: sshNoBrowser,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
cmd.Printf("Failed to connect to %s@%s\n", username, target)
|
||||||
|
cmd.Printf("\nTroubleshooting steps:\n")
|
||||||
|
cmd.Printf(" 1. Check peer connectivity: netbird status -d\n")
|
||||||
|
cmd.Printf(" 2. Verify SSH server is enabled on the peer\n")
|
||||||
|
cmd.Printf(" 3. Ensure correct hostname/IP is used\n")
|
||||||
|
return fmt.Errorf("dial %s: %w", target, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sshCtx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-sshCtx.Done()
|
||||||
|
if err := c.Close(); err != nil {
|
||||||
|
cmd.Printf("Error closing SSH connection: %v\n", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := startPortForwarding(sshCtx, c, cmd); err != nil {
|
||||||
|
return fmt.Errorf("start port forwarding: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if command != "" {
|
||||||
|
return executeSSHCommand(sshCtx, c, command)
|
||||||
|
}
|
||||||
|
return openSSHTerminal(sshCtx, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeSSHCommand executes a command over SSH.
|
||||||
|
func executeSSHCommand(ctx context.Context, c *sshclient.Client, command string) error {
|
||||||
|
var err error
|
||||||
|
if requestPTY {
|
||||||
|
err = c.ExecuteCommandWithPTY(ctx, command)
|
||||||
|
} else {
|
||||||
|
err = c.ExecuteCommandWithIO(ctx, command)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var exitErr *ssh.ExitError
|
||||||
|
if errors.As(err, &exitErr) {
|
||||||
|
os.Exit(exitErr.ExitStatus())
|
||||||
|
}
|
||||||
|
|
||||||
|
var exitMissingErr *ssh.ExitMissingError
|
||||||
|
if errors.As(err, &exitMissingErr) {
|
||||||
|
log.Debugf("Remote command exited without exit status: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("execute command: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// openSSHTerminal opens an interactive SSH terminal.
|
||||||
|
func openSSHTerminal(ctx context.Context, c *sshclient.Client) error {
|
||||||
|
if err := c.OpenTerminal(ctx); err != nil {
|
||||||
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var exitMissingErr *ssh.ExitMissingError
|
||||||
|
if errors.As(err, &exitMissingErr) {
|
||||||
|
log.Debugf("Remote terminal exited without exit status: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("open terminal: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// startPortForwarding starts local and remote port forwarding based on command line flags
|
||||||
|
func startPortForwarding(ctx context.Context, c *sshclient.Client, cmd *cobra.Command) error {
|
||||||
|
for _, forward := range localForwards {
|
||||||
|
if err := parseAndStartLocalForward(ctx, c, forward, cmd); err != nil {
|
||||||
|
return fmt.Errorf("local port forward %s: %w", forward, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, forward := range remoteForwards {
|
||||||
|
if err := parseAndStartRemoteForward(ctx, c, forward, cmd); err != nil {
|
||||||
|
return fmt.Errorf("remote port forward %s: %w", forward, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseAndStartLocalForward parses and starts a local port forward (-L)
|
||||||
|
func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error {
|
||||||
|
localAddr, remoteAddr, err := parsePortForwardSpec(forward)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) {
|
||||||
|
cmd.Printf("Local port forward error: %v\n", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseAndStartRemoteForward parses and starts a remote port forward (-R)
|
||||||
|
func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error {
|
||||||
|
remoteAddr, localAddr, err := parsePortForwardSpec(forward)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) {
|
||||||
|
cmd.Printf("Remote port forward error: %v\n", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80".
|
||||||
|
// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket".
|
||||||
|
func parsePortForwardSpec(spec string) (string, string, error) {
|
||||||
|
// Support formats:
|
||||||
|
// port:host:hostport -> localhost:port -> host:hostport
|
||||||
|
// host:port:host:hostport -> host:port -> host:hostport
|
||||||
|
// [host]:port:host:hostport -> [host]:port -> host:hostport
|
||||||
|
// port:unix_socket_path -> localhost:port -> unix_socket_path
|
||||||
|
// host:port:unix_socket_path -> host:port -> unix_socket_path
|
||||||
|
|
||||||
|
if strings.HasPrefix(spec, "[") && strings.Contains(spec, "]:") {
|
||||||
|
return parseIPv6ForwardSpec(spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(spec, ":")
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_target)", spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch len(parts) {
|
||||||
|
case 2:
|
||||||
|
return parseTwoPartForwardSpec(parts, spec)
|
||||||
|
case 3:
|
||||||
|
return parseThreePartForwardSpec(parts)
|
||||||
|
case 4:
|
||||||
|
return parseFourPartForwardSpec(parts)
|
||||||
|
default:
|
||||||
|
return "", "", fmt.Errorf("invalid port forward specification: %s", spec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseTwoPartForwardSpec handles "port:unix_socket" format.
|
||||||
|
func parseTwoPartForwardSpec(parts []string, spec string) (string, string, error) {
|
||||||
|
if isUnixSocket(parts[1]) {
|
||||||
|
localAddr := "localhost:" + parts[0]
|
||||||
|
remoteAddr := parts[1]
|
||||||
|
return localAddr, remoteAddr, nil
|
||||||
|
}
|
||||||
|
return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_host:remote_port or [local_host:]local_port:unix_socket)", spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseThreePartForwardSpec handles "port:host:hostport" or "host:port:unix_socket" formats.
|
||||||
|
func parseThreePartForwardSpec(parts []string) (string, string, error) {
|
||||||
|
if isUnixSocket(parts[2]) {
|
||||||
|
localHost := normalizeLocalHost(parts[0])
|
||||||
|
localAddr := localHost + ":" + parts[1]
|
||||||
|
remoteAddr := parts[2]
|
||||||
|
return localAddr, remoteAddr, nil
|
||||||
|
}
|
||||||
|
localAddr := "localhost:" + parts[0]
|
||||||
|
remoteAddr := parts[1] + ":" + parts[2]
|
||||||
|
return localAddr, remoteAddr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseFourPartForwardSpec handles "host:port:host:hostport" format.
|
||||||
|
func parseFourPartForwardSpec(parts []string) (string, string, error) {
|
||||||
|
localHost := normalizeLocalHost(parts[0])
|
||||||
|
localAddr := localHost + ":" + parts[1]
|
||||||
|
remoteAddr := parts[2] + ":" + parts[3]
|
||||||
|
return localAddr, remoteAddr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseIPv6ForwardSpec handles "[host]:port:host:hostport" format.
|
||||||
|
func parseIPv6ForwardSpec(spec string) (string, string, error) {
|
||||||
|
idx := strings.Index(spec, "]:")
|
||||||
|
if idx == -1 {
|
||||||
|
return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s", spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
ipv6Host := spec[:idx+1]
|
||||||
|
remaining := spec[idx+2:]
|
||||||
|
|
||||||
|
parts := strings.Split(remaining, ":")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s (expected [ipv6]:port:host:hostport)", spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
localAddr := ipv6Host + ":" + parts[0]
|
||||||
|
remoteAddr := parts[1] + ":" + parts[2]
|
||||||
|
return localAddr, remoteAddr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isUnixSocket checks if a path is a Unix socket path.
|
||||||
|
func isUnixSocket(path string) bool {
|
||||||
|
return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./")
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces.
|
||||||
|
func normalizeLocalHost(host string) string {
|
||||||
|
if host == "*" {
|
||||||
|
return "0.0.0.0"
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
var sshProxyCmd = &cobra.Command{
|
||||||
|
Use: "proxy <host> <port>",
|
||||||
|
Short: "Internal SSH proxy for native SSH client integration",
|
||||||
|
Long: "Internal command used by SSH ProxyCommand to handle JWT authentication",
|
||||||
|
Hidden: true,
|
||||||
|
Args: cobra.ExactArgs(2),
|
||||||
|
RunE: sshProxyFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
func sshProxyFn(cmd *cobra.Command, args []string) error {
|
||||||
|
logOutput := "console"
|
||||||
|
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
|
||||||
|
logOutput = firstLogFile
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||||
|
if err := util.InitLog(proxyLogLevel, logOutput); err != nil {
|
||||||
|
return fmt.Errorf("init log: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
host := args[0]
|
||||||
|
portStr := args[1]
|
||||||
|
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid port: %s", portStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check env var for browser setting since this command is invoked via SSH ProxyCommand
|
||||||
|
// where command-line flags cannot be passed. Default is to open browser.
|
||||||
|
noBrowser := getBoolEnvOrDefault("NO_BROWSER", false)
|
||||||
|
var browserOpener func(string) error
|
||||||
|
if !noBrowser {
|
||||||
|
browserOpener = util.OpenBrowser
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr(), browserOpener)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create SSH proxy: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := proxy.Close(); err != nil {
|
||||||
|
log.Debugf("close SSH proxy: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := proxy.Connect(cmd.Context()); err != nil {
|
||||||
|
return fmt.Errorf("SSH proxy: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var sshDetectCmd = &cobra.Command{
|
||||||
|
Use: "detect <host> <port>",
|
||||||
|
Short: "Detect if a host is running NetBird SSH",
|
||||||
|
Long: "Internal command used by SSH Match exec to detect NetBird SSH servers. Exit codes: 0=JWT, 1=no-JWT, 2=regular SSH",
|
||||||
|
Hidden: true,
|
||||||
|
Args: cobra.ExactArgs(2),
|
||||||
|
RunE: sshDetectFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
func sshDetectFn(cmd *cobra.Command, args []string) error {
|
||||||
|
detectLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||||
|
if err := util.InitLog(detectLogLevel, "console"); err != nil {
|
||||||
|
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||||
|
}
|
||||||
|
|
||||||
|
host := args[0]
|
||||||
|
portStr := args[1]
|
||||||
|
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("invalid port %q: %v", portStr, err)
|
||||||
|
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(cmd.Context(), detection.DefaultTimeout)
|
||||||
|
|
||||||
|
dialer := &net.Dialer{}
|
||||||
|
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("SSH server detection failed: %v", err)
|
||||||
|
cancel()
|
||||||
|
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
os.Exit(serverType.ExitCode())
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
74
client/cmd/ssh_exec_unix.go
Normal file
74
client/cmd/ssh_exec_unix.go
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
//go:build unix
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
sshExecUID uint32
|
||||||
|
sshExecGID uint32
|
||||||
|
sshExecGroups []uint
|
||||||
|
sshExecWorkingDir string
|
||||||
|
sshExecShell string
|
||||||
|
sshExecCommand string
|
||||||
|
sshExecPTY bool
|
||||||
|
)
|
||||||
|
|
||||||
|
// sshExecCmd represents the hidden ssh exec subcommand for privilege dropping
|
||||||
|
var sshExecCmd = &cobra.Command{
|
||||||
|
Use: "exec",
|
||||||
|
Short: "Internal SSH execution with privilege dropping (hidden)",
|
||||||
|
Hidden: true,
|
||||||
|
RunE: runSSHExec,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
sshExecCmd.Flags().Uint32Var(&sshExecUID, "uid", 0, "Target user ID")
|
||||||
|
sshExecCmd.Flags().Uint32Var(&sshExecGID, "gid", 0, "Target group ID")
|
||||||
|
sshExecCmd.Flags().UintSliceVar(&sshExecGroups, "groups", nil, "Supplementary group IDs (can be repeated)")
|
||||||
|
sshExecCmd.Flags().StringVar(&sshExecWorkingDir, "working-dir", "", "Working directory")
|
||||||
|
sshExecCmd.Flags().StringVar(&sshExecShell, "shell", "/bin/sh", "Shell to execute")
|
||||||
|
sshExecCmd.Flags().BoolVar(&sshExecPTY, "pty", false, "Request PTY (will fail as executor doesn't support PTY)")
|
||||||
|
sshExecCmd.Flags().StringVar(&sshExecCommand, "cmd", "", "Command to execute")
|
||||||
|
|
||||||
|
if err := sshExecCmd.MarkFlagRequired("uid"); err != nil {
|
||||||
|
_, _ = fmt.Fprintf(os.Stderr, "failed to mark uid flag as required: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
if err := sshExecCmd.MarkFlagRequired("gid"); err != nil {
|
||||||
|
_, _ = fmt.Fprintf(os.Stderr, "failed to mark gid flag as required: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
sshCmd.AddCommand(sshExecCmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// runSSHExec handles the SSH exec subcommand execution.
|
||||||
|
func runSSHExec(cmd *cobra.Command, _ []string) error {
|
||||||
|
privilegeDropper := sshserver.NewPrivilegeDropper()
|
||||||
|
|
||||||
|
var groups []uint32
|
||||||
|
for _, groupInt := range sshExecGroups {
|
||||||
|
groups = append(groups, uint32(groupInt))
|
||||||
|
}
|
||||||
|
|
||||||
|
config := sshserver.ExecutorConfig{
|
||||||
|
UID: sshExecUID,
|
||||||
|
GID: sshExecGID,
|
||||||
|
Groups: groups,
|
||||||
|
WorkingDir: sshExecWorkingDir,
|
||||||
|
Shell: sshExecShell,
|
||||||
|
Command: sshExecCommand,
|
||||||
|
PTY: sshExecPTY,
|
||||||
|
}
|
||||||
|
|
||||||
|
privilegeDropper.ExecuteWithPrivilegeDrop(cmd.Context(), config)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
94
client/cmd/ssh_sftp_unix.go
Normal file
94
client/cmd/ssh_sftp_unix.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
//go:build unix
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/pkg/sftp"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
sftpUID uint32
|
||||||
|
sftpGID uint32
|
||||||
|
sftpGroupsInt []uint
|
||||||
|
sftpWorkingDir string
|
||||||
|
)
|
||||||
|
|
||||||
|
var sshSftpCmd = &cobra.Command{
|
||||||
|
Use: "sftp",
|
||||||
|
Short: "SFTP server with privilege dropping (internal use)",
|
||||||
|
Hidden: true,
|
||||||
|
RunE: sftpMain,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
sshSftpCmd.Flags().Uint32Var(&sftpUID, "uid", 0, "Target user ID")
|
||||||
|
sshSftpCmd.Flags().Uint32Var(&sftpGID, "gid", 0, "Target group ID")
|
||||||
|
sshSftpCmd.Flags().UintSliceVar(&sftpGroupsInt, "groups", nil, "Supplementary group IDs (can be repeated)")
|
||||||
|
sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory")
|
||||||
|
}
|
||||||
|
|
||||||
|
func sftpMain(cmd *cobra.Command, _ []string) error {
|
||||||
|
privilegeDropper := sshserver.NewPrivilegeDropper()
|
||||||
|
|
||||||
|
var groups []uint32
|
||||||
|
for _, groupInt := range sftpGroupsInt {
|
||||||
|
groups = append(groups, uint32(groupInt))
|
||||||
|
}
|
||||||
|
|
||||||
|
config := sshserver.ExecutorConfig{
|
||||||
|
UID: sftpUID,
|
||||||
|
GID: sftpGID,
|
||||||
|
Groups: groups,
|
||||||
|
WorkingDir: sftpWorkingDir,
|
||||||
|
Shell: "",
|
||||||
|
Command: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Tracef("dropping privileges for SFTP to UID=%d, GID=%d, groups=%v", config.UID, config.GID, config.Groups)
|
||||||
|
|
||||||
|
if err := privilegeDropper.DropPrivileges(config.UID, config.GID, config.Groups); err != nil {
|
||||||
|
cmd.PrintErrf("privilege drop failed: %v\n", err)
|
||||||
|
os.Exit(sshserver.ExitCodePrivilegeDropFail)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.WorkingDir != "" {
|
||||||
|
if err := os.Chdir(config.WorkingDir); err != nil {
|
||||||
|
cmd.PrintErrf("failed to change to working directory %s: %v\n", config.WorkingDir, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sftpServer, err := sftp.NewServer(struct {
|
||||||
|
io.Reader
|
||||||
|
io.WriteCloser
|
||||||
|
}{
|
||||||
|
Reader: os.Stdin,
|
||||||
|
WriteCloser: os.Stdout,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
cmd.PrintErrf("SFTP server creation failed: %v\n", err)
|
||||||
|
os.Exit(sshserver.ExitCodeShellExecFail)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Tracef("starting SFTP server with dropped privileges")
|
||||||
|
if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
cmd.PrintErrf("SFTP server error: %v\n", err)
|
||||||
|
if closeErr := sftpServer.Close(); closeErr != nil {
|
||||||
|
cmd.PrintErrf("SFTP server close error: %v\n", closeErr)
|
||||||
|
}
|
||||||
|
os.Exit(sshserver.ExitCodeShellExecFail)
|
||||||
|
}
|
||||||
|
|
||||||
|
if closeErr := sftpServer.Close(); closeErr != nil {
|
||||||
|
cmd.PrintErrf("SFTP server close error: %v\n", closeErr)
|
||||||
|
}
|
||||||
|
os.Exit(sshserver.ExitCodeSuccess)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
94
client/cmd/ssh_sftp_windows.go
Normal file
94
client/cmd/ssh_sftp_windows.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"os/user"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pkg/sftp"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
sftpWorkingDir string
|
||||||
|
windowsUsername string
|
||||||
|
windowsDomain string
|
||||||
|
)
|
||||||
|
|
||||||
|
var sshSftpCmd = &cobra.Command{
|
||||||
|
Use: "sftp",
|
||||||
|
Short: "SFTP server with user switching for Windows (internal use)",
|
||||||
|
Hidden: true,
|
||||||
|
RunE: sftpMain,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory")
|
||||||
|
sshSftpCmd.Flags().StringVar(&windowsUsername, "windows-username", "", "Windows username for user switching")
|
||||||
|
sshSftpCmd.Flags().StringVar(&windowsDomain, "windows-domain", "", "Windows domain for user switching")
|
||||||
|
}
|
||||||
|
|
||||||
|
func sftpMain(cmd *cobra.Command, _ []string) error {
|
||||||
|
return sftpMainDirect(cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sftpMainDirect(cmd *cobra.Command) error {
|
||||||
|
currentUser, err := user.Current()
|
||||||
|
if err != nil {
|
||||||
|
cmd.PrintErrf("failed to get current user: %v\n", err)
|
||||||
|
os.Exit(sshserver.ExitCodeValidationFail)
|
||||||
|
}
|
||||||
|
|
||||||
|
if windowsUsername != "" {
|
||||||
|
expectedUsername := windowsUsername
|
||||||
|
if windowsDomain != "" {
|
||||||
|
expectedUsername = fmt.Sprintf(`%s\%s`, windowsDomain, windowsUsername)
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(currentUser.Username, expectedUsername) && !strings.EqualFold(currentUser.Username, windowsUsername) {
|
||||||
|
cmd.PrintErrf("user switching failed\n")
|
||||||
|
os.Exit(sshserver.ExitCodeValidationFail)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("SFTP process running as: %s (UID: %s, Name: %s)", currentUser.Username, currentUser.Uid, currentUser.Name)
|
||||||
|
|
||||||
|
if sftpWorkingDir != "" {
|
||||||
|
if err := os.Chdir(sftpWorkingDir); err != nil {
|
||||||
|
cmd.PrintErrf("failed to change to working directory %s: %v\n", sftpWorkingDir, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sftpServer, err := sftp.NewServer(struct {
|
||||||
|
io.Reader
|
||||||
|
io.WriteCloser
|
||||||
|
}{
|
||||||
|
Reader: os.Stdin,
|
||||||
|
WriteCloser: os.Stdout,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
cmd.PrintErrf("SFTP server creation failed: %v\n", err)
|
||||||
|
os.Exit(sshserver.ExitCodeShellExecFail)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("starting SFTP server")
|
||||||
|
exitCode := sshserver.ExitCodeSuccess
|
||||||
|
if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
cmd.PrintErrf("SFTP server error: %v\n", err)
|
||||||
|
exitCode = sshserver.ExitCodeShellExecFail
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sftpServer.Close(); err != nil {
|
||||||
|
log.Debugf("SFTP server close error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
os.Exit(exitCode)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
717
client/cmd/ssh_test.go
Normal file
717
client/cmd/ssh_test.go
Normal file
@@ -0,0 +1,717 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSSHCommand_FlagParsing(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
expectedHost string
|
||||||
|
expectedUser string
|
||||||
|
expectedPort int
|
||||||
|
expectedCmd string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic host",
|
||||||
|
args: []string{"hostname"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedUser: "",
|
||||||
|
expectedPort: 22,
|
||||||
|
expectedCmd: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user@host format",
|
||||||
|
args: []string{"user@hostname"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedUser: "user",
|
||||||
|
expectedPort: 22,
|
||||||
|
expectedCmd: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "host with command",
|
||||||
|
args: []string{"hostname", "echo", "hello"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedUser: "",
|
||||||
|
expectedPort: 22,
|
||||||
|
expectedCmd: "echo hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "command with flags should be preserved",
|
||||||
|
args: []string{"hostname", "ls", "-la", "/tmp"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedUser: "",
|
||||||
|
expectedPort: 22,
|
||||||
|
expectedCmd: "ls -la /tmp",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "double dash separator",
|
||||||
|
args: []string{"hostname", "--", "ls", "-la"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedUser: "",
|
||||||
|
expectedPort: 22,
|
||||||
|
expectedCmd: "-- ls -la",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Reset global variables
|
||||||
|
host = ""
|
||||||
|
username = ""
|
||||||
|
port = 22
|
||||||
|
command = ""
|
||||||
|
|
||||||
|
// Mock command for testing
|
||||||
|
cmd := sshCmd
|
||||||
|
cmd.SetArgs(tt.args)
|
||||||
|
|
||||||
|
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||||
|
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||||
|
if tt.expectedUser != "" {
|
||||||
|
assert.Equal(t, tt.expectedUser, username, "username mismatch")
|
||||||
|
}
|
||||||
|
assert.Equal(t, tt.expectedPort, port, "port mismatch")
|
||||||
|
assert.Equal(t, tt.expectedCmd, command, "command mismatch")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHCommand_FlagConflictPrevention(t *testing.T) {
|
||||||
|
// Test that SSH flags don't conflict with command flags
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
expectedCmd string
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ls with -la flags",
|
||||||
|
args: []string{"hostname", "ls", "-la"},
|
||||||
|
expectedCmd: "ls -la",
|
||||||
|
description: "ls flags should be passed to remote command",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "grep with -r flag",
|
||||||
|
args: []string{"hostname", "grep", "-r", "pattern", "/path"},
|
||||||
|
expectedCmd: "grep -r pattern /path",
|
||||||
|
description: "grep flags should be passed to remote command",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ps with aux flags",
|
||||||
|
args: []string{"hostname", "ps", "aux"},
|
||||||
|
expectedCmd: "ps aux",
|
||||||
|
description: "ps flags should be passed to remote command",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "command with double dash",
|
||||||
|
args: []string{"hostname", "--", "ls", "-la"},
|
||||||
|
expectedCmd: "-- ls -la",
|
||||||
|
description: "double dash should be preserved in command",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Reset global variables
|
||||||
|
host = ""
|
||||||
|
username = ""
|
||||||
|
port = 22
|
||||||
|
command = ""
|
||||||
|
|
||||||
|
cmd := sshCmd
|
||||||
|
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||||
|
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHCommand_NonInteractiveExecution(t *testing.T) {
|
||||||
|
// Test that commands with arguments should execute the command and exit,
|
||||||
|
// not drop to an interactive shell
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
expectedCmd string
|
||||||
|
shouldExit bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ls command should execute and exit",
|
||||||
|
args: []string{"hostname", "ls"},
|
||||||
|
expectedCmd: "ls",
|
||||||
|
shouldExit: true,
|
||||||
|
description: "ls command should execute and exit, not drop to shell",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ls with flags should execute and exit",
|
||||||
|
args: []string{"hostname", "ls", "-la"},
|
||||||
|
expectedCmd: "ls -la",
|
||||||
|
shouldExit: true,
|
||||||
|
description: "ls with flags should execute and exit, not drop to shell",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "pwd command should execute and exit",
|
||||||
|
args: []string{"hostname", "pwd"},
|
||||||
|
expectedCmd: "pwd",
|
||||||
|
shouldExit: true,
|
||||||
|
description: "pwd command should execute and exit, not drop to shell",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "echo command should execute and exit",
|
||||||
|
args: []string{"hostname", "echo", "hello"},
|
||||||
|
expectedCmd: "echo hello",
|
||||||
|
shouldExit: true,
|
||||||
|
description: "echo command should execute and exit, not drop to shell",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no command should open shell",
|
||||||
|
args: []string{"hostname"},
|
||||||
|
expectedCmd: "",
|
||||||
|
shouldExit: false,
|
||||||
|
description: "no command should open interactive shell",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Reset global variables
|
||||||
|
host = ""
|
||||||
|
username = ""
|
||||||
|
port = 22
|
||||||
|
command = ""
|
||||||
|
|
||||||
|
cmd := sshCmd
|
||||||
|
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||||
|
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||||
|
|
||||||
|
// When command is present, it should execute the command and exit
|
||||||
|
// When command is empty, it should open interactive shell
|
||||||
|
hasCommand := command != ""
|
||||||
|
assert.Equal(t, tt.shouldExit, hasCommand, "Command presence should match expected behavior")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHCommand_FlagHandling(t *testing.T) {
|
||||||
|
// Test that flags after hostname are not parsed by netbird but passed to SSH command
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
expectedHost string
|
||||||
|
expectedCmd string
|
||||||
|
expectError bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ls with -la flag should not be parsed by netbird",
|
||||||
|
args: []string{"debian2", "ls", "-la"},
|
||||||
|
expectedHost: "debian2",
|
||||||
|
expectedCmd: "ls -la",
|
||||||
|
expectError: false,
|
||||||
|
description: "ls -la should be passed as SSH command, not parsed as netbird flags",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "command with netbird-like flags should be passed through",
|
||||||
|
args: []string{"hostname", "echo", "--help"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedCmd: "echo --help",
|
||||||
|
expectError: false,
|
||||||
|
description: "--help should be passed to echo, not parsed by netbird",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "command with -p flag should not conflict with SSH port flag",
|
||||||
|
args: []string{"hostname", "ps", "-p", "1234"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedCmd: "ps -p 1234",
|
||||||
|
expectError: false,
|
||||||
|
description: "ps -p should be passed to ps command, not parsed as port",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tar with flags should be passed through",
|
||||||
|
args: []string{"hostname", "tar", "-czf", "backup.tar.gz", "/home"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedCmd: "tar -czf backup.tar.gz /home",
|
||||||
|
expectError: false,
|
||||||
|
description: "tar flags should be passed to tar command",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Reset global variables
|
||||||
|
host = ""
|
||||||
|
username = ""
|
||||||
|
port = 22
|
||||||
|
command = ""
|
||||||
|
|
||||||
|
cmd := sshCmd
|
||||||
|
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||||
|
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||||
|
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHCommand_RegressionFlagParsing(t *testing.T) {
|
||||||
|
// Regression test for the specific issue: "sudo ./netbird ssh debian2 ls -la"
|
||||||
|
// should not parse -la as netbird flags but pass them to the SSH command
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
expectedHost string
|
||||||
|
expectedCmd string
|
||||||
|
expectError bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "original issue: ls -la should be preserved",
|
||||||
|
args: []string{"debian2", "ls", "-la"},
|
||||||
|
expectedHost: "debian2",
|
||||||
|
expectedCmd: "ls -la",
|
||||||
|
expectError: false,
|
||||||
|
description: "The original failing case should now work",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ls -l should be preserved",
|
||||||
|
args: []string{"hostname", "ls", "-l"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedCmd: "ls -l",
|
||||||
|
expectError: false,
|
||||||
|
description: "Single letter flags should be preserved",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SSH port flag should work",
|
||||||
|
args: []string{"-p", "2222", "hostname", "ls", "-la"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedCmd: "ls -la",
|
||||||
|
expectError: false,
|
||||||
|
description: "SSH -p flag should be parsed, command flags preserved",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Reset global variables
|
||||||
|
host = ""
|
||||||
|
username = ""
|
||||||
|
port = 22
|
||||||
|
command = ""
|
||||||
|
|
||||||
|
cmd := sshCmd
|
||||||
|
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||||
|
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||||
|
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||||
|
|
||||||
|
// Check port for the test case with -p flag
|
||||||
|
if len(tt.args) > 0 && tt.args[0] == "-p" {
|
||||||
|
assert.Equal(t, 2222, port, "port should be parsed from -p flag")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHCommand_PortForwardingFlagParsing(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
expectedHost string
|
||||||
|
expectedLocal []string
|
||||||
|
expectedRemote []string
|
||||||
|
expectError bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "local port forwarding -L",
|
||||||
|
args: []string{"-L", "8080:localhost:80", "hostname"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedLocal: []string{"8080:localhost:80"},
|
||||||
|
expectedRemote: []string{},
|
||||||
|
expectError: false,
|
||||||
|
description: "Single -L flag should be parsed correctly",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "remote port forwarding -R",
|
||||||
|
args: []string{"-R", "8080:localhost:80", "hostname"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedLocal: []string{},
|
||||||
|
expectedRemote: []string{"8080:localhost:80"},
|
||||||
|
expectError: false,
|
||||||
|
description: "Single -R flag should be parsed correctly",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple local port forwards",
|
||||||
|
args: []string{"-L", "8080:localhost:80", "-L", "9090:localhost:443", "hostname"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedLocal: []string{"8080:localhost:80", "9090:localhost:443"},
|
||||||
|
expectedRemote: []string{},
|
||||||
|
expectError: false,
|
||||||
|
description: "Multiple -L flags should be parsed correctly",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple remote port forwards",
|
||||||
|
args: []string{"-R", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedLocal: []string{},
|
||||||
|
expectedRemote: []string{"8080:localhost:80", "9090:localhost:443"},
|
||||||
|
expectError: false,
|
||||||
|
description: "Multiple -R flags should be parsed correctly",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed local and remote forwards",
|
||||||
|
args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedLocal: []string{"8080:localhost:80"},
|
||||||
|
expectedRemote: []string{"9090:localhost:443"},
|
||||||
|
expectError: false,
|
||||||
|
description: "Mixed -L and -R flags should be parsed correctly",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "port forwarding with bind address",
|
||||||
|
args: []string{"-L", "127.0.0.1:8080:localhost:80", "hostname"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedLocal: []string{"127.0.0.1:8080:localhost:80"},
|
||||||
|
expectedRemote: []string{},
|
||||||
|
expectError: false,
|
||||||
|
description: "Port forwarding with bind address should work",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "port forwarding with command",
|
||||||
|
args: []string{"-L", "8080:localhost:80", "hostname", "ls", "-la"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedLocal: []string{"8080:localhost:80"},
|
||||||
|
expectedRemote: []string{},
|
||||||
|
expectError: false,
|
||||||
|
description: "Port forwarding with command should work",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Reset global variables
|
||||||
|
host = ""
|
||||||
|
username = ""
|
||||||
|
port = 22
|
||||||
|
command = ""
|
||||||
|
localForwards = nil
|
||||||
|
remoteForwards = nil
|
||||||
|
|
||||||
|
cmd := sshCmd
|
||||||
|
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||||
|
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||||
|
// Handle nil vs empty slice comparison
|
||||||
|
if len(tt.expectedLocal) == 0 {
|
||||||
|
assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty")
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards")
|
||||||
|
}
|
||||||
|
if len(tt.expectedRemote) == 0 {
|
||||||
|
assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty")
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePortForward(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
spec string
|
||||||
|
expectedLocal string
|
||||||
|
expectedRemote string
|
||||||
|
expectError bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple port forward",
|
||||||
|
spec: "8080:localhost:80",
|
||||||
|
expectedLocal: "localhost:8080",
|
||||||
|
expectedRemote: "localhost:80",
|
||||||
|
expectError: false,
|
||||||
|
description: "Simple port:host:port format should work",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "port forward with bind address",
|
||||||
|
spec: "127.0.0.1:8080:localhost:80",
|
||||||
|
expectedLocal: "127.0.0.1:8080",
|
||||||
|
expectedRemote: "localhost:80",
|
||||||
|
expectError: false,
|
||||||
|
description: "bind_address:port:host:port format should work",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "port forward to different host",
|
||||||
|
spec: "8080:example.com:443",
|
||||||
|
expectedLocal: "localhost:8080",
|
||||||
|
expectedRemote: "example.com:443",
|
||||||
|
expectError: false,
|
||||||
|
description: "Forwarding to different host should work",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "port forward with IPv6 (needs bracket support)",
|
||||||
|
spec: "::1:8080:localhost:80",
|
||||||
|
expectError: true,
|
||||||
|
description: "IPv6 without brackets fails as expected (feature to implement)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid format - too few parts",
|
||||||
|
spec: "8080:localhost",
|
||||||
|
expectError: true,
|
||||||
|
description: "Invalid format with too few parts should fail",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid format - too many parts",
|
||||||
|
spec: "127.0.0.1:8080:localhost:80:extra",
|
||||||
|
expectError: true,
|
||||||
|
description: "Invalid format with too many parts should fail",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty spec",
|
||||||
|
spec: "",
|
||||||
|
expectError: true,
|
||||||
|
description: "Empty spec should fail",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unix socket local forward",
|
||||||
|
spec: "8080:/tmp/socket",
|
||||||
|
expectedLocal: "localhost:8080",
|
||||||
|
expectedRemote: "/tmp/socket",
|
||||||
|
expectError: false,
|
||||||
|
description: "Unix socket forwarding should work",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unix socket with bind address",
|
||||||
|
spec: "127.0.0.1:8080:/tmp/socket",
|
||||||
|
expectedLocal: "127.0.0.1:8080",
|
||||||
|
expectedRemote: "/tmp/socket",
|
||||||
|
expectError: false,
|
||||||
|
description: "Unix socket with bind address should work",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard bind all interfaces",
|
||||||
|
spec: "*:8080:localhost:80",
|
||||||
|
expectedLocal: "0.0.0.0:8080",
|
||||||
|
expectedRemote: "localhost:80",
|
||||||
|
expectError: false,
|
||||||
|
description: "Wildcard * should bind to all interfaces (0.0.0.0)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard for port only",
|
||||||
|
spec: "8080:*:80",
|
||||||
|
expectedLocal: "localhost:8080",
|
||||||
|
expectedRemote: "*:80",
|
||||||
|
expectError: false,
|
||||||
|
description: "Wildcard in remote host should be preserved",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
localAddr, remoteAddr, err := parsePortForwardSpec(tt.spec)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err, tt.description)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err, tt.description)
|
||||||
|
assert.Equal(t, tt.expectedLocal, localAddr, tt.description+" - local address")
|
||||||
|
assert.Equal(t, tt.expectedRemote, remoteAddr, tt.description+" - remote address")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHCommand_IntegrationPortForwarding(t *testing.T) {
|
||||||
|
// Integration test for port forwarding with the actual SSH command implementation
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
expectedHost string
|
||||||
|
expectedLocal []string
|
||||||
|
expectedRemote []string
|
||||||
|
expectedCmd string
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "local forward with command",
|
||||||
|
args: []string{"-L", "8080:localhost:80", "hostname", "echo", "test"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedLocal: []string{"8080:localhost:80"},
|
||||||
|
expectedRemote: []string{},
|
||||||
|
expectedCmd: "echo test",
|
||||||
|
description: "Local forwarding should work with commands",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "remote forward with command",
|
||||||
|
args: []string{"-R", "8080:localhost:80", "hostname", "ls", "-la"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedLocal: []string{},
|
||||||
|
expectedRemote: []string{"8080:localhost:80"},
|
||||||
|
expectedCmd: "ls -la",
|
||||||
|
description: "Remote forwarding should work with commands",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple forwards with user and command",
|
||||||
|
args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "user@hostname", "ps", "aux"},
|
||||||
|
expectedHost: "hostname",
|
||||||
|
expectedLocal: []string{"8080:localhost:80"},
|
||||||
|
expectedRemote: []string{"9090:localhost:443"},
|
||||||
|
expectedCmd: "ps aux",
|
||||||
|
description: "Complex case with multiple forwards, user, and command",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Reset global variables
|
||||||
|
host = ""
|
||||||
|
username = ""
|
||||||
|
port = 22
|
||||||
|
command = ""
|
||||||
|
localForwards = nil
|
||||||
|
remoteForwards = nil
|
||||||
|
|
||||||
|
cmd := sshCmd
|
||||||
|
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||||
|
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||||
|
// Handle nil vs empty slice comparison
|
||||||
|
if len(tt.expectedLocal) == 0 {
|
||||||
|
assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty")
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards")
|
||||||
|
}
|
||||||
|
if len(tt.expectedRemote) == 0 {
|
||||||
|
assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty")
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards")
|
||||||
|
}
|
||||||
|
assert.Equal(t, tt.expectedCmd, command, tt.description+" - command")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHCommand_ParameterIsolation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
expectedCmd string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "cmd flag passed as command",
|
||||||
|
args: []string{"hostname", "--cmd", "echo test"},
|
||||||
|
expectedCmd: "--cmd echo test",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "uid flag passed as command",
|
||||||
|
args: []string{"hostname", "--uid", "1000"},
|
||||||
|
expectedCmd: "--uid 1000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "shell flag passed as command",
|
||||||
|
args: []string{"hostname", "--shell", "/bin/bash"},
|
||||||
|
expectedCmd: "--shell /bin/bash",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
host = ""
|
||||||
|
username = ""
|
||||||
|
port = 22
|
||||||
|
command = ""
|
||||||
|
|
||||||
|
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "hostname", host)
|
||||||
|
assert.Equal(t, tt.expectedCmd, command)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHCommand_InvalidFlagRejection(t *testing.T) {
|
||||||
|
// Test that invalid flags are properly rejected and not misinterpreted as hostnames
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "invalid long flag before hostname",
|
||||||
|
args: []string{"--invalid-flag", "hostname"},
|
||||||
|
description: "Invalid flag should return parse error, not treat flag as hostname",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid short flag before hostname",
|
||||||
|
args: []string{"-x", "hostname"},
|
||||||
|
description: "Invalid short flag should return parse error",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid flag with value before hostname",
|
||||||
|
args: []string{"--invalid-option=value", "hostname"},
|
||||||
|
description: "Invalid flag with value should return parse error",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "typo in known flag",
|
||||||
|
args: []string{"--por", "2222", "hostname"},
|
||||||
|
description: "Typo in flag name should return parse error (not silently ignored)",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Reset global variables
|
||||||
|
host = ""
|
||||||
|
username = ""
|
||||||
|
port = 22
|
||||||
|
command = ""
|
||||||
|
|
||||||
|
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
|
||||||
|
|
||||||
|
// Should return an error for invalid flags
|
||||||
|
assert.Error(t, err, tt.description)
|
||||||
|
|
||||||
|
// Should not have set host to the invalid flag
|
||||||
|
assert.NotEqual(t, tt.args[0], host, "Invalid flag should not be interpreted as hostname")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -109,7 +109,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
case yamlFlag:
|
case yamlFlag:
|
||||||
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
|
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
|
||||||
default:
|
default:
|
||||||
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
|
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -13,6 +13,12 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
|
||||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||||
client "github.com/netbirdio/netbird/client/server"
|
client "github.com/netbirdio/netbird/client/server"
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
@@ -20,8 +26,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/peers"
|
|
||||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
@@ -84,7 +88,6 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
}
|
}
|
||||||
t.Cleanup(cleanUp)
|
t.Cleanup(cleanUp)
|
||||||
|
|
||||||
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
|
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@@ -110,13 +113,21 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
Return(&types.Settings{}, nil).
|
Return(&types.Settings{}, nil).
|
||||||
AnyTimes()
|
AnyTimes()
|
||||||
|
|
||||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
ctx := context.Background()
|
||||||
|
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
|
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
||||||
|
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
|
||||||
|
|
||||||
|
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{})
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -355,6 +355,25 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
|||||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||||
req.ServerSSHAllowed = &serverSSHAllowed
|
req.ServerSSHAllowed = &serverSSHAllowed
|
||||||
}
|
}
|
||||||
|
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||||
|
req.EnableSSHRoot = &enableSSHRoot
|
||||||
|
}
|
||||||
|
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||||
|
req.EnableSSHSFTP = &enableSSHSFTP
|
||||||
|
}
|
||||||
|
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||||
|
req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||||
|
}
|
||||||
|
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||||
|
req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||||
|
}
|
||||||
|
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||||
|
req.DisableSSHAuth = &disableSSHAuth
|
||||||
|
}
|
||||||
|
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||||
|
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
||||||
|
req.SshJWTCacheTTL = &sshJWTCacheTTL32
|
||||||
|
}
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
if err := parseInterfaceName(interfaceName); err != nil {
|
if err := parseInterfaceName(interfaceName); err != nil {
|
||||||
log.Errorf("parse interface name: %v", err)
|
log.Errorf("parse interface name: %v", err)
|
||||||
@@ -439,6 +458,30 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
|||||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||||
|
ic.EnableSSHRoot = &enableSSHRoot
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||||
|
ic.EnableSSHSFTP = &enableSSHSFTP
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||||
|
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||||
|
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||||
|
ic.DisableSSHAuth = &disableSSHAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||||
|
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
|
||||||
|
}
|
||||||
|
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
if err := parseInterfaceName(interfaceName); err != nil {
|
if err := parseInterfaceName(interfaceName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -539,6 +582,31 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
|||||||
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||||
|
loginRequest.EnableSSHRoot = &enableSSHRoot
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||||
|
loginRequest.EnableSSHSFTP = &enableSSHSFTP
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||||
|
loginRequest.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||||
|
loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||||
|
loginRequest.DisableSSHAuth = &disableSSHAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||||
|
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
||||||
|
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
|
||||||
|
}
|
||||||
|
|
||||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||||
loginRequest.DisableAutoConnect = &autoConnectDisabled
|
loginRequest.DisableAutoConnect = &autoConnectDisabled
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,12 +18,16 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrClientAlreadyStarted = errors.New("client already started")
|
var (
|
||||||
var ErrClientNotStarted = errors.New("client not started")
|
ErrClientAlreadyStarted = errors.New("client already started")
|
||||||
var ErrConfigNotInitialized = errors.New("config not initialized")
|
ErrClientNotStarted = errors.New("client not started")
|
||||||
|
ErrEngineNotStarted = errors.New("engine not started")
|
||||||
|
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||||
|
)
|
||||||
|
|
||||||
// Client manages a netbird embedded client instance.
|
// Client manages a netbird embedded client instance.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
@@ -238,17 +242,9 @@ func (c *Client) GetConfig() (profilemanager.Config, error) {
|
|||||||
// Dial dials a network address in the netbird network.
|
// Dial dials a network address in the netbird network.
|
||||||
// Not applicable if the userspace networking mode is disabled.
|
// Not applicable if the userspace networking mode is disabled.
|
||||||
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
c.mu.Lock()
|
engine, err := c.getEngine()
|
||||||
connect := c.connect
|
if err != nil {
|
||||||
if connect == nil {
|
return nil, err
|
||||||
c.mu.Unlock()
|
|
||||||
return nil, ErrClientNotStarted
|
|
||||||
}
|
|
||||||
c.mu.Unlock()
|
|
||||||
|
|
||||||
engine := connect.Engine()
|
|
||||||
if engine == nil {
|
|
||||||
return nil, errors.New("engine not started")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
nsnet, err := engine.GetNet()
|
nsnet, err := engine.GetNet()
|
||||||
@@ -259,6 +255,11 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
|
|||||||
return nsnet.DialContext(ctx, network, address)
|
return nsnet.DialContext(ctx, network, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DialContext dials a network address in the netbird network with context
|
||||||
|
func (c *Client) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return c.Dial(ctx, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
// ListenTCP listens on the given address in the netbird network.
|
// ListenTCP listens on the given address in the netbird network.
|
||||||
// Not applicable if the userspace networking mode is disabled.
|
// Not applicable if the userspace networking mode is disabled.
|
||||||
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||||
@@ -314,18 +315,47 @@ func (c *Client) NewHTTPClient() *http.Client {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
|
// VerifySSHHostKey verifies an SSH host key against stored peer keys.
|
||||||
|
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
|
||||||
|
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.
|
||||||
|
func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
|
||||||
|
engine, err := c.getEngine()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
storedKey, found := engine.GetPeerSSHKey(peerAddress)
|
||||||
|
if !found {
|
||||||
|
return sshcommon.ErrPeerNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getEngine safely retrieves the engine from the client with proper locking.
|
||||||
|
// Returns ErrClientNotStarted if the client is not started.
|
||||||
|
// Returns ErrEngineNotStarted if the engine is not available.
|
||||||
|
func (c *Client) getEngine() (*internal.Engine, error) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
connect := c.connect
|
connect := c.connect
|
||||||
if connect == nil {
|
|
||||||
c.mu.Unlock()
|
|
||||||
return nil, netip.Addr{}, errors.New("client not started")
|
|
||||||
}
|
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
if connect == nil {
|
||||||
|
return nil, ErrClientNotStarted
|
||||||
|
}
|
||||||
|
|
||||||
engine := connect.Engine()
|
engine := connect.Engine()
|
||||||
if engine == nil {
|
if engine == nil {
|
||||||
return nil, netip.Addr{}, errors.New("engine not started")
|
return nil, ErrEngineNotStarted
|
||||||
|
}
|
||||||
|
|
||||||
|
return engine, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
|
||||||
|
engine, err := c.getEngine()
|
||||||
|
if err != nil {
|
||||||
|
return nil, netip.Addr{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
addr, err := engine.Address()
|
addr, err := engine.Address()
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/nadoo/ipset"
|
ipset "github.com/lrh3321/ipset-go"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
@@ -40,19 +41,13 @@ type aclManager struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
|
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
|
||||||
m := &aclManager{
|
return &aclManager{
|
||||||
iptablesClient: iptablesClient,
|
iptablesClient: iptablesClient,
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
entries: make(map[string][][]string),
|
entries: make(map[string][][]string),
|
||||||
optionalEntries: make(map[string][]entry),
|
optionalEntries: make(map[string][]entry),
|
||||||
ipsetStore: newIpsetStore(),
|
ipsetStore: newIpsetStore(),
|
||||||
}
|
}, nil
|
||||||
|
|
||||||
if err := ipset.Init(); err != nil {
|
|
||||||
return nil, fmt.Errorf("init ipset: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return m, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
||||||
@@ -98,8 +93,8 @@ func (m *aclManager) AddPeerFiltering(
|
|||||||
specs = append(specs, "-j", actionToStr(action))
|
specs = append(specs, "-j", actionToStr(action))
|
||||||
if ipsetName != "" {
|
if ipsetName != "" {
|
||||||
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
||||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
if err := m.addToIPSet(ipsetName, ip); err != nil {
|
||||||
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
return nil, fmt.Errorf("add IP to ipset: %w", err)
|
||||||
}
|
}
|
||||||
// if ruleset already exists it means we already have the firewall rule
|
// if ruleset already exists it means we already have the firewall rule
|
||||||
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
||||||
@@ -113,14 +108,18 @@ func (m *aclManager) AddPeerFiltering(
|
|||||||
}}, nil
|
}}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := ipset.Flush(ipsetName); err != nil {
|
if err := m.flushIPSet(ipsetName); err != nil {
|
||||||
log.Errorf("flush ipset %s before use it: %s", ipsetName, err)
|
if errors.Is(err, ipset.ErrSetNotExist) {
|
||||||
|
log.Debugf("flush ipset %s before use: %v", ipsetName, err)
|
||||||
|
} else {
|
||||||
|
log.Errorf("flush ipset %s before use: %v", ipsetName, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if err := ipset.Create(ipsetName); err != nil {
|
if err := m.createIPSet(ipsetName); err != nil {
|
||||||
return nil, fmt.Errorf("failed to create ipset: %w", err)
|
return nil, fmt.Errorf("create ipset: %w", err)
|
||||||
}
|
}
|
||||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
if err := m.addToIPSet(ipsetName, ip); err != nil {
|
||||||
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
return nil, fmt.Errorf("add IP to ipset: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ipList := newIpList(ip.String())
|
ipList := newIpList(ip.String())
|
||||||
@@ -172,11 +171,16 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
return fmt.Errorf("invalid rule type")
|
return fmt.Errorf("invalid rule type")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
shouldDestroyIpset := false
|
||||||
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
|
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
|
||||||
// delete IP from ruleset IPs list and ipset
|
// delete IP from ruleset IPs list and ipset
|
||||||
if _, ok := ipsetList.ips[r.ip]; ok {
|
if _, ok := ipsetList.ips[r.ip]; ok {
|
||||||
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
|
ip := net.ParseIP(r.ip)
|
||||||
return fmt.Errorf("failed to delete ip from ipset: %w", err)
|
if ip == nil {
|
||||||
|
return fmt.Errorf("parse IP %s", r.ip)
|
||||||
|
}
|
||||||
|
if err := m.delFromIPSet(r.ipsetName, ip); err != nil {
|
||||||
|
return fmt.Errorf("delete ip from ipset: %w", err)
|
||||||
}
|
}
|
||||||
delete(ipsetList.ips, r.ip)
|
delete(ipsetList.ips, r.ip)
|
||||||
}
|
}
|
||||||
@@ -190,10 +194,7 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
// we delete last IP from the set, that means we need to delete
|
// we delete last IP from the set, that means we need to delete
|
||||||
// set itself and associated firewall rule too
|
// set itself and associated firewall rule too
|
||||||
m.ipsetStore.deleteIpset(r.ipsetName)
|
m.ipsetStore.deleteIpset(r.ipsetName)
|
||||||
|
shouldDestroyIpset = true
|
||||||
if err := ipset.Destroy(r.ipsetName); err != nil {
|
|
||||||
log.Errorf("delete empty ipset: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
|
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
|
||||||
@@ -206,6 +207,16 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if shouldDestroyIpset {
|
||||||
|
if err := m.destroyIPSet(r.ipsetName); err != nil {
|
||||||
|
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
|
||||||
|
log.Debugf("destroy empty ipset: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Errorf("destroy empty ipset: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
m.updateState()
|
m.updateState()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -264,11 +275,19 @@ func (m *aclManager) cleanChains() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||||
if err := ipset.Flush(ipsetName); err != nil {
|
if err := m.flushIPSet(ipsetName); err != nil {
|
||||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
if errors.Is(err, ipset.ErrSetNotExist) {
|
||||||
|
log.Debugf("flush ipset %q during reset: %v", ipsetName, err)
|
||||||
|
} else {
|
||||||
|
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if err := ipset.Destroy(ipsetName); err != nil {
|
if err := m.destroyIPSet(ipsetName); err != nil {
|
||||||
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
|
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
|
||||||
|
log.Debugf("destroy ipset %q during reset: %v", ipsetName, err)
|
||||||
|
} else {
|
||||||
|
log.Errorf("destroy ipset %q during reset: %v", ipsetName, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
m.ipsetStore.deleteIpset(ipsetName)
|
m.ipsetStore.deleteIpset(ipsetName)
|
||||||
}
|
}
|
||||||
@@ -368,8 +387,8 @@ func (m *aclManager) updateState() {
|
|||||||
// filterRuleSpecs returns the specs of a filtering rule
|
// filterRuleSpecs returns the specs of a filtering rule
|
||||||
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
||||||
matchByIP := true
|
matchByIP := true
|
||||||
// don't use IP matching if IP is ip 0.0.0.0
|
// don't use IP matching if IP is 0.0.0.0
|
||||||
if ip.String() == "0.0.0.0" {
|
if ip.IsUnspecified() {
|
||||||
matchByIP = false
|
matchByIP = false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -416,3 +435,61 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
|
|||||||
return ipsetName + actionSuffix
|
return ipsetName + actionSuffix
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) createIPSet(name string) error {
|
||||||
|
opts := ipset.CreateOptions{
|
||||||
|
Replace: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||||
|
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("created ipset %s with type hash:net", name)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) addToIPSet(name string, ip net.IP) error {
|
||||||
|
cidr := uint8(32)
|
||||||
|
if ip.To4() == nil {
|
||||||
|
cidr = 128
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := &ipset.Entry{
|
||||||
|
IP: ip,
|
||||||
|
CIDR: cidr,
|
||||||
|
Replace: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ipset.Add(name, entry); err != nil {
|
||||||
|
return fmt.Errorf("add IP to ipset %s: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) delFromIPSet(name string, ip net.IP) error {
|
||||||
|
cidr := uint8(32)
|
||||||
|
if ip.To4() == nil {
|
||||||
|
cidr = 128
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := &ipset.Entry{
|
||||||
|
IP: ip,
|
||||||
|
CIDR: cidr,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ipset.Del(name, entry); err != nil {
|
||||||
|
return fmt.Errorf("delete IP from ipset %s: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) flushIPSet(name string) error {
|
||||||
|
return ipset.Flush(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) destroyIPSet(name string) error {
|
||||||
|
return ipset.Destroy(name)
|
||||||
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
"github.com/nadoo/ipset"
|
ipset "github.com/lrh3321/ipset-go"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
@@ -107,10 +107,6 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := ipset.Init(); err != nil {
|
|
||||||
return nil, fmt.Errorf("init ipset: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -232,12 +228,12 @@ func (r *router) findSets(rule []string) []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
||||||
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
|
if err := r.createIPSet(setName); err != nil {
|
||||||
return fmt.Errorf("create set %s: %w", setName, err)
|
return fmt.Errorf("create set %s: %w", setName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, prefix := range sources {
|
for _, prefix := range sources {
|
||||||
if err := ipset.AddPrefix(setName, prefix); err != nil {
|
if err := r.addPrefixToIPSet(setName, prefix); err != nil {
|
||||||
return fmt.Errorf("add element to set %s: %w", setName, err)
|
return fmt.Errorf("add element to set %s: %w", setName, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -246,7 +242,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) deleteIpSet(setName string) error {
|
func (r *router) deleteIpSet(setName string) error {
|
||||||
if err := ipset.Destroy(setName); err != nil {
|
if err := r.destroyIPSet(setName); err != nil {
|
||||||
return fmt.Errorf("destroy set %s: %w", setName, err)
|
return fmt.Errorf("destroy set %s: %w", setName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -915,8 +911,8 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
|||||||
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil {
|
if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if merr == nil {
|
if merr == nil {
|
||||||
@@ -993,3 +989,37 @@ func applyPort(flag string, port *firewall.Port) []string {
|
|||||||
|
|
||||||
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *router) createIPSet(name string) error {
|
||||||
|
opts := ipset.CreateOptions{
|
||||||
|
Replace: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||||
|
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("created ipset %s with type hash:net", name)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) addPrefixToIPSet(name string, prefix netip.Prefix) error {
|
||||||
|
addr := prefix.Addr()
|
||||||
|
ip := addr.AsSlice()
|
||||||
|
|
||||||
|
entry := &ipset.Entry{
|
||||||
|
IP: ip,
|
||||||
|
CIDR: uint8(prefix.Bits()),
|
||||||
|
Replace: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ipset.Add(name, entry); err != nil {
|
||||||
|
return fmt.Errorf("add prefix to ipset %s: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) destroyIPSet(name string) error {
|
||||||
|
return ipset.Destroy(name)
|
||||||
|
}
|
||||||
|
|||||||
@@ -27,7 +27,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
tableNat = "nat"
|
tableNat = "nat"
|
||||||
|
tableMangle = "mangle"
|
||||||
|
tableRaw = "raw"
|
||||||
|
tableSecurity = "security"
|
||||||
|
|
||||||
chainNameNatPrerouting = "PREROUTING"
|
chainNameNatPrerouting = "PREROUTING"
|
||||||
chainNameRoutingFw = "netbird-rt-fwd"
|
chainNameRoutingFw = "netbird-rt-fwd"
|
||||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||||
@@ -91,11 +95,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
|
|||||||
var err error
|
var err error
|
||||||
r.filterTable, err = r.loadFilterTable()
|
r.filterTable, err = r.loadFilterTable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, errFilterTableNotFound) {
|
log.Debugf("ip filter table not found: %v", err)
|
||||||
log.Warnf("table 'filter' not found for forward rules")
|
|
||||||
} else {
|
|
||||||
return nil, fmt.Errorf("load filter table: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return r, nil
|
return r, nil
|
||||||
@@ -175,7 +175,7 @@ func (r *router) removeNatPreroutingRules() error {
|
|||||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to list tables: %v", err)
|
return nil, fmt.Errorf("list tables: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, table := range tables {
|
for _, table := range tables {
|
||||||
@@ -187,14 +187,39 @@ func (r *router) loadFilterTable() (*nftables.Table, error) {
|
|||||||
return nil, errFilterTableNotFound
|
return nil, errFilterTableNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hookName(hook *nftables.ChainHook) string {
|
||||||
|
if hook == nil {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
switch *hook {
|
||||||
|
case *nftables.ChainHookForward:
|
||||||
|
return chainNameForward
|
||||||
|
case *nftables.ChainHookInput:
|
||||||
|
return chainNameInput
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("hook(%d)", *hook)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func familyName(family nftables.TableFamily) string {
|
||||||
|
switch family {
|
||||||
|
case nftables.TableFamilyIPv4:
|
||||||
|
return "ip"
|
||||||
|
case nftables.TableFamilyIPv6:
|
||||||
|
return "ip6"
|
||||||
|
case nftables.TableFamilyINet:
|
||||||
|
return "inet"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("family(%d)", family)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *router) createContainers() error {
|
func (r *router) createContainers() error {
|
||||||
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
||||||
Name: chainNameRoutingFw,
|
Name: chainNameRoutingFw,
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
})
|
})
|
||||||
|
|
||||||
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
|
||||||
|
|
||||||
prio := *nftables.ChainPriorityNATSource - 1
|
prio := *nftables.ChainPriorityNATSource - 1
|
||||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||||
Name: chainNameRoutingNat,
|
Name: chainNameRoutingNat,
|
||||||
@@ -236,9 +261,12 @@ func (r *router) createContainers() error {
|
|||||||
Type: nftables.ChainTypeFilter,
|
Type: nftables.ChainTypeFilter,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Add the single NAT rule that matches on mark
|
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||||
if err := r.addPostroutingRules(); err != nil {
|
|
||||||
return fmt.Errorf("add single nat rule: %v", err)
|
r.addPostroutingRules()
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("initialize tables: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.addMSSClampingRules(); err != nil {
|
if err := r.addMSSClampingRules(); err != nil {
|
||||||
@@ -250,11 +278,7 @@ func (r *router) createContainers() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
log.Errorf("failed to refresh rules: %s", err)
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("initialize tables: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -695,7 +719,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addPostroutingRules adds the masquerade rules
|
// addPostroutingRules adds the masquerade rules
|
||||||
func (r *router) addPostroutingRules() error {
|
func (r *router) addPostroutingRules() {
|
||||||
// First masquerade rule for traffic coming in from WireGuard interface
|
// First masquerade rule for traffic coming in from WireGuard interface
|
||||||
exprs := []expr.Any{
|
exprs := []expr.Any{
|
||||||
// Match on the first fwmark
|
// Match on the first fwmark
|
||||||
@@ -761,8 +785,6 @@ func (r *router) addPostroutingRules() error {
|
|||||||
Chain: r.chains[chainNameRoutingNat],
|
Chain: r.chains[chainNameRoutingNat],
|
||||||
Exprs: exprs2,
|
Exprs: exprs2,
|
||||||
})
|
})
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||||
@@ -839,7 +861,7 @@ func (r *router) addMSSClampingRules() error {
|
|||||||
Exprs: exprsOut,
|
Exprs: exprsOut,
|
||||||
})
|
})
|
||||||
|
|
||||||
return nil
|
return r.conn.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||||
@@ -939,8 +961,21 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
|||||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||||
// This method also adds INPUT chain rules to allow traffic to the local interface.
|
// This method also adds INPUT chain rules to allow traffic to the local interface.
|
||||||
func (r *router) acceptForwardRules() error {
|
func (r *router) acceptForwardRules() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := r.acceptFilterTableRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.acceptExternalChainsRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) acceptFilterTableRules() error {
|
||||||
if r.filterTable == nil {
|
if r.filterTable == nil {
|
||||||
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -953,11 +988,11 @@ func (r *router) acceptForwardRules() error {
|
|||||||
// Try iptables first and fallback to nftables if iptables is not available
|
// Try iptables first and fallback to nftables if iptables is not available
|
||||||
ipt, err := iptables.New()
|
ipt, err := iptables.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// filter table exists but iptables is not
|
// iptables is not available but the filter table exists
|
||||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||||
|
|
||||||
fw = "nftables"
|
fw = "nftables"
|
||||||
return r.acceptFilterRulesNftables()
|
return r.acceptFilterRulesNftables(r.filterTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.acceptFilterRulesIptables(ipt)
|
return r.acceptFilterRulesIptables(ipt)
|
||||||
@@ -968,7 +1003,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
|||||||
|
|
||||||
for _, rule := range r.getAcceptForwardRules() {
|
for _, rule := range r.getAcceptForwardRules() {
|
||||||
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
||||||
merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err))
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("added iptables forward rule: %v", rule)
|
log.Debugf("added iptables forward rule: %v", rule)
|
||||||
}
|
}
|
||||||
@@ -976,7 +1011,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
|||||||
|
|
||||||
inputRule := r.getAcceptInputRule()
|
inputRule := r.getAcceptInputRule()
|
||||||
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
|
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
|
||||||
merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err))
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("added iptables input rule: %v", inputRule)
|
log.Debugf("added iptables input rule: %v", inputRule)
|
||||||
}
|
}
|
||||||
@@ -996,18 +1031,70 @@ func (r *router) getAcceptInputRule() []string {
|
|||||||
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
|
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) acceptFilterRulesNftables() error {
|
// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables.
|
||||||
|
// This is used when iptables is not available.
|
||||||
|
func (r *router) acceptFilterRulesNftables(table *nftables.Table) error {
|
||||||
intf := ifname(r.wgIface.Name())
|
intf := ifname(r.wgIface.Name())
|
||||||
|
|
||||||
|
forwardChain := &nftables.Chain{
|
||||||
|
Name: chainNameForward,
|
||||||
|
Table: table,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookForward,
|
||||||
|
Priority: nftables.ChainPriorityFilter,
|
||||||
|
}
|
||||||
|
r.insertForwardAcceptRules(forwardChain, intf)
|
||||||
|
|
||||||
|
inputChain := &nftables.Chain{
|
||||||
|
Name: chainNameInput,
|
||||||
|
Table: table,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookInput,
|
||||||
|
Priority: nftables.ChainPriorityFilter,
|
||||||
|
}
|
||||||
|
r.insertInputAcceptRule(inputChain, intf)
|
||||||
|
|
||||||
|
return r.conn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables).
|
||||||
|
// It dynamically finds chains at call time to handle chains that may have been created after startup.
|
||||||
|
func (r *router) acceptExternalChainsRules() error {
|
||||||
|
chains := r.findExternalChains()
|
||||||
|
if len(chains) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
intf := ifname(r.wgIface.Name())
|
||||||
|
|
||||||
|
for _, chain := range chains {
|
||||||
|
if chain.Hooknum == nil {
|
||||||
|
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
|
||||||
|
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
|
||||||
|
|
||||||
|
switch *chain.Hooknum {
|
||||||
|
case *nftables.ChainHookForward:
|
||||||
|
r.insertForwardAcceptRules(chain, intf)
|
||||||
|
case *nftables.ChainHookInput:
|
||||||
|
r.insertInputAcceptRule(chain, intf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush external chain rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
|
||||||
iifRule := &nftables.Rule{
|
iifRule := &nftables.Rule{
|
||||||
Table: r.filterTable,
|
Table: chain.Table,
|
||||||
Chain: &nftables.Chain{
|
Chain: chain,
|
||||||
Name: chainNameForward,
|
|
||||||
Table: r.filterTable,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookForward,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
},
|
|
||||||
Exprs: []expr.Any{
|
Exprs: []expr.Any{
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
@@ -1030,30 +1117,19 @@ func (r *router) acceptFilterRulesNftables() error {
|
|||||||
Data: intf,
|
Data: intf,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
oifRule := &nftables.Rule{
|
oifRule := &nftables.Rule{
|
||||||
Table: r.filterTable,
|
Table: chain.Table,
|
||||||
Chain: &nftables.Chain{
|
Chain: chain,
|
||||||
Name: chainNameForward,
|
|
||||||
Table: r.filterTable,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookForward,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
},
|
|
||||||
Exprs: append(oifExprs, getEstablishedExprs(2)...),
|
Exprs: append(oifExprs, getEstablishedExprs(2)...),
|
||||||
UserData: []byte(userDataAcceptForwardRuleOif),
|
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||||
}
|
}
|
||||||
r.conn.InsertRule(oifRule)
|
r.conn.InsertRule(oifRule)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
|
||||||
inputRule := &nftables.Rule{
|
inputRule := &nftables.Rule{
|
||||||
Table: r.filterTable,
|
Table: chain.Table,
|
||||||
Chain: &nftables.Chain{
|
Chain: chain,
|
||||||
Name: chainNameInput,
|
|
||||||
Table: r.filterTable,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookInput,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
},
|
|
||||||
Exprs: []expr.Any{
|
Exprs: []expr.Any{
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
@@ -1067,32 +1143,44 @@ func (r *router) acceptFilterRulesNftables() error {
|
|||||||
UserData: []byte(userDataAcceptInputRule),
|
UserData: []byte(userDataAcceptInputRule),
|
||||||
}
|
}
|
||||||
r.conn.InsertRule(inputRule)
|
r.conn.InsertRule(inputRule)
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) removeAcceptFilterRules() error {
|
func (r *router) removeAcceptFilterRules() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := r.removeFilterTableRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.removeExternalChainsRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) removeFilterTableRules() error {
|
||||||
if r.filterTable == nil {
|
if r.filterTable == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ipt, err := iptables.New()
|
ipt, err := iptables.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
|
||||||
return r.removeAcceptFilterRulesNftables()
|
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.removeAcceptFilterRulesIptables(ipt)
|
return r.removeAcceptFilterRulesIptables(ipt)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) removeAcceptFilterRulesNftables() error {
|
func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
|
||||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
chains, err := r.conn.ListChainsOfTableFamily(table.Family)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("list chains: %v", err)
|
return fmt.Errorf("list chains: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, chain := range chains {
|
for _, chain := range chains {
|
||||||
if chain.Table.Name != r.filterTable.Name {
|
if chain.Table.Name != table.Name {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1100,27 +1188,101 @@ func (r *router) removeAcceptFilterRulesNftables() error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
rules, err := r.conn.GetRules(r.filterTable, chain)
|
if err := r.removeAcceptRulesFromChain(table, chain); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.conn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error {
|
||||||
|
rules, err := r.conn.GetRules(table, chain)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range rules {
|
||||||
|
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||||
|
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
|
||||||
|
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeExternalChainsRules removes our accept rules from all external chains.
|
||||||
|
// This is deterministic - it scans for chains at removal time rather than relying on saved state,
|
||||||
|
// ensuring cleanup works even after a crash or if chains changed.
|
||||||
|
func (r *router) removeExternalChainsRules() error {
|
||||||
|
chains := r.findExternalChains()
|
||||||
|
if len(chains) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, chain := range chains {
|
||||||
|
if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil {
|
||||||
|
log.Warnf("remove rules from external chain %s/%s: %v", chain.Table.Name, chain.Name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.conn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks.
|
||||||
|
// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal).
|
||||||
|
func (r *router) findExternalChains() []*nftables.Chain {
|
||||||
|
var chains []*nftables.Chain
|
||||||
|
|
||||||
|
families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet}
|
||||||
|
|
||||||
|
for _, family := range families {
|
||||||
|
allChains, err := r.conn.ListChainsOfTableFamily(family)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get rules: %v", err)
|
log.Debugf("list chains for family %d: %v", family, err)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, rule := range rules {
|
for _, chain := range allChains {
|
||||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
if r.isExternalChain(chain) {
|
||||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
|
chains = append(chains, chain)
|
||||||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
|
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
|
||||||
return fmt.Errorf("delete rule: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
return chains
|
||||||
return fmt.Errorf(flushError, err)
|
}
|
||||||
|
|
||||||
|
func (r *router) isExternalChain(chain *nftables.Chain) bool {
|
||||||
|
if r.workTable != nil && chain.Table.Name == r.workTable.Name {
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
// Skip all iptables-managed tables in the ip family
|
||||||
|
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if chain.Type != nftables.ChainTypeFilter {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if chain.Hooknum == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIptablesTable(name string) bool {
|
||||||
|
switch name {
|
||||||
|
case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||||
@@ -1128,13 +1290,13 @@ func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
|||||||
|
|
||||||
for _, rule := range r.getAcceptForwardRules() {
|
for _, rule := range r.getAcceptForwardRules() {
|
||||||
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
||||||
merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inputRule := r.getAcceptInputRule()
|
inputRule := r.getAcceptInputRule()
|
||||||
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
|
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
|
||||||
merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
@@ -1196,7 +1358,7 @@ func (r *router) refreshRulesMap() error {
|
|||||||
for _, chain := range r.chains {
|
for _, chain := range r.chains {
|
||||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf(" unable to list rules: %v", err)
|
return fmt.Errorf("list rules: %w", err)
|
||||||
}
|
}
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 {
|
if len(rule.UserData) > 0 {
|
||||||
|
|||||||
@@ -35,6 +35,12 @@ const (
|
|||||||
ipTCPHeaderMinSize = 40
|
ipTCPHeaderMinSize = 40
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// serviceKey represents a protocol/port combination for netstack service registry
|
||||||
|
type serviceKey struct {
|
||||||
|
protocol gopacket.LayerType
|
||||||
|
port uint16
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
|
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
|
||||||
EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
||||||
@@ -59,12 +65,6 @@ const (
|
|||||||
|
|
||||||
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
||||||
|
|
||||||
// serviceKey represents a protocol/port combination for netstack service registry
|
|
||||||
type serviceKey struct {
|
|
||||||
protocol gopacket.LayerType
|
|
||||||
port uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
// RuleSet is a set of rules grouped by a string key
|
// RuleSet is a set of rules grouped by a string key
|
||||||
type RuleSet map[string]PeerRule
|
type RuleSet map[string]PeerRule
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1114,3 +1115,138 @@ func generateTCPPacketWithFlags(tb testing.TB, srcIP, dstIP net.IP, srcPort, dst
|
|||||||
|
|
||||||
return buf.Bytes()
|
return buf.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestShouldForward(t *testing.T) {
|
||||||
|
// Set up test addresses
|
||||||
|
wgIP := netip.MustParseAddr("100.10.0.1")
|
||||||
|
otherIP := netip.MustParseAddr("100.10.0.2")
|
||||||
|
|
||||||
|
// Create test manager with mock interface
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
// Set the mock to return our test WG IP
|
||||||
|
ifaceMock.AddressFunc = func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{IP: wgIP, Network: netip.PrefixFrom(wgIP, 24)}
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Helper to create decoder with TCP packet
|
||||||
|
createTCPDecoder := func(dstPort uint16) *decoder {
|
||||||
|
ipv4 := &layers.IPv4{
|
||||||
|
Version: 4,
|
||||||
|
Protocol: layers.IPProtocolTCP,
|
||||||
|
SrcIP: net.ParseIP("192.168.1.100"),
|
||||||
|
DstIP: wgIP.AsSlice(),
|
||||||
|
}
|
||||||
|
tcp := &layers.TCP{
|
||||||
|
SrcPort: 54321,
|
||||||
|
DstPort: layers.TCPPort(dstPort),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := tcp.SetNetworkLayerForChecksum(ipv4)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
buf := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||||
|
err = gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
d := &decoder{
|
||||||
|
decoded: []gopacket.LayerType{},
|
||||||
|
}
|
||||||
|
d.parser = gopacket.NewDecodingLayerParser(
|
||||||
|
layers.LayerTypeIPv4,
|
||||||
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
|
)
|
||||||
|
d.parser.IgnoreUnsupported = true
|
||||||
|
|
||||||
|
err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
localForwarding bool
|
||||||
|
netstack bool
|
||||||
|
dstIP netip.Addr
|
||||||
|
serviceRegistered bool
|
||||||
|
servicePort uint16
|
||||||
|
expected bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no local forwarding",
|
||||||
|
localForwarding: false,
|
||||||
|
netstack: true,
|
||||||
|
dstIP: wgIP,
|
||||||
|
expected: false,
|
||||||
|
description: "should never forward when local forwarding disabled",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "traffic to other local interface",
|
||||||
|
localForwarding: true,
|
||||||
|
netstack: false,
|
||||||
|
dstIP: otherIP,
|
||||||
|
expected: true,
|
||||||
|
description: "should forward traffic to our other local interfaces (not NetBird IP)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "traffic to NetBird IP, no netstack",
|
||||||
|
localForwarding: true,
|
||||||
|
netstack: false,
|
||||||
|
dstIP: wgIP,
|
||||||
|
expected: false,
|
||||||
|
description: "should send to netstack listeners (final return false path)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "traffic to our IP, netstack mode, no service",
|
||||||
|
localForwarding: true,
|
||||||
|
netstack: true,
|
||||||
|
dstIP: wgIP,
|
||||||
|
expected: true,
|
||||||
|
description: "should forward when in netstack mode with no matching service",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "traffic to our IP, netstack mode, with service",
|
||||||
|
localForwarding: true,
|
||||||
|
netstack: true,
|
||||||
|
dstIP: wgIP,
|
||||||
|
serviceRegistered: true,
|
||||||
|
servicePort: 22,
|
||||||
|
expected: false,
|
||||||
|
description: "should send to netstack listeners when service is registered",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Configure manager
|
||||||
|
manager.localForwarding = tt.localForwarding
|
||||||
|
manager.netstack = tt.netstack
|
||||||
|
|
||||||
|
// Register service if needed
|
||||||
|
if tt.serviceRegistered {
|
||||||
|
manager.RegisterNetstackService(nftypes.TCP, tt.servicePort)
|
||||||
|
defer manager.UnregisterNetstackService(nftypes.TCP, tt.servicePort)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create decoder for the test
|
||||||
|
decoder := createTCPDecoder(tt.servicePort)
|
||||||
|
if !tt.serviceRegistered {
|
||||||
|
decoder = createTCPDecoder(8080) // Use non-registered port
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test the method
|
||||||
|
result := manager.shouldForward(decoder, tt.dstIP)
|
||||||
|
require.Equal(t, tt.expected, result, tt.description)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
85
client/firewall/uspfilter/nat_stateful_test.go
Normal file
85
client/firewall/uspfilter/nat_stateful_test.go
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestPortDNATBasic tests basic port DNAT functionality
|
||||||
|
func TestPortDNATBasic(t *testing.T) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Define peer IPs
|
||||||
|
peerA := netip.MustParseAddr("100.10.0.50")
|
||||||
|
peerB := netip.MustParseAddr("100.10.0.51")
|
||||||
|
|
||||||
|
// Add SSH port redirection rule for peer B (the target)
|
||||||
|
err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Scenario: Peer A connects to Peer B on port 22 (should get NAT)
|
||||||
|
packetAtoB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22)
|
||||||
|
d := parsePacket(t, packetAtoB)
|
||||||
|
translatedAtoB := manager.translateInboundPortDNAT(packetAtoB, d, peerA, peerB)
|
||||||
|
require.True(t, translatedAtoB, "Peer A to Peer B should be translated (NAT applied)")
|
||||||
|
|
||||||
|
// Verify port was translated to 22022
|
||||||
|
d = parsePacket(t, packetAtoB)
|
||||||
|
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be rewritten to 22022")
|
||||||
|
|
||||||
|
// Scenario: Return traffic from Peer B to Peer A should NOT be translated
|
||||||
|
// (prevents double NAT - original port stored in conntrack)
|
||||||
|
returnPacket := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 22022, 54321)
|
||||||
|
d2 := parsePacket(t, returnPacket)
|
||||||
|
translatedReturn := manager.translateInboundPortDNAT(returnPacket, d2, peerB, peerA)
|
||||||
|
require.False(t, translatedReturn, "Return traffic from same IP should not be translated")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPortDNATMultipleRules tests multiple port DNAT rules
|
||||||
|
func TestPortDNATMultipleRules(t *testing.T) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Define peer IPs
|
||||||
|
peerA := netip.MustParseAddr("100.10.0.50")
|
||||||
|
peerB := netip.MustParseAddr("100.10.0.51")
|
||||||
|
|
||||||
|
// Add SSH port redirection rules for both peers
|
||||||
|
err = manager.addPortRedirection(peerA, layers.LayerTypeTCP, 22, 22022)
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Test traffic to peer B gets translated
|
||||||
|
packetToB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22)
|
||||||
|
d1 := parsePacket(t, packetToB)
|
||||||
|
translatedToB := manager.translateInboundPortDNAT(packetToB, d1, peerA, peerB)
|
||||||
|
require.True(t, translatedToB, "Traffic to peer B should be translated")
|
||||||
|
d1 = parsePacket(t, packetToB)
|
||||||
|
require.Equal(t, uint16(22022), uint16(d1.tcp.DstPort), "Port should be 22022")
|
||||||
|
|
||||||
|
// Test traffic to peer A gets translated
|
||||||
|
packetToA := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 54322, 22)
|
||||||
|
d2 := parsePacket(t, packetToA)
|
||||||
|
translatedToA := manager.translateInboundPortDNAT(packetToA, d2, peerB, peerA)
|
||||||
|
require.True(t, translatedToA, "Traffic to peer A should be translated")
|
||||||
|
d2 = parsePacket(t, packetToA)
|
||||||
|
require.Equal(t, uint16(22022), uint16(d2.tcp.DstPort), "Port should be 22022")
|
||||||
|
}
|
||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
@@ -12,7 +11,6 @@ import (
|
|||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/connectivity"
|
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
@@ -20,9 +18,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready
|
|
||||||
var ErrConnectionShutdown = errors.New("connection shutdown before ready")
|
|
||||||
|
|
||||||
// Backoff returns a backoff configuration for gRPC calls
|
// Backoff returns a backoff configuration for gRPC calls
|
||||||
func Backoff(ctx context.Context) backoff.BackOff {
|
func Backoff(ctx context.Context) backoff.BackOff {
|
||||||
b := backoff.NewExponentialBackOff()
|
b := backoff.NewExponentialBackOff()
|
||||||
@@ -31,26 +26,6 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
|||||||
return backoff.WithContext(b, ctx)
|
return backoff.WithContext(b, ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// waitForConnectionReady blocks until the connection becomes ready or fails.
|
|
||||||
// Returns an error if the connection times out, is cancelled, or enters shutdown state.
|
|
||||||
func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error {
|
|
||||||
conn.Connect()
|
|
||||||
|
|
||||||
state := conn.GetState()
|
|
||||||
for state != connectivity.Ready && state != connectivity.Shutdown {
|
|
||||||
if !conn.WaitForStateChange(ctx, state) {
|
|
||||||
return fmt.Errorf("wait state change from %s: %w", state, ctx.Err())
|
|
||||||
}
|
|
||||||
state = conn.GetState()
|
|
||||||
}
|
|
||||||
|
|
||||||
if state == connectivity.Shutdown {
|
|
||||||
return ErrConnectionShutdown
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
||||||
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||||
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
||||||
@@ -68,25 +43,22 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := grpc.NewClient(
|
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := grpc.DialContext(
|
||||||
|
connCtx,
|
||||||
addr,
|
addr,
|
||||||
transportOption,
|
transportOption,
|
||||||
WithCustomDialer(tlsEnabled, component),
|
WithCustomDialer(tlsEnabled, component),
|
||||||
|
grpc.WithBlock(),
|
||||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||||
Time: 30 * time.Second,
|
Time: 30 * time.Second,
|
||||||
Timeout: 10 * time.Second,
|
Timeout: 10 * time.Second,
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("new client: %w", err)
|
return nil, fmt.Errorf("dial context: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
if err := waitForConnectionReady(ctx, conn); err != nil {
|
|
||||||
_ = conn.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -19,11 +20,12 @@ import (
|
|||||||
|
|
||||||
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
|
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
|
||||||
type WGTunDevice struct {
|
type WGTunDevice struct {
|
||||||
address wgaddr.Address
|
address wgaddr.Address
|
||||||
port int
|
port int
|
||||||
key string
|
key string
|
||||||
mtu uint16
|
mtu uint16
|
||||||
iceBind *bind.ICEBind
|
iceBind *bind.ICEBind
|
||||||
|
// todo: review if we can eliminate the TunAdapter
|
||||||
tunAdapter TunAdapter
|
tunAdapter TunAdapter
|
||||||
disableDNS bool
|
disableDNS bool
|
||||||
|
|
||||||
@@ -32,17 +34,19 @@ type WGTunDevice struct {
|
|||||||
filteredDevice *FilteredDevice
|
filteredDevice *FilteredDevice
|
||||||
udpMux *udpmux.UniversalUDPMuxDefault
|
udpMux *udpmux.UniversalUDPMuxDefault
|
||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
|
renewableTun *RenewableTUN
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
|
func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
|
||||||
return &WGTunDevice{
|
return &WGTunDevice{
|
||||||
address: address,
|
address: address,
|
||||||
port: port,
|
port: port,
|
||||||
key: key,
|
key: key,
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
iceBind: iceBind,
|
iceBind: iceBind,
|
||||||
tunAdapter: tunAdapter,
|
tunAdapter: tunAdapter,
|
||||||
disableDNS: disableDNS,
|
disableDNS: disableDNS,
|
||||||
|
renewableTun: NewRenewableTUN(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,14 +69,17 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
tunDevice, name, err := tun.CreateUnmonitoredTUNFromFD(fd)
|
unmonitoredTUN, name, err := tun.CreateUnmonitoredTUNFromFD(fd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = unix.Close(fd)
|
_ = unix.Close(fd)
|
||||||
log.Errorf("failed to create Android interface: %s", err)
|
log.Errorf("failed to create Android interface: %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.renewableTun.AddDevice(unmonitoredTUN)
|
||||||
|
|
||||||
t.name = name
|
t.name = name
|
||||||
t.filteredDevice = newDeviceFilter(tunDevice)
|
t.filteredDevice = newDeviceFilter(t.renewableTun)
|
||||||
|
|
||||||
log.Debugf("attaching to interface %v", name)
|
log.Debugf("attaching to interface %v", name)
|
||||||
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
|
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
|
||||||
@@ -104,6 +111,23 @@ func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
|||||||
return udpMux, nil
|
return udpMux, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *WGTunDevice) RenewTun(fd int) error {
|
||||||
|
if t.device == nil {
|
||||||
|
return fmt.Errorf("device not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
unmonitoredTUN, _, err := tun.CreateUnmonitoredTUNFromFD(fd)
|
||||||
|
if err != nil {
|
||||||
|
_ = unix.Close(fd)
|
||||||
|
log.Errorf("failed to renew Android interface: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
t.renewableTun.AddDevice(unmonitoredTUN)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
|
func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
|
||||||
// todo implement
|
// todo implement
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -2,6 +2,13 @@
|
|||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
func (t *TunNetstackDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
|
func (t *TunNetstackDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
|
||||||
return t.create()
|
return t.create()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *TunNetstackDevice) RenewTun(fd int) error {
|
||||||
|
// Doesn't make sense in Android for Netstack.
|
||||||
|
return fmt.Errorf("this function has not been implemented in Netstack for Android")
|
||||||
|
}
|
||||||
|
|||||||
309
client/iface/device/renewable_tun.go
Normal file
309
client/iface/device/renewable_tun.go
Normal file
@@ -0,0 +1,309 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
)
|
||||||
|
|
||||||
|
// closeAwareDevice wraps a tun.Device along with a flag
|
||||||
|
// indicating whether its Close method was called.
|
||||||
|
//
|
||||||
|
// It also redirects tun.Device's Events() to a separate goroutine
|
||||||
|
// and closes it when Close is called.
|
||||||
|
//
|
||||||
|
// The WaitGroup and CloseOnce fields are used to ensure that the
|
||||||
|
// goroutine is awaited and closed only once.
|
||||||
|
type closeAwareDevice struct {
|
||||||
|
isClosed atomic.Bool
|
||||||
|
tun.Device
|
||||||
|
closeEventCh chan struct{}
|
||||||
|
wg sync.WaitGroup
|
||||||
|
closeOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func newClosableDevice(tunDevice tun.Device) *closeAwareDevice {
|
||||||
|
return &closeAwareDevice{
|
||||||
|
Device: tunDevice,
|
||||||
|
isClosed: atomic.Bool{},
|
||||||
|
closeEventCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// redirectEvents redirects the Events() method of the underlying tun.Device
|
||||||
|
// to the given channel (RenewableTUN's events channel).
|
||||||
|
func (c *closeAwareDevice) redirectEvents(out chan tun.Event) {
|
||||||
|
c.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer c.wg.Done()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case ev, ok := <-c.Device.Events():
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if ev == tun.EventDown {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case out <- ev:
|
||||||
|
case <-c.closeEventCh:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-c.closeEventCh:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close calls the underlying Device's Close method
|
||||||
|
// after setting isClosed to true.
|
||||||
|
func (c *closeAwareDevice) Close() (err error) {
|
||||||
|
c.closeOnce.Do(func() {
|
||||||
|
c.isClosed.Store(true)
|
||||||
|
close(c.closeEventCh)
|
||||||
|
err = c.Device.Close()
|
||||||
|
c.wg.Wait()
|
||||||
|
})
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *closeAwareDevice) IsClosed() bool {
|
||||||
|
return c.isClosed.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
type RenewableTUN struct {
|
||||||
|
devices []*closeAwareDevice
|
||||||
|
mu sync.Mutex
|
||||||
|
cond *sync.Cond
|
||||||
|
events chan tun.Event
|
||||||
|
closed atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRenewableTUN() *RenewableTUN {
|
||||||
|
r := &RenewableTUN{
|
||||||
|
devices: make([]*closeAwareDevice, 0),
|
||||||
|
mu: sync.Mutex{},
|
||||||
|
events: make(chan tun.Event, 16),
|
||||||
|
}
|
||||||
|
r.cond = sync.NewCond(&r.mu)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RenewableTUN) File() *os.File {
|
||||||
|
for {
|
||||||
|
dev := r.peekLast()
|
||||||
|
if dev == nil {
|
||||||
|
if !r.waitForDevice() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
file := dev.File()
|
||||||
|
|
||||||
|
if dev.IsClosed() {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return file
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read reads from an underlying tun.Device kept in the r.devices slice.
|
||||||
|
// If no device is available, it waits for one to be added via AddDevice().
|
||||||
|
//
|
||||||
|
// On error, it retries reading from the newest device instead of returning the error
|
||||||
|
// if the device is closed; if not, it propagates the error.
|
||||||
|
func (r *RenewableTUN) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||||
|
for {
|
||||||
|
dev := r.peekLast()
|
||||||
|
if dev == nil {
|
||||||
|
// wait until AddDevice() signals a new device via cond.Broadcast()
|
||||||
|
if !r.waitForDevice() { // returns false if the renewable TUN itself is closed
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err = dev.Read(bufs, sizes, offset)
|
||||||
|
if err == nil {
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// swap in progress; retry on the newest instead of returning the error
|
||||||
|
if dev.IsClosed() {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return n, err // propagate non-swap error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write writes to underlying tun.Device kept in the r.devices slice.
|
||||||
|
// If no device is available, it waits for one to be added via AddDevice().
|
||||||
|
//
|
||||||
|
// On error, it retries writing to the newest device instead of returning the error
|
||||||
|
// if the device is closed; if not, it propagates the error.
|
||||||
|
func (r *RenewableTUN) Write(bufs [][]byte, offset int) (int, error) {
|
||||||
|
for {
|
||||||
|
dev := r.peekLast()
|
||||||
|
if dev == nil {
|
||||||
|
if !r.waitForDevice() {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := dev.Write(bufs, offset)
|
||||||
|
if err == nil {
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if dev.IsClosed() {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RenewableTUN) MTU() (int, error) {
|
||||||
|
for {
|
||||||
|
dev := r.peekLast()
|
||||||
|
if dev == nil {
|
||||||
|
if !r.waitForDevice() {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
mtu, err := dev.MTU()
|
||||||
|
if err == nil {
|
||||||
|
return mtu, nil
|
||||||
|
}
|
||||||
|
if dev.IsClosed() {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RenewableTUN) Name() (string, error) {
|
||||||
|
for {
|
||||||
|
dev := r.peekLast()
|
||||||
|
if dev == nil {
|
||||||
|
if !r.waitForDevice() {
|
||||||
|
return "", io.EOF
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name, err := dev.Name()
|
||||||
|
if err == nil {
|
||||||
|
return name, nil
|
||||||
|
}
|
||||||
|
if dev.IsClosed() {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Events returns a channel that is fed events from the underlying tun.Device's events channel
|
||||||
|
// once it is added.
|
||||||
|
func (r *RenewableTUN) Events() <-chan tun.Event {
|
||||||
|
return r.events
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RenewableTUN) Close() error {
|
||||||
|
// Attempts to set the RenewableTUN closed flag to true.
|
||||||
|
// If it's already true, returns immediately.
|
||||||
|
if !r.closed.CompareAndSwap(false, true) {
|
||||||
|
return nil // already closed: idempotent
|
||||||
|
}
|
||||||
|
r.mu.Lock()
|
||||||
|
devices := r.devices
|
||||||
|
r.devices = nil
|
||||||
|
r.cond.Broadcast()
|
||||||
|
r.mu.Unlock()
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
|
||||||
|
log.Debugf("closing %d devices", len(devices))
|
||||||
|
for _, device := range devices {
|
||||||
|
if err := device.Close(); err != nil {
|
||||||
|
log.Debugf("error closing a device: %v", err)
|
||||||
|
lastErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
close(r.events)
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RenewableTUN) BatchSize() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RenewableTUN) AddDevice(device tun.Device) {
|
||||||
|
r.mu.Lock()
|
||||||
|
if r.closed.Load() {
|
||||||
|
r.mu.Unlock()
|
||||||
|
_ = device.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var toClose *closeAwareDevice
|
||||||
|
if len(r.devices) > 0 {
|
||||||
|
toClose = r.devices[len(r.devices)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
cad := newClosableDevice(device)
|
||||||
|
cad.redirectEvents(r.events)
|
||||||
|
|
||||||
|
r.devices = []*closeAwareDevice{cad}
|
||||||
|
r.cond.Broadcast()
|
||||||
|
|
||||||
|
r.mu.Unlock()
|
||||||
|
|
||||||
|
if toClose != nil {
|
||||||
|
if err := toClose.Close(); err != nil {
|
||||||
|
log.Debugf("error closing last device: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RenewableTUN) waitForDevice() bool {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
for len(r.devices) == 0 && !r.closed.Load() {
|
||||||
|
r.cond.Wait()
|
||||||
|
}
|
||||||
|
return !r.closed.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RenewableTUN) peekLast() *closeAwareDevice {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
if len(r.devices) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.devices[len(r.devices)-1]
|
||||||
|
}
|
||||||
@@ -21,5 +21,6 @@ type WGTunDevice interface {
|
|||||||
FilteredDevice() *device.FilteredDevice
|
FilteredDevice() *device.FilteredDevice
|
||||||
Device() *wgdevice.Device
|
Device() *wgdevice.Device
|
||||||
GetNet() *netstack.Net
|
GetNet() *netstack.Net
|
||||||
|
RenewTun(fd int) error
|
||||||
GetICEBind() device.EndpointManager
|
GetICEBind() device.EndpointManager
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,3 +24,7 @@ func (w *WGIface) Create() error {
|
|||||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
||||||
return fmt.Errorf("this function has not implemented on non mobile")
|
return fmt.Errorf("this function has not implemented on non mobile")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *WGIface) RenewTun(fd int) error {
|
||||||
|
return fmt.Errorf("this function has not been implemented on non-android")
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
|
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
|
||||||
// Will reuse an existing one.
|
// Will reuse an existing one.
|
||||||
|
// todo: review does this function really necessary or can we merge it with iOS
|
||||||
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
|
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
@@ -22,3 +23,9 @@ func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []s
|
|||||||
func (w *WGIface) Create() error {
|
func (w *WGIface) Create() error {
|
||||||
return fmt.Errorf("this function has not implemented on this platform")
|
return fmt.Errorf("this function has not implemented on this platform")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *WGIface) RenewTun(fd int) error {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
return w.tun.RenewTun(fd)
|
||||||
|
}
|
||||||
|
|||||||
@@ -39,3 +39,7 @@ func (w *WGIface) Create() error {
|
|||||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
||||||
return fmt.Errorf("this function has not implemented on this platform")
|
return fmt.Errorf("this function has not implemented on this platform")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *WGIface) RenewTun(fd int) error {
|
||||||
|
return fmt.Errorf("this function has not been implemented on this platform")
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -9,13 +10,13 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/pion/transport/v3/stdnet"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
)
|
)
|
||||||
|
|
||||||
// keep darwin compatibility
|
// keep darwin compatibility
|
||||||
@@ -40,7 +41,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
|
|||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||||
addr := "100.64.0.1/8"
|
addr := "100.64.0.1/8"
|
||||||
wgPort := 33100
|
wgPort := 33100
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -123,7 +124,7 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
|
|||||||
func Test_CreateInterface(t *testing.T) {
|
func Test_CreateInterface(t *testing.T) {
|
||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
|
||||||
wgIP := "10.99.99.1/32"
|
wgIP := "10.99.99.1/32"
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -166,7 +167,7 @@ func Test_Close(t *testing.T) {
|
|||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
||||||
wgIP := "10.99.99.2/32"
|
wgIP := "10.99.99.2/32"
|
||||||
wgPort := 33100
|
wgPort := 33100
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -211,7 +212,7 @@ func TestRecreation(t *testing.T) {
|
|||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
||||||
wgIP := "10.99.99.2/32"
|
wgIP := "10.99.99.2/32"
|
||||||
wgPort := 33100
|
wgPort := 33100
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -284,7 +285,7 @@ func Test_ConfigureInterface(t *testing.T) {
|
|||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
|
||||||
wgIP := "10.99.99.5/30"
|
wgIP := "10.99.99.5/30"
|
||||||
wgPort := 33100
|
wgPort := 33100
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -339,7 +340,7 @@ func Test_ConfigureInterface(t *testing.T) {
|
|||||||
func Test_UpdatePeer(t *testing.T) {
|
func Test_UpdatePeer(t *testing.T) {
|
||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||||
wgIP := "10.99.99.9/30"
|
wgIP := "10.99.99.9/30"
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -409,7 +410,7 @@ func Test_UpdatePeer(t *testing.T) {
|
|||||||
func Test_RemovePeer(t *testing.T) {
|
func Test_RemovePeer(t *testing.T) {
|
||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||||
wgIP := "10.99.99.13/30"
|
wgIP := "10.99.99.13/30"
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -471,7 +472,7 @@ func Test_ConnectPeers(t *testing.T) {
|
|||||||
peer2wgPort := 33200
|
peer2wgPort := 33200
|
||||||
|
|
||||||
keepAlive := 1 * time.Second
|
keepAlive := 1 * time.Second
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -514,7 +515,7 @@ func Test_ConnectPeers(t *testing.T) {
|
|||||||
guid = fmt.Sprintf("{%s}", uuid.New().String())
|
guid = fmt.Sprintf("{%s}", uuid.New().String())
|
||||||
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||||
|
|
||||||
newNet, err = stdnet.NewNet()
|
newNet, err = stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package udpmux
|
package udpmux
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@@ -12,8 +13,9 @@ import (
|
|||||||
"github.com/pion/logging"
|
"github.com/pion/logging"
|
||||||
"github.com/pion/stun/v3"
|
"github.com/pion/stun/v3"
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
"github.com/pion/transport/v3/stdnet"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
)
|
)
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@@ -199,7 +201,7 @@ func (m *SingleSocketUDPMux) updateLocalAddresses() {
|
|||||||
if len(networks) > 0 {
|
if len(networks) > 0 {
|
||||||
if m.params.Net == nil {
|
if m.params.Net == nil {
|
||||||
var err error
|
var err error
|
||||||
if m.params.Net, err = stdnet.NewNet(); err != nil {
|
if m.params.Net, err = stdnet.NewNet(context.Background(), nil); err != nil {
|
||||||
m.params.Logger.Errorf("failed to get create network: %v", err)
|
m.params.Logger.Errorf("failed to get create network: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ import (
|
|||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
@@ -83,22 +82,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
|||||||
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||||
rules := networkMap.FirewallRules
|
rules := networkMap.FirewallRules
|
||||||
|
|
||||||
enableSSH := networkMap.PeerConfig != nil &&
|
|
||||||
networkMap.PeerConfig.SshConfig != nil &&
|
|
||||||
networkMap.PeerConfig.SshConfig.SshEnabled
|
|
||||||
|
|
||||||
// If SSH enabled, add default firewall rule which accepts connection to any peer
|
|
||||||
// in the network by SSH (TCP port defined by ssh.DefaultSSHPort).
|
|
||||||
if enableSSH {
|
|
||||||
rules = append(rules, &mgmProto.FirewallRule{
|
|
||||||
PeerIP: "0.0.0.0",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: strconv.Itoa(ssh.DefaultSSHPort),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
|
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
|
||||||
// we have old version of management without rules handling, we should allow all traffic
|
// we have old version of management without rules handling, we should allow all traffic
|
||||||
if len(networkMap.FirewallRules) == 0 && !networkMap.FirewallRulesIsEmpty {
|
if len(networkMap.FirewallRules) == 0 && !networkMap.FirewallRulesIsEmpty {
|
||||||
|
|||||||
@@ -272,70 +272,3 @@ func TestPortInfoEmpty(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
|
||||||
PeerConfig: &mgmProto.PeerConfig{
|
|
||||||
SshConfig: &mgmProto.SSHConfig{
|
|
||||||
SshEnabled: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
|
||||||
{AllowedIps: []string{"10.93.0.1"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.2"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.3"}},
|
|
||||||
},
|
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_OUT,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_UDP,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
defer ctrl.Finish()
|
|
||||||
|
|
||||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
|
||||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
|
||||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
|
||||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
|
||||||
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
|
||||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
|
||||||
IP: network.Addr(),
|
|
||||||
Network: network,
|
|
||||||
}).AnyTimes()
|
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
|
||||||
|
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
err = fw.Close(nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
acl := NewDefaultManager(fw)
|
|
||||||
|
|
||||||
acl.ApplyFiltering(networkMap, false)
|
|
||||||
|
|
||||||
expectedRules := 3
|
|
||||||
if fw.IsStateful() {
|
|
||||||
expectedRules = 3 // 2 inbound rules + SSH rule
|
|
||||||
}
|
|
||||||
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -60,14 +60,19 @@ func (t TokenInfo) GetTokenToUse() string {
|
|||||||
return t.AccessToken
|
return t.AccessToken
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldUseDeviceFlow(force bool, isUnixDesktopClient bool) bool {
|
||||||
|
return force || (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient
|
||||||
|
}
|
||||||
|
|
||||||
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
|
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
|
||||||
//
|
//
|
||||||
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
|
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
|
||||||
// and if that also fails, the authentication process is deemed unsuccessful
|
// and if that also fails, the authentication process is deemed unsuccessful
|
||||||
//
|
//
|
||||||
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||||
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, hint string) (OAuthFlow, error) {
|
// forceDeviceCodeFlow can be used to skip PKCE and go directly to Device Code Flow (e.g., for Android TV)
|
||||||
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
|
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, forceDeviceCodeFlow bool, hint string) (OAuthFlow, error) {
|
||||||
|
if shouldUseDeviceFlow(forceDeviceCodeFlow, isUnixDesktopClient) {
|
||||||
return authenticateWithDeviceCodeFlow(ctx, config, hint)
|
return authenticateWithDeviceCodeFlow(ctx, config, hint)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -21,6 +22,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/templates"
|
"github.com/netbirdio/netbird/client/internal/templates"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ OAuthFlow = &PKCEAuthorizationFlow{}
|
var _ OAuthFlow = &PKCEAuthorizationFlow{}
|
||||||
@@ -46,9 +48,10 @@ type PKCEAuthorizationFlow struct {
|
|||||||
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||||
var availableRedirectURL string
|
var availableRedirectURL string
|
||||||
|
|
||||||
// find the first available redirect URL
|
excludedRanges := getSystemExcludedPortRanges()
|
||||||
|
|
||||||
for _, redirectURL := range config.RedirectURLs {
|
for _, redirectURL := range config.RedirectURLs {
|
||||||
if !isRedirectURLPortUsed(redirectURL) {
|
if !isRedirectURLPortUsed(redirectURL, excludedRanges) {
|
||||||
availableRedirectURL = redirectURL
|
availableRedirectURL = redirectURL
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -102,10 +105,10 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
|||||||
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
||||||
}
|
}
|
||||||
if !p.providerConfig.DisablePromptLogin {
|
if !p.providerConfig.DisablePromptLogin {
|
||||||
if p.providerConfig.LoginFlag.IsPromptLogin() {
|
switch p.providerConfig.LoginFlag {
|
||||||
|
case common.LoginFlagPromptLogin:
|
||||||
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
||||||
}
|
case common.LoginFlagMaxAge0:
|
||||||
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
|
|
||||||
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -192,17 +195,20 @@ func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token,
|
|||||||
|
|
||||||
if authError := query.Get(queryError); authError != "" {
|
if authError := query.Get(queryError); authError != "" {
|
||||||
authErrorDesc := query.Get(queryErrorDesc)
|
authErrorDesc := query.Get(queryErrorDesc)
|
||||||
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
|
if authErrorDesc != "" {
|
||||||
|
return nil, fmt.Errorf("authentication failed: %s", authErrorDesc)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("authentication failed: %s", authError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prevent timing attacks on the state
|
// Prevent timing attacks on the state
|
||||||
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
||||||
return nil, fmt.Errorf("invalid state")
|
return nil, fmt.Errorf("authentication failed: Invalid state")
|
||||||
}
|
}
|
||||||
|
|
||||||
code := query.Get(queryCode)
|
code := query.Get(queryCode)
|
||||||
if code == "" {
|
if code == "" {
|
||||||
return nil, fmt.Errorf("missing code")
|
return nil, fmt.Errorf("authentication failed: missing code")
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.oAuthConfig.Exchange(
|
return p.oAuthConfig.Exchange(
|
||||||
@@ -231,7 +237,7 @@ func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil {
|
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil {
|
||||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
return TokenInfo{}, fmt.Errorf("authentication failed: invalid access token - %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
email, err := parseEmailFromIDToken(tokenInfo.IDToken)
|
email, err := parseEmailFromIDToken(tokenInfo.IDToken)
|
||||||
@@ -279,15 +285,22 @@ func createCodeChallenge(codeVerifier string) string {
|
|||||||
return base64.RawURLEncoding.EncodeToString(sha2[:])
|
return base64.RawURLEncoding.EncodeToString(sha2[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use.
|
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use or excluded on Windows.
|
||||||
func isRedirectURLPortUsed(redirectURL string) bool {
|
func isRedirectURLPortUsed(redirectURL string, excludedRanges []excludedPortRange) bool {
|
||||||
parsedURL, err := url.Parse(redirectURL)
|
parsedURL, err := url.Parse(redirectURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to parse redirect URL: %v", err)
|
log.Errorf("failed to parse redirect URL: %v", err)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
addr := fmt.Sprintf(":%s", parsedURL.Port())
|
port := parsedURL.Port()
|
||||||
|
|
||||||
|
if isPortInExcludedRange(port, excludedRanges) {
|
||||||
|
log.Warnf("port %s is in Windows excluded port range, skipping", port)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := fmt.Sprintf(":%s", port)
|
||||||
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
|
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
@@ -301,6 +314,33 @@ func isRedirectURLPortUsed(redirectURL string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// excludedPortRange represents a range of excluded ports.
|
||||||
|
type excludedPortRange struct {
|
||||||
|
start int
|
||||||
|
end int
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPortInExcludedRange checks if the given port is in any of the excluded ranges.
|
||||||
|
func isPortInExcludedRange(port string, excludedRanges []excludedPortRange) bool {
|
||||||
|
if len(excludedRanges) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
portNum, err := strconv.Atoi(port)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("invalid port number %s: %v", port, err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range excludedRanges {
|
||||||
|
if portNum >= r.start && portNum <= r.end {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) {
|
func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) {
|
||||||
tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl)
|
tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
8
client/internal/auth/pkce_flow_other.go
Normal file
8
client/internal/auth/pkce_flow_other.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package auth
|
||||||
|
|
||||||
|
// getSystemExcludedPortRanges returns nil on non-Windows platforms.
|
||||||
|
func getSystemExcludedPortRanges() []excludedPortRange {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -2,8 +2,11 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
@@ -20,22 +23,28 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
loginFlag mgm.LoginFlag
|
loginFlag mgm.LoginFlag
|
||||||
disablePromptLogin bool
|
disablePromptLogin bool
|
||||||
expect string
|
expectContains []string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Prompt login",
|
name: "Prompt login",
|
||||||
loginFlag: mgm.LoginFlagPrompt,
|
loginFlag: mgm.LoginFlagPromptLogin,
|
||||||
expect: promptLogin,
|
expectContains: []string{promptLogin},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Max age 0 login",
|
name: "Max age 0",
|
||||||
loginFlag: mgm.LoginFlagMaxAge0,
|
loginFlag: mgm.LoginFlagMaxAge0,
|
||||||
expect: maxAge0,
|
expectContains: []string{maxAge0},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Disable prompt login",
|
name: "Disable prompt login",
|
||||||
loginFlag: mgm.LoginFlagPrompt,
|
loginFlag: mgm.LoginFlagPromptLogin,
|
||||||
disablePromptLogin: true,
|
disablePromptLogin: true,
|
||||||
|
expectContains: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "None flag should not add parameters",
|
||||||
|
loginFlag: mgm.LoginFlagNone,
|
||||||
|
expectContains: []string{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,6 +59,7 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
||||||
UseIDToken: true,
|
UseIDToken: true,
|
||||||
LoginFlag: tc.loginFlag,
|
LoginFlag: tc.loginFlag,
|
||||||
|
DisablePromptLogin: tc.disablePromptLogin,
|
||||||
}
|
}
|
||||||
pkce, err := NewPKCEAuthorizationFlow(config)
|
pkce, err := NewPKCEAuthorizationFlow(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -60,12 +70,153 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
t.Fatalf("Failed to request auth info: %v", err)
|
t.Fatalf("Failed to request auth info: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tc.disablePromptLogin {
|
for _, expected := range tc.expectContains {
|
||||||
require.Contains(t, authInfo.VerificationURIComplete, tc.expect)
|
require.Contains(t, authInfo.VerificationURIComplete, expected)
|
||||||
} else {
|
|
||||||
require.Contains(t, authInfo.VerificationURIComplete, promptLogin)
|
|
||||||
require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsPortInExcludedRange(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
port string
|
||||||
|
excludedRanges []excludedPortRange
|
||||||
|
expectedBlocked bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Port in excluded range",
|
||||||
|
port: "8080",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port at start of range",
|
||||||
|
port: "8000",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port at end of range",
|
||||||
|
port: "8100",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port before range",
|
||||||
|
port: "7999",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port after range",
|
||||||
|
port: "8101",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty excluded ranges",
|
||||||
|
port: "8080",
|
||||||
|
excludedRanges: []excludedPortRange{},
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Nil excluded ranges",
|
||||||
|
port: "8080",
|
||||||
|
excludedRanges: nil,
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple ranges - port in second range",
|
||||||
|
port: "9050",
|
||||||
|
excludedRanges: []excludedPortRange{
|
||||||
|
{start: 8000, end: 8100},
|
||||||
|
{start: 9000, end: 9100},
|
||||||
|
},
|
||||||
|
expectedBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple ranges - port not in any range",
|
||||||
|
port: "8500",
|
||||||
|
excludedRanges: []excludedPortRange{
|
||||||
|
{start: 8000, end: 8100},
|
||||||
|
{start: 9000, end: 9100},
|
||||||
|
},
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid port string",
|
||||||
|
port: "invalid",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty port string",
|
||||||
|
port: "",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := isPortInExcludedRange(tt.port, tt.excludedRanges)
|
||||||
|
assert.Equal(t, tt.expectedBlocked, result, "Port exclusion check mismatch")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsRedirectURLPortUsed(t *testing.T) {
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
_ = listener.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
usedPort := listener.Addr().(*net.TCPAddr).Port
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
redirectURL string
|
||||||
|
excludedRanges []excludedPortRange
|
||||||
|
expectedUsed bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Port in excluded range",
|
||||||
|
redirectURL: "http://127.0.0.1:8080/",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedUsed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port actually in use",
|
||||||
|
redirectURL: fmt.Sprintf("http://127.0.0.1:%d/", usedPort),
|
||||||
|
excludedRanges: nil,
|
||||||
|
expectedUsed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port not in use and not excluded",
|
||||||
|
redirectURL: "http://127.0.0.1:65432/",
|
||||||
|
excludedRanges: nil,
|
||||||
|
expectedUsed: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid URL without port",
|
||||||
|
redirectURL: "not-a-valid-url",
|
||||||
|
excludedRanges: nil,
|
||||||
|
expectedUsed: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port excluded even if not in use",
|
||||||
|
redirectURL: "http://127.0.0.1:8050/",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedUsed: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := isRedirectURLPortUsed(tt.redirectURL, tt.excludedRanges)
|
||||||
|
assert.Equal(t, tt.expectedUsed, result, "Port usage check mismatch")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
86
client/internal/auth/pkce_flow_windows.go
Normal file
86
client/internal/auth/pkce_flow_windows.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// getSystemExcludedPortRanges retrieves the excluded port ranges from Windows using netsh.
|
||||||
|
func getSystemExcludedPortRanges() []excludedPortRange {
|
||||||
|
ranges, err := getExcludedPortRangesFromNetsh()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to get Windows excluded port ranges: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return ranges
|
||||||
|
}
|
||||||
|
|
||||||
|
// getExcludedPortRangesFromNetsh retrieves excluded port ranges using netsh command.
|
||||||
|
func getExcludedPortRangesFromNetsh() ([]excludedPortRange, error) {
|
||||||
|
cmd := exec.Command("netsh", "interface", "ipv4", "show", "excludedportrange", "protocol=tcp")
|
||||||
|
output, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("netsh command: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return parseExcludedPortRanges(string(output))
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseExcludedPortRanges parses the output of the netsh command to extract port ranges.
|
||||||
|
func parseExcludedPortRanges(output string) ([]excludedPortRange, error) {
|
||||||
|
var ranges []excludedPortRange
|
||||||
|
scanner := bufio.NewScanner(strings.NewReader(output))
|
||||||
|
|
||||||
|
foundHeader := false
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := strings.TrimSpace(scanner.Text())
|
||||||
|
|
||||||
|
if strings.Contains(line, "Start Port") && strings.Contains(line, "End Port") {
|
||||||
|
foundHeader = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundHeader {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(line, "----------") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
if len(fields) < 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
startPort, err := strconv.Atoi(fields[0])
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
endPort, err := strconv.Atoi(fields[1])
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ranges = append(ranges, excludedPortRange{start: startPort, end: endPort})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("scan output: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ranges, nil
|
||||||
|
}
|
||||||
116
client/internal/auth/pkce_flow_windows_test.go
Normal file
116
client/internal/auth/pkce_flow_windows_test.go
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseExcludedPortRanges(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
netshOutput string
|
||||||
|
expectedRanges []excludedPortRange
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Valid netsh output with multiple ranges",
|
||||||
|
netshOutput: `
|
||||||
|
Protocol tcp Dynamic Port Range
|
||||||
|
---------------------------------
|
||||||
|
Start Port : 49152
|
||||||
|
Number of Ports : 16384
|
||||||
|
|
||||||
|
Protocol tcp Excluded Port Ranges
|
||||||
|
---------------------------------
|
||||||
|
Start Port End Port
|
||||||
|
---------- --------
|
||||||
|
5357 5357 *
|
||||||
|
50000 50059 *
|
||||||
|
`,
|
||||||
|
expectedRanges: []excludedPortRange{
|
||||||
|
{start: 5357, end: 5357},
|
||||||
|
{start: 50000, end: 50059},
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty output",
|
||||||
|
netshOutput: `
|
||||||
|
Protocol tcp Dynamic Port Range
|
||||||
|
---------------------------------
|
||||||
|
Start Port : 49152
|
||||||
|
Number of Ports : 16384
|
||||||
|
`,
|
||||||
|
expectedRanges: nil,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Single range",
|
||||||
|
netshOutput: `
|
||||||
|
Protocol tcp Excluded Port Ranges
|
||||||
|
---------------------------------
|
||||||
|
Start Port End Port
|
||||||
|
---------- --------
|
||||||
|
8080 8090
|
||||||
|
`,
|
||||||
|
expectedRanges: []excludedPortRange{
|
||||||
|
{start: 8080, end: 8090},
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ranges, err := parseExcludedPortRanges(tt.netshOutput)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expectedRanges, ranges)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) {
|
||||||
|
ranges := getSystemExcludedPortRanges()
|
||||||
|
t.Logf("Found %d excluded port ranges on this system", len(ranges))
|
||||||
|
|
||||||
|
listener1, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
_ = listener1.Close()
|
||||||
|
}()
|
||||||
|
usedPort1 := listener1.Addr().(*net.TCPAddr).Port
|
||||||
|
|
||||||
|
availablePort := 65432
|
||||||
|
|
||||||
|
config := internal.PKCEAuthProviderConfig{
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
Audience: "test-audience",
|
||||||
|
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||||
|
Scope: "openid email profile",
|
||||||
|
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
|
||||||
|
RedirectURLs: []string{
|
||||||
|
fmt.Sprintf("http://127.0.0.1:%d/", usedPort1),
|
||||||
|
fmt.Sprintf("http://127.0.0.1:%d/", availablePort),
|
||||||
|
},
|
||||||
|
UseIDToken: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
flow, err := NewPKCEAuthorizationFlow(config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, flow)
|
||||||
|
assert.Contains(t, flow.oAuthConfig.RedirectURL, fmt.Sprintf(":%d", availablePort),
|
||||||
|
"Should skip port in use and select available port")
|
||||||
|
}
|
||||||
@@ -74,6 +74,7 @@ func (c *ConnectClient) RunOnAndroid(
|
|||||||
networkChangeListener listener.NetworkChangeListener,
|
networkChangeListener listener.NetworkChangeListener,
|
||||||
dnsAddresses []netip.AddrPort,
|
dnsAddresses []netip.AddrPort,
|
||||||
dnsReadyListener dns.ReadyListener,
|
dnsReadyListener dns.ReadyListener,
|
||||||
|
stateFilePath string,
|
||||||
) error {
|
) error {
|
||||||
// in case of non Android os these variables will be nil
|
// in case of non Android os these variables will be nil
|
||||||
mobileDependency := MobileDependency{
|
mobileDependency := MobileDependency{
|
||||||
@@ -82,6 +83,7 @@ func (c *ConnectClient) RunOnAndroid(
|
|||||||
NetworkChangeListener: networkChangeListener,
|
NetworkChangeListener: networkChangeListener,
|
||||||
HostDNSAddresses: dnsAddresses,
|
HostDNSAddresses: dnsAddresses,
|
||||||
DnsReadyListener: dnsReadyListener,
|
DnsReadyListener: dnsReadyListener,
|
||||||
|
StateFilePath: stateFilePath,
|
||||||
}
|
}
|
||||||
return c.run(mobileDependency, nil)
|
return c.run(mobileDependency, nil)
|
||||||
}
|
}
|
||||||
@@ -271,11 +273,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
checks := loginResp.GetChecks()
|
checks := loginResp.GetChecks()
|
||||||
|
|
||||||
c.engineMutex.Lock()
|
c.engineMutex.Lock()
|
||||||
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
|
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
|
||||||
c.engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||||
|
c.engine = engine
|
||||||
c.engineMutex.Unlock()
|
c.engineMutex.Unlock()
|
||||||
|
|
||||||
if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
|
if err := engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
|
||||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
@@ -291,12 +294,14 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
<-engineCtx.Done()
|
<-engineCtx.Done()
|
||||||
|
|
||||||
c.engineMutex.Lock()
|
c.engineMutex.Lock()
|
||||||
engine := c.engine
|
|
||||||
c.engine = nil
|
c.engine = nil
|
||||||
c.engineMutex.Unlock()
|
c.engineMutex.Unlock()
|
||||||
|
|
||||||
if engine != nil && engine.wgInterface != nil {
|
// todo: consider to remove this condition. Is not thread safe.
|
||||||
|
// We should always call Stop(), but we need to verify that it is idempotent
|
||||||
|
if engine.wgInterface != nil {
|
||||||
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
|
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
|
||||||
|
|
||||||
if err := engine.Stop(); err != nil {
|
if err := engine.Stop(); err != nil {
|
||||||
log.Errorf("Failed to stop engine: %v", err)
|
log.Errorf("Failed to stop engine: %v", err)
|
||||||
}
|
}
|
||||||
@@ -416,20 +421,25 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
|||||||
nm = *config.NetworkMonitor
|
nm = *config.NetworkMonitor
|
||||||
}
|
}
|
||||||
engineConf := &EngineConfig{
|
engineConf := &EngineConfig{
|
||||||
WgIfaceName: config.WgIface,
|
WgIfaceName: config.WgIface,
|
||||||
WgAddr: peerConfig.Address,
|
WgAddr: peerConfig.Address,
|
||||||
IFaceBlackList: config.IFaceBlackList,
|
IFaceBlackList: config.IFaceBlackList,
|
||||||
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: config.WgPort,
|
WgPort: config.WgPort,
|
||||||
NetworkMonitor: nm,
|
NetworkMonitor: nm,
|
||||||
SSHKey: []byte(config.SSHKey),
|
SSHKey: []byte(config.SSHKey),
|
||||||
NATExternalIPs: config.NATExternalIPs,
|
NATExternalIPs: config.NATExternalIPs,
|
||||||
CustomDNSAddress: config.CustomDNSAddress,
|
CustomDNSAddress: config.CustomDNSAddress,
|
||||||
RosenpassEnabled: config.RosenpassEnabled,
|
RosenpassEnabled: config.RosenpassEnabled,
|
||||||
RosenpassPermissive: config.RosenpassPermissive,
|
RosenpassPermissive: config.RosenpassPermissive,
|
||||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||||
DNSRouteInterval: config.DNSRouteInterval,
|
EnableSSHRoot: config.EnableSSHRoot,
|
||||||
|
EnableSSHSFTP: config.EnableSSHSFTP,
|
||||||
|
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
|
||||||
|
EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding,
|
||||||
|
DisableSSHAuth: config.DisableSSHAuth,
|
||||||
|
DNSRouteInterval: config.DNSRouteInterval,
|
||||||
|
|
||||||
DisableClientRoutes: config.DisableClientRoutes,
|
DisableClientRoutes: config.DisableClientRoutes,
|
||||||
DisableServerRoutes: config.DisableServerRoutes || config.BlockInbound,
|
DisableServerRoutes: config.DisableServerRoutes || config.BlockInbound,
|
||||||
@@ -515,6 +525,11 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
|||||||
config.BlockLANAccess,
|
config.BlockLANAccess,
|
||||||
config.BlockInbound,
|
config.BlockInbound,
|
||||||
config.LazyConnectionEnabled,
|
config.LazyConnectionEnabled,
|
||||||
|
config.EnableSSHRoot,
|
||||||
|
config.EnableSSHSFTP,
|
||||||
|
config.EnableSSHLocalPortForwarding,
|
||||||
|
config.EnableSSHRemotePortForwarding,
|
||||||
|
config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -453,6 +453,18 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
|||||||
if g.internalConfig.ServerSSHAllowed != nil {
|
if g.internalConfig.ServerSSHAllowed != nil {
|
||||||
configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed))
|
configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed))
|
||||||
}
|
}
|
||||||
|
if g.internalConfig.EnableSSHRoot != nil {
|
||||||
|
configContent.WriteString(fmt.Sprintf("EnableSSHRoot: %v\n", *g.internalConfig.EnableSSHRoot))
|
||||||
|
}
|
||||||
|
if g.internalConfig.EnableSSHSFTP != nil {
|
||||||
|
configContent.WriteString(fmt.Sprintf("EnableSSHSFTP: %v\n", *g.internalConfig.EnableSSHSFTP))
|
||||||
|
}
|
||||||
|
if g.internalConfig.EnableSSHLocalPortForwarding != nil {
|
||||||
|
configContent.WriteString(fmt.Sprintf("EnableSSHLocalPortForwarding: %v\n", *g.internalConfig.EnableSSHLocalPortForwarding))
|
||||||
|
}
|
||||||
|
if g.internalConfig.EnableSSHRemotePortForwarding != nil {
|
||||||
|
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
|
||||||
|
}
|
||||||
|
|
||||||
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
||||||
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
||||||
|
|||||||
@@ -76,6 +76,9 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
|
|||||||
var records []nbdns.SimpleRecord
|
var records []nbdns.SimpleRecord
|
||||||
|
|
||||||
for _, zone := range config.CustomZones {
|
for _, zone := range config.CustomZones {
|
||||||
|
if zone.SkipPTRProcess {
|
||||||
|
continue
|
||||||
|
}
|
||||||
for _, record := range zone.Records {
|
for _, record := range zone.Records {
|
||||||
if record.Type != int(dns.TypeA) {
|
if record.Type != int(dns.TypeA) {
|
||||||
continue
|
continue
|
||||||
@@ -106,8 +109,9 @@ func addReverseZone(config *nbdns.Config, network netip.Prefix) {
|
|||||||
records := collectPTRRecords(config, network)
|
records := collectPTRRecords(config, network)
|
||||||
|
|
||||||
reverseZone := nbdns.CustomZone{
|
reverseZone := nbdns.CustomZone{
|
||||||
Domain: zoneName,
|
Domain: zoneName,
|
||||||
Records: records,
|
Records: records,
|
||||||
|
SearchDomainDisabled: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
config.CustomZones = append(config.CustomZones, reverseZone)
|
config.CustomZones = append(config.CustomZones, reverseZone)
|
||||||
|
|||||||
@@ -11,11 +11,6 @@ import (
|
|||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
ipv4ReverseZone = ".in-addr.arpa."
|
|
||||||
ipv6ReverseZone = ".ip6.arpa."
|
|
||||||
)
|
|
||||||
|
|
||||||
type hostManager interface {
|
type hostManager interface {
|
||||||
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
|
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
|
||||||
restoreHostDNS() error
|
restoreHostDNS() error
|
||||||
@@ -110,10 +105,9 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip netip.Addr, port int) H
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, customZone := range dnsConfig.CustomZones {
|
for _, customZone := range dnsConfig.CustomZones {
|
||||||
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
|
|
||||||
config.Domains = append(config.Domains, DomainConfig{
|
config.Domains = append(config.Domains, DomainConfig{
|
||||||
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
|
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
|
||||||
MatchOnly: matchOnly,
|
MatchOnly: customZone.SearchDomainDisabled,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -335,7 +335,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
for n, testCase := range testCases {
|
for n, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
privKey, _ := wgtypes.GenerateKey()
|
privKey, _ := wgtypes.GenerateKey()
|
||||||
newNet, err := stdnet.NewNet(nil)
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -434,7 +434,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||||
|
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
newNet, err := stdnet.NewNet([]string{"utun2301"})
|
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create stdnet: %v", err)
|
t.Errorf("create stdnet: %v", err)
|
||||||
return
|
return
|
||||||
@@ -915,7 +915,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
|||||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||||
|
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
newNet, err := stdnet.NewNet([]string{"utun2301"})
|
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create stdnet: %v", err)
|
t.Fatalf("create stdnet: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -197,7 +197,7 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
|
|||||||
timeoutMsg += " " + peerInfo
|
timeoutMsg += " " + peerInfo
|
||||||
}
|
}
|
||||||
timeoutMsg += fmt.Sprintf(" - error: %v", err)
|
timeoutMsg += fmt.Sprintf(" - error: %v", err)
|
||||||
logger.Warnf(timeoutMsg)
|
logger.Warn(timeoutMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
||||||
|
|||||||
@@ -234,6 +234,11 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
|
||||||
|
for i, ip := range ips {
|
||||||
|
ips[i] = ip.Unmap()
|
||||||
|
}
|
||||||
|
|
||||||
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
||||||
f.addIPsToResponse(resp, domain, ips)
|
f.addIPsToResponse(resp, domain, ips)
|
||||||
f.cache.set(domain, question.Qtype, ips)
|
f.cache.set(domain, question.Qtype, ips)
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"slices"
|
"slices"
|
||||||
"sort"
|
"sort"
|
||||||
@@ -30,7 +29,6 @@ import (
|
|||||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
@@ -51,10 +49,10 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
cProto "github.com/netbirdio/netbird/client/proto"
|
cProto "github.com/netbirdio/netbird/client/proto"
|
||||||
|
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||||
|
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
@@ -115,7 +113,12 @@ type EngineConfig struct {
|
|||||||
RosenpassEnabled bool
|
RosenpassEnabled bool
|
||||||
RosenpassPermissive bool
|
RosenpassPermissive bool
|
||||||
|
|
||||||
ServerSSHAllowed bool
|
ServerSSHAllowed bool
|
||||||
|
EnableSSHRoot *bool
|
||||||
|
EnableSSHSFTP *bool
|
||||||
|
EnableSSHLocalPortForwarding *bool
|
||||||
|
EnableSSHRemotePortForwarding *bool
|
||||||
|
DisableSSHAuth *bool
|
||||||
|
|
||||||
DNSRouteInterval time.Duration
|
DNSRouteInterval time.Duration
|
||||||
|
|
||||||
@@ -148,8 +151,6 @@ type Engine struct {
|
|||||||
|
|
||||||
// syncMsgMux is used to guarantee sequential Management Service message processing
|
// syncMsgMux is used to guarantee sequential Management Service message processing
|
||||||
syncMsgMux *sync.Mutex
|
syncMsgMux *sync.Mutex
|
||||||
// sshMux protects sshServer field access
|
|
||||||
sshMux sync.Mutex
|
|
||||||
|
|
||||||
config *EngineConfig
|
config *EngineConfig
|
||||||
mobileDep MobileDependency
|
mobileDep MobileDependency
|
||||||
@@ -175,8 +176,7 @@ type Engine struct {
|
|||||||
|
|
||||||
networkMonitor *networkmonitor.NetworkMonitor
|
networkMonitor *networkmonitor.NetworkMonitor
|
||||||
|
|
||||||
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
|
sshServer sshServer
|
||||||
sshServer nbssh.Server
|
|
||||||
|
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
|
|
||||||
@@ -246,7 +246,6 @@ func NewEngine(
|
|||||||
STUNs: []*stun.URI{},
|
STUNs: []*stun.URI{},
|
||||||
TURNs: []*stun.URI{},
|
TURNs: []*stun.URI{},
|
||||||
networkSerial: 0,
|
networkSerial: 0,
|
||||||
sshServerFunc: nbssh.DefaultSSHServer,
|
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
checks: checks,
|
checks: checks,
|
||||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||||
@@ -256,7 +255,7 @@ func NewEngine(
|
|||||||
sm := profilemanager.NewServiceManager("")
|
sm := profilemanager.NewServiceManager("")
|
||||||
|
|
||||||
path := sm.GetStatePath()
|
path := sm.GetStatePath()
|
||||||
if runtime.GOOS == "ios" {
|
if runtime.GOOS == "ios" || runtime.GOOS == "android" {
|
||||||
if !fileExists(mobileDep.StateFilePath) {
|
if !fileExists(mobileDep.StateFilePath) {
|
||||||
err := createFile(mobileDep.StateFilePath)
|
err := createFile(mobileDep.StateFilePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -268,6 +267,7 @@ func NewEngine(
|
|||||||
path = mobileDep.StateFilePath
|
path = mobileDep.StateFilePath
|
||||||
}
|
}
|
||||||
engine.stateManager = statemanager.New(path)
|
engine.stateManager = statemanager.New(path)
|
||||||
|
engine.stateManager.RegisterState(&sshconfig.ShutdownState{})
|
||||||
|
|
||||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||||
return engine
|
return engine
|
||||||
@@ -280,7 +280,6 @@ func (e *Engine) Stop() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
|
||||||
|
|
||||||
if e.connMgr != nil {
|
if e.connMgr != nil {
|
||||||
e.connMgr.Close()
|
e.connMgr.Close()
|
||||||
@@ -292,8 +291,11 @@ func (e *Engine) Stop() error {
|
|||||||
}
|
}
|
||||||
log.Info("Network monitor: stopped")
|
log.Info("Network monitor: stopped")
|
||||||
|
|
||||||
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
if err := e.stopSSHServer(); err != nil {
|
||||||
e.stopDNSServer()
|
log.Warnf("failed to stop SSH server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.cleanupSSHConfig()
|
||||||
|
|
||||||
if e.ingressGatewayMgr != nil {
|
if e.ingressGatewayMgr != nil {
|
||||||
if err := e.ingressGatewayMgr.Close(); err != nil {
|
if err := e.ingressGatewayMgr.Close(); err != nil {
|
||||||
@@ -302,24 +304,29 @@ func (e *Engine) Stop() error {
|
|||||||
e.ingressGatewayMgr = nil
|
e.ingressGatewayMgr = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
e.stopDNSForwarder()
|
|
||||||
|
|
||||||
if e.routeManager != nil {
|
|
||||||
e.routeManager.Stop(e.stateManager)
|
|
||||||
}
|
|
||||||
|
|
||||||
if e.srWatcher != nil {
|
if e.srWatcher != nil {
|
||||||
e.srWatcher.Close()
|
e.srWatcher.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Info("cleaning up status recorder states")
|
||||||
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
|
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
|
||||||
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
||||||
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
|
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
|
||||||
|
|
||||||
if err := e.removeAllPeers(); err != nil {
|
if err := e.removeAllPeers(); err != nil {
|
||||||
return fmt.Errorf("failed to remove all peers: %s", err)
|
log.Errorf("failed to remove all peers: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if e.routeManager != nil {
|
||||||
|
e.routeManager.Stop(e.stateManager)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.stopDNSForwarder()
|
||||||
|
|
||||||
|
// stop/restore DNS after peers are closed but before interface goes down
|
||||||
|
// so dbus and friends don't complain because of a missing interface
|
||||||
|
e.stopDNSServer()
|
||||||
|
|
||||||
if e.cancel != nil {
|
if e.cancel != nil {
|
||||||
e.cancel()
|
e.cancel()
|
||||||
}
|
}
|
||||||
@@ -331,16 +338,18 @@ func (e *Engine) Stop() error {
|
|||||||
e.flowManager.Close()
|
e.flowManager.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
stateCtx, stateCancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
defer cancel()
|
defer stateCancel()
|
||||||
|
|
||||||
if err := e.stateManager.Stop(ctx); err != nil {
|
if err := e.stateManager.Stop(stateCtx); err != nil {
|
||||||
return fmt.Errorf("failed to stop state manager: %w", err)
|
log.Errorf("failed to stop state manager: %v", err)
|
||||||
}
|
}
|
||||||
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
||||||
log.Errorf("failed to persist state: %v", err)
|
log.Errorf("failed to persist state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
timeout := e.calculateShutdownTimeout()
|
timeout := e.calculateShutdownTimeout()
|
||||||
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
@@ -426,8 +435,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create rosenpass manager: %w", err)
|
return fmt.Errorf("create rosenpass manager: %w", err)
|
||||||
}
|
}
|
||||||
err := e.rpManager.Run()
|
if err := e.rpManager.Run(); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("run rosenpass manager: %w", err)
|
return fmt.Errorf("run rosenpass manager: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -479,6 +487,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := e.createFirewall(); err != nil {
|
if err := e.createFirewall(); err != nil {
|
||||||
|
e.close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -703,16 +712,10 @@ func (e *Engine) removeAllPeers() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// removePeer closes an existing peer connection, removes a peer, and clears authorized key of the SSH server
|
// removePeer closes an existing peer connection and removes a peer
|
||||||
func (e *Engine) removePeer(peerKey string) error {
|
func (e *Engine) removePeer(peerKey string) error {
|
||||||
log.Debugf("removing peer from engine %s", peerKey)
|
log.Debugf("removing peer from engine %s", peerKey)
|
||||||
|
|
||||||
e.sshMux.Lock()
|
|
||||||
if !isNil(e.sshServer) {
|
|
||||||
e.sshServer.RemoveAuthorizedKey(peerKey)
|
|
||||||
}
|
|
||||||
e.sshMux.Unlock()
|
|
||||||
|
|
||||||
e.connMgr.RemovePeerConn(peerKey)
|
e.connMgr.RemovePeerConn(peerKey)
|
||||||
|
|
||||||
err := e.statusRecorder.RemovePeer(peerKey)
|
err := e.statusRecorder.RemovePeer(peerKey)
|
||||||
@@ -750,6 +753,11 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||||
|
if e.ctx.Err() != nil {
|
||||||
|
return e.ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
if update.GetNetbirdConfig() != nil {
|
if update.GetNetbirdConfig() != nil {
|
||||||
wCfg := update.GetNetbirdConfig()
|
wCfg := update.GetNetbirdConfig()
|
||||||
err := e.updateTURNs(wCfg.GetTurns())
|
err := e.updateTURNs(wCfg.GetTurns())
|
||||||
@@ -789,7 +797,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
nm := update.GetNetworkMap()
|
nm := update.GetNetworkMap()
|
||||||
if nm == nil {
|
if nm == nil || update.SkipNetworkMapUpdate {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -884,6 +892,11 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
|||||||
e.config.BlockLANAccess,
|
e.config.BlockLANAccess,
|
||||||
e.config.BlockInbound,
|
e.config.BlockInbound,
|
||||||
e.config.LazyConnectionEnabled,
|
e.config.LazyConnectionEnabled,
|
||||||
|
e.config.EnableSSHRoot,
|
||||||
|
e.config.EnableSSHSFTP,
|
||||||
|
e.config.EnableSSHLocalPortForwarding,
|
||||||
|
e.config.EnableSSHRemotePortForwarding,
|
||||||
|
e.config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||||
@@ -893,74 +906,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isNil(server nbssh.Server) bool {
|
|
||||||
return server == nil || reflect.ValueOf(server).IsNil()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
|
||||||
if e.config.BlockInbound {
|
|
||||||
log.Infof("SSH server is disabled because inbound connections are blocked")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if !e.config.ServerSSHAllowed {
|
|
||||||
log.Info("SSH server is not enabled")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if sshConf.GetSshEnabled() {
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
e.sshMux.Lock()
|
|
||||||
// start SSH server if it wasn't running
|
|
||||||
if isNil(e.sshServer) {
|
|
||||||
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
|
|
||||||
if nbnetstack.IsEnabled() {
|
|
||||||
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
|
|
||||||
}
|
|
||||||
// nil sshServer means it has not yet been started
|
|
||||||
server, err := e.sshServerFunc(e.config.SSHKey, listenAddr)
|
|
||||||
if err != nil {
|
|
||||||
e.sshMux.Unlock()
|
|
||||||
return fmt.Errorf("create ssh server: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
e.sshServer = server
|
|
||||||
e.sshMux.Unlock()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
// blocking
|
|
||||||
err = server.Start()
|
|
||||||
if err != nil {
|
|
||||||
// will throw error when we stop it even if it is a graceful stop
|
|
||||||
log.Debugf("stopped SSH server with error %v", err)
|
|
||||||
}
|
|
||||||
e.sshMux.Lock()
|
|
||||||
e.sshServer = nil
|
|
||||||
e.sshMux.Unlock()
|
|
||||||
log.Infof("stopped SSH server")
|
|
||||||
}()
|
|
||||||
} else {
|
|
||||||
e.sshMux.Unlock()
|
|
||||||
log.Debugf("SSH server is already running")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
e.sshMux.Lock()
|
|
||||||
if !isNil(e.sshServer) {
|
|
||||||
// Disable SSH server request, so stop it if it was running
|
|
||||||
err := e.sshServer.Stop()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to stop SSH server %v", err)
|
|
||||||
}
|
|
||||||
e.sshServer = nil
|
|
||||||
}
|
|
||||||
e.sshMux.Unlock()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||||
if e.wgInterface == nil {
|
if e.wgInterface == nil {
|
||||||
return errors.New("wireguard interface is not initialized")
|
return errors.New("wireguard interface is not initialized")
|
||||||
@@ -973,8 +918,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if conf.GetSshConfig() != nil {
|
if conf.GetSshConfig() != nil {
|
||||||
err := e.updateSSH(conf.GetSshConfig())
|
if err := e.updateSSH(conf.GetSshConfig()); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed handling SSH server setup: %v", err)
|
log.Warnf("failed handling SSH server setup: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1012,9 +956,14 @@ func (e *Engine) receiveManagementEvents() {
|
|||||||
e.config.BlockLANAccess,
|
e.config.BlockLANAccess,
|
||||||
e.config.BlockInbound,
|
e.config.BlockInbound,
|
||||||
e.config.LazyConnectionEnabled,
|
e.config.LazyConnectionEnabled,
|
||||||
|
e.config.EnableSSHRoot,
|
||||||
|
e.config.EnableSSHSFTP,
|
||||||
|
e.config.EnableSSHLocalPortForwarding,
|
||||||
|
e.config.EnableSSHRemotePortForwarding,
|
||||||
|
e.config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
|
|
||||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
err = e.mgmClient.Sync(e.ctx, info, e.networkSerial, e.handleSync)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// happens if management is unavailable for a long time.
|
// happens if management is unavailable for a long time.
|
||||||
// We want to cancel the operation of the whole client
|
// We want to cancel the operation of the whole client
|
||||||
@@ -1170,19 +1119,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
|
|
||||||
e.statusRecorder.FinishPeerListModifications()
|
e.statusRecorder.FinishPeerListModifications()
|
||||||
|
|
||||||
// update SSHServer by adding remote peer SSH keys
|
e.updatePeerSSHHostKeys(networkMap.GetRemotePeers())
|
||||||
e.sshMux.Lock()
|
|
||||||
if !isNil(e.sshServer) {
|
if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil {
|
||||||
for _, config := range networkMap.GetRemotePeers() {
|
log.Warnf("failed to update SSH client config: %v", err)
|
||||||
if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil {
|
|
||||||
err := e.sshServer.AddAuthorizedKey(config.WgPubKey, string(config.GetSshConfig().GetSshPubKey()))
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed adding authorized key to SSH DefaultServer %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
e.sshMux.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||||
@@ -1259,6 +1200,7 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
|
|||||||
}
|
}
|
||||||
|
|
||||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
|
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
|
||||||
|
//nolint
|
||||||
forwarderPort := uint16(protoDNSConfig.GetForwarderPort())
|
forwarderPort := uint16(protoDNSConfig.GetForwarderPort())
|
||||||
if forwarderPort == 0 {
|
if forwarderPort == 0 {
|
||||||
forwarderPort = nbdns.ForwarderClientPort
|
forwarderPort = nbdns.ForwarderClientPort
|
||||||
@@ -1273,7 +1215,9 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns
|
|||||||
|
|
||||||
for _, zone := range protoDNSConfig.GetCustomZones() {
|
for _, zone := range protoDNSConfig.GetCustomZones() {
|
||||||
dnsZone := nbdns.CustomZone{
|
dnsZone := nbdns.CustomZone{
|
||||||
Domain: zone.GetDomain(),
|
Domain: zone.GetDomain(),
|
||||||
|
SearchDomainDisabled: zone.GetSearchDomainDisabled(),
|
||||||
|
SkipPTRProcess: zone.GetSkipPTRProcess(),
|
||||||
}
|
}
|
||||||
for _, record := range zone.Records {
|
for _, record := range zone.Records {
|
||||||
dnsRecord := nbdns.SimpleRecord{
|
dnsRecord := nbdns.SimpleRecord{
|
||||||
@@ -1433,6 +1377,11 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||||
|
if e.ctx.Err() != nil {
|
||||||
|
return e.ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
conn, ok := e.peerStore.PeerConn(msg.Key)
|
conn, ok := e.peerStore.PeerConn(msg.Key)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
||||||
@@ -1544,15 +1493,6 @@ func (e *Engine) close() {
|
|||||||
e.statusRecorder.SetWgIface(nil)
|
e.statusRecorder.SetWgIface(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
e.sshMux.Lock()
|
|
||||||
if !isNil(e.sshServer) {
|
|
||||||
err := e.sshServer.Stop()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed stopping the SSH server: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
e.sshMux.Unlock()
|
|
||||||
|
|
||||||
if e.firewall != nil {
|
if e.firewall != nil {
|
||||||
err := e.firewall.Close(e.stateManager)
|
err := e.firewall.Close(e.stateManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1583,6 +1523,11 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
|||||||
e.config.BlockLANAccess,
|
e.config.BlockLANAccess,
|
||||||
e.config.BlockInbound,
|
e.config.BlockInbound,
|
||||||
e.config.LazyConnectionEnabled,
|
e.config.LazyConnectionEnabled,
|
||||||
|
e.config.EnableSSHRoot,
|
||||||
|
e.config.EnableSSHSFTP,
|
||||||
|
e.config.EnableSSHLocalPortForwarding,
|
||||||
|
e.config.EnableSSHRemotePortForwarding,
|
||||||
|
e.config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
|
|
||||||
netMap, err := e.mgmClient.GetNetworkMap(info)
|
netMap, err := e.mgmClient.GetNetworkMap(info)
|
||||||
@@ -1901,6 +1846,18 @@ func (e *Engine) GetWgAddr() netip.Addr {
|
|||||||
return e.wgInterface.Address().IP
|
return e.wgInterface.Address().IP
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Engine) RenewTun(fd int) error {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
wgInterface := e.wgInterface
|
||||||
|
e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
if wgInterface == nil {
|
||||||
|
return fmt.Errorf("wireguard interface not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
return wgInterface.RenewTun(fd)
|
||||||
|
}
|
||||||
|
|
||||||
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
|
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
|
||||||
func (e *Engine) updateDNSForwarder(
|
func (e *Engine) updateDNSForwarder(
|
||||||
enabled bool,
|
enabled bool,
|
||||||
|
|||||||
355
client/internal/engine_ssh.go
Normal file
355
client/internal/engine_ssh.go
Normal file
@@ -0,0 +1,355 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||||
|
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type sshServer interface {
|
||||||
|
Start(ctx context.Context, addr netip.AddrPort) error
|
||||||
|
Stop() error
|
||||||
|
GetStatus() (bool, []sshserver.SessionInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) setupSSHPortRedirection() error {
|
||||||
|
if e.firewall == nil || e.wgInterface == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
localAddr := e.wgInterface.Address().IP
|
||||||
|
if !localAddr.IsValid() {
|
||||||
|
return errors.New("invalid local NetBird address")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, 22, 22022); err != nil {
|
||||||
|
return fmt.Errorf("add SSH port redirection: %w", err)
|
||||||
|
}
|
||||||
|
log.Infof("SSH port redirection enabled: %s:22 -> %s:22022", localAddr, localAddr)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||||
|
if e.config.BlockInbound {
|
||||||
|
log.Info("SSH server is disabled because inbound connections are blocked")
|
||||||
|
return e.stopSSHServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
if !e.config.ServerSSHAllowed {
|
||||||
|
log.Info("SSH server is disabled in config")
|
||||||
|
return e.stopSSHServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
if !sshConf.GetSshEnabled() {
|
||||||
|
if e.config.ServerSSHAllowed {
|
||||||
|
log.Info("SSH server is locally allowed but disabled by management server")
|
||||||
|
}
|
||||||
|
return e.stopSSHServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.sshServer != nil {
|
||||||
|
log.Debug("SSH server is already running")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.config.DisableSSHAuth != nil && *e.config.DisableSSHAuth {
|
||||||
|
log.Info("starting SSH server without JWT authentication (authentication disabled by config)")
|
||||||
|
return e.startSSHServer(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
|
||||||
|
jwtConfig := &sshserver.JWTConfig{
|
||||||
|
Issuer: protoJWT.GetIssuer(),
|
||||||
|
Audience: protoJWT.GetAudience(),
|
||||||
|
KeysLocation: protoJWT.GetKeysLocation(),
|
||||||
|
MaxTokenAge: protoJWT.GetMaxTokenAge(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.startSSHServer(jwtConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors.New("SSH server requires valid JWT configuration")
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateSSHClientConfig updates the SSH client configuration with peer information
|
||||||
|
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
|
||||||
|
peerInfo := e.extractPeerSSHInfo(remotePeers)
|
||||||
|
if len(peerInfo) == 0 {
|
||||||
|
log.Debug("no SSH-enabled peers found, skipping SSH config update")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
configMgr := sshconfig.New()
|
||||||
|
if err := configMgr.SetupSSHClientConfig(peerInfo); err != nil {
|
||||||
|
log.Warnf("failed to update SSH client config: %v", err)
|
||||||
|
return nil // Don't fail engine startup on SSH config issues
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("updated SSH client config with %d peers", len(peerInfo))
|
||||||
|
|
||||||
|
if err := e.stateManager.UpdateState(&sshconfig.ShutdownState{
|
||||||
|
SSHConfigDir: configMgr.GetSSHConfigDir(),
|
||||||
|
SSHConfigFile: configMgr.GetSSHConfigFile(),
|
||||||
|
}); err != nil {
|
||||||
|
log.Warnf("failed to update SSH config state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractPeerSSHInfo extracts SSH information from peer configurations
|
||||||
|
func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []sshconfig.PeerSSHInfo {
|
||||||
|
var peerInfo []sshconfig.PeerSSHInfo
|
||||||
|
|
||||||
|
for _, peerConfig := range remotePeers {
|
||||||
|
if peerConfig.GetSshConfig() == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
|
||||||
|
if len(sshPubKeyBytes) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
peerIP := e.extractPeerIP(peerConfig)
|
||||||
|
hostname := e.extractHostname(peerConfig)
|
||||||
|
|
||||||
|
peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{
|
||||||
|
Hostname: hostname,
|
||||||
|
IP: peerIP,
|
||||||
|
FQDN: peerConfig.GetFqdn(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return peerInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractPeerIP extracts IP address from peer's allowed IPs
|
||||||
|
func (e *Engine) extractPeerIP(peerConfig *mgmProto.RemotePeerConfig) string {
|
||||||
|
if len(peerConfig.GetAllowedIps()) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefix, err := netip.ParsePrefix(peerConfig.GetAllowedIps()[0]); err == nil {
|
||||||
|
return prefix.Addr().String()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractHostname extracts short hostname from FQDN
|
||||||
|
func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string {
|
||||||
|
fqdn := peerConfig.GetFqdn()
|
||||||
|
if fqdn == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(fqdn, ".")
|
||||||
|
if len(parts) > 0 && parts[0] != "" {
|
||||||
|
return parts[0]
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// updatePeerSSHHostKeys updates peer SSH host keys in the status recorder for daemon API access
|
||||||
|
func (e *Engine) updatePeerSSHHostKeys(remotePeers []*mgmProto.RemotePeerConfig) {
|
||||||
|
for _, peerConfig := range remotePeers {
|
||||||
|
if peerConfig.GetSshConfig() == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
|
||||||
|
if len(sshPubKeyBytes) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.statusRecorder.UpdatePeerSSHHostKey(peerConfig.GetWgPubKey(), sshPubKeyBytes); err != nil {
|
||||||
|
log.Warnf("failed to update SSH host key for peer %s: %v", peerConfig.GetWgPubKey(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("updated peer SSH host keys for daemon API access")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeerSSHKey returns the SSH host key for a specific peer by IP or FQDN
|
||||||
|
func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
statusRecorder := e.statusRecorder
|
||||||
|
e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
if statusRecorder == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
fullStatus := statusRecorder.GetFullStatus()
|
||||||
|
for _, peerState := range fullStatus.Peers {
|
||||||
|
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
|
||||||
|
if len(peerState.SSHHostKey) > 0 {
|
||||||
|
return peerState.SSHHostKey, true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
|
||||||
|
func (e *Engine) cleanupSSHConfig() {
|
||||||
|
configMgr := sshconfig.New()
|
||||||
|
|
||||||
|
if err := configMgr.RemoveSSHClientConfig(); err != nil {
|
||||||
|
log.Warnf("failed to remove SSH client config: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Debugf("SSH client config cleanup completed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// startSSHServer initializes and starts the SSH server with proper configuration.
|
||||||
|
func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
|
||||||
|
if e.wgInterface == nil {
|
||||||
|
return errors.New("wg interface not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
serverConfig := &sshserver.Config{
|
||||||
|
HostKeyPEM: e.config.SSHKey,
|
||||||
|
JWT: jwtConfig,
|
||||||
|
}
|
||||||
|
server := sshserver.New(serverConfig)
|
||||||
|
|
||||||
|
wgAddr := e.wgInterface.Address()
|
||||||
|
server.SetNetworkValidation(wgAddr)
|
||||||
|
|
||||||
|
netbirdIP := wgAddr.IP
|
||||||
|
listenAddr := netip.AddrPortFrom(netbirdIP, sshserver.InternalSSHPort)
|
||||||
|
|
||||||
|
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||||
|
server.SetNetstackNet(netstackNet)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.configureSSHServer(server)
|
||||||
|
|
||||||
|
if err := server.Start(e.ctx, listenAddr); err != nil {
|
||||||
|
return fmt.Errorf("start SSH server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.sshServer = server
|
||||||
|
|
||||||
|
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||||
|
if registrar, ok := e.firewall.(interface {
|
||||||
|
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||||
|
}); ok {
|
||||||
|
registrar.RegisterNetstackService(nftypes.TCP, sshserver.InternalSSHPort)
|
||||||
|
log.Debugf("registered SSH service with netstack for TCP:%d", sshserver.InternalSSHPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.setupSSHPortRedirection(); err != nil {
|
||||||
|
log.Warnf("failed to setup SSH port redirection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// configureSSHServer applies SSH configuration options to the server.
|
||||||
|
func (e *Engine) configureSSHServer(server *sshserver.Server) {
|
||||||
|
if e.config.EnableSSHRoot != nil && *e.config.EnableSSHRoot {
|
||||||
|
server.SetAllowRootLogin(true)
|
||||||
|
log.Info("SSH root login enabled")
|
||||||
|
} else {
|
||||||
|
server.SetAllowRootLogin(false)
|
||||||
|
log.Info("SSH root login disabled (default)")
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.config.EnableSSHSFTP != nil && *e.config.EnableSSHSFTP {
|
||||||
|
server.SetAllowSFTP(true)
|
||||||
|
log.Info("SSH SFTP subsystem enabled")
|
||||||
|
} else {
|
||||||
|
server.SetAllowSFTP(false)
|
||||||
|
log.Info("SSH SFTP subsystem disabled (default)")
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.config.EnableSSHLocalPortForwarding != nil && *e.config.EnableSSHLocalPortForwarding {
|
||||||
|
server.SetAllowLocalPortForwarding(true)
|
||||||
|
log.Info("SSH local port forwarding enabled")
|
||||||
|
} else {
|
||||||
|
server.SetAllowLocalPortForwarding(false)
|
||||||
|
log.Info("SSH local port forwarding disabled (default)")
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.config.EnableSSHRemotePortForwarding != nil && *e.config.EnableSSHRemotePortForwarding {
|
||||||
|
server.SetAllowRemotePortForwarding(true)
|
||||||
|
log.Info("SSH remote port forwarding enabled")
|
||||||
|
} else {
|
||||||
|
server.SetAllowRemotePortForwarding(false)
|
||||||
|
log.Info("SSH remote port forwarding disabled (default)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) cleanupSSHPortRedirection() error {
|
||||||
|
if e.firewall == nil || e.wgInterface == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
localAddr := e.wgInterface.Address().IP
|
||||||
|
if !localAddr.IsValid() {
|
||||||
|
return errors.New("invalid local NetBird address")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, 22, 22022); err != nil {
|
||||||
|
return fmt.Errorf("remove SSH port redirection: %w", err)
|
||||||
|
}
|
||||||
|
log.Debugf("SSH port redirection removed: %s:22 -> %s:22022", localAddr, localAddr)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) stopSSHServer() error {
|
||||||
|
if e.sshServer == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.cleanupSSHPortRedirection(); err != nil {
|
||||||
|
log.Warnf("failed to cleanup SSH port redirection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||||
|
if registrar, ok := e.firewall.(interface {
|
||||||
|
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||||
|
}); ok {
|
||||||
|
registrar.UnregisterNetstackService(nftypes.TCP, sshserver.InternalSSHPort)
|
||||||
|
log.Debugf("unregistered SSH service from netstack for TCP:%d", sshserver.InternalSSHPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("stopping SSH server")
|
||||||
|
err := e.sshServer.Stop()
|
||||||
|
e.sshServer = nil
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("stop: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSSHServerStatus returns the SSH server status and active sessions
|
||||||
|
func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.SessionInfo) {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
sshServer := e.sshServer
|
||||||
|
e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
if sshServer == nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return sshServer.GetStatus()
|
||||||
|
}
|
||||||
@@ -7,5 +7,5 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
||||||
return stdnet.NewNet(e.config.IFaceBlackList)
|
return stdnet.NewNet(e.clientCtx, e.config.IFaceBlackList)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,5 +3,5 @@ package internal
|
|||||||
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
|
||||||
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
||||||
return stdnet.NewNetWithDiscover(e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
|
return stdnet.NewNetWithDiscover(e.clientCtx, e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
|
||||||
}
|
}
|
||||||
|
|||||||
79
client/internal/engine_sync_test.go
Normal file
79
client/internal/engine_sync_test.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/client"
|
||||||
|
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Ensures handleSync exits early when SkipNetworkMapUpdate is true
|
||||||
|
func TestEngine_HandleSync_SkipNetworkMapUpdate(t *testing.T) {
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
engine := NewEngine(ctx, cancel, nil, &client.MockClient{}, nil, &EngineConfig{
|
||||||
|
WgIfaceName: "utun199",
|
||||||
|
WgAddr: "100.70.0.1/24",
|
||||||
|
WgPrivateKey: key,
|
||||||
|
WgPort: 33100,
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||||
|
engine.ctx = ctx
|
||||||
|
|
||||||
|
// Precondition
|
||||||
|
if engine.networkSerial != 0 {
|
||||||
|
t.Fatalf("unexpected initial serial: %d", engine.networkSerial)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &mgmtProto.SyncResponse{
|
||||||
|
NetworkMap: &mgmtProto.NetworkMap{Serial: 42},
|
||||||
|
SkipNetworkMapUpdate: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := engine.handleSync(resp); err != nil {
|
||||||
|
t.Fatalf("handleSync returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if engine.networkSerial != 0 {
|
||||||
|
t.Fatalf("networkSerial changed despite SkipNetworkMapUpdate; got %d, want 0", engine.networkSerial)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensures handleSync exits early when NetworkMap is nil
|
||||||
|
func TestEngine_HandleSync_NilNetworkMap(t *testing.T) {
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
engine := NewEngine(ctx, cancel, nil, &client.MockClient{}, nil, &EngineConfig{
|
||||||
|
WgIfaceName: "utun198",
|
||||||
|
WgAddr: "100.70.0.2/24",
|
||||||
|
WgPrivateKey: key,
|
||||||
|
WgPort: 33101,
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||||
|
engine.ctx = ctx
|
||||||
|
|
||||||
|
resp := &mgmtProto.SyncResponse{NetworkMap: nil}
|
||||||
|
|
||||||
|
if err := engine.handleSync(resp); err != nil {
|
||||||
|
t.Fatalf("handleSync returned error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/pion/transport/v3/stdnet"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -25,11 +24,18 @@ import (
|
|||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
@@ -43,13 +49,12 @@ import (
|
|||||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/peers"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
@@ -105,6 +110,10 @@ type MockWGIface struct {
|
|||||||
LastActivitiesFunc func() map[string]monotime.Time
|
LastActivitiesFunc func() map[string]monotime.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockWGIface) RenewTun(_ int) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *MockWGIface) RemoveEndpointAddress(_ string) error {
|
func (m *MockWGIface) RemoveEndpointAddress(_ string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -211,11 +220,13 @@ func TestMain(m *testing.M) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestEngine_SSH(t *testing.T) {
|
func TestEngine_SSH(t *testing.T) {
|
||||||
if runtime.GOOS == "windows" {
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
t.Skip("skipping TestEngine_SSH")
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
return
|
return
|
||||||
@@ -237,6 +248,7 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
ServerSSHAllowed: true,
|
ServerSSHAllowed: true,
|
||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
|
SSHKey: sshKey,
|
||||||
},
|
},
|
||||||
MobileDependency{},
|
MobileDependency{},
|
||||||
peer.NewRecorder("https://mgm"),
|
peer.NewRecorder("https://mgm"),
|
||||||
@@ -247,35 +259,8 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
var sshKeysAdded []string
|
|
||||||
var sshPeersRemoved []string
|
|
||||||
|
|
||||||
sshCtx, cancel := context.WithCancel(context.Background())
|
|
||||||
|
|
||||||
engine.sshServerFunc = func(hostKeyPEM []byte, addr string) (ssh.Server, error) {
|
|
||||||
return &ssh.MockServer{
|
|
||||||
Ctx: sshCtx,
|
|
||||||
StopFunc: func() error {
|
|
||||||
cancel()
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
StartFunc: func() error {
|
|
||||||
<-ctx.Done()
|
|
||||||
return ctx.Err()
|
|
||||||
},
|
|
||||||
AddAuthorizedKeyFunc: func(peer, newKey string) error {
|
|
||||||
sshKeysAdded = append(sshKeysAdded, newKey)
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
RemoveAuthorizedKeyFunc: func(peer string) {
|
|
||||||
sshPeersRemoved = append(sshPeersRemoved, peer)
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
err = engine.Start(nil, nil)
|
err = engine.Start(nil, nil)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := engine.Stop()
|
err := engine.Stop()
|
||||||
@@ -301,9 +286,7 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
err = engine.updateNetworkMap(networkMap)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Nil(t, engine.sshServer)
|
assert.Nil(t, engine.sshServer)
|
||||||
|
|
||||||
@@ -311,19 +294,24 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
networkMap = &mgmtProto.NetworkMap{
|
networkMap = &mgmtProto.NetworkMap{
|
||||||
Serial: 7,
|
Serial: 7,
|
||||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||||
SshConfig: &mgmtProto.SSHConfig{SshEnabled: true}},
|
SshConfig: &mgmtProto.SSHConfig{
|
||||||
|
SshEnabled: true,
|
||||||
|
JwtConfig: &mgmtProto.JWTConfig{
|
||||||
|
Issuer: "test-issuer",
|
||||||
|
Audience: "test-audience",
|
||||||
|
KeysLocation: "test-keys",
|
||||||
|
MaxTokenAge: 3600,
|
||||||
|
},
|
||||||
|
}},
|
||||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||||
RemotePeersIsEmpty: false,
|
RemotePeersIsEmpty: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
err = engine.updateNetworkMap(networkMap)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(250 * time.Millisecond)
|
time.Sleep(250 * time.Millisecond)
|
||||||
assert.NotNil(t, engine.sshServer)
|
assert.NotNil(t, engine.sshServer)
|
||||||
assert.Contains(t, sshKeysAdded, "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ")
|
|
||||||
|
|
||||||
// now remove peer
|
// now remove peer
|
||||||
networkMap = &mgmtProto.NetworkMap{
|
networkMap = &mgmtProto.NetworkMap{
|
||||||
@@ -333,13 +321,10 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
err = engine.updateNetworkMap(networkMap)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// time.Sleep(250 * time.Millisecond)
|
// time.Sleep(250 * time.Millisecond)
|
||||||
assert.NotNil(t, engine.sshServer)
|
assert.NotNil(t, engine.sshServer)
|
||||||
assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=")
|
|
||||||
|
|
||||||
// now disable SSH server
|
// now disable SSH server
|
||||||
networkMap = &mgmtProto.NetworkMap{
|
networkMap = &mgmtProto.NetworkMap{
|
||||||
@@ -351,12 +336,70 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
err = engine.updateNetworkMap(networkMap)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Nil(t, engine.sshServer)
|
assert.Nil(t, engine.sshServer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEngine_SSHUpdateLogic(t *testing.T) {
|
||||||
|
// Test that SSH server start/stop logic works based on config
|
||||||
|
engine := &Engine{
|
||||||
|
config: &EngineConfig{
|
||||||
|
ServerSSHAllowed: false, // Start with SSH disabled
|
||||||
|
},
|
||||||
|
syncMsgMux: &sync.Mutex{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test SSH disabled config
|
||||||
|
sshConfig := &mgmtProto.SSHConfig{SshEnabled: false}
|
||||||
|
err := engine.updateSSH(sshConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, engine.sshServer)
|
||||||
|
|
||||||
|
// Test inbound blocked
|
||||||
|
engine.config.BlockInbound = true
|
||||||
|
err = engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, engine.sshServer)
|
||||||
|
engine.config.BlockInbound = false
|
||||||
|
|
||||||
|
// Test with server SSH not allowed
|
||||||
|
err = engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, engine.sshServer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEngine_SSHServerConsistency(t *testing.T) {
|
||||||
|
|
||||||
|
t.Run("server set only on successful creation", func(t *testing.T) {
|
||||||
|
engine := &Engine{
|
||||||
|
config: &EngineConfig{
|
||||||
|
ServerSSHAllowed: true,
|
||||||
|
SSHKey: []byte("test-key"),
|
||||||
|
},
|
||||||
|
syncMsgMux: &sync.Mutex{},
|
||||||
|
}
|
||||||
|
|
||||||
|
engine.wgInterface = nil
|
||||||
|
|
||||||
|
err := engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true})
|
||||||
|
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, engine.sshServer)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cleanup handles nil gracefully", func(t *testing.T) {
|
||||||
|
engine := &Engine{
|
||||||
|
config: &EngineConfig{
|
||||||
|
ServerSSHAllowed: false,
|
||||||
|
},
|
||||||
|
syncMsgMux: &sync.Mutex{},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := engine.stopSSHServer()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, engine.sshServer)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEngine_UpdateNetworkMap(t *testing.T) {
|
func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||||
@@ -588,7 +631,7 @@ func TestEngine_Sync(t *testing.T) {
|
|||||||
// feed updates to Engine via mocked Management client
|
// feed updates to Engine via mocked Management client
|
||||||
updates := make(chan *mgmtProto.SyncResponse)
|
updates := make(chan *mgmtProto.SyncResponse)
|
||||||
defer close(updates)
|
defer close(updates)
|
||||||
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
syncFunc := func(ctx context.Context, info *system.Info, networkSerial uint64, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||||
for msg := range updates {
|
for msg := range updates {
|
||||||
err := msgHandler(msg)
|
err := msgHandler(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -771,7 +814,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||||
engine.ctx = ctx
|
engine.ctx = ctx
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -974,7 +1017,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||||
engine.ctx = ctx
|
engine.ctx = ctx
|
||||||
|
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -1556,7 +1599,6 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
}
|
}
|
||||||
t.Cleanup(cleanUp)
|
t.Cleanup(cleanUp)
|
||||||
|
|
||||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
@@ -1584,13 +1626,19 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
|
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
|
|
||||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
|
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||||
|
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||||
|
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
type wgIfaceBase interface {
|
type wgIfaceBase interface {
|
||||||
Create() error
|
Create() error
|
||||||
CreateOnAndroid(routeRange []string, ip string, domains []string) error
|
CreateOnAndroid(routeRange []string, ip string, domains []string) error
|
||||||
|
RenewTun(fd int) error
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
Name() string
|
Name() string
|
||||||
Address() wgaddr.Address
|
Address() wgaddr.Address
|
||||||
|
|||||||
@@ -124,6 +124,11 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte
|
|||||||
config.BlockLANAccess,
|
config.BlockLANAccess,
|
||||||
config.BlockInbound,
|
config.BlockInbound,
|
||||||
config.LazyConnectionEnabled,
|
config.LazyConnectionEnabled,
|
||||||
|
config.EnableSSHRoot,
|
||||||
|
config.EnableSSHSFTP,
|
||||||
|
config.EnableSSHLocalPortForwarding,
|
||||||
|
config.EnableSSHRemotePortForwarding,
|
||||||
|
config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
|
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||||
return serverKey, loginResp, err
|
return serverKey, loginResp, err
|
||||||
@@ -150,6 +155,11 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
|
|||||||
config.BlockLANAccess,
|
config.BlockLANAccess,
|
||||||
config.BlockInbound,
|
config.BlockInbound,
|
||||||
config.LazyConnectionEnabled,
|
config.LazyConnectionEnabled,
|
||||||
|
config.EnableSSHRoot,
|
||||||
|
config.EnableSSHSFTP,
|
||||||
|
config.EnableSSHLocalPortForwarding,
|
||||||
|
config.EnableSSHRemotePortForwarding,
|
||||||
|
config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
|
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -666,7 +666,7 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
if runtime.GOOS != "js" && conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package peer
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -10,5 +11,8 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func isForceRelayed() bool {
|
func isForceRelayed() bool {
|
||||||
|
if runtime.GOOS == "js" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
|
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
|
|||||||
func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) {
|
func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) {
|
||||||
log.Debugf("Gathering ICE candidates")
|
log.Debugf("Gathering ICE candidates")
|
||||||
|
|
||||||
agent, err := icemaker.NewAgent(cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd)
|
agent, err := icemaker.NewAgent(ctx, cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("create ICE agent: %w", err)
|
return false, fmt.Errorf("create ICE agent: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package ice
|
package ice
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -22,6 +23,8 @@ const (
|
|||||||
iceFailedTimeoutDefault = 6 * time.Second
|
iceFailedTimeoutDefault = 6 * time.Second
|
||||||
// iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
|
// iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
|
||||||
iceRelayAcceptanceMinWaitDefault = 2 * time.Second
|
iceRelayAcceptanceMinWaitDefault = 2 * time.Second
|
||||||
|
// iceAgentCloseTimeout is the maximum time to wait for ICE agent close to complete
|
||||||
|
iceAgentCloseTimeout = 3 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
type ThreadSafeAgent struct {
|
type ThreadSafeAgent struct {
|
||||||
@@ -32,18 +35,28 @@ type ThreadSafeAgent struct {
|
|||||||
func (a *ThreadSafeAgent) Close() error {
|
func (a *ThreadSafeAgent) Close() error {
|
||||||
var err error
|
var err error
|
||||||
a.once.Do(func() {
|
a.once.Do(func() {
|
||||||
err = a.Agent.Close()
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- a.Agent.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err = <-done:
|
||||||
|
case <-time.After(iceAgentCloseTimeout):
|
||||||
|
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
})
|
})
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
||||||
iceKeepAlive := iceKeepAlive()
|
iceKeepAlive := iceKeepAlive()
|
||||||
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
||||||
iceFailedTimeout := iceFailedTimeout()
|
iceFailedTimeout := iceFailedTimeout()
|
||||||
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
|
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
|
||||||
|
|
||||||
transportNet, err := newStdNet(iFaceDiscover, config.InterfaceBlackList)
|
transportNet, err := newStdNet(ctx, iFaceDiscover, config.InterfaceBlackList)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to create pion's stdnet: %s", err)
|
log.Errorf("failed to create pion's stdnet: %s", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,9 +3,11 @@
|
|||||||
package ice
|
package ice
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
func newStdNet(ctx context.Context, _ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
||||||
return stdnet.NewNet(ifaceBlacklist)
|
return stdnet.NewNet(ctx, ifaceBlacklist)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
package ice
|
package ice
|
||||||
|
|
||||||
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist)
|
)
|
||||||
|
|
||||||
|
func newStdNet(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
||||||
|
return stdnet.NewNetWithDiscover(ctx, iFaceDiscover, ifaceBlacklist)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,9 +21,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||||
"github.com/netbirdio/netbird/client/internal/relay"
|
"github.com/netbirdio/netbird/client/internal/relay"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const eventQueueSize = 10
|
const eventQueueSize = 10
|
||||||
@@ -67,6 +67,7 @@ type State struct {
|
|||||||
BytesRx int64
|
BytesRx int64
|
||||||
Latency time.Duration
|
Latency time.Duration
|
||||||
RosenpassEnabled bool
|
RosenpassEnabled bool
|
||||||
|
SSHHostKey []byte
|
||||||
routes map[string]struct{}
|
routes map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -572,6 +573,22 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdatePeerSSHHostKey updates peer's SSH host key
|
||||||
|
func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) error {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
peerState, ok := d.peers[peerPubKey]
|
||||||
|
if !ok {
|
||||||
|
return errors.New("peer doesn't exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
peerState.SSHHostKey = sshHostKey
|
||||||
|
d.peers[peerPubKey] = peerState
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// FinishPeerListModifications this event invoke the notification
|
// FinishPeerListModifications this event invoke the notification
|
||||||
func (d *Status) FinishPeerListModifications() {
|
func (d *Status) FinishPeerListModifications() {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
|
|||||||
@@ -209,7 +209,7 @@ func (w *WorkerICE) Close() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
|
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
|
||||||
agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
|
agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create agent: %w", err)
|
return nil, fmt.Errorf("create agent: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,24 +44,30 @@ var DefaultInterfaceBlacklist = []string{
|
|||||||
|
|
||||||
// ConfigInput carries configuration changes to the client
|
// ConfigInput carries configuration changes to the client
|
||||||
type ConfigInput struct {
|
type ConfigInput struct {
|
||||||
ManagementURL string
|
ManagementURL string
|
||||||
AdminURL string
|
AdminURL string
|
||||||
ConfigPath string
|
ConfigPath string
|
||||||
StateFilePath string
|
StateFilePath string
|
||||||
PreSharedKey *string
|
PreSharedKey *string
|
||||||
ServerSSHAllowed *bool
|
ServerSSHAllowed *bool
|
||||||
NATExternalIPs []string
|
EnableSSHRoot *bool
|
||||||
CustomDNSAddress []byte
|
EnableSSHSFTP *bool
|
||||||
RosenpassEnabled *bool
|
EnableSSHLocalPortForwarding *bool
|
||||||
RosenpassPermissive *bool
|
EnableSSHRemotePortForwarding *bool
|
||||||
InterfaceName *string
|
DisableSSHAuth *bool
|
||||||
WireguardPort *int
|
SSHJWTCacheTTL *int
|
||||||
NetworkMonitor *bool
|
NATExternalIPs []string
|
||||||
DisableAutoConnect *bool
|
CustomDNSAddress []byte
|
||||||
ExtraIFaceBlackList []string
|
RosenpassEnabled *bool
|
||||||
DNSRouteInterval *time.Duration
|
RosenpassPermissive *bool
|
||||||
ClientCertPath string
|
InterfaceName *string
|
||||||
ClientCertKeyPath string
|
WireguardPort *int
|
||||||
|
NetworkMonitor *bool
|
||||||
|
DisableAutoConnect *bool
|
||||||
|
ExtraIFaceBlackList []string
|
||||||
|
DNSRouteInterval *time.Duration
|
||||||
|
ClientCertPath string
|
||||||
|
ClientCertKeyPath string
|
||||||
|
|
||||||
DisableClientRoutes *bool
|
DisableClientRoutes *bool
|
||||||
DisableServerRoutes *bool
|
DisableServerRoutes *bool
|
||||||
@@ -82,18 +88,24 @@ type ConfigInput struct {
|
|||||||
// Config Configuration type
|
// Config Configuration type
|
||||||
type Config struct {
|
type Config struct {
|
||||||
// Wireguard private key of local peer
|
// Wireguard private key of local peer
|
||||||
PrivateKey string
|
PrivateKey string
|
||||||
PreSharedKey string
|
PreSharedKey string
|
||||||
ManagementURL *url.URL
|
ManagementURL *url.URL
|
||||||
AdminURL *url.URL
|
AdminURL *url.URL
|
||||||
WgIface string
|
WgIface string
|
||||||
WgPort int
|
WgPort int
|
||||||
NetworkMonitor *bool
|
NetworkMonitor *bool
|
||||||
IFaceBlackList []string
|
IFaceBlackList []string
|
||||||
DisableIPv6Discovery bool
|
DisableIPv6Discovery bool
|
||||||
RosenpassEnabled bool
|
RosenpassEnabled bool
|
||||||
RosenpassPermissive bool
|
RosenpassPermissive bool
|
||||||
ServerSSHAllowed *bool
|
ServerSSHAllowed *bool
|
||||||
|
EnableSSHRoot *bool
|
||||||
|
EnableSSHSFTP *bool
|
||||||
|
EnableSSHLocalPortForwarding *bool
|
||||||
|
EnableSSHRemotePortForwarding *bool
|
||||||
|
DisableSSHAuth *bool
|
||||||
|
SSHJWTCacheTTL *int
|
||||||
|
|
||||||
DisableClientRoutes bool
|
DisableClientRoutes bool
|
||||||
DisableServerRoutes bool
|
DisableServerRoutes bool
|
||||||
@@ -376,6 +388,62 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
|
||||||
|
if *input.EnableSSHRoot {
|
||||||
|
log.Infof("enabling SSH root login")
|
||||||
|
} else {
|
||||||
|
log.Infof("disabling SSH root login")
|
||||||
|
}
|
||||||
|
config.EnableSSHRoot = input.EnableSSHRoot
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.EnableSSHSFTP != nil && input.EnableSSHSFTP != config.EnableSSHSFTP {
|
||||||
|
if *input.EnableSSHSFTP {
|
||||||
|
log.Infof("enabling SSH SFTP subsystem")
|
||||||
|
} else {
|
||||||
|
log.Infof("disabling SSH SFTP subsystem")
|
||||||
|
}
|
||||||
|
config.EnableSSHSFTP = input.EnableSSHSFTP
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.EnableSSHLocalPortForwarding != nil && input.EnableSSHLocalPortForwarding != config.EnableSSHLocalPortForwarding {
|
||||||
|
if *input.EnableSSHLocalPortForwarding {
|
||||||
|
log.Infof("enabling SSH local port forwarding")
|
||||||
|
} else {
|
||||||
|
log.Infof("disabling SSH local port forwarding")
|
||||||
|
}
|
||||||
|
config.EnableSSHLocalPortForwarding = input.EnableSSHLocalPortForwarding
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.EnableSSHRemotePortForwarding != nil && input.EnableSSHRemotePortForwarding != config.EnableSSHRemotePortForwarding {
|
||||||
|
if *input.EnableSSHRemotePortForwarding {
|
||||||
|
log.Infof("enabling SSH remote port forwarding")
|
||||||
|
} else {
|
||||||
|
log.Infof("disabling SSH remote port forwarding")
|
||||||
|
}
|
||||||
|
config.EnableSSHRemotePortForwarding = input.EnableSSHRemotePortForwarding
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.DisableSSHAuth != nil && input.DisableSSHAuth != config.DisableSSHAuth {
|
||||||
|
if *input.DisableSSHAuth {
|
||||||
|
log.Infof("disabling SSH authentication")
|
||||||
|
} else {
|
||||||
|
log.Infof("enabling SSH authentication")
|
||||||
|
}
|
||||||
|
config.DisableSSHAuth = input.DisableSSHAuth
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
|
||||||
|
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
|
||||||
|
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
|
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
|
||||||
log.Infof("updating DNS route interval to %s (old value %s)",
|
log.Infof("updating DNS route interval to %s (old value %s)",
|
||||||
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
|
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
|
||||||
|
|||||||
@@ -193,10 +193,10 @@ func TestWireguardPortZeroExplicit(t *testing.T) {
|
|||||||
|
|
||||||
func TestWireguardPortDefaultVsExplicit(t *testing.T) {
|
func TestWireguardPortDefaultVsExplicit(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
wireguardPort *int
|
wireguardPort *int
|
||||||
expectedPort int
|
expectedPort int
|
||||||
description string
|
description string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "no port specified uses default",
|
name: "no port specified uses default",
|
||||||
|
|||||||
@@ -132,3 +132,21 @@ func (pm *ProfileManager) setActiveProfileState(profileName string) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetLoginHint retrieves the email from the active profile to use as login_hint.
|
||||||
|
func GetLoginHint() string {
|
||||||
|
pm := NewProfileManager()
|
||||||
|
activeProf, err := pm.GetActiveProfile()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to get active profile for login hint: %v", err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return profileState.Email
|
||||||
|
}
|
||||||
|
|||||||
@@ -197,7 +197,7 @@ func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr stri
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
net, err := stdnet.NewNet(nil)
|
net, err := stdnet.NewNet(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
probeErr = fmt.Errorf("new net: %w", err)
|
probeErr = fmt.Errorf("new net: %w", err)
|
||||||
return
|
return
|
||||||
@@ -286,7 +286,7 @@ func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr stri
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
net, err := stdnet.NewNet(nil)
|
net, err := stdnet.NewNet(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
probeErr = fmt.Errorf("new net: %w", err)
|
probeErr = fmt.Errorf("new net: %w", err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -18,8 +18,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/client"
|
"github.com/netbirdio/netbird/client/internal/routemanager/client"
|
||||||
@@ -39,6 +38,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pion/transport/v3/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -403,7 +403,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
for n, testCase := range testCases {
|
for n, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pion/transport/v3/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
@@ -436,7 +436,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
|
|||||||
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
|
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
opts := iface.WGIFaceOpts{
|
opts := iface.WGIFaceOpts{
|
||||||
|
|||||||
218
client/internal/sleep/detector_darwin.go
Normal file
218
client/internal/sleep/detector_darwin.go
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
//go:build darwin && !ios
|
||||||
|
|
||||||
|
package sleep
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo LDFLAGS: -framework IOKit -framework CoreFoundation
|
||||||
|
#include <IOKit/pwr_mgt/IOPMLib.h>
|
||||||
|
#include <IOKit/IOMessage.h>
|
||||||
|
#include <CoreFoundation/CoreFoundation.h>
|
||||||
|
|
||||||
|
extern void sleepCallbackBridge();
|
||||||
|
extern void poweredOnCallbackBridge();
|
||||||
|
extern void suspendedCallbackBridge();
|
||||||
|
extern void resumedCallbackBridge();
|
||||||
|
|
||||||
|
|
||||||
|
// C global variables for IOKit state
|
||||||
|
static IONotificationPortRef g_notifyPortRef = NULL;
|
||||||
|
static io_object_t g_notifierObject = 0;
|
||||||
|
static io_object_t g_generalInterestNotifier = 0;
|
||||||
|
static io_connect_t g_rootPort = 0;
|
||||||
|
static CFRunLoopRef g_runLoop = NULL;
|
||||||
|
|
||||||
|
static void sleepCallback(void* refCon, io_service_t service, natural_t messageType, void* messageArgument) {
|
||||||
|
switch (messageType) {
|
||||||
|
case kIOMessageSystemWillSleep:
|
||||||
|
sleepCallbackBridge();
|
||||||
|
IOAllowPowerChange(g_rootPort, (long)messageArgument);
|
||||||
|
break;
|
||||||
|
case kIOMessageSystemHasPoweredOn:
|
||||||
|
poweredOnCallbackBridge();
|
||||||
|
break;
|
||||||
|
case kIOMessageServiceIsSuspended:
|
||||||
|
suspendedCallbackBridge();
|
||||||
|
break;
|
||||||
|
case kIOMessageServiceIsResumed:
|
||||||
|
resumedCallbackBridge();
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void registerNotifications() {
|
||||||
|
g_rootPort = IORegisterForSystemPower(
|
||||||
|
NULL,
|
||||||
|
&g_notifyPortRef,
|
||||||
|
(IOServiceInterestCallback)sleepCallback,
|
||||||
|
&g_notifierObject
|
||||||
|
);
|
||||||
|
|
||||||
|
if (g_rootPort == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
CFRunLoopAddSource(CFRunLoopGetCurrent(),
|
||||||
|
IONotificationPortGetRunLoopSource(g_notifyPortRef),
|
||||||
|
kCFRunLoopCommonModes);
|
||||||
|
|
||||||
|
g_runLoop = CFRunLoopGetCurrent();
|
||||||
|
CFRunLoopRun();
|
||||||
|
}
|
||||||
|
|
||||||
|
static void unregisterNotifications() {
|
||||||
|
CFRunLoopRemoveSource(g_runLoop,
|
||||||
|
IONotificationPortGetRunLoopSource(g_notifyPortRef),
|
||||||
|
kCFRunLoopCommonModes);
|
||||||
|
|
||||||
|
IODeregisterForSystemPower(&g_notifierObject);
|
||||||
|
IOServiceClose(g_rootPort);
|
||||||
|
IONotificationPortDestroy(g_notifyPortRef);
|
||||||
|
CFRunLoopStop(g_runLoop);
|
||||||
|
|
||||||
|
g_notifyPortRef = NULL;
|
||||||
|
g_notifierObject = 0;
|
||||||
|
g_rootPort = 0;
|
||||||
|
g_runLoop = NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
serviceRegistry = make(map[*Detector]struct{})
|
||||||
|
serviceRegistryMu sync.Mutex
|
||||||
|
)
|
||||||
|
|
||||||
|
//export sleepCallbackBridge
|
||||||
|
func sleepCallbackBridge() {
|
||||||
|
log.Info("sleepCallbackBridge event triggered")
|
||||||
|
|
||||||
|
serviceRegistryMu.Lock()
|
||||||
|
defer serviceRegistryMu.Unlock()
|
||||||
|
|
||||||
|
for svc := range serviceRegistry {
|
||||||
|
svc.triggerCallback(EventTypeSleep)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//export resumedCallbackBridge
|
||||||
|
func resumedCallbackBridge() {
|
||||||
|
log.Info("resumedCallbackBridge event triggered")
|
||||||
|
}
|
||||||
|
|
||||||
|
//export suspendedCallbackBridge
|
||||||
|
func suspendedCallbackBridge() {
|
||||||
|
log.Info("suspendedCallbackBridge event triggered")
|
||||||
|
}
|
||||||
|
|
||||||
|
//export poweredOnCallbackBridge
|
||||||
|
func poweredOnCallbackBridge() {
|
||||||
|
log.Info("poweredOnCallbackBridge event triggered")
|
||||||
|
serviceRegistryMu.Lock()
|
||||||
|
defer serviceRegistryMu.Unlock()
|
||||||
|
|
||||||
|
for svc := range serviceRegistry {
|
||||||
|
svc.triggerCallback(EventTypeWakeUp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Detector struct {
|
||||||
|
callback func(event EventType)
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDetector() (*Detector, error) {
|
||||||
|
return &Detector{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Detector) Register(callback func(event EventType)) error {
|
||||||
|
serviceRegistryMu.Lock()
|
||||||
|
defer serviceRegistryMu.Unlock()
|
||||||
|
|
||||||
|
if _, exists := serviceRegistry[d]; exists {
|
||||||
|
return fmt.Errorf("detector service already registered")
|
||||||
|
}
|
||||||
|
|
||||||
|
d.callback = callback
|
||||||
|
|
||||||
|
d.ctx, d.cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
if len(serviceRegistry) > 0 {
|
||||||
|
serviceRegistry[d] = struct{}{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
serviceRegistry[d] = struct{}{}
|
||||||
|
|
||||||
|
// CFRunLoop must run on a single fixed OS thread
|
||||||
|
go func() {
|
||||||
|
runtime.LockOSThread()
|
||||||
|
defer runtime.UnlockOSThread()
|
||||||
|
|
||||||
|
C.registerNotifications()
|
||||||
|
}()
|
||||||
|
|
||||||
|
log.Info("sleep detection service started on macOS")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deregister removes the detector. When the last detector is removed, IOKit registration is torn down
|
||||||
|
// and the runloop is stopped and cleaned up.
|
||||||
|
func (d *Detector) Deregister() error {
|
||||||
|
serviceRegistryMu.Lock()
|
||||||
|
defer serviceRegistryMu.Unlock()
|
||||||
|
_, exists := serviceRegistry[d]
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cancel and remove this detector
|
||||||
|
d.cancel()
|
||||||
|
delete(serviceRegistry, d)
|
||||||
|
|
||||||
|
// If other Detectors still exist, leave IOKit running
|
||||||
|
if len(serviceRegistry) > 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("sleep detection service stopping (deregister)")
|
||||||
|
|
||||||
|
// Deregister IOKit notifications, stop runloop, and free resources
|
||||||
|
C.unregisterNotifications()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Detector) triggerCallback(event EventType) {
|
||||||
|
doneChan := make(chan struct{})
|
||||||
|
|
||||||
|
timeout := time.NewTimer(500 * time.Millisecond)
|
||||||
|
defer timeout.Stop()
|
||||||
|
|
||||||
|
cb := d.callback
|
||||||
|
go func(callback func(event EventType)) {
|
||||||
|
log.Info("sleep detection event fired")
|
||||||
|
callback(event)
|
||||||
|
close(doneChan)
|
||||||
|
}(cb)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-doneChan:
|
||||||
|
case <-d.ctx.Done():
|
||||||
|
case <-timeout.C:
|
||||||
|
log.Warnf("sleep callback timed out")
|
||||||
|
}
|
||||||
|
}
|
||||||
9
client/internal/sleep/detector_notsupported.go
Normal file
9
client/internal/sleep/detector_notsupported.go
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
//go:build !darwin || ios
|
||||||
|
|
||||||
|
package sleep
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
func NewDetector() (detector, error) {
|
||||||
|
return nil, fmt.Errorf("sleep not supported on this platform")
|
||||||
|
}
|
||||||
37
client/internal/sleep/service.go
Normal file
37
client/internal/sleep/service.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package sleep
|
||||||
|
|
||||||
|
var (
|
||||||
|
EventTypeUnknown EventType = 0
|
||||||
|
EventTypeSleep EventType = 1
|
||||||
|
EventTypeWakeUp EventType = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
type EventType int
|
||||||
|
|
||||||
|
type detector interface {
|
||||||
|
Register(callback func(eventType EventType)) error
|
||||||
|
Deregister() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Service struct {
|
||||||
|
detector detector
|
||||||
|
}
|
||||||
|
|
||||||
|
func New() (*Service, error) {
|
||||||
|
d, err := NewDetector()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Service{
|
||||||
|
detector: d,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Register(callback func(eventType EventType)) error {
|
||||||
|
return s.detector.Register(callback)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Deregister() error {
|
||||||
|
return s.detector.Deregister()
|
||||||
|
}
|
||||||
@@ -4,17 +4,28 @@
|
|||||||
package stdnet
|
package stdnet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
"github.com/pion/transport/v3/stdnet"
|
"github.com/pion/transport/v3/stdnet"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
)
|
)
|
||||||
|
|
||||||
const updateInterval = 30 * time.Second
|
const (
|
||||||
|
updateInterval = 30 * time.Second
|
||||||
|
dnsResolveTimeout = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
var errNoSuitableAddress = errors.New("no suitable address found")
|
||||||
|
|
||||||
// Net is an implementation of the net.Net interface
|
// Net is an implementation of the net.Net interface
|
||||||
// based on functions of the standard net package.
|
// based on functions of the standard net package.
|
||||||
@@ -28,12 +39,19 @@ type Net struct {
|
|||||||
|
|
||||||
// mu is shared between interfaces and lastUpdate
|
// mu is shared between interfaces and lastUpdate
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
|
||||||
|
// ctx is the context for network operations that supports cancellation
|
||||||
|
ctx context.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewNetWithDiscover creates a new StdNet instance.
|
// NewNetWithDiscover creates a new StdNet instance.
|
||||||
func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
|
func NewNetWithDiscover(ctx context.Context, iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
n := &Net{
|
n := &Net{
|
||||||
interfaceFilter: InterfaceFilter(disallowList),
|
interfaceFilter: InterfaceFilter(disallowList),
|
||||||
|
ctx: ctx,
|
||||||
}
|
}
|
||||||
// current ExternalIFaceDiscover implement in android-client https://github.dev/netbirdio/android-client
|
// current ExternalIFaceDiscover implement in android-client https://github.dev/netbirdio/android-client
|
||||||
// so in android cli use pionDiscover
|
// so in android cli use pionDiscover
|
||||||
@@ -46,14 +64,64 @@ func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewNet creates a new StdNet instance.
|
// NewNet creates a new StdNet instance.
|
||||||
func NewNet(disallowList []string) (*Net, error) {
|
func NewNet(ctx context.Context, disallowList []string) (*Net, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
n := &Net{
|
n := &Net{
|
||||||
iFaceDiscover: pionDiscover{},
|
iFaceDiscover: pionDiscover{},
|
||||||
interfaceFilter: InterfaceFilter(disallowList),
|
interfaceFilter: InterfaceFilter(disallowList),
|
||||||
|
ctx: ctx,
|
||||||
}
|
}
|
||||||
return n, n.UpdateInterfaces()
|
return n, n.UpdateInterfaces()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// resolveAddr performs DNS resolution with context support and timeout.
|
||||||
|
func (n *Net) resolveAddr(network, address string) (netip.AddrPort, error) {
|
||||||
|
host, portStr, err := net.SplitHostPort(address)
|
||||||
|
if err != nil {
|
||||||
|
return netip.AddrPort{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
if err != nil {
|
||||||
|
return netip.AddrPort{}, fmt.Errorf("invalid port: %w", err)
|
||||||
|
}
|
||||||
|
if port < 0 || port > 65535 {
|
||||||
|
return netip.AddrPort{}, fmt.Errorf("invalid port: %d", port)
|
||||||
|
}
|
||||||
|
|
||||||
|
ipNet := "ip"
|
||||||
|
switch network {
|
||||||
|
case "tcp4", "udp4":
|
||||||
|
ipNet = "ip4"
|
||||||
|
case "tcp6", "udp6":
|
||||||
|
ipNet = "ip6"
|
||||||
|
}
|
||||||
|
|
||||||
|
if host == "" {
|
||||||
|
addr := netip.IPv4Unspecified()
|
||||||
|
if ipNet == "ip6" {
|
||||||
|
addr = netip.IPv6Unspecified()
|
||||||
|
}
|
||||||
|
return netip.AddrPortFrom(addr, uint16(port)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(n.ctx, dnsResolveTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
addrs, err := net.DefaultResolver.LookupNetIP(ctx, ipNet, host)
|
||||||
|
if err != nil {
|
||||||
|
return netip.AddrPort{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(addrs) == 0 {
|
||||||
|
return netip.AddrPort{}, errNoSuitableAddress
|
||||||
|
}
|
||||||
|
|
||||||
|
return netip.AddrPortFrom(addrs[0], uint16(port)), nil
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateInterfaces updates the internal list of network interfaces
|
// UpdateInterfaces updates the internal list of network interfaces
|
||||||
// and associated addresses filtering them by name.
|
// and associated addresses filtering them by name.
|
||||||
// The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one
|
// The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one
|
||||||
@@ -137,3 +205,39 @@ func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.I
|
|||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ResolveUDPAddr resolves UDP addresses with context support and timeout.
|
||||||
|
func (n *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) {
|
||||||
|
switch network {
|
||||||
|
case "udp", "udp4", "udp6":
|
||||||
|
case "":
|
||||||
|
network = "udp"
|
||||||
|
default:
|
||||||
|
return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)}
|
||||||
|
}
|
||||||
|
|
||||||
|
addrPort, err := n.resolveAddr(network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.UDPAddr{IP: nil}, Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
return net.UDPAddrFromAddrPort(addrPort), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveTCPAddr resolves TCP addresses with context support and timeout.
|
||||||
|
func (n *Net) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) {
|
||||||
|
switch network {
|
||||||
|
case "tcp", "tcp4", "tcp6":
|
||||||
|
case "":
|
||||||
|
network = "tcp"
|
||||||
|
default:
|
||||||
|
return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)}
|
||||||
|
}
|
||||||
|
|
||||||
|
addrPort, err := n.resolveAddr(network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.TCPAddr{IP: nil}, Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
return net.TCPAddrFromAddrPort(addrPort), nil
|
||||||
|
}
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
299
client/internal/templates/pkce_auth_msg_test.go
Normal file
299
client/internal/templates/pkce_auth_msg_test.go
Normal file
@@ -0,0 +1,299 @@
|
|||||||
|
package templates
|
||||||
|
|
||||||
|
import (
|
||||||
|
"html/template"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPKCEAuthMsgTemplate(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data map[string]string
|
||||||
|
outputFile string
|
||||||
|
expectedTitle string
|
||||||
|
expectedInContent []string
|
||||||
|
notExpectedInContent []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "error_state",
|
||||||
|
data: map[string]string{
|
||||||
|
"Error": "authentication failed: invalid state",
|
||||||
|
},
|
||||||
|
outputFile: "pkce-auth-error.html",
|
||||||
|
expectedTitle: "Login Failed",
|
||||||
|
expectedInContent: []string{
|
||||||
|
"authentication failed: invalid state",
|
||||||
|
"Login Failed",
|
||||||
|
},
|
||||||
|
notExpectedInContent: []string{
|
||||||
|
"Login Successful",
|
||||||
|
"Your device is now registered and logged in to NetBird",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success_state",
|
||||||
|
data: map[string]string{
|
||||||
|
// No error field means success
|
||||||
|
},
|
||||||
|
outputFile: "pkce-auth-success.html",
|
||||||
|
expectedTitle: "Login Successful",
|
||||||
|
expectedInContent: []string{
|
||||||
|
"Login Successful",
|
||||||
|
"Your device is now registered and logged in to NetBird. You can now close this window.",
|
||||||
|
},
|
||||||
|
notExpectedInContent: []string{
|
||||||
|
"Login Failed",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error_state_timeout",
|
||||||
|
data: map[string]string{
|
||||||
|
"Error": "authentication timeout: request expired after 5 minutes",
|
||||||
|
},
|
||||||
|
outputFile: "pkce-auth-timeout.html",
|
||||||
|
expectedTitle: "Login Failed",
|
||||||
|
expectedInContent: []string{
|
||||||
|
"authentication timeout: request expired after 5 minutes",
|
||||||
|
"Login Failed",
|
||||||
|
},
|
||||||
|
notExpectedInContent: []string{
|
||||||
|
"Login Successful",
|
||||||
|
"Your device is now registered and logged in to NetBird",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Parse the template
|
||||||
|
tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to parse template: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create temp directory for this test
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
outputPath := filepath.Join(tempDir, tt.outputFile)
|
||||||
|
|
||||||
|
// Create output file
|
||||||
|
file, err := os.Create(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create output file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the template
|
||||||
|
if err := tmpl.Execute(file, tt.data); err != nil {
|
||||||
|
file.Close()
|
||||||
|
t.Fatalf("Failed to execute template: %v", err)
|
||||||
|
}
|
||||||
|
file.Close()
|
||||||
|
|
||||||
|
t.Logf("Generated test output: %s", outputPath)
|
||||||
|
|
||||||
|
// Read the generated file
|
||||||
|
content, err := os.ReadFile(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read output file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
contentStr := string(content)
|
||||||
|
|
||||||
|
// Verify file has content
|
||||||
|
if len(contentStr) == 0 {
|
||||||
|
t.Error("Output file is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify basic HTML structure
|
||||||
|
basicElements := []string{
|
||||||
|
"<!DOCTYPE html>",
|
||||||
|
"<html",
|
||||||
|
"<head>",
|
||||||
|
"<body>",
|
||||||
|
"NetBird",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, elem := range basicElements {
|
||||||
|
if !contains(contentStr, elem) {
|
||||||
|
t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify expected title
|
||||||
|
if !contains(contentStr, tt.expectedTitle) {
|
||||||
|
t.Errorf("Expected HTML to contain title '%s', but it was not found", tt.expectedTitle)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify expected content is present
|
||||||
|
for _, expected := range tt.expectedInContent {
|
||||||
|
if !contains(contentStr, expected) {
|
||||||
|
t.Errorf("Expected HTML to contain '%s', but it was not found", expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify unexpected content is not present
|
||||||
|
for _, notExpected := range tt.notExpectedInContent {
|
||||||
|
if contains(contentStr, notExpected) {
|
||||||
|
t.Errorf("Expected HTML to NOT contain '%s', but it was found", notExpected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPKCEAuthMsgTemplateValidation(t *testing.T) {
|
||||||
|
// Test that the template can be parsed without errors
|
||||||
|
tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Template parsing failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with empty data
|
||||||
|
t.Run("empty_data", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
outputPath := filepath.Join(tempDir, "empty-data.html")
|
||||||
|
|
||||||
|
file, err := os.Create(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create output file: %v", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
if err := tmpl.Execute(file, nil); err != nil {
|
||||||
|
t.Errorf("Template execution with nil data failed: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test with error data
|
||||||
|
t.Run("with_error", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
outputPath := filepath.Join(tempDir, "with-error.html")
|
||||||
|
|
||||||
|
file, err := os.Create(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create output file: %v", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
data := map[string]string{
|
||||||
|
"Error": "test error message",
|
||||||
|
}
|
||||||
|
if err := tmpl.Execute(file, data); err != nil {
|
||||||
|
t.Errorf("Template execution with error data failed: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPKCEAuthMsgTemplateContent(t *testing.T) {
|
||||||
|
// Test that the template contains expected elements
|
||||||
|
tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Template parsing failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("success_content", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
outputPath := filepath.Join(tempDir, "success.html")
|
||||||
|
|
||||||
|
file, err := os.Create(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create output file: %v", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
data := map[string]string{}
|
||||||
|
if err := tmpl.Execute(file, data); err != nil {
|
||||||
|
t.Fatalf("Template execution failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the file and verify it contains expected content
|
||||||
|
content, err := os.ReadFile(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read output file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for success indicators
|
||||||
|
contentStr := string(content)
|
||||||
|
if len(contentStr) == 0 {
|
||||||
|
t.Error("Generated HTML is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basic HTML structure checks
|
||||||
|
requiredElements := []string{
|
||||||
|
"<!DOCTYPE html>",
|
||||||
|
"<html",
|
||||||
|
"<head>",
|
||||||
|
"<body>",
|
||||||
|
"Login Successful",
|
||||||
|
"NetBird",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, elem := range requiredElements {
|
||||||
|
if !contains(contentStr, elem) {
|
||||||
|
t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("error_content", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
outputPath := filepath.Join(tempDir, "error.html")
|
||||||
|
|
||||||
|
file, err := os.Create(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create output file: %v", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
errorMsg := "test error message"
|
||||||
|
data := map[string]string{
|
||||||
|
"Error": errorMsg,
|
||||||
|
}
|
||||||
|
if err := tmpl.Execute(file, data); err != nil {
|
||||||
|
t.Fatalf("Template execution failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the file and verify it contains expected content
|
||||||
|
content, err := os.ReadFile(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read output file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for error indicators
|
||||||
|
contentStr := string(content)
|
||||||
|
if len(contentStr) == 0 {
|
||||||
|
t.Error("Generated HTML is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basic HTML structure checks
|
||||||
|
requiredElements := []string{
|
||||||
|
"<!DOCTYPE html>",
|
||||||
|
"<html",
|
||||||
|
"<head>",
|
||||||
|
"<body>",
|
||||||
|
"Login Failed",
|
||||||
|
errorMsg,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, elem := range requiredElements {
|
||||||
|
if !contains(contentStr, elem) {
|
||||||
|
t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func contains(s, substr string) bool {
|
||||||
|
return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
|
||||||
|
(len(s) > 0 && len(substr) > 0 && containsHelper(s, substr)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsHelper(s, substr string) bool {
|
||||||
|
for i := 0; i <= len(s)-len(substr); i++ {
|
||||||
|
if s[i:i+len(substr)] == substr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -1,9 +1,12 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -20,8 +23,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionListener export internal Listener for mobile
|
// ConnectionListener export internal Listener for mobile
|
||||||
@@ -90,7 +93,8 @@ func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName s
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Run start the internal client. It is a blocker function
|
// Run start the internal client. It is a blocker function
|
||||||
func (c *Client) Run(fd int32, interfaceName string) error {
|
func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
||||||
|
exportEnvList(envList)
|
||||||
log.Infof("Starting NetBird client")
|
log.Infof("Starting NetBird client")
|
||||||
log.Debugf("Tunnel uses interface: %s", interfaceName)
|
log.Debugf("Tunnel uses interface: %s", interfaceName)
|
||||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||||
@@ -228,7 +232,7 @@ func (c *Client) LoginForMobile() string {
|
|||||||
ConfigPath: c.cfgFile,
|
ConfigPath: c.cfgFile,
|
||||||
})
|
})
|
||||||
|
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, "")
|
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, false, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err.Error()
|
return err.Error()
|
||||||
}
|
}
|
||||||
@@ -433,3 +437,19 @@ func toNetIDs(routes []string) []route.NetID {
|
|||||||
}
|
}
|
||||||
return netIDs
|
return netIDs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func exportEnvList(list *EnvList) {
|
||||||
|
if list == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for k, v := range list.AllItems() {
|
||||||
|
log.Debugf("Env variable %s's value is currently: %s", k, os.Getenv(k))
|
||||||
|
log.Debugf("Setting env variable %s: %s", k, v)
|
||||||
|
|
||||||
|
if err := os.Setenv(k, v); err != nil {
|
||||||
|
log.Errorf("could not set env variable %s: %v", k, err)
|
||||||
|
} else {
|
||||||
|
log.Debugf("Env variable %s was set successfully", k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
34
client/ios/NetBirdSDK/env_list.go
Normal file
34
client/ios/NetBirdSDK/env_list.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
|
package NetBirdSDK
|
||||||
|
|
||||||
|
import "github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
|
||||||
|
// EnvList is an exported struct to be bound by gomobile
|
||||||
|
type EnvList struct {
|
||||||
|
data map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewEnvList creates a new EnvList
|
||||||
|
func NewEnvList() *EnvList {
|
||||||
|
return &EnvList{data: make(map[string]string)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put adds a key-value pair
|
||||||
|
func (el *EnvList) Put(key, value string) {
|
||||||
|
el.data[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a value by key
|
||||||
|
func (el *EnvList) Get(key string) string {
|
||||||
|
return el.data[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (el *EnvList) AllItems() map[string]string {
|
||||||
|
return el.data
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEnvKeyNBForceRelay Exports the environment variable for the iOS client
|
||||||
|
func GetEnvKeyNBForceRelay() string {
|
||||||
|
return peer.EnvKeyNBForceRelay
|
||||||
|
}
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
import _ "golang.org/x/mobile/bind"
|
import _ "golang.org/x/mobile/bind"
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
// PeerInfo describe information about the peers. It designed for the UI usage
|
// PeerInfo describe information about the peers. It designed for the UI usage
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user