mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-07 16:39:55 +00:00
Compare commits
71 Commits
move-licen
...
sync-clien
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
72513d7522 | ||
|
|
a1f1bf1f19 | ||
|
|
b5dec3df39 | ||
|
|
447cd287f5 | ||
|
|
5748bdd64e | ||
|
|
08f31fbcb3 | ||
|
|
932c02eaab | ||
|
|
abcbde26f9 | ||
|
|
90e3b8009f | ||
|
|
94d34dc0c5 | ||
|
|
44851e06fb | ||
|
|
3f4f825ec1 | ||
|
|
f538e6e9ae | ||
|
|
cb6b086164 | ||
|
|
71b6855e09 | ||
|
|
9bdc4908fb | ||
|
|
031ab11178 | ||
|
|
d2e48d4f5e | ||
|
|
27dd97c9c4 | ||
|
|
e87b4ace11 | ||
|
|
a232cf614c | ||
|
|
a293f760af | ||
|
|
10e9cf8c62 | ||
|
|
7193bd2da7 | ||
|
|
52948ccd61 | ||
|
|
4b77359042 | ||
|
|
387d43bcc1 | ||
|
|
e47d815dd2 | ||
|
|
cb83b7c0d3 | ||
|
|
ddcd182859 | ||
|
|
aca0398105 | ||
|
|
02200d790b | ||
|
|
f31bba87b4 | ||
|
|
7285fef0f0 | ||
|
|
20973063d8 | ||
|
|
ba2e9b6d88 | ||
|
|
131d7a3694 | ||
|
|
290fe2d8b9 | ||
|
|
7fb1a2fe31 | ||
|
|
32146e576d | ||
|
|
1311364397 | ||
|
|
68f56b797d | ||
|
|
3351b38434 | ||
|
|
05cbead39b | ||
|
|
60f4d5f9b0 | ||
|
|
4eeb2d8deb | ||
|
|
d71a82769c | ||
|
|
0d79301141 | ||
|
|
e4b41d0ad7 | ||
|
|
9cc9462cd5 | ||
|
|
3176b53968 | ||
|
|
27957036c9 | ||
|
|
6fb568728f | ||
|
|
cc97cffff1 | ||
|
|
c28275611b | ||
|
|
56f169eede | ||
|
|
07cf9d5895 | ||
|
|
7df49e249d | ||
|
|
dbfc8a52c9 | ||
|
|
98ddac07bf | ||
|
|
48475ddc05 | ||
|
|
6aa4ba7af4 | ||
|
|
2e16c9914a | ||
|
|
5c29d395b2 | ||
|
|
229e0038ee | ||
|
|
75327d9519 | ||
|
|
20f5f00635 | ||
|
|
fc141cf3a3 | ||
|
|
d0c65fa08e | ||
|
|
f241bfa339 | ||
|
|
4b2cd97d5f |
11
.githooks/pre-push
Executable file
11
.githooks/pre-push
Executable file
@@ -0,0 +1,11 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
echo "Running pre-push hook..."
|
||||||
|
if ! make lint; then
|
||||||
|
echo ""
|
||||||
|
echo "Hint: To push without verification, run:"
|
||||||
|
echo " git push --no-verify"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "All checks passed!"
|
||||||
116
.github/workflows/check-license-dependencies.yml
vendored
116
.github/workflows/check-license-dependencies.yml
vendored
@@ -3,40 +3,108 @@ name: Check License Dependencies
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: [ main ]
|
branches: [ main ]
|
||||||
|
paths:
|
||||||
|
- 'go.mod'
|
||||||
|
- 'go.sum'
|
||||||
|
- '.github/workflows/check-license-dependencies.yml'
|
||||||
pull_request:
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- 'go.mod'
|
||||||
|
- 'go.sum'
|
||||||
|
- '.github/workflows/check-license-dependencies.yml'
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
check-dependencies:
|
check-internal-dependencies:
|
||||||
|
name: Check Internal AGPL Dependencies
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Check for problematic license dependencies
|
||||||
|
run: |
|
||||||
|
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Find all directories except the problematic ones and system dirs
|
||||||
|
FOUND_ISSUES=0
|
||||||
|
while IFS= read -r dir; do
|
||||||
|
echo "=== Checking $dir ==="
|
||||||
|
# Search for problematic imports, excluding test files
|
||||||
|
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||||
|
if [ -n "$RESULTS" ]; then
|
||||||
|
echo "❌ Found problematic dependencies:"
|
||||||
|
echo "$RESULTS"
|
||||||
|
FOUND_ISSUES=1
|
||||||
|
else
|
||||||
|
echo "✓ No problematic dependencies found"
|
||||||
|
fi
|
||||||
|
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||||
|
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
||||||
|
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
||||||
|
exit 1
|
||||||
|
else
|
||||||
|
echo ""
|
||||||
|
echo "✅ All internal license dependencies are clean"
|
||||||
|
fi
|
||||||
|
|
||||||
|
check-external-licenses:
|
||||||
|
name: Check External GPL/AGPL Licenses
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Check for problematic license dependencies
|
- name: Set up Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: 'go.mod'
|
||||||
|
cache: true
|
||||||
|
|
||||||
|
- name: Install go-licenses
|
||||||
|
run: go install github.com/google/go-licenses@v1.6.0
|
||||||
|
|
||||||
|
- name: Check for GPL/AGPL licensed dependencies
|
||||||
run: |
|
run: |
|
||||||
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# Find all directories except the problematic ones and system dirs
|
# Check all Go packages for copyleft licenses, excluding internal netbird packages
|
||||||
FOUND_ISSUES=0
|
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
|
||||||
while IFS= read -r dir; do
|
|
||||||
echo "=== Checking $dir ==="
|
if [ -n "$COPYLEFT_DEPS" ]; then
|
||||||
# Search for problematic imports, excluding test files
|
echo "Found copyleft licensed dependencies:"
|
||||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
echo "$COPYLEFT_DEPS"
|
||||||
if [ -n "$RESULTS" ]; then
|
echo ""
|
||||||
echo "❌ Found problematic dependencies:"
|
|
||||||
echo "$RESULTS"
|
# Filter out dependencies that are only pulled in by internal AGPL packages
|
||||||
FOUND_ISSUES=1
|
INCOMPATIBLE=""
|
||||||
else
|
while IFS=',' read -r package url license; do
|
||||||
echo "✓ No problematic dependencies found"
|
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
|
||||||
|
# Find ALL packages that import this GPL package using go list
|
||||||
|
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||||
|
|
||||||
|
# Check if any importer is NOT in management/signal/relay
|
||||||
|
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
|
||||||
|
|
||||||
|
if [ -n "$BSD_IMPORTER" ]; then
|
||||||
|
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||||
|
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
|
||||||
|
else
|
||||||
|
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
done <<< "$COPYLEFT_DEPS"
|
||||||
|
|
||||||
|
if [ -n "$INCOMPATIBLE" ]; then
|
||||||
|
echo ""
|
||||||
|
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
|
||||||
|
echo -e "$INCOMPATIBLE"
|
||||||
|
exit 1
|
||||||
fi
|
fi
|
||||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
|
||||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
|
||||||
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
|
||||||
exit 1
|
|
||||||
else
|
|
||||||
echo "✅ All license dependencies are clean"
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
|
||||||
|
|||||||
7
.github/workflows/golang-test-darwin.yml
vendored
7
.github/workflows/golang-test-darwin.yml
vendored
@@ -15,13 +15,14 @@ jobs:
|
|||||||
name: "Client / Unit"
|
name: "Client / Unit"
|
||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
|
|||||||
2
.github/workflows/golang-test-freebsd.yml
vendored
2
.github/workflows/golang-test-freebsd.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
|||||||
release: "14.2"
|
release: "14.2"
|
||||||
prepare: |
|
prepare: |
|
||||||
pkg install -y curl pkgconf xorg
|
pkg install -y curl pkgconf xorg
|
||||||
GO_TARBALL="go1.23.12.freebsd-amd64.tar.gz"
|
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
|
||||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||||
curl -vLO "$GO_URL"
|
curl -vLO "$GO_URL"
|
||||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||||
|
|||||||
71
.github/workflows/golang-test-linux.yml
vendored
71
.github/workflows/golang-test-linux.yml
vendored
@@ -30,7 +30,7 @@ jobs:
|
|||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
@@ -106,15 +106,15 @@ jobs:
|
|||||||
arch: [ '386','amd64' ]
|
arch: [ '386','amd64' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
@@ -151,15 +151,15 @@ jobs:
|
|||||||
needs: [ build-cache ]
|
needs: [ build-cache ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
id: go-env
|
id: go-env
|
||||||
run: |
|
run: |
|
||||||
@@ -200,7 +200,7 @@ jobs:
|
|||||||
-e GOCACHE=${CONTAINER_GOCACHE} \
|
-e GOCACHE=${CONTAINER_GOCACHE} \
|
||||||
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
|
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
|
||||||
-e CONTAINER=${CONTAINER} \
|
-e CONTAINER=${CONTAINER} \
|
||||||
golang:1.23-alpine \
|
golang:1.24-alpine \
|
||||||
sh -c ' \
|
sh -c ' \
|
||||||
apk update; apk add --no-cache \
|
apk update; apk add --no-cache \
|
||||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||||
@@ -220,15 +220,15 @@ jobs:
|
|||||||
raceFlag: "-race"
|
raceFlag: "-race"
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
if: steps.cache.outputs.cache-hit != 'true'
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||||
@@ -270,15 +270,15 @@ jobs:
|
|||||||
arch: [ '386','amd64' ]
|
arch: [ '386','amd64' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
if: steps.cache.outputs.cache-hit != 'true'
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||||
@@ -321,15 +321,15 @@ jobs:
|
|||||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
@@ -408,15 +408,16 @@ jobs:
|
|||||||
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
|
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
|
||||||
-p 9090:9090 \
|
-p 9090:9090 \
|
||||||
prom/prometheus
|
prom/prometheus
|
||||||
- name: Install Go
|
|
||||||
uses: actions/setup-go@v5
|
|
||||||
with:
|
|
||||||
go-version: "1.23.x"
|
|
||||||
cache: false
|
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: "go.mod"
|
||||||
|
cache: false
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
@@ -497,15 +498,15 @@ jobs:
|
|||||||
-p 9090:9090 \
|
-p 9090:9090 \
|
||||||
prom/prometheus
|
prom/prometheus
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
@@ -561,15 +562,15 @@ jobs:
|
|||||||
store: [ 'sqlite', 'postgres']
|
store: [ 'sqlite', 'postgres']
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
|||||||
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
|||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
id: go
|
id: go
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
|
|||||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -46,7 +46,7 @@ jobs:
|
|||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
if: matrix.os == 'ubuntu-latest'
|
if: matrix.os == 'ubuntu-latest'
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ jobs:
|
|||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
- name: Setup Android SDK
|
- name: Setup Android SDK
|
||||||
uses: android-actions/setup-android@v3
|
uses: android-actions/setup-android@v3
|
||||||
with:
|
with:
|
||||||
@@ -39,7 +39,7 @@ jobs:
|
|||||||
- name: Setup NDK
|
- name: Setup NDK
|
||||||
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
|
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
|
||||||
- name: install gomobile
|
- name: install gomobile
|
||||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
|
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab
|
||||||
- name: gomobile init
|
- name: gomobile init
|
||||||
run: gomobile init
|
run: gomobile init
|
||||||
- name: build android netbird lib
|
- name: build android netbird lib
|
||||||
@@ -56,9 +56,9 @@ jobs:
|
|||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
- name: install gomobile
|
- name: install gomobile
|
||||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
|
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab
|
||||||
- name: gomobile init
|
- name: gomobile init
|
||||||
run: gomobile init
|
run: gomobile init
|
||||||
- name: build iOS netbird lib
|
- name: build iOS netbird lib
|
||||||
|
|||||||
8
.github/workflows/release.yml
vendored
8
.github/workflows/release.yml
vendored
@@ -20,7 +20,7 @@ concurrency:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
release:
|
release:
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-latest-m
|
||||||
env:
|
env:
|
||||||
flags: ""
|
flags: ""
|
||||||
steps:
|
steps:
|
||||||
@@ -40,7 +40,7 @@ jobs:
|
|||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
@@ -136,7 +136,7 @@ jobs:
|
|||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
@@ -200,7 +200,7 @@ jobs:
|
|||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
|
|||||||
@@ -67,10 +67,13 @@ jobs:
|
|||||||
- name: Install curl
|
- name: Install curl
|
||||||
run: sudo apt-get install -y curl
|
run: sudo apt-get install -y curl
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
@@ -80,9 +83,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-go-
|
${{ runner.os }}-go-
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Setup MySQL privileges
|
- name: Setup MySQL privileges
|
||||||
if: matrix.store == 'mysql'
|
if: matrix.store == 'mysql'
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
8
.github/workflows/wasm-build-validation.yml
vendored
8
.github/workflows/wasm-build-validation.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
|||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||||
- name: Install golangci-lint
|
- name: Install golangci-lint
|
||||||
@@ -45,7 +45,7 @@ jobs:
|
|||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
- name: Build Wasm client
|
- name: Build Wasm client
|
||||||
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
|
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
|
||||||
env:
|
env:
|
||||||
@@ -60,8 +60,8 @@ jobs:
|
|||||||
|
|
||||||
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||||
|
|
||||||
if [ ${SIZE} -gt 52428800 ]; then
|
if [ ${SIZE} -gt 57671680 ]; then
|
||||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!"
|
echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|||||||
@@ -136,6 +136,14 @@ checked out and set up:
|
|||||||
go mod tidy
|
go mod tidy
|
||||||
```
|
```
|
||||||
|
|
||||||
|
6. Configure Git hooks for automatic linting:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make setup-hooks
|
||||||
|
```
|
||||||
|
|
||||||
|
This will configure Git to run linting automatically before each push, helping catch issues early.
|
||||||
|
|
||||||
### Dev Container Support
|
### Dev Container Support
|
||||||
|
|
||||||
If you prefer using a dev container for development, NetBird now includes support for dev containers.
|
If you prefer using a dev container for development, NetBird now includes support for dev containers.
|
||||||
|
|||||||
27
Makefile
Normal file
27
Makefile
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
.PHONY: lint lint-all lint-install setup-hooks
|
||||||
|
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
|
||||||
|
|
||||||
|
# Install golangci-lint locally if needed
|
||||||
|
$(GOLANGCI_LINT):
|
||||||
|
@echo "Installing golangci-lint..."
|
||||||
|
@mkdir -p ./bin
|
||||||
|
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||||
|
|
||||||
|
# Lint only changed files (fast, for pre-push)
|
||||||
|
lint: $(GOLANGCI_LINT)
|
||||||
|
@echo "Running lint on changed files..."
|
||||||
|
@$(GOLANGCI_LINT) run --new-from-rev=origin/main --timeout=2m
|
||||||
|
|
||||||
|
# Lint entire codebase (slow, matches CI)
|
||||||
|
lint-all: $(GOLANGCI_LINT)
|
||||||
|
@echo "Running lint on all files..."
|
||||||
|
@$(GOLANGCI_LINT) run --timeout=12m
|
||||||
|
|
||||||
|
# Just install the linter
|
||||||
|
lint-install: $(GOLANGCI_LINT)
|
||||||
|
|
||||||
|
# Setup git hooks for all developers
|
||||||
|
setup-hooks:
|
||||||
|
@git config core.hooksPath .githooks
|
||||||
|
@chmod +x .githooks/pre-push
|
||||||
|
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"
|
||||||
@@ -4,10 +4,13 @@ package android
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
@@ -16,10 +19,13 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/client/net"
|
"github.com/netbirdio/netbird/client/net"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionListener export internal Listener for mobile
|
// ConnectionListener export internal Listener for mobile
|
||||||
@@ -62,17 +68,18 @@ type Client struct {
|
|||||||
deviceName string
|
deviceName string
|
||||||
uiVersion string
|
uiVersion string
|
||||||
networkChangeListener listener.NetworkChangeListener
|
networkChangeListener listener.NetworkChangeListener
|
||||||
|
stateFile string
|
||||||
|
|
||||||
connectClient *internal.ConnectClient
|
connectClient *internal.ConnectClient
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient instantiate a new Client
|
// NewClient instantiate a new Client
|
||||||
func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
||||||
execWorkaround(androidSDKVersion)
|
execWorkaround(androidSDKVersion)
|
||||||
|
|
||||||
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
|
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
|
||||||
return &Client{
|
return &Client{
|
||||||
cfgFile: cfgFile,
|
cfgFile: platformFiles.ConfigurationFilePath(),
|
||||||
deviceName: deviceName,
|
deviceName: deviceName,
|
||||||
uiVersion: uiVersion,
|
uiVersion: uiVersion,
|
||||||
tunAdapter: tunAdapter,
|
tunAdapter: tunAdapter,
|
||||||
@@ -80,11 +87,12 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi
|
|||||||
recorder: peer.NewRecorder(""),
|
recorder: peer.NewRecorder(""),
|
||||||
ctxCancelLock: &sync.Mutex{},
|
ctxCancelLock: &sync.Mutex{},
|
||||||
networkChangeListener: networkChangeListener,
|
networkChangeListener: networkChangeListener,
|
||||||
|
stateFile: platformFiles.StateFilePath(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run start the internal client. It is a blocker function
|
// Run start the internal client. It is a blocker function
|
||||||
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
||||||
exportEnvList(envList)
|
exportEnvList(envList)
|
||||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||||
ConfigPath: c.cfgFile,
|
ConfigPath: c.cfgFile,
|
||||||
@@ -107,7 +115,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
|||||||
c.ctxCancelLock.Unlock()
|
c.ctxCancelLock.Unlock()
|
||||||
|
|
||||||
auth := NewAuthWithConfig(ctx, cfg)
|
auth := NewAuthWithConfig(ctx, cfg)
|
||||||
err = auth.login(urlOpener)
|
err = auth.login(urlOpener, isAndroidTV)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -115,7 +123,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
|||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
|
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||||
@@ -142,7 +150,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
|
|||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
|
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
@@ -156,6 +164,19 @@ func (c *Client) Stop() {
|
|||||||
c.ctxCancel()
|
c.ctxCancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) RenewTun(fd int) error {
|
||||||
|
if c.connectClient == nil {
|
||||||
|
return fmt.Errorf("engine not running")
|
||||||
|
}
|
||||||
|
|
||||||
|
e := c.connectClient.Engine()
|
||||||
|
if e == nil {
|
||||||
|
return fmt.Errorf("engine not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.RenewTun(fd)
|
||||||
|
}
|
||||||
|
|
||||||
// SetTraceLogLevel configure the logger to trace level
|
// SetTraceLogLevel configure the logger to trace level
|
||||||
func (c *Client) SetTraceLogLevel() {
|
func (c *Client) SetTraceLogLevel() {
|
||||||
log.SetLevel(log.TraceLevel)
|
log.SetLevel(log.TraceLevel)
|
||||||
@@ -177,6 +198,7 @@ func (c *Client) PeersList() *PeerInfoArray {
|
|||||||
p.IP,
|
p.IP,
|
||||||
p.FQDN,
|
p.FQDN,
|
||||||
p.ConnStatus.String(),
|
p.ConnStatus.String(),
|
||||||
|
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
|
||||||
}
|
}
|
||||||
peerInfos[n] = pi
|
peerInfos[n] = pi
|
||||||
}
|
}
|
||||||
@@ -201,31 +223,43 @@ func (c *Client) Networks() *NetworkArray {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
routeSelector := routeManager.GetRouteSelector()
|
||||||
|
if routeSelector == nil {
|
||||||
|
log.Error("could not get route selector")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
networkArray := &NetworkArray{
|
networkArray := &NetworkArray{
|
||||||
items: make([]Network, 0),
|
items: make([]Network, 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
resolvedDomains := c.recorder.GetResolvedDomainsStates()
|
||||||
|
|
||||||
for id, routes := range routeManager.GetClientRoutesWithNetID() {
|
for id, routes := range routeManager.GetClientRoutesWithNetID() {
|
||||||
if len(routes) == 0 {
|
if len(routes) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
r := routes[0]
|
r := routes[0]
|
||||||
|
domains := c.getNetworkDomainsFromRoute(r, resolvedDomains)
|
||||||
netStr := r.Network.String()
|
netStr := r.Network.String()
|
||||||
|
|
||||||
if r.IsDynamic() {
|
if r.IsDynamic() {
|
||||||
netStr = r.Domains.SafeString()
|
netStr = r.Domains.SafeString()
|
||||||
}
|
}
|
||||||
|
|
||||||
peer, err := c.recorder.GetPeer(routes[0].Peer)
|
routePeer, err := c.recorder.GetPeer(routes[0].Peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
|
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
network := Network{
|
network := Network{
|
||||||
Name: string(id),
|
Name: string(id),
|
||||||
Network: netStr,
|
Network: netStr,
|
||||||
Peer: peer.FQDN,
|
Peer: routePeer.FQDN,
|
||||||
Status: peer.ConnStatus.String(),
|
Status: routePeer.ConnStatus.String(),
|
||||||
|
IsSelected: routeSelector.IsSelected(id),
|
||||||
|
Domains: domains,
|
||||||
}
|
}
|
||||||
networkArray.Add(network)
|
networkArray.Add(network)
|
||||||
}
|
}
|
||||||
@@ -253,6 +287,69 @@ func (c *Client) RemoveConnectionListener() {
|
|||||||
c.recorder.RemoveConnectionListener()
|
c.recorder.RemoveConnectionListener()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) toggleRoute(command routeCommand) error {
|
||||||
|
return command.toggleRoute()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) getRouteManager() (routemanager.Manager, error) {
|
||||||
|
client := c.connectClient
|
||||||
|
if client == nil {
|
||||||
|
return nil, fmt.Errorf("not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := client.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return nil, fmt.Errorf("engine is not running")
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := engine.GetRouteManager()
|
||||||
|
if manager == nil {
|
||||||
|
return nil, fmt.Errorf("could not get route manager")
|
||||||
|
}
|
||||||
|
|
||||||
|
return manager, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) SelectRoute(route string) error {
|
||||||
|
manager, err := c.getRouteManager()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.toggleRoute(selectRouteCommand{route: route, manager: manager})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) DeselectRoute(route string) error {
|
||||||
|
manager, err := c.getRouteManager()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.toggleRoute(deselectRouteCommand{route: route, manager: manager})
|
||||||
|
}
|
||||||
|
|
||||||
|
// getNetworkDomainsFromRoute extracts domains from a route and enriches each domain
|
||||||
|
// with its resolved IP addresses from the provided resolvedDomains map.
|
||||||
|
func (c *Client) getNetworkDomainsFromRoute(route *route.Route, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) NetworkDomains {
|
||||||
|
domains := NetworkDomains{}
|
||||||
|
|
||||||
|
for _, d := range route.Domains {
|
||||||
|
networkDomain := NetworkDomain{
|
||||||
|
Address: d.SafeString(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if info, exists := resolvedDomains[d]; exists {
|
||||||
|
for _, prefix := range info.Prefixes {
|
||||||
|
networkDomain.addResolvedIP(prefix.Addr().String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
domains.Add(&networkDomain)
|
||||||
|
}
|
||||||
|
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
func exportEnvList(list *EnvList) {
|
func exportEnvList(list *EnvList) {
|
||||||
if list == nil {
|
if list == nil {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ type ErrListener interface {
|
|||||||
// URLOpener it is a callback interface. The Open function will be triggered if
|
// URLOpener it is a callback interface. The Open function will be triggered if
|
||||||
// the backend want to show an url for the user
|
// the backend want to show an url for the user
|
||||||
type URLOpener interface {
|
type URLOpener interface {
|
||||||
Open(string)
|
Open(url string, userCode string)
|
||||||
OnLoginSuccess()
|
OnLoginSuccess()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -148,9 +148,9 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Login try register the client on the server
|
// Login try register the client on the server
|
||||||
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) {
|
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidTV bool) {
|
||||||
go func() {
|
go func() {
|
||||||
err := a.login(urlOpener)
|
err := a.login(urlOpener, isAndroidTV)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resultListener.OnError(err)
|
resultListener.OnError(err)
|
||||||
} else {
|
} else {
|
||||||
@@ -159,7 +159,7 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) login(urlOpener URLOpener) error {
|
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
|
||||||
var needsLogin bool
|
var needsLogin bool
|
||||||
|
|
||||||
// check if we need to generate JWT token
|
// check if we need to generate JWT token
|
||||||
@@ -173,7 +173,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
|||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
if needsLogin {
|
if needsLogin {
|
||||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener)
|
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -199,8 +199,8 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
|
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false)
|
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -210,7 +210,7 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, err
|
|||||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
go urlOpener.Open(flowInfo.VerificationURIComplete)
|
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||||
|
|
||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
||||||
|
|||||||
56
client/android/network_domains.go
Normal file
56
client/android/network_domains.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package android
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
type ResolvedIPs struct {
|
||||||
|
resolvedIPs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResolvedIPs) Add(ipAddress string) {
|
||||||
|
r.resolvedIPs = append(r.resolvedIPs, ipAddress)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResolvedIPs) Get(i int) (string, error) {
|
||||||
|
if i < 0 || i >= len(r.resolvedIPs) {
|
||||||
|
return "", fmt.Errorf("%d is out of range", i)
|
||||||
|
}
|
||||||
|
return r.resolvedIPs[i], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResolvedIPs) Size() int {
|
||||||
|
return len(r.resolvedIPs)
|
||||||
|
}
|
||||||
|
|
||||||
|
type NetworkDomain struct {
|
||||||
|
Address string
|
||||||
|
resolvedIPs ResolvedIPs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *NetworkDomain) addResolvedIP(resolvedIP string) {
|
||||||
|
d.resolvedIPs.Add(resolvedIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *NetworkDomain) GetResolvedIPs() *ResolvedIPs {
|
||||||
|
return &d.resolvedIPs
|
||||||
|
}
|
||||||
|
|
||||||
|
type NetworkDomains struct {
|
||||||
|
domains []*NetworkDomain
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NetworkDomains) Add(domain *NetworkDomain) {
|
||||||
|
n.domains = append(n.domains, domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NetworkDomains) Get(i int) (*NetworkDomain, error) {
|
||||||
|
if i < 0 || i >= len(n.domains) {
|
||||||
|
return nil, fmt.Errorf("%d is out of range", i)
|
||||||
|
}
|
||||||
|
return n.domains[i], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NetworkDomains) Size() int {
|
||||||
|
return len(n.domains)
|
||||||
|
}
|
||||||
@@ -3,10 +3,16 @@
|
|||||||
package android
|
package android
|
||||||
|
|
||||||
type Network struct {
|
type Network struct {
|
||||||
Name string
|
Name string
|
||||||
Network string
|
Network string
|
||||||
Peer string
|
Peer string
|
||||||
Status string
|
Status string
|
||||||
|
IsSelected bool
|
||||||
|
Domains NetworkDomains
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n Network) GetNetworkDomains() *NetworkDomains {
|
||||||
|
return &n.Domains
|
||||||
}
|
}
|
||||||
|
|
||||||
type NetworkArray struct {
|
type NetworkArray struct {
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
package android
|
package android
|
||||||
|
|
||||||
// PeerInfo describe information about the peers. It designed for the UI usage
|
// PeerInfo describe information about the peers. It designed for the UI usage
|
||||||
@@ -5,6 +7,11 @@ type PeerInfo struct {
|
|||||||
IP string
|
IP string
|
||||||
FQDN string
|
FQDN string
|
||||||
ConnStatus string // Todo replace to enum
|
ConnStatus string // Todo replace to enum
|
||||||
|
Routes PeerRoutes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PeerInfo) GetPeerRoutes() *PeerRoutes {
|
||||||
|
return &p.Routes
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeerInfoArray is a wrapper of []PeerInfo
|
// PeerInfoArray is a wrapper of []PeerInfo
|
||||||
|
|||||||
20
client/android/peer_routes.go
Normal file
20
client/android/peer_routes.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package android
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
type PeerRoutes struct {
|
||||||
|
routes []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PeerRoutes) Get(i int) (string, error) {
|
||||||
|
if i < 0 || i >= len(p.routes) {
|
||||||
|
return "", fmt.Errorf("%d is out of range", i)
|
||||||
|
}
|
||||||
|
return p.routes[i], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PeerRoutes) Size() int {
|
||||||
|
return len(p.routes)
|
||||||
|
}
|
||||||
10
client/android/platform_files.go
Normal file
10
client/android/platform_files.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package android
|
||||||
|
|
||||||
|
// PlatformFiles groups paths to files used internally by the engine that can't be created/modified
|
||||||
|
// at their default locations due to android OS restrictions.
|
||||||
|
type PlatformFiles interface {
|
||||||
|
ConfigurationFilePath() string
|
||||||
|
StateFilePath() string
|
||||||
|
}
|
||||||
67
client/android/route_command.go
Normal file
67
client/android/route_command.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package android
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
func executeRouteToggle(id string, manager routemanager.Manager,
|
||||||
|
operationName string,
|
||||||
|
routeOperation func(routes []route.NetID, allRoutes []route.NetID) error) error {
|
||||||
|
netID := route.NetID(id)
|
||||||
|
routes := []route.NetID{netID}
|
||||||
|
|
||||||
|
log.Debugf("%s with id: %s", operationName, id)
|
||||||
|
|
||||||
|
if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil {
|
||||||
|
log.Debugf("error when %s: %s", operationName, err)
|
||||||
|
return fmt.Errorf("error %s: %w", operationName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.TriggerSelection(manager.GetClientRoutes())
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type routeCommand interface {
|
||||||
|
toggleRoute() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type selectRouteCommand struct {
|
||||||
|
route string
|
||||||
|
manager routemanager.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s selectRouteCommand) toggleRoute() error {
|
||||||
|
routeSelector := s.manager.GetRouteSelector()
|
||||||
|
if routeSelector == nil {
|
||||||
|
return fmt.Errorf("no route selector available")
|
||||||
|
}
|
||||||
|
|
||||||
|
routeOperation := func(routes []route.NetID, allRoutes []route.NetID) error {
|
||||||
|
return routeSelector.SelectRoutes(routes, true, allRoutes)
|
||||||
|
}
|
||||||
|
|
||||||
|
return executeRouteToggle(s.route, s.manager, "selecting route", routeOperation)
|
||||||
|
}
|
||||||
|
|
||||||
|
type deselectRouteCommand struct {
|
||||||
|
route string
|
||||||
|
manager routemanager.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d deselectRouteCommand) toggleRoute() error {
|
||||||
|
routeSelector := d.manager.GetRouteSelector()
|
||||||
|
if routeSelector == nil {
|
||||||
|
return fmt.Errorf("no route selector available")
|
||||||
|
}
|
||||||
|
|
||||||
|
return executeRouteToggle(d.route, d.manager, "deselecting route", routeSelector.DeselectRoutes)
|
||||||
|
}
|
||||||
@@ -4,14 +4,12 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
|
||||||
"os/user"
|
"os/user"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/skratchdot/open-golang/open"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
@@ -106,6 +104,13 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
|
|||||||
Username: &username,
|
Username: &username,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||||
|
} else if profileState.Email != "" {
|
||||||
|
loginRequest.Hint = &profileState.Email
|
||||||
|
}
|
||||||
|
|
||||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||||
loginRequest.OptionalPreSharedKey = &preSharedKey
|
loginRequest.OptionalPreSharedKey = &preSharedKey
|
||||||
}
|
}
|
||||||
@@ -241,7 +246,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
|
|||||||
return fmt.Errorf("read config file %s: %v", configFilePath, err)
|
return fmt.Errorf("read config file %s: %v", configFilePath, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = foregroundLogin(ctx, cmd, config, setupKey)
|
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("foreground login failed: %v", err)
|
return fmt.Errorf("foreground login failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -269,7 +274,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error {
|
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
|
||||||
needsLogin := false
|
needsLogin := false
|
||||||
|
|
||||||
err := WithBackOff(func() error {
|
err := WithBackOff(func() error {
|
||||||
@@ -286,7 +291,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
|||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
if setupKey == "" && needsLogin {
|
if setupKey == "" && needsLogin {
|
||||||
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config)
|
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -315,8 +320,17 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) {
|
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) {
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
|
hint := ""
|
||||||
|
pm := profilemanager.NewProfileManager()
|
||||||
|
profileState, err := pm.GetProfileState(profileName)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||||
|
} else if profileState.Email != "" {
|
||||||
|
hint = profileState.Email
|
||||||
|
}
|
||||||
|
|
||||||
|
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), false, hint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -357,21 +371,13 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
|||||||
cmd.Println("")
|
cmd.Println("")
|
||||||
|
|
||||||
if !noBrowser {
|
if !noBrowser {
|
||||||
if err := openBrowser(verificationURIComplete); err != nil {
|
if err := util.OpenBrowser(verificationURIComplete); err != nil {
|
||||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// openBrowser opens the URL in a browser, respecting the BROWSER environment variable.
|
|
||||||
func openBrowser(url string) error {
|
|
||||||
if browser := os.Getenv("BROWSER"); browser != "" {
|
|
||||||
return exec.Command(browser, url).Start()
|
|
||||||
}
|
|
||||||
return open.Run(url)
|
|
||||||
}
|
|
||||||
|
|
||||||
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
||||||
func isUnixRunningDesktop() bool {
|
func isUnixRunningDesktop() bool {
|
||||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||||
|
|||||||
@@ -259,6 +259,7 @@ func isServiceRunning() (bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
networkdConf = "/etc/systemd/networkd.conf"
|
||||||
networkdConfDir = "/etc/systemd/networkd.conf.d"
|
networkdConfDir = "/etc/systemd/networkd.conf.d"
|
||||||
networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf"
|
networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf"
|
||||||
networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing
|
networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing
|
||||||
@@ -273,12 +274,16 @@ ManageForeignRoutingPolicyRules=no
|
|||||||
// configureSystemdNetworkd creates a drop-in configuration file to prevent
|
// configureSystemdNetworkd creates a drop-in configuration file to prevent
|
||||||
// systemd-networkd from removing NetBird's routes and policy rules.
|
// systemd-networkd from removing NetBird's routes and policy rules.
|
||||||
func configureSystemdNetworkd() error {
|
func configureSystemdNetworkd() error {
|
||||||
parentDir := filepath.Dir(networkdConfDir)
|
if _, err := os.Stat(networkdConf); os.IsNotExist(err) {
|
||||||
if _, err := os.Stat(parentDir); os.IsNotExist(err) {
|
log.Debug("systemd-networkd not in use, skipping configuration")
|
||||||
log.Debug("systemd networkd.conf.d parent directory does not exist, skipping configuration")
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nolint:gosec // standard networkd permissions
|
||||||
|
if err := os.MkdirAll(networkdConfDir, 0755); err != nil {
|
||||||
|
return fmt.Errorf("create networkd.conf.d directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// nolint:gosec // standard networkd permissions
|
// nolint:gosec // standard networkd permissions
|
||||||
if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil {
|
if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil {
|
||||||
return fmt.Errorf("write networkd configuration: %w", err)
|
return fmt.Errorf("write networkd configuration: %w", err)
|
||||||
|
|||||||
@@ -14,7 +14,9 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
sshclient "github.com/netbirdio/netbird/client/ssh/client"
|
sshclient "github.com/netbirdio/netbird/client/ssh/client"
|
||||||
@@ -34,6 +36,7 @@ const (
|
|||||||
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
|
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
|
||||||
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
|
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
|
||||||
disableSSHAuthFlag = "disable-ssh-auth"
|
disableSSHAuthFlag = "disable-ssh-auth"
|
||||||
|
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -47,6 +50,8 @@ var (
|
|||||||
knownHostsFile string
|
knownHostsFile string
|
||||||
identityFile string
|
identityFile string
|
||||||
skipCachedToken bool
|
skipCachedToken bool
|
||||||
|
requestPTY bool
|
||||||
|
sshNoBrowser bool
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -56,6 +61,7 @@ var (
|
|||||||
enableSSHLocalPortForward bool
|
enableSSHLocalPortForward bool
|
||||||
enableSSHRemotePortForward bool
|
enableSSHRemotePortForward bool
|
||||||
disableSSHAuth bool
|
disableSSHAuth bool
|
||||||
|
sshJWTCacheTTL int
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -65,14 +71,18 @@ func init() {
|
|||||||
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
|
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
|
||||||
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
|
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
|
||||||
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
|
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
|
||||||
|
upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)")
|
||||||
|
|
||||||
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
|
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
|
||||||
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
|
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
|
||||||
sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)")
|
sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)")
|
||||||
|
sshCmd.PersistentFlags().BoolVarP(&requestPTY, "tty", "t", false, "Force pseudo-terminal allocation")
|
||||||
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
|
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
|
||||||
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
|
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
|
||||||
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file")
|
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)")
|
||||||
|
_ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used")
|
||||||
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||||
|
sshCmd.PersistentFlags().BoolVar(&sshNoBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||||
|
|
||||||
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
|
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
|
||||||
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
|
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
|
||||||
@@ -97,9 +107,9 @@ SSH Options:
|
|||||||
-p, --port int Remote SSH port (default 22)
|
-p, --port int Remote SSH port (default 22)
|
||||||
-u, --user string SSH username
|
-u, --user string SSH username
|
||||||
--login string SSH username (alias for --user)
|
--login string SSH username (alias for --user)
|
||||||
|
-t, --tty Force pseudo-terminal allocation
|
||||||
--strict-host-key-checking Enable strict host key checking (default: true)
|
--strict-host-key-checking Enable strict host key checking (default: true)
|
||||||
-o, --known-hosts string Path to known_hosts file
|
-o, --known-hosts string Path to known_hosts file
|
||||||
-i, --identity string Path to SSH private key file
|
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
netbird ssh peer-hostname
|
netbird ssh peer-hostname
|
||||||
@@ -107,8 +117,10 @@ Examples:
|
|||||||
netbird ssh --login root peer-hostname
|
netbird ssh --login root peer-hostname
|
||||||
netbird ssh peer-hostname ls -la
|
netbird ssh peer-hostname ls -la
|
||||||
netbird ssh peer-hostname whoami
|
netbird ssh peer-hostname whoami
|
||||||
netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding
|
netbird ssh -t peer-hostname tmux # Force PTY for tmux/screen
|
||||||
netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding
|
netbird ssh -t peer-hostname sudo -i # Force PTY for interactive sudo
|
||||||
|
netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding
|
||||||
|
netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding
|
||||||
netbird ssh -L "*:8080:localhost:80" peer-hostname # Bind to all interfaces
|
netbird ssh -L "*:8080:localhost:80" peer-hostname # Bind to all interfaces
|
||||||
netbird ssh -L 8080:/tmp/socket peer-hostname # Unix socket forwarding`,
|
netbird ssh -L 8080:/tmp/socket peer-hostname # Unix socket forwarding`,
|
||||||
DisableFlagParsing: true,
|
DisableFlagParsing: true,
|
||||||
@@ -143,10 +155,10 @@ func sshFn(cmd *cobra.Command, args []string) error {
|
|||||||
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
||||||
sshctx, cancel := context.WithCancel(ctx)
|
sshctx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
|
errCh := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
if err := runSSH(sshctx, host, cmd); err != nil {
|
if err := runSSH(sshctx, host, cmd); err != nil {
|
||||||
cmd.Printf("Error: %v\n", err)
|
errCh <- err
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
cancel()
|
cancel()
|
||||||
}()
|
}()
|
||||||
@@ -154,6 +166,10 @@ func sshFn(cmd *cobra.Command, args []string) error {
|
|||||||
select {
|
select {
|
||||||
case <-sig:
|
case <-sig:
|
||||||
cancel()
|
cancel()
|
||||||
|
<-sshctx.Done()
|
||||||
|
return nil
|
||||||
|
case err := <-errCh:
|
||||||
|
return err
|
||||||
case <-sshctx.Done():
|
case <-sshctx.Done():
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -171,6 +187,21 @@ func getEnvOrDefault(flagName, defaultValue string) string {
|
|||||||
return defaultValue
|
return defaultValue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getBoolEnvOrDefault checks for boolean environment variables with WT_ and NB_ prefixes
|
||||||
|
func getBoolEnvOrDefault(flagName string, defaultValue bool) bool {
|
||||||
|
if envValue := os.Getenv("WT_" + flagName); envValue != "" {
|
||||||
|
if parsed, err := strconv.ParseBool(envValue); err == nil {
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if envValue := os.Getenv("NB_" + flagName); envValue != "" {
|
||||||
|
if parsed, err := strconv.ParseBool(envValue); err == nil {
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
// resetSSHGlobals sets SSH globals to their default values
|
// resetSSHGlobals sets SSH globals to their default values
|
||||||
func resetSSHGlobals() {
|
func resetSSHGlobals() {
|
||||||
port = sshserver.DefaultSSHPort
|
port = sshserver.DefaultSSHPort
|
||||||
@@ -182,6 +213,7 @@ func resetSSHGlobals() {
|
|||||||
strictHostKeyChecking = true
|
strictHostKeyChecking = true
|
||||||
knownHostsFile = ""
|
knownHostsFile = ""
|
||||||
identityFile = ""
|
identityFile = ""
|
||||||
|
sshNoBrowser = false
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseCustomSSHFlags extracts -L, -R flags and returns filtered args
|
// parseCustomSSHFlags extracts -L, -R flags and returns filtered args
|
||||||
@@ -351,10 +383,12 @@ type sshFlags struct {
|
|||||||
Port int
|
Port int
|
||||||
Username string
|
Username string
|
||||||
Login string
|
Login string
|
||||||
|
RequestPTY bool
|
||||||
StrictHostKeyChecking bool
|
StrictHostKeyChecking bool
|
||||||
KnownHostsFile string
|
KnownHostsFile string
|
||||||
IdentityFile string
|
IdentityFile string
|
||||||
SkipCachedToken bool
|
SkipCachedToken bool
|
||||||
|
NoBrowser bool
|
||||||
ConfigPath string
|
ConfigPath string
|
||||||
LogLevel string
|
LogLevel string
|
||||||
LocalForwards []string
|
LocalForwards []string
|
||||||
@@ -366,6 +400,7 @@ type sshFlags struct {
|
|||||||
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
|
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
|
||||||
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
|
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
|
||||||
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||||
|
defaultNoBrowser := getBoolEnvOrDefault("NO_BROWSER", false)
|
||||||
|
|
||||||
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
|
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
|
||||||
fs.SetOutput(nil)
|
fs.SetOutput(nil)
|
||||||
@@ -373,22 +408,25 @@ func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
|
|||||||
flags := &sshFlags{}
|
flags := &sshFlags{}
|
||||||
|
|
||||||
fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port")
|
fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port")
|
||||||
fs.Int("port", sshserver.DefaultSSHPort, "SSH port")
|
fs.IntVar(&flags.Port, "port", sshserver.DefaultSSHPort, "SSH port")
|
||||||
fs.StringVar(&flags.Username, "u", "", sshUsernameDesc)
|
fs.StringVar(&flags.Username, "u", "", sshUsernameDesc)
|
||||||
fs.String("user", "", sshUsernameDesc)
|
fs.StringVar(&flags.Username, "user", "", sshUsernameDesc)
|
||||||
fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)")
|
fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)")
|
||||||
|
fs.BoolVar(&flags.RequestPTY, "t", false, "Force pseudo-terminal allocation")
|
||||||
|
fs.BoolVar(&flags.RequestPTY, "tty", false, "Force pseudo-terminal allocation")
|
||||||
|
|
||||||
fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking")
|
fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking")
|
||||||
fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file")
|
fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file")
|
||||||
fs.String("known-hosts", "", "Path to known_hosts file")
|
fs.StringVar(&flags.KnownHostsFile, "known-hosts", "", "Path to known_hosts file")
|
||||||
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
|
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
|
||||||
fs.String("identity", "", "Path to SSH private key file")
|
fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file")
|
||||||
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||||
|
fs.BoolVar(&flags.NoBrowser, "no-browser", defaultNoBrowser, noBrowserDesc)
|
||||||
|
|
||||||
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
|
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
|
||||||
fs.String("config", defaultConfigPath, "Netbird config file location")
|
fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location")
|
||||||
fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level")
|
fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level")
|
||||||
fs.String("log-level", defaultLogLevel, "sets Netbird log level")
|
fs.StringVar(&flags.LogLevel, "log-level", defaultLogLevel, "sets Netbird log level")
|
||||||
|
|
||||||
return fs, flags
|
return fs, flags
|
||||||
}
|
}
|
||||||
@@ -409,7 +447,10 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
|||||||
fs, flags := createSSHFlagSet()
|
fs, flags := createSSHFlagSet()
|
||||||
|
|
||||||
if err := fs.Parse(filteredArgs); err != nil {
|
if err := fs.Parse(filteredArgs); err != nil {
|
||||||
return parseHostnameAndCommand(filteredArgs)
|
if errors.Is(err, flag.ErrHelp) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
remaining := fs.Args()
|
remaining := fs.Args()
|
||||||
@@ -424,10 +465,12 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
|||||||
username = flags.Login
|
username = flags.Login
|
||||||
}
|
}
|
||||||
|
|
||||||
|
requestPTY = flags.RequestPTY
|
||||||
strictHostKeyChecking = flags.StrictHostKeyChecking
|
strictHostKeyChecking = flags.StrictHostKeyChecking
|
||||||
knownHostsFile = flags.KnownHostsFile
|
knownHostsFile = flags.KnownHostsFile
|
||||||
identityFile = flags.IdentityFile
|
identityFile = flags.IdentityFile
|
||||||
skipCachedToken = flags.SkipCachedToken
|
skipCachedToken = flags.SkipCachedToken
|
||||||
|
sshNoBrowser = flags.NoBrowser
|
||||||
|
|
||||||
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
|
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
|
||||||
configPath = flags.ConfigPath
|
configPath = flags.ConfigPath
|
||||||
@@ -487,6 +530,7 @@ func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
|
|||||||
DaemonAddr: daemonAddr,
|
DaemonAddr: daemonAddr,
|
||||||
SkipCachedToken: skipCachedToken,
|
SkipCachedToken: skipCachedToken,
|
||||||
InsecureSkipVerify: !strictHostKeyChecking,
|
InsecureSkipVerify: !strictHostKeyChecking,
|
||||||
|
NoBrowser: sshNoBrowser,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -520,10 +564,29 @@ func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
|
|||||||
|
|
||||||
// executeSSHCommand executes a command over SSH.
|
// executeSSHCommand executes a command over SSH.
|
||||||
func executeSSHCommand(ctx context.Context, c *sshclient.Client, command string) error {
|
func executeSSHCommand(ctx context.Context, c *sshclient.Client, command string) error {
|
||||||
if err := c.ExecuteCommandWithIO(ctx, command); err != nil {
|
var err error
|
||||||
|
if requestPTY {
|
||||||
|
err = c.ExecuteCommandWithPTY(ctx, command)
|
||||||
|
} else {
|
||||||
|
err = c.ExecuteCommandWithIO(ctx, command)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var exitErr *ssh.ExitError
|
||||||
|
if errors.As(err, &exitErr) {
|
||||||
|
os.Exit(exitErr.ExitStatus())
|
||||||
|
}
|
||||||
|
|
||||||
|
var exitMissingErr *ssh.ExitMissingError
|
||||||
|
if errors.As(err, &exitMissingErr) {
|
||||||
|
log.Debugf("Remote command exited without exit status: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
return fmt.Errorf("execute command: %w", err)
|
return fmt.Errorf("execute command: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -535,6 +598,13 @@ func openSSHTerminal(ctx context.Context, c *sshclient.Client) error {
|
|||||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var exitMissingErr *ssh.ExitMissingError
|
||||||
|
if errors.As(err, &exitMissingErr) {
|
||||||
|
log.Debugf("Remote terminal exited without exit status: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
return fmt.Errorf("open terminal: %w", err)
|
return fmt.Errorf("open terminal: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -702,7 +772,9 @@ func sshProxyFn(cmd *cobra.Command, args []string) error {
|
|||||||
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
|
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
|
||||||
logOutput = firstLogFile
|
logOutput = firstLogFile
|
||||||
}
|
}
|
||||||
if err := util.InitLog(logLevel, logOutput); err != nil {
|
|
||||||
|
proxyLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||||
|
if err := util.InitLog(proxyLogLevel, logOutput); err != nil {
|
||||||
return fmt.Errorf("init log: %w", err)
|
return fmt.Errorf("init log: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -714,10 +786,23 @@ func sshProxyFn(cmd *cobra.Command, args []string) error {
|
|||||||
return fmt.Errorf("invalid port: %s", portStr)
|
return fmt.Errorf("invalid port: %s", portStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr())
|
// Check env var for browser setting since this command is invoked via SSH ProxyCommand
|
||||||
|
// where command-line flags cannot be passed. Default is to open browser.
|
||||||
|
noBrowser := getBoolEnvOrDefault("NO_BROWSER", false)
|
||||||
|
var browserOpener func(string) error
|
||||||
|
if !noBrowser {
|
||||||
|
browserOpener = util.OpenBrowser
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr(), browserOpener)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create SSH proxy: %w", err)
|
return fmt.Errorf("create SSH proxy: %w", err)
|
||||||
}
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := proxy.Close(); err != nil {
|
||||||
|
log.Debugf("close SSH proxy: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if err := proxy.Connect(cmd.Context()); err != nil {
|
if err := proxy.Connect(cmd.Context()); err != nil {
|
||||||
return fmt.Errorf("SSH proxy: %w", err)
|
return fmt.Errorf("SSH proxy: %w", err)
|
||||||
@@ -736,7 +821,8 @@ var sshDetectCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
|
|
||||||
func sshDetectFn(cmd *cobra.Command, args []string) error {
|
func sshDetectFn(cmd *cobra.Command, args []string) error {
|
||||||
if err := util.InitLog(logLevel, "console"); err != nil {
|
detectLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||||
|
if err := util.InitLog(detectLogLevel, "console"); err != nil {
|
||||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -745,15 +831,21 @@ func sshDetectFn(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
port, err := strconv.Atoi(portStr)
|
port, err := strconv.Atoi(portStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Debugf("invalid port %q: %v", portStr, err)
|
||||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||||
}
|
}
|
||||||
|
|
||||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
ctx, cancel := context.WithTimeout(cmd.Context(), detection.DefaultTimeout)
|
||||||
serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port)
|
|
||||||
|
dialer := &net.Dialer{}
|
||||||
|
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Debugf("SSH server detection failed: %v", err)
|
||||||
|
cancel()
|
||||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
os.Exit(serverType.ExitCode())
|
os.Exit(serverType.ExitCode())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"os/user"
|
"os/user"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/pkg/sftp"
|
"github.com/pkg/sftp"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -51,7 +52,7 @@ func sftpMainDirect(cmd *cobra.Command) error {
|
|||||||
if windowsDomain != "" {
|
if windowsDomain != "" {
|
||||||
expectedUsername = fmt.Sprintf(`%s\%s`, windowsDomain, windowsUsername)
|
expectedUsername = fmt.Sprintf(`%s\%s`, windowsDomain, windowsUsername)
|
||||||
}
|
}
|
||||||
if currentUser.Username != expectedUsername && currentUser.Username != windowsUsername {
|
if !strings.EqualFold(currentUser.Username, expectedUsername) && !strings.EqualFold(currentUser.Username, windowsUsername) {
|
||||||
cmd.PrintErrf("user switching failed\n")
|
cmd.PrintErrf("user switching failed\n")
|
||||||
os.Exit(sshserver.ExitCodeValidationFail)
|
os.Exit(sshserver.ExitCodeValidationFail)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -667,3 +667,51 @@ func TestSSHCommand_ParameterIsolation(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSSHCommand_InvalidFlagRejection(t *testing.T) {
|
||||||
|
// Test that invalid flags are properly rejected and not misinterpreted as hostnames
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "invalid long flag before hostname",
|
||||||
|
args: []string{"--invalid-flag", "hostname"},
|
||||||
|
description: "Invalid flag should return parse error, not treat flag as hostname",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid short flag before hostname",
|
||||||
|
args: []string{"-x", "hostname"},
|
||||||
|
description: "Invalid short flag should return parse error",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid flag with value before hostname",
|
||||||
|
args: []string{"--invalid-option=value", "hostname"},
|
||||||
|
description: "Invalid flag with value should return parse error",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "typo in known flag",
|
||||||
|
args: []string{"--por", "2222", "hostname"},
|
||||||
|
description: "Typo in flag name should return parse error (not silently ignored)",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Reset global variables
|
||||||
|
host = ""
|
||||||
|
username = ""
|
||||||
|
port = 22
|
||||||
|
command = ""
|
||||||
|
|
||||||
|
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
|
||||||
|
|
||||||
|
// Should return an error for invalid flags
|
||||||
|
assert.Error(t, err, tt.description)
|
||||||
|
|
||||||
|
// Should not have set host to the invalid flag
|
||||||
|
assert.NotEqual(t, tt.args[0], host, "Invalid flag should not be interpreted as hostname")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
case yamlFlag:
|
case yamlFlag:
|
||||||
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
|
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
|
||||||
default:
|
default:
|
||||||
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
|
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -13,6 +13,12 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
|
||||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||||
client "github.com/netbirdio/netbird/client/server"
|
client "github.com/netbirdio/netbird/client/server"
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
@@ -20,8 +26,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/peers"
|
|
||||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
@@ -84,7 +88,6 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
}
|
}
|
||||||
t.Cleanup(cleanUp)
|
t.Cleanup(cleanUp)
|
||||||
|
|
||||||
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
|
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@@ -110,13 +113,21 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
Return(&types.Settings{}, nil).
|
Return(&types.Settings{}, nil).
|
||||||
AnyTimes()
|
AnyTimes()
|
||||||
|
|
||||||
accountManager, err := mgmt.BuildManager(context.Background(), config, store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
ctx := context.Background()
|
||||||
|
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
|
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
||||||
|
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
|
||||||
|
|
||||||
|
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{})
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -185,7 +185,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
|||||||
|
|
||||||
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
|
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
|
||||||
|
|
||||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
|
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("foreground login failed: %v", err)
|
return fmt.Errorf("foreground login failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -286,6 +286,13 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
|
|||||||
loginRequest.ProfileName = &activeProf.Name
|
loginRequest.ProfileName = &activeProf.Name
|
||||||
loginRequest.Username = &username
|
loginRequest.Username = &username
|
||||||
|
|
||||||
|
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||||
|
} else if profileState.Email != "" {
|
||||||
|
loginRequest.Hint = &profileState.Email
|
||||||
|
}
|
||||||
|
|
||||||
var loginErr error
|
var loginErr error
|
||||||
var loginResp *proto.LoginResponse
|
var loginResp *proto.LoginResponse
|
||||||
|
|
||||||
@@ -355,14 +362,18 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
|||||||
req.EnableSSHSFTP = &enableSSHSFTP
|
req.EnableSSHSFTP = &enableSSHSFTP
|
||||||
}
|
}
|
||||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||||
req.EnableSSHLocalPortForward = &enableSSHLocalPortForward
|
req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||||
}
|
}
|
||||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||||
req.EnableSSHRemotePortForward = &enableSSHRemotePortForward
|
req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||||
}
|
}
|
||||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||||
req.DisableSSHAuth = &disableSSHAuth
|
req.DisableSSHAuth = &disableSSHAuth
|
||||||
}
|
}
|
||||||
|
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||||
|
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
||||||
|
req.SshJWTCacheTTL = &sshJWTCacheTTL32
|
||||||
|
}
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
if err := parseInterfaceName(interfaceName); err != nil {
|
if err := parseInterfaceName(interfaceName); err != nil {
|
||||||
log.Errorf("parse interface name: %v", err)
|
log.Errorf("parse interface name: %v", err)
|
||||||
@@ -467,6 +478,10 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
|||||||
ic.DisableSSHAuth = &disableSSHAuth
|
ic.DisableSSHAuth = &disableSSHAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||||
|
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
|
||||||
|
}
|
||||||
|
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
if err := parseInterfaceName(interfaceName); err != nil {
|
if err := parseInterfaceName(interfaceName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -587,6 +602,11 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
|||||||
loginRequest.DisableSSHAuth = &disableSSHAuth
|
loginRequest.DisableSSHAuth = &disableSSHAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||||
|
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
||||||
|
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
|
||||||
|
}
|
||||||
|
|
||||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||||
loginRequest.DisableAutoConnect = &autoConnectDisabled
|
loginRequest.DisableAutoConnect = &autoConnectDisabled
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/nadoo/ipset"
|
ipset "github.com/lrh3321/ipset-go"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
@@ -40,19 +41,13 @@ type aclManager struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
|
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
|
||||||
m := &aclManager{
|
return &aclManager{
|
||||||
iptablesClient: iptablesClient,
|
iptablesClient: iptablesClient,
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
entries: make(map[string][][]string),
|
entries: make(map[string][][]string),
|
||||||
optionalEntries: make(map[string][]entry),
|
optionalEntries: make(map[string][]entry),
|
||||||
ipsetStore: newIpsetStore(),
|
ipsetStore: newIpsetStore(),
|
||||||
}
|
}, nil
|
||||||
|
|
||||||
if err := ipset.Init(); err != nil {
|
|
||||||
return nil, fmt.Errorf("init ipset: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return m, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
||||||
@@ -98,8 +93,8 @@ func (m *aclManager) AddPeerFiltering(
|
|||||||
specs = append(specs, "-j", actionToStr(action))
|
specs = append(specs, "-j", actionToStr(action))
|
||||||
if ipsetName != "" {
|
if ipsetName != "" {
|
||||||
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
||||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
if err := m.addToIPSet(ipsetName, ip); err != nil {
|
||||||
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
return nil, fmt.Errorf("add IP to ipset: %w", err)
|
||||||
}
|
}
|
||||||
// if ruleset already exists it means we already have the firewall rule
|
// if ruleset already exists it means we already have the firewall rule
|
||||||
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
||||||
@@ -113,14 +108,18 @@ func (m *aclManager) AddPeerFiltering(
|
|||||||
}}, nil
|
}}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := ipset.Flush(ipsetName); err != nil {
|
if err := m.flushIPSet(ipsetName); err != nil {
|
||||||
log.Errorf("flush ipset %s before use it: %s", ipsetName, err)
|
if errors.Is(err, ipset.ErrSetNotExist) {
|
||||||
|
log.Debugf("flush ipset %s before use: %v", ipsetName, err)
|
||||||
|
} else {
|
||||||
|
log.Errorf("flush ipset %s before use: %v", ipsetName, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if err := ipset.Create(ipsetName); err != nil {
|
if err := m.createIPSet(ipsetName); err != nil {
|
||||||
return nil, fmt.Errorf("failed to create ipset: %w", err)
|
return nil, fmt.Errorf("create ipset: %w", err)
|
||||||
}
|
}
|
||||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
if err := m.addToIPSet(ipsetName, ip); err != nil {
|
||||||
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
return nil, fmt.Errorf("add IP to ipset: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ipList := newIpList(ip.String())
|
ipList := newIpList(ip.String())
|
||||||
@@ -172,11 +171,16 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
return fmt.Errorf("invalid rule type")
|
return fmt.Errorf("invalid rule type")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
shouldDestroyIpset := false
|
||||||
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
|
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
|
||||||
// delete IP from ruleset IPs list and ipset
|
// delete IP from ruleset IPs list and ipset
|
||||||
if _, ok := ipsetList.ips[r.ip]; ok {
|
if _, ok := ipsetList.ips[r.ip]; ok {
|
||||||
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
|
ip := net.ParseIP(r.ip)
|
||||||
return fmt.Errorf("failed to delete ip from ipset: %w", err)
|
if ip == nil {
|
||||||
|
return fmt.Errorf("parse IP %s", r.ip)
|
||||||
|
}
|
||||||
|
if err := m.delFromIPSet(r.ipsetName, ip); err != nil {
|
||||||
|
return fmt.Errorf("delete ip from ipset: %w", err)
|
||||||
}
|
}
|
||||||
delete(ipsetList.ips, r.ip)
|
delete(ipsetList.ips, r.ip)
|
||||||
}
|
}
|
||||||
@@ -190,10 +194,7 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
// we delete last IP from the set, that means we need to delete
|
// we delete last IP from the set, that means we need to delete
|
||||||
// set itself and associated firewall rule too
|
// set itself and associated firewall rule too
|
||||||
m.ipsetStore.deleteIpset(r.ipsetName)
|
m.ipsetStore.deleteIpset(r.ipsetName)
|
||||||
|
shouldDestroyIpset = true
|
||||||
if err := ipset.Destroy(r.ipsetName); err != nil {
|
|
||||||
log.Errorf("delete empty ipset: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
|
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
|
||||||
@@ -206,6 +207,16 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if shouldDestroyIpset {
|
||||||
|
if err := m.destroyIPSet(r.ipsetName); err != nil {
|
||||||
|
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
|
||||||
|
log.Debugf("destroy empty ipset: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Errorf("destroy empty ipset: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
m.updateState()
|
m.updateState()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -264,11 +275,19 @@ func (m *aclManager) cleanChains() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||||
if err := ipset.Flush(ipsetName); err != nil {
|
if err := m.flushIPSet(ipsetName); err != nil {
|
||||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
if errors.Is(err, ipset.ErrSetNotExist) {
|
||||||
|
log.Debugf("flush ipset %q during reset: %v", ipsetName, err)
|
||||||
|
} else {
|
||||||
|
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if err := ipset.Destroy(ipsetName); err != nil {
|
if err := m.destroyIPSet(ipsetName); err != nil {
|
||||||
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
|
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
|
||||||
|
log.Debugf("destroy ipset %q during reset: %v", ipsetName, err)
|
||||||
|
} else {
|
||||||
|
log.Errorf("destroy ipset %q during reset: %v", ipsetName, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
m.ipsetStore.deleteIpset(ipsetName)
|
m.ipsetStore.deleteIpset(ipsetName)
|
||||||
}
|
}
|
||||||
@@ -368,8 +387,8 @@ func (m *aclManager) updateState() {
|
|||||||
// filterRuleSpecs returns the specs of a filtering rule
|
// filterRuleSpecs returns the specs of a filtering rule
|
||||||
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
||||||
matchByIP := true
|
matchByIP := true
|
||||||
// don't use IP matching if IP is ip 0.0.0.0
|
// don't use IP matching if IP is 0.0.0.0
|
||||||
if ip.String() == "0.0.0.0" {
|
if ip.IsUnspecified() {
|
||||||
matchByIP = false
|
matchByIP = false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -416,3 +435,61 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
|
|||||||
return ipsetName + actionSuffix
|
return ipsetName + actionSuffix
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) createIPSet(name string) error {
|
||||||
|
opts := ipset.CreateOptions{
|
||||||
|
Replace: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||||
|
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("created ipset %s with type hash:net", name)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) addToIPSet(name string, ip net.IP) error {
|
||||||
|
cidr := uint8(32)
|
||||||
|
if ip.To4() == nil {
|
||||||
|
cidr = 128
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := &ipset.Entry{
|
||||||
|
IP: ip,
|
||||||
|
CIDR: cidr,
|
||||||
|
Replace: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ipset.Add(name, entry); err != nil {
|
||||||
|
return fmt.Errorf("add IP to ipset %s: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) delFromIPSet(name string, ip net.IP) error {
|
||||||
|
cidr := uint8(32)
|
||||||
|
if ip.To4() == nil {
|
||||||
|
cidr = 128
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := &ipset.Entry{
|
||||||
|
IP: ip,
|
||||||
|
CIDR: cidr,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ipset.Del(name, entry); err != nil {
|
||||||
|
return fmt.Errorf("delete IP from ipset %s: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) flushIPSet(name string) error {
|
||||||
|
return ipset.Flush(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) destroyIPSet(name string) error {
|
||||||
|
return ipset.Destroy(name)
|
||||||
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
"github.com/nadoo/ipset"
|
ipset "github.com/lrh3321/ipset-go"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
@@ -107,10 +107,6 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := ipset.Init(); err != nil {
|
|
||||||
return nil, fmt.Errorf("init ipset: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -232,12 +228,12 @@ func (r *router) findSets(rule []string) []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
||||||
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
|
if err := r.createIPSet(setName); err != nil {
|
||||||
return fmt.Errorf("create set %s: %w", setName, err)
|
return fmt.Errorf("create set %s: %w", setName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, prefix := range sources {
|
for _, prefix := range sources {
|
||||||
if err := ipset.AddPrefix(setName, prefix); err != nil {
|
if err := r.addPrefixToIPSet(setName, prefix); err != nil {
|
||||||
return fmt.Errorf("add element to set %s: %w", setName, err)
|
return fmt.Errorf("add element to set %s: %w", setName, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -246,7 +242,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) deleteIpSet(setName string) error {
|
func (r *router) deleteIpSet(setName string) error {
|
||||||
if err := ipset.Destroy(setName); err != nil {
|
if err := r.destroyIPSet(setName); err != nil {
|
||||||
return fmt.Errorf("destroy set %s: %w", setName, err)
|
return fmt.Errorf("destroy set %s: %w", setName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -915,8 +911,8 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
|||||||
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil {
|
if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if merr == nil {
|
if merr == nil {
|
||||||
@@ -993,3 +989,37 @@ func applyPort(flag string, port *firewall.Port) []string {
|
|||||||
|
|
||||||
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *router) createIPSet(name string) error {
|
||||||
|
opts := ipset.CreateOptions{
|
||||||
|
Replace: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||||
|
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("created ipset %s with type hash:net", name)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) addPrefixToIPSet(name string, prefix netip.Prefix) error {
|
||||||
|
addr := prefix.Addr()
|
||||||
|
ip := addr.AsSlice()
|
||||||
|
|
||||||
|
entry := &ipset.Entry{
|
||||||
|
IP: ip,
|
||||||
|
CIDR: uint8(prefix.Bits()),
|
||||||
|
Replace: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ipset.Add(name, entry); err != nil {
|
||||||
|
return fmt.Errorf("add prefix to ipset %s: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) destroyIPSet(name string) error {
|
||||||
|
return ipset.Destroy(name)
|
||||||
|
}
|
||||||
|
|||||||
@@ -27,7 +27,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
tableNat = "nat"
|
tableNat = "nat"
|
||||||
|
tableMangle = "mangle"
|
||||||
|
tableRaw = "raw"
|
||||||
|
tableSecurity = "security"
|
||||||
|
|
||||||
chainNameNatPrerouting = "PREROUTING"
|
chainNameNatPrerouting = "PREROUTING"
|
||||||
chainNameRoutingFw = "netbird-rt-fwd"
|
chainNameRoutingFw = "netbird-rt-fwd"
|
||||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||||
@@ -91,11 +95,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
|
|||||||
var err error
|
var err error
|
||||||
r.filterTable, err = r.loadFilterTable()
|
r.filterTable, err = r.loadFilterTable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, errFilterTableNotFound) {
|
log.Debugf("ip filter table not found: %v", err)
|
||||||
log.Warnf("table 'filter' not found for forward rules")
|
|
||||||
} else {
|
|
||||||
return nil, fmt.Errorf("load filter table: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return r, nil
|
return r, nil
|
||||||
@@ -175,7 +175,7 @@ func (r *router) removeNatPreroutingRules() error {
|
|||||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to list tables: %v", err)
|
return nil, fmt.Errorf("list tables: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, table := range tables {
|
for _, table := range tables {
|
||||||
@@ -187,14 +187,39 @@ func (r *router) loadFilterTable() (*nftables.Table, error) {
|
|||||||
return nil, errFilterTableNotFound
|
return nil, errFilterTableNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hookName(hook *nftables.ChainHook) string {
|
||||||
|
if hook == nil {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
switch *hook {
|
||||||
|
case *nftables.ChainHookForward:
|
||||||
|
return chainNameForward
|
||||||
|
case *nftables.ChainHookInput:
|
||||||
|
return chainNameInput
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("hook(%d)", *hook)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func familyName(family nftables.TableFamily) string {
|
||||||
|
switch family {
|
||||||
|
case nftables.TableFamilyIPv4:
|
||||||
|
return "ip"
|
||||||
|
case nftables.TableFamilyIPv6:
|
||||||
|
return "ip6"
|
||||||
|
case nftables.TableFamilyINet:
|
||||||
|
return "inet"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("family(%d)", family)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *router) createContainers() error {
|
func (r *router) createContainers() error {
|
||||||
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
||||||
Name: chainNameRoutingFw,
|
Name: chainNameRoutingFw,
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
})
|
})
|
||||||
|
|
||||||
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
|
||||||
|
|
||||||
prio := *nftables.ChainPriorityNATSource - 1
|
prio := *nftables.ChainPriorityNATSource - 1
|
||||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||||
Name: chainNameRoutingNat,
|
Name: chainNameRoutingNat,
|
||||||
@@ -236,9 +261,12 @@ func (r *router) createContainers() error {
|
|||||||
Type: nftables.ChainTypeFilter,
|
Type: nftables.ChainTypeFilter,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Add the single NAT rule that matches on mark
|
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||||
if err := r.addPostroutingRules(); err != nil {
|
|
||||||
return fmt.Errorf("add single nat rule: %v", err)
|
r.addPostroutingRules()
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("initialize tables: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.addMSSClampingRules(); err != nil {
|
if err := r.addMSSClampingRules(); err != nil {
|
||||||
@@ -250,11 +278,7 @@ func (r *router) createContainers() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
log.Errorf("failed to refresh rules: %s", err)
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("initialize tables: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -695,7 +719,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addPostroutingRules adds the masquerade rules
|
// addPostroutingRules adds the masquerade rules
|
||||||
func (r *router) addPostroutingRules() error {
|
func (r *router) addPostroutingRules() {
|
||||||
// First masquerade rule for traffic coming in from WireGuard interface
|
// First masquerade rule for traffic coming in from WireGuard interface
|
||||||
exprs := []expr.Any{
|
exprs := []expr.Any{
|
||||||
// Match on the first fwmark
|
// Match on the first fwmark
|
||||||
@@ -761,8 +785,6 @@ func (r *router) addPostroutingRules() error {
|
|||||||
Chain: r.chains[chainNameRoutingNat],
|
Chain: r.chains[chainNameRoutingNat],
|
||||||
Exprs: exprs2,
|
Exprs: exprs2,
|
||||||
})
|
})
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||||
@@ -839,7 +861,7 @@ func (r *router) addMSSClampingRules() error {
|
|||||||
Exprs: exprsOut,
|
Exprs: exprsOut,
|
||||||
})
|
})
|
||||||
|
|
||||||
return nil
|
return r.conn.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||||
@@ -939,8 +961,21 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
|||||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||||
// This method also adds INPUT chain rules to allow traffic to the local interface.
|
// This method also adds INPUT chain rules to allow traffic to the local interface.
|
||||||
func (r *router) acceptForwardRules() error {
|
func (r *router) acceptForwardRules() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := r.acceptFilterTableRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.acceptExternalChainsRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) acceptFilterTableRules() error {
|
||||||
if r.filterTable == nil {
|
if r.filterTable == nil {
|
||||||
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -953,11 +988,11 @@ func (r *router) acceptForwardRules() error {
|
|||||||
// Try iptables first and fallback to nftables if iptables is not available
|
// Try iptables first and fallback to nftables if iptables is not available
|
||||||
ipt, err := iptables.New()
|
ipt, err := iptables.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// filter table exists but iptables is not
|
// iptables is not available but the filter table exists
|
||||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||||
|
|
||||||
fw = "nftables"
|
fw = "nftables"
|
||||||
return r.acceptFilterRulesNftables()
|
return r.acceptFilterRulesNftables(r.filterTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.acceptFilterRulesIptables(ipt)
|
return r.acceptFilterRulesIptables(ipt)
|
||||||
@@ -968,7 +1003,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
|||||||
|
|
||||||
for _, rule := range r.getAcceptForwardRules() {
|
for _, rule := range r.getAcceptForwardRules() {
|
||||||
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
||||||
merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err))
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("added iptables forward rule: %v", rule)
|
log.Debugf("added iptables forward rule: %v", rule)
|
||||||
}
|
}
|
||||||
@@ -976,7 +1011,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
|||||||
|
|
||||||
inputRule := r.getAcceptInputRule()
|
inputRule := r.getAcceptInputRule()
|
||||||
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
|
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
|
||||||
merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err))
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("added iptables input rule: %v", inputRule)
|
log.Debugf("added iptables input rule: %v", inputRule)
|
||||||
}
|
}
|
||||||
@@ -996,18 +1031,70 @@ func (r *router) getAcceptInputRule() []string {
|
|||||||
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
|
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) acceptFilterRulesNftables() error {
|
// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables.
|
||||||
|
// This is used when iptables is not available.
|
||||||
|
func (r *router) acceptFilterRulesNftables(table *nftables.Table) error {
|
||||||
intf := ifname(r.wgIface.Name())
|
intf := ifname(r.wgIface.Name())
|
||||||
|
|
||||||
|
forwardChain := &nftables.Chain{
|
||||||
|
Name: chainNameForward,
|
||||||
|
Table: table,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookForward,
|
||||||
|
Priority: nftables.ChainPriorityFilter,
|
||||||
|
}
|
||||||
|
r.insertForwardAcceptRules(forwardChain, intf)
|
||||||
|
|
||||||
|
inputChain := &nftables.Chain{
|
||||||
|
Name: chainNameInput,
|
||||||
|
Table: table,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookInput,
|
||||||
|
Priority: nftables.ChainPriorityFilter,
|
||||||
|
}
|
||||||
|
r.insertInputAcceptRule(inputChain, intf)
|
||||||
|
|
||||||
|
return r.conn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables).
|
||||||
|
// It dynamically finds chains at call time to handle chains that may have been created after startup.
|
||||||
|
func (r *router) acceptExternalChainsRules() error {
|
||||||
|
chains := r.findExternalChains()
|
||||||
|
if len(chains) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
intf := ifname(r.wgIface.Name())
|
||||||
|
|
||||||
|
for _, chain := range chains {
|
||||||
|
if chain.Hooknum == nil {
|
||||||
|
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
|
||||||
|
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
|
||||||
|
|
||||||
|
switch *chain.Hooknum {
|
||||||
|
case *nftables.ChainHookForward:
|
||||||
|
r.insertForwardAcceptRules(chain, intf)
|
||||||
|
case *nftables.ChainHookInput:
|
||||||
|
r.insertInputAcceptRule(chain, intf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush external chain rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
|
||||||
iifRule := &nftables.Rule{
|
iifRule := &nftables.Rule{
|
||||||
Table: r.filterTable,
|
Table: chain.Table,
|
||||||
Chain: &nftables.Chain{
|
Chain: chain,
|
||||||
Name: chainNameForward,
|
|
||||||
Table: r.filterTable,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookForward,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
},
|
|
||||||
Exprs: []expr.Any{
|
Exprs: []expr.Any{
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
@@ -1030,30 +1117,19 @@ func (r *router) acceptFilterRulesNftables() error {
|
|||||||
Data: intf,
|
Data: intf,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
oifRule := &nftables.Rule{
|
oifRule := &nftables.Rule{
|
||||||
Table: r.filterTable,
|
Table: chain.Table,
|
||||||
Chain: &nftables.Chain{
|
Chain: chain,
|
||||||
Name: chainNameForward,
|
|
||||||
Table: r.filterTable,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookForward,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
},
|
|
||||||
Exprs: append(oifExprs, getEstablishedExprs(2)...),
|
Exprs: append(oifExprs, getEstablishedExprs(2)...),
|
||||||
UserData: []byte(userDataAcceptForwardRuleOif),
|
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||||
}
|
}
|
||||||
r.conn.InsertRule(oifRule)
|
r.conn.InsertRule(oifRule)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
|
||||||
inputRule := &nftables.Rule{
|
inputRule := &nftables.Rule{
|
||||||
Table: r.filterTable,
|
Table: chain.Table,
|
||||||
Chain: &nftables.Chain{
|
Chain: chain,
|
||||||
Name: chainNameInput,
|
|
||||||
Table: r.filterTable,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookInput,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
},
|
|
||||||
Exprs: []expr.Any{
|
Exprs: []expr.Any{
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
@@ -1067,32 +1143,44 @@ func (r *router) acceptFilterRulesNftables() error {
|
|||||||
UserData: []byte(userDataAcceptInputRule),
|
UserData: []byte(userDataAcceptInputRule),
|
||||||
}
|
}
|
||||||
r.conn.InsertRule(inputRule)
|
r.conn.InsertRule(inputRule)
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) removeAcceptFilterRules() error {
|
func (r *router) removeAcceptFilterRules() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := r.removeFilterTableRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.removeExternalChainsRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) removeFilterTableRules() error {
|
||||||
if r.filterTable == nil {
|
if r.filterTable == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ipt, err := iptables.New()
|
ipt, err := iptables.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
|
||||||
return r.removeAcceptFilterRulesNftables()
|
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.removeAcceptFilterRulesIptables(ipt)
|
return r.removeAcceptFilterRulesIptables(ipt)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) removeAcceptFilterRulesNftables() error {
|
func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
|
||||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
chains, err := r.conn.ListChainsOfTableFamily(table.Family)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("list chains: %v", err)
|
return fmt.Errorf("list chains: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, chain := range chains {
|
for _, chain := range chains {
|
||||||
if chain.Table.Name != r.filterTable.Name {
|
if chain.Table.Name != table.Name {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1100,27 +1188,101 @@ func (r *router) removeAcceptFilterRulesNftables() error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
rules, err := r.conn.GetRules(r.filterTable, chain)
|
if err := r.removeAcceptRulesFromChain(table, chain); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.conn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error {
|
||||||
|
rules, err := r.conn.GetRules(table, chain)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range rules {
|
||||||
|
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||||
|
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
|
||||||
|
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeExternalChainsRules removes our accept rules from all external chains.
|
||||||
|
// This is deterministic - it scans for chains at removal time rather than relying on saved state,
|
||||||
|
// ensuring cleanup works even after a crash or if chains changed.
|
||||||
|
func (r *router) removeExternalChainsRules() error {
|
||||||
|
chains := r.findExternalChains()
|
||||||
|
if len(chains) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, chain := range chains {
|
||||||
|
if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil {
|
||||||
|
log.Warnf("remove rules from external chain %s/%s: %v", chain.Table.Name, chain.Name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.conn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks.
|
||||||
|
// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal).
|
||||||
|
func (r *router) findExternalChains() []*nftables.Chain {
|
||||||
|
var chains []*nftables.Chain
|
||||||
|
|
||||||
|
families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet}
|
||||||
|
|
||||||
|
for _, family := range families {
|
||||||
|
allChains, err := r.conn.ListChainsOfTableFamily(family)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get rules: %v", err)
|
log.Debugf("list chains for family %d: %v", family, err)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, rule := range rules {
|
for _, chain := range allChains {
|
||||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
if r.isExternalChain(chain) {
|
||||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
|
chains = append(chains, chain)
|
||||||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
|
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
|
||||||
return fmt.Errorf("delete rule: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
return chains
|
||||||
return fmt.Errorf(flushError, err)
|
}
|
||||||
|
|
||||||
|
func (r *router) isExternalChain(chain *nftables.Chain) bool {
|
||||||
|
if r.workTable != nil && chain.Table.Name == r.workTable.Name {
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
// Skip all iptables-managed tables in the ip family
|
||||||
|
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if chain.Type != nftables.ChainTypeFilter {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if chain.Hooknum == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIptablesTable(name string) bool {
|
||||||
|
switch name {
|
||||||
|
case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||||
@@ -1128,13 +1290,13 @@ func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
|||||||
|
|
||||||
for _, rule := range r.getAcceptForwardRules() {
|
for _, rule := range r.getAcceptForwardRules() {
|
||||||
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
||||||
merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inputRule := r.getAcceptInputRule()
|
inputRule := r.getAcceptInputRule()
|
||||||
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
|
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
|
||||||
merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
@@ -1196,7 +1358,7 @@ func (r *router) refreshRulesMap() error {
|
|||||||
for _, chain := range r.chains {
|
for _, chain := range r.chains {
|
||||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf(" unable to list rules: %v", err)
|
return fmt.Errorf("list rules: %w", err)
|
||||||
}
|
}
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 {
|
if len(rule.UserData) > 0 {
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
@@ -12,7 +11,6 @@ import (
|
|||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/connectivity"
|
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
@@ -20,9 +18,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready
|
|
||||||
var ErrConnectionShutdown = errors.New("connection shutdown before ready")
|
|
||||||
|
|
||||||
// Backoff returns a backoff configuration for gRPC calls
|
// Backoff returns a backoff configuration for gRPC calls
|
||||||
func Backoff(ctx context.Context) backoff.BackOff {
|
func Backoff(ctx context.Context) backoff.BackOff {
|
||||||
b := backoff.NewExponentialBackOff()
|
b := backoff.NewExponentialBackOff()
|
||||||
@@ -31,26 +26,6 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
|||||||
return backoff.WithContext(b, ctx)
|
return backoff.WithContext(b, ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// waitForConnectionReady blocks until the connection becomes ready or fails.
|
|
||||||
// Returns an error if the connection times out, is cancelled, or enters shutdown state.
|
|
||||||
func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error {
|
|
||||||
conn.Connect()
|
|
||||||
|
|
||||||
state := conn.GetState()
|
|
||||||
for state != connectivity.Ready && state != connectivity.Shutdown {
|
|
||||||
if !conn.WaitForStateChange(ctx, state) {
|
|
||||||
return fmt.Errorf("wait state change from %s: %w", state, ctx.Err())
|
|
||||||
}
|
|
||||||
state = conn.GetState()
|
|
||||||
}
|
|
||||||
|
|
||||||
if state == connectivity.Shutdown {
|
|
||||||
return ErrConnectionShutdown
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
||||||
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||||
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
||||||
@@ -68,25 +43,22 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := grpc.NewClient(
|
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := grpc.DialContext(
|
||||||
|
connCtx,
|
||||||
addr,
|
addr,
|
||||||
transportOption,
|
transportOption,
|
||||||
WithCustomDialer(tlsEnabled, component),
|
WithCustomDialer(tlsEnabled, component),
|
||||||
|
grpc.WithBlock(),
|
||||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||||
Time: 30 * time.Second,
|
Time: 30 * time.Second,
|
||||||
Timeout: 10 * time.Second,
|
Timeout: 10 * time.Second,
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("new client: %w", err)
|
return nil, fmt.Errorf("dial context: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
if err := waitForConnectionReady(ctx, conn); err != nil {
|
|
||||||
_ = conn.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -19,11 +20,12 @@ import (
|
|||||||
|
|
||||||
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
|
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
|
||||||
type WGTunDevice struct {
|
type WGTunDevice struct {
|
||||||
address wgaddr.Address
|
address wgaddr.Address
|
||||||
port int
|
port int
|
||||||
key string
|
key string
|
||||||
mtu uint16
|
mtu uint16
|
||||||
iceBind *bind.ICEBind
|
iceBind *bind.ICEBind
|
||||||
|
// todo: review if we can eliminate the TunAdapter
|
||||||
tunAdapter TunAdapter
|
tunAdapter TunAdapter
|
||||||
disableDNS bool
|
disableDNS bool
|
||||||
|
|
||||||
@@ -32,17 +34,19 @@ type WGTunDevice struct {
|
|||||||
filteredDevice *FilteredDevice
|
filteredDevice *FilteredDevice
|
||||||
udpMux *udpmux.UniversalUDPMuxDefault
|
udpMux *udpmux.UniversalUDPMuxDefault
|
||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
|
renewableTun *RenewableTUN
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
|
func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
|
||||||
return &WGTunDevice{
|
return &WGTunDevice{
|
||||||
address: address,
|
address: address,
|
||||||
port: port,
|
port: port,
|
||||||
key: key,
|
key: key,
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
iceBind: iceBind,
|
iceBind: iceBind,
|
||||||
tunAdapter: tunAdapter,
|
tunAdapter: tunAdapter,
|
||||||
disableDNS: disableDNS,
|
disableDNS: disableDNS,
|
||||||
|
renewableTun: NewRenewableTUN(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,14 +69,17 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
tunDevice, name, err := tun.CreateUnmonitoredTUNFromFD(fd)
|
unmonitoredTUN, name, err := tun.CreateUnmonitoredTUNFromFD(fd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = unix.Close(fd)
|
_ = unix.Close(fd)
|
||||||
log.Errorf("failed to create Android interface: %s", err)
|
log.Errorf("failed to create Android interface: %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.renewableTun.AddDevice(unmonitoredTUN)
|
||||||
|
|
||||||
t.name = name
|
t.name = name
|
||||||
t.filteredDevice = newDeviceFilter(tunDevice)
|
t.filteredDevice = newDeviceFilter(t.renewableTun)
|
||||||
|
|
||||||
log.Debugf("attaching to interface %v", name)
|
log.Debugf("attaching to interface %v", name)
|
||||||
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
|
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
|
||||||
@@ -104,6 +111,23 @@ func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
|||||||
return udpMux, nil
|
return udpMux, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *WGTunDevice) RenewTun(fd int) error {
|
||||||
|
if t.device == nil {
|
||||||
|
return fmt.Errorf("device not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
unmonitoredTUN, _, err := tun.CreateUnmonitoredTUNFromFD(fd)
|
||||||
|
if err != nil {
|
||||||
|
_ = unix.Close(fd)
|
||||||
|
log.Errorf("failed to renew Android interface: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
t.renewableTun.AddDevice(unmonitoredTUN)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
|
func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
|
||||||
// todo implement
|
// todo implement
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -2,6 +2,13 @@
|
|||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
func (t *TunNetstackDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
|
func (t *TunNetstackDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
|
||||||
return t.create()
|
return t.create()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *TunNetstackDevice) RenewTun(fd int) error {
|
||||||
|
// Doesn't make sense in Android for Netstack.
|
||||||
|
return fmt.Errorf("this function has not been implemented in Netstack for Android")
|
||||||
|
}
|
||||||
|
|||||||
309
client/iface/device/renewable_tun.go
Normal file
309
client/iface/device/renewable_tun.go
Normal file
@@ -0,0 +1,309 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
)
|
||||||
|
|
||||||
|
// closeAwareDevice wraps a tun.Device along with a flag
|
||||||
|
// indicating whether its Close method was called.
|
||||||
|
//
|
||||||
|
// It also redirects tun.Device's Events() to a separate goroutine
|
||||||
|
// and closes it when Close is called.
|
||||||
|
//
|
||||||
|
// The WaitGroup and CloseOnce fields are used to ensure that the
|
||||||
|
// goroutine is awaited and closed only once.
|
||||||
|
type closeAwareDevice struct {
|
||||||
|
isClosed atomic.Bool
|
||||||
|
tun.Device
|
||||||
|
closeEventCh chan struct{}
|
||||||
|
wg sync.WaitGroup
|
||||||
|
closeOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func newClosableDevice(tunDevice tun.Device) *closeAwareDevice {
|
||||||
|
return &closeAwareDevice{
|
||||||
|
Device: tunDevice,
|
||||||
|
isClosed: atomic.Bool{},
|
||||||
|
closeEventCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// redirectEvents redirects the Events() method of the underlying tun.Device
|
||||||
|
// to the given channel (RenewableTUN's events channel).
|
||||||
|
func (c *closeAwareDevice) redirectEvents(out chan tun.Event) {
|
||||||
|
c.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer c.wg.Done()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case ev, ok := <-c.Device.Events():
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if ev == tun.EventDown {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case out <- ev:
|
||||||
|
case <-c.closeEventCh:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-c.closeEventCh:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close calls the underlying Device's Close method
|
||||||
|
// after setting isClosed to true.
|
||||||
|
func (c *closeAwareDevice) Close() (err error) {
|
||||||
|
c.closeOnce.Do(func() {
|
||||||
|
c.isClosed.Store(true)
|
||||||
|
close(c.closeEventCh)
|
||||||
|
err = c.Device.Close()
|
||||||
|
c.wg.Wait()
|
||||||
|
})
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *closeAwareDevice) IsClosed() bool {
|
||||||
|
return c.isClosed.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
type RenewableTUN struct {
|
||||||
|
devices []*closeAwareDevice
|
||||||
|
mu sync.Mutex
|
||||||
|
cond *sync.Cond
|
||||||
|
events chan tun.Event
|
||||||
|
closed atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRenewableTUN() *RenewableTUN {
|
||||||
|
r := &RenewableTUN{
|
||||||
|
devices: make([]*closeAwareDevice, 0),
|
||||||
|
mu: sync.Mutex{},
|
||||||
|
events: make(chan tun.Event, 16),
|
||||||
|
}
|
||||||
|
r.cond = sync.NewCond(&r.mu)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RenewableTUN) File() *os.File {
|
||||||
|
for {
|
||||||
|
dev := r.peekLast()
|
||||||
|
if dev == nil {
|
||||||
|
if !r.waitForDevice() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
file := dev.File()
|
||||||
|
|
||||||
|
if dev.IsClosed() {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return file
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read reads from an underlying tun.Device kept in the r.devices slice.
|
||||||
|
// If no device is available, it waits for one to be added via AddDevice().
|
||||||
|
//
|
||||||
|
// On error, it retries reading from the newest device instead of returning the error
|
||||||
|
// if the device is closed; if not, it propagates the error.
|
||||||
|
func (r *RenewableTUN) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||||
|
for {
|
||||||
|
dev := r.peekLast()
|
||||||
|
if dev == nil {
|
||||||
|
// wait until AddDevice() signals a new device via cond.Broadcast()
|
||||||
|
if !r.waitForDevice() { // returns false if the renewable TUN itself is closed
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err = dev.Read(bufs, sizes, offset)
|
||||||
|
if err == nil {
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// swap in progress; retry on the newest instead of returning the error
|
||||||
|
if dev.IsClosed() {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return n, err // propagate non-swap error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write writes to underlying tun.Device kept in the r.devices slice.
|
||||||
|
// If no device is available, it waits for one to be added via AddDevice().
|
||||||
|
//
|
||||||
|
// On error, it retries writing to the newest device instead of returning the error
|
||||||
|
// if the device is closed; if not, it propagates the error.
|
||||||
|
func (r *RenewableTUN) Write(bufs [][]byte, offset int) (int, error) {
|
||||||
|
for {
|
||||||
|
dev := r.peekLast()
|
||||||
|
if dev == nil {
|
||||||
|
if !r.waitForDevice() {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := dev.Write(bufs, offset)
|
||||||
|
if err == nil {
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if dev.IsClosed() {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RenewableTUN) MTU() (int, error) {
|
||||||
|
for {
|
||||||
|
dev := r.peekLast()
|
||||||
|
if dev == nil {
|
||||||
|
if !r.waitForDevice() {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
mtu, err := dev.MTU()
|
||||||
|
if err == nil {
|
||||||
|
return mtu, nil
|
||||||
|
}
|
||||||
|
if dev.IsClosed() {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RenewableTUN) Name() (string, error) {
|
||||||
|
for {
|
||||||
|
dev := r.peekLast()
|
||||||
|
if dev == nil {
|
||||||
|
if !r.waitForDevice() {
|
||||||
|
return "", io.EOF
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name, err := dev.Name()
|
||||||
|
if err == nil {
|
||||||
|
return name, nil
|
||||||
|
}
|
||||||
|
if dev.IsClosed() {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Events returns a channel that is fed events from the underlying tun.Device's events channel
|
||||||
|
// once it is added.
|
||||||
|
func (r *RenewableTUN) Events() <-chan tun.Event {
|
||||||
|
return r.events
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RenewableTUN) Close() error {
|
||||||
|
// Attempts to set the RenewableTUN closed flag to true.
|
||||||
|
// If it's already true, returns immediately.
|
||||||
|
if !r.closed.CompareAndSwap(false, true) {
|
||||||
|
return nil // already closed: idempotent
|
||||||
|
}
|
||||||
|
r.mu.Lock()
|
||||||
|
devices := r.devices
|
||||||
|
r.devices = nil
|
||||||
|
r.cond.Broadcast()
|
||||||
|
r.mu.Unlock()
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
|
||||||
|
log.Debugf("closing %d devices", len(devices))
|
||||||
|
for _, device := range devices {
|
||||||
|
if err := device.Close(); err != nil {
|
||||||
|
log.Debugf("error closing a device: %v", err)
|
||||||
|
lastErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
close(r.events)
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RenewableTUN) BatchSize() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RenewableTUN) AddDevice(device tun.Device) {
|
||||||
|
r.mu.Lock()
|
||||||
|
if r.closed.Load() {
|
||||||
|
r.mu.Unlock()
|
||||||
|
_ = device.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var toClose *closeAwareDevice
|
||||||
|
if len(r.devices) > 0 {
|
||||||
|
toClose = r.devices[len(r.devices)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
cad := newClosableDevice(device)
|
||||||
|
cad.redirectEvents(r.events)
|
||||||
|
|
||||||
|
r.devices = []*closeAwareDevice{cad}
|
||||||
|
r.cond.Broadcast()
|
||||||
|
|
||||||
|
r.mu.Unlock()
|
||||||
|
|
||||||
|
if toClose != nil {
|
||||||
|
if err := toClose.Close(); err != nil {
|
||||||
|
log.Debugf("error closing last device: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RenewableTUN) waitForDevice() bool {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
for len(r.devices) == 0 && !r.closed.Load() {
|
||||||
|
r.cond.Wait()
|
||||||
|
}
|
||||||
|
return !r.closed.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RenewableTUN) peekLast() *closeAwareDevice {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
if len(r.devices) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.devices[len(r.devices)-1]
|
||||||
|
}
|
||||||
@@ -21,5 +21,6 @@ type WGTunDevice interface {
|
|||||||
FilteredDevice() *device.FilteredDevice
|
FilteredDevice() *device.FilteredDevice
|
||||||
Device() *wgdevice.Device
|
Device() *wgdevice.Device
|
||||||
GetNet() *netstack.Net
|
GetNet() *netstack.Net
|
||||||
|
RenewTun(fd int) error
|
||||||
GetICEBind() device.EndpointManager
|
GetICEBind() device.EndpointManager
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,3 +24,7 @@ func (w *WGIface) Create() error {
|
|||||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
||||||
return fmt.Errorf("this function has not implemented on non mobile")
|
return fmt.Errorf("this function has not implemented on non mobile")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *WGIface) RenewTun(fd int) error {
|
||||||
|
return fmt.Errorf("this function has not been implemented on non-android")
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
|
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
|
||||||
// Will reuse an existing one.
|
// Will reuse an existing one.
|
||||||
|
// todo: review does this function really necessary or can we merge it with iOS
|
||||||
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
|
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
@@ -22,3 +23,9 @@ func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []s
|
|||||||
func (w *WGIface) Create() error {
|
func (w *WGIface) Create() error {
|
||||||
return fmt.Errorf("this function has not implemented on this platform")
|
return fmt.Errorf("this function has not implemented on this platform")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *WGIface) RenewTun(fd int) error {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
return w.tun.RenewTun(fd)
|
||||||
|
}
|
||||||
|
|||||||
@@ -39,3 +39,7 @@ func (w *WGIface) Create() error {
|
|||||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
||||||
return fmt.Errorf("this function has not implemented on this platform")
|
return fmt.Errorf("this function has not implemented on this platform")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *WGIface) RenewTun(fd int) error {
|
||||||
|
return fmt.Errorf("this function has not been implemented on this platform")
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -9,13 +10,13 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/pion/transport/v3/stdnet"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
)
|
)
|
||||||
|
|
||||||
// keep darwin compatibility
|
// keep darwin compatibility
|
||||||
@@ -40,7 +41,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
|
|||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||||
addr := "100.64.0.1/8"
|
addr := "100.64.0.1/8"
|
||||||
wgPort := 33100
|
wgPort := 33100
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -123,7 +124,7 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
|
|||||||
func Test_CreateInterface(t *testing.T) {
|
func Test_CreateInterface(t *testing.T) {
|
||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
|
||||||
wgIP := "10.99.99.1/32"
|
wgIP := "10.99.99.1/32"
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -166,7 +167,7 @@ func Test_Close(t *testing.T) {
|
|||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
||||||
wgIP := "10.99.99.2/32"
|
wgIP := "10.99.99.2/32"
|
||||||
wgPort := 33100
|
wgPort := 33100
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -211,7 +212,7 @@ func TestRecreation(t *testing.T) {
|
|||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
||||||
wgIP := "10.99.99.2/32"
|
wgIP := "10.99.99.2/32"
|
||||||
wgPort := 33100
|
wgPort := 33100
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -284,7 +285,7 @@ func Test_ConfigureInterface(t *testing.T) {
|
|||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
|
||||||
wgIP := "10.99.99.5/30"
|
wgIP := "10.99.99.5/30"
|
||||||
wgPort := 33100
|
wgPort := 33100
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -339,7 +340,7 @@ func Test_ConfigureInterface(t *testing.T) {
|
|||||||
func Test_UpdatePeer(t *testing.T) {
|
func Test_UpdatePeer(t *testing.T) {
|
||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||||
wgIP := "10.99.99.9/30"
|
wgIP := "10.99.99.9/30"
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -409,7 +410,7 @@ func Test_UpdatePeer(t *testing.T) {
|
|||||||
func Test_RemovePeer(t *testing.T) {
|
func Test_RemovePeer(t *testing.T) {
|
||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||||
wgIP := "10.99.99.13/30"
|
wgIP := "10.99.99.13/30"
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -471,7 +472,7 @@ func Test_ConnectPeers(t *testing.T) {
|
|||||||
peer2wgPort := 33200
|
peer2wgPort := 33200
|
||||||
|
|
||||||
keepAlive := 1 * time.Second
|
keepAlive := 1 * time.Second
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -514,7 +515,7 @@ func Test_ConnectPeers(t *testing.T) {
|
|||||||
guid = fmt.Sprintf("{%s}", uuid.New().String())
|
guid = fmt.Sprintf("{%s}", uuid.New().String())
|
||||||
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||||
|
|
||||||
newNet, err = stdnet.NewNet()
|
newNet, err = stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package udpmux
|
package udpmux
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@@ -12,8 +13,9 @@ import (
|
|||||||
"github.com/pion/logging"
|
"github.com/pion/logging"
|
||||||
"github.com/pion/stun/v3"
|
"github.com/pion/stun/v3"
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
"github.com/pion/transport/v3/stdnet"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
)
|
)
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@@ -199,7 +201,7 @@ func (m *SingleSocketUDPMux) updateLocalAddresses() {
|
|||||||
if len(networks) > 0 {
|
if len(networks) > 0 {
|
||||||
if m.params.Net == nil {
|
if m.params.Net == nil {
|
||||||
var err error
|
var err error
|
||||||
if m.params.Net, err = stdnet.NewNet(); err != nil {
|
if m.params.Net, err = stdnet.NewNet(context.Background(), nil); err != nil {
|
||||||
m.params.Logger.Errorf("failed to get create network: %v", err)
|
m.params.Logger.Errorf("failed to get create network: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -128,9 +128,34 @@ func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlow
|
|||||||
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
|
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if d.providerConfig.LoginHint != "" {
|
||||||
|
deviceCode.VerificationURIComplete = appendLoginHint(deviceCode.VerificationURIComplete, d.providerConfig.LoginHint)
|
||||||
|
if deviceCode.VerificationURI != "" {
|
||||||
|
deviceCode.VerificationURI = appendLoginHint(deviceCode.VerificationURI, d.providerConfig.LoginHint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return deviceCode, err
|
return deviceCode, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func appendLoginHint(uri, loginHint string) string {
|
||||||
|
if uri == "" || loginHint == "" {
|
||||||
|
return uri
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedURL, err := url.Parse(uri)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to parse verification URI for login_hint: %v", err)
|
||||||
|
return uri
|
||||||
|
}
|
||||||
|
|
||||||
|
query := parsedURL.Query()
|
||||||
|
query.Set("login_hint", loginHint)
|
||||||
|
parsedURL.RawQuery = query.Encode()
|
||||||
|
|
||||||
|
return parsedURL.String()
|
||||||
|
}
|
||||||
|
|
||||||
func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) {
|
func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) {
|
||||||
form := url.Values{}
|
form := url.Values{}
|
||||||
form.Add("client_id", d.providerConfig.ClientID)
|
form.Add("client_id", d.providerConfig.ClientID)
|
||||||
|
|||||||
@@ -60,38 +60,45 @@ func (t TokenInfo) GetTokenToUse() string {
|
|||||||
return t.AccessToken
|
return t.AccessToken
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldUseDeviceFlow(force bool, isUnixDesktopClient bool) bool {
|
||||||
|
return force || (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient
|
||||||
|
}
|
||||||
|
|
||||||
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
|
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
|
||||||
//
|
//
|
||||||
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
|
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
|
||||||
// and if that also fails, the authentication process is deemed unsuccessful
|
// and if that also fails, the authentication process is deemed unsuccessful
|
||||||
//
|
//
|
||||||
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||||
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
|
// forceDeviceCodeFlow can be used to skip PKCE and go directly to Device Code Flow (e.g., for Android TV)
|
||||||
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
|
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, forceDeviceCodeFlow bool, hint string) (OAuthFlow, error) {
|
||||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
if shouldUseDeviceFlow(forceDeviceCodeFlow, isUnixDesktopClient) {
|
||||||
|
return authenticateWithDeviceCodeFlow(ctx, config, hint)
|
||||||
}
|
}
|
||||||
|
|
||||||
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
|
pkceFlow, err := authenticateWithPKCEFlow(ctx, config, hint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// fallback to device code flow
|
|
||||||
log.Debugf("failed to initialize pkce authentication with error: %v\n", err)
|
log.Debugf("failed to initialize pkce authentication with error: %v\n", err)
|
||||||
log.Debug("falling back to device code flow")
|
log.Debug("falling back to device code flow")
|
||||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
return authenticateWithDeviceCodeFlow(ctx, config, hint)
|
||||||
}
|
}
|
||||||
return pkceFlow, nil
|
return pkceFlow, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||||
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
|
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
|
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pkceFlowInfo.ProviderConfig.LoginHint = hint
|
||||||
|
|
||||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
||||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
|
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch s, ok := gstatus.FromError(err); {
|
switch s, ok := gstatus.FromError(err); {
|
||||||
@@ -107,5 +114,7 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
deviceFlowInfo.ProviderConfig.LoginHint = hint
|
||||||
|
|
||||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -21,6 +22,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/templates"
|
"github.com/netbirdio/netbird/client/internal/templates"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ OAuthFlow = &PKCEAuthorizationFlow{}
|
var _ OAuthFlow = &PKCEAuthorizationFlow{}
|
||||||
@@ -46,9 +48,10 @@ type PKCEAuthorizationFlow struct {
|
|||||||
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||||
var availableRedirectURL string
|
var availableRedirectURL string
|
||||||
|
|
||||||
// find the first available redirect URL
|
excludedRanges := getSystemExcludedPortRanges()
|
||||||
|
|
||||||
for _, redirectURL := range config.RedirectURLs {
|
for _, redirectURL := range config.RedirectURLs {
|
||||||
if !isRedirectURLPortUsed(redirectURL) {
|
if !isRedirectURLPortUsed(redirectURL, excludedRanges) {
|
||||||
availableRedirectURL = redirectURL
|
availableRedirectURL = redirectURL
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -102,13 +105,16 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
|||||||
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
||||||
}
|
}
|
||||||
if !p.providerConfig.DisablePromptLogin {
|
if !p.providerConfig.DisablePromptLogin {
|
||||||
if p.providerConfig.LoginFlag.IsPromptLogin() {
|
switch p.providerConfig.LoginFlag {
|
||||||
|
case common.LoginFlagPromptLogin:
|
||||||
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
||||||
}
|
case common.LoginFlagMaxAge0:
|
||||||
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
|
|
||||||
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if p.providerConfig.LoginHint != "" {
|
||||||
|
params = append(params, oauth2.SetAuthURLParam("login_hint", p.providerConfig.LoginHint))
|
||||||
|
}
|
||||||
|
|
||||||
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
||||||
|
|
||||||
@@ -189,17 +195,20 @@ func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token,
|
|||||||
|
|
||||||
if authError := query.Get(queryError); authError != "" {
|
if authError := query.Get(queryError); authError != "" {
|
||||||
authErrorDesc := query.Get(queryErrorDesc)
|
authErrorDesc := query.Get(queryErrorDesc)
|
||||||
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
|
if authErrorDesc != "" {
|
||||||
|
return nil, fmt.Errorf("authentication failed: %s", authErrorDesc)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("authentication failed: %s", authError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prevent timing attacks on the state
|
// Prevent timing attacks on the state
|
||||||
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
||||||
return nil, fmt.Errorf("invalid state")
|
return nil, fmt.Errorf("authentication failed: Invalid state")
|
||||||
}
|
}
|
||||||
|
|
||||||
code := query.Get(queryCode)
|
code := query.Get(queryCode)
|
||||||
if code == "" {
|
if code == "" {
|
||||||
return nil, fmt.Errorf("missing code")
|
return nil, fmt.Errorf("authentication failed: missing code")
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.oAuthConfig.Exchange(
|
return p.oAuthConfig.Exchange(
|
||||||
@@ -228,7 +237,7 @@ func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil {
|
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil {
|
||||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
return TokenInfo{}, fmt.Errorf("authentication failed: invalid access token - %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
email, err := parseEmailFromIDToken(tokenInfo.IDToken)
|
email, err := parseEmailFromIDToken(tokenInfo.IDToken)
|
||||||
@@ -276,15 +285,22 @@ func createCodeChallenge(codeVerifier string) string {
|
|||||||
return base64.RawURLEncoding.EncodeToString(sha2[:])
|
return base64.RawURLEncoding.EncodeToString(sha2[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use.
|
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use or excluded on Windows.
|
||||||
func isRedirectURLPortUsed(redirectURL string) bool {
|
func isRedirectURLPortUsed(redirectURL string, excludedRanges []excludedPortRange) bool {
|
||||||
parsedURL, err := url.Parse(redirectURL)
|
parsedURL, err := url.Parse(redirectURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to parse redirect URL: %v", err)
|
log.Errorf("failed to parse redirect URL: %v", err)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
addr := fmt.Sprintf(":%s", parsedURL.Port())
|
port := parsedURL.Port()
|
||||||
|
|
||||||
|
if isPortInExcludedRange(port, excludedRanges) {
|
||||||
|
log.Warnf("port %s is in Windows excluded port range, skipping", port)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := fmt.Sprintf(":%s", port)
|
||||||
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
|
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
@@ -298,6 +314,33 @@ func isRedirectURLPortUsed(redirectURL string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// excludedPortRange represents a range of excluded ports.
|
||||||
|
type excludedPortRange struct {
|
||||||
|
start int
|
||||||
|
end int
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPortInExcludedRange checks if the given port is in any of the excluded ranges.
|
||||||
|
func isPortInExcludedRange(port string, excludedRanges []excludedPortRange) bool {
|
||||||
|
if len(excludedRanges) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
portNum, err := strconv.Atoi(port)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("invalid port number %s: %v", port, err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range excludedRanges {
|
||||||
|
if portNum >= r.start && portNum <= r.end {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) {
|
func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) {
|
||||||
tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl)
|
tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
8
client/internal/auth/pkce_flow_other.go
Normal file
8
client/internal/auth/pkce_flow_other.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package auth
|
||||||
|
|
||||||
|
// getSystemExcludedPortRanges returns nil on non-Windows platforms.
|
||||||
|
func getSystemExcludedPortRanges() []excludedPortRange {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -2,8 +2,11 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
@@ -20,22 +23,28 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
loginFlag mgm.LoginFlag
|
loginFlag mgm.LoginFlag
|
||||||
disablePromptLogin bool
|
disablePromptLogin bool
|
||||||
expect string
|
expectContains []string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Prompt login",
|
name: "Prompt login",
|
||||||
loginFlag: mgm.LoginFlagPrompt,
|
loginFlag: mgm.LoginFlagPromptLogin,
|
||||||
expect: promptLogin,
|
expectContains: []string{promptLogin},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Max age 0 login",
|
name: "Max age 0",
|
||||||
loginFlag: mgm.LoginFlagMaxAge0,
|
loginFlag: mgm.LoginFlagMaxAge0,
|
||||||
expect: maxAge0,
|
expectContains: []string{maxAge0},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Disable prompt login",
|
name: "Disable prompt login",
|
||||||
loginFlag: mgm.LoginFlagPrompt,
|
loginFlag: mgm.LoginFlagPromptLogin,
|
||||||
disablePromptLogin: true,
|
disablePromptLogin: true,
|
||||||
|
expectContains: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "None flag should not add parameters",
|
||||||
|
loginFlag: mgm.LoginFlagNone,
|
||||||
|
expectContains: []string{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,6 +59,7 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
||||||
UseIDToken: true,
|
UseIDToken: true,
|
||||||
LoginFlag: tc.loginFlag,
|
LoginFlag: tc.loginFlag,
|
||||||
|
DisablePromptLogin: tc.disablePromptLogin,
|
||||||
}
|
}
|
||||||
pkce, err := NewPKCEAuthorizationFlow(config)
|
pkce, err := NewPKCEAuthorizationFlow(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -60,12 +70,153 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
t.Fatalf("Failed to request auth info: %v", err)
|
t.Fatalf("Failed to request auth info: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tc.disablePromptLogin {
|
for _, expected := range tc.expectContains {
|
||||||
require.Contains(t, authInfo.VerificationURIComplete, tc.expect)
|
require.Contains(t, authInfo.VerificationURIComplete, expected)
|
||||||
} else {
|
|
||||||
require.Contains(t, authInfo.VerificationURIComplete, promptLogin)
|
|
||||||
require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsPortInExcludedRange(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
port string
|
||||||
|
excludedRanges []excludedPortRange
|
||||||
|
expectedBlocked bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Port in excluded range",
|
||||||
|
port: "8080",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port at start of range",
|
||||||
|
port: "8000",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port at end of range",
|
||||||
|
port: "8100",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port before range",
|
||||||
|
port: "7999",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port after range",
|
||||||
|
port: "8101",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty excluded ranges",
|
||||||
|
port: "8080",
|
||||||
|
excludedRanges: []excludedPortRange{},
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Nil excluded ranges",
|
||||||
|
port: "8080",
|
||||||
|
excludedRanges: nil,
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple ranges - port in second range",
|
||||||
|
port: "9050",
|
||||||
|
excludedRanges: []excludedPortRange{
|
||||||
|
{start: 8000, end: 8100},
|
||||||
|
{start: 9000, end: 9100},
|
||||||
|
},
|
||||||
|
expectedBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple ranges - port not in any range",
|
||||||
|
port: "8500",
|
||||||
|
excludedRanges: []excludedPortRange{
|
||||||
|
{start: 8000, end: 8100},
|
||||||
|
{start: 9000, end: 9100},
|
||||||
|
},
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid port string",
|
||||||
|
port: "invalid",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty port string",
|
||||||
|
port: "",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := isPortInExcludedRange(tt.port, tt.excludedRanges)
|
||||||
|
assert.Equal(t, tt.expectedBlocked, result, "Port exclusion check mismatch")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsRedirectURLPortUsed(t *testing.T) {
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
_ = listener.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
usedPort := listener.Addr().(*net.TCPAddr).Port
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
redirectURL string
|
||||||
|
excludedRanges []excludedPortRange
|
||||||
|
expectedUsed bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Port in excluded range",
|
||||||
|
redirectURL: "http://127.0.0.1:8080/",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedUsed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port actually in use",
|
||||||
|
redirectURL: fmt.Sprintf("http://127.0.0.1:%d/", usedPort),
|
||||||
|
excludedRanges: nil,
|
||||||
|
expectedUsed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port not in use and not excluded",
|
||||||
|
redirectURL: "http://127.0.0.1:65432/",
|
||||||
|
excludedRanges: nil,
|
||||||
|
expectedUsed: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid URL without port",
|
||||||
|
redirectURL: "not-a-valid-url",
|
||||||
|
excludedRanges: nil,
|
||||||
|
expectedUsed: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port excluded even if not in use",
|
||||||
|
redirectURL: "http://127.0.0.1:8050/",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedUsed: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := isRedirectURLPortUsed(tt.redirectURL, tt.excludedRanges)
|
||||||
|
assert.Equal(t, tt.expectedUsed, result, "Port usage check mismatch")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
86
client/internal/auth/pkce_flow_windows.go
Normal file
86
client/internal/auth/pkce_flow_windows.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// getSystemExcludedPortRanges retrieves the excluded port ranges from Windows using netsh.
|
||||||
|
func getSystemExcludedPortRanges() []excludedPortRange {
|
||||||
|
ranges, err := getExcludedPortRangesFromNetsh()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to get Windows excluded port ranges: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return ranges
|
||||||
|
}
|
||||||
|
|
||||||
|
// getExcludedPortRangesFromNetsh retrieves excluded port ranges using netsh command.
|
||||||
|
func getExcludedPortRangesFromNetsh() ([]excludedPortRange, error) {
|
||||||
|
cmd := exec.Command("netsh", "interface", "ipv4", "show", "excludedportrange", "protocol=tcp")
|
||||||
|
output, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("netsh command: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return parseExcludedPortRanges(string(output))
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseExcludedPortRanges parses the output of the netsh command to extract port ranges.
|
||||||
|
func parseExcludedPortRanges(output string) ([]excludedPortRange, error) {
|
||||||
|
var ranges []excludedPortRange
|
||||||
|
scanner := bufio.NewScanner(strings.NewReader(output))
|
||||||
|
|
||||||
|
foundHeader := false
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := strings.TrimSpace(scanner.Text())
|
||||||
|
|
||||||
|
if strings.Contains(line, "Start Port") && strings.Contains(line, "End Port") {
|
||||||
|
foundHeader = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundHeader {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(line, "----------") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
if len(fields) < 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
startPort, err := strconv.Atoi(fields[0])
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
endPort, err := strconv.Atoi(fields[1])
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ranges = append(ranges, excludedPortRange{start: startPort, end: endPort})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("scan output: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ranges, nil
|
||||||
|
}
|
||||||
116
client/internal/auth/pkce_flow_windows_test.go
Normal file
116
client/internal/auth/pkce_flow_windows_test.go
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseExcludedPortRanges(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
netshOutput string
|
||||||
|
expectedRanges []excludedPortRange
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Valid netsh output with multiple ranges",
|
||||||
|
netshOutput: `
|
||||||
|
Protocol tcp Dynamic Port Range
|
||||||
|
---------------------------------
|
||||||
|
Start Port : 49152
|
||||||
|
Number of Ports : 16384
|
||||||
|
|
||||||
|
Protocol tcp Excluded Port Ranges
|
||||||
|
---------------------------------
|
||||||
|
Start Port End Port
|
||||||
|
---------- --------
|
||||||
|
5357 5357 *
|
||||||
|
50000 50059 *
|
||||||
|
`,
|
||||||
|
expectedRanges: []excludedPortRange{
|
||||||
|
{start: 5357, end: 5357},
|
||||||
|
{start: 50000, end: 50059},
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty output",
|
||||||
|
netshOutput: `
|
||||||
|
Protocol tcp Dynamic Port Range
|
||||||
|
---------------------------------
|
||||||
|
Start Port : 49152
|
||||||
|
Number of Ports : 16384
|
||||||
|
`,
|
||||||
|
expectedRanges: nil,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Single range",
|
||||||
|
netshOutput: `
|
||||||
|
Protocol tcp Excluded Port Ranges
|
||||||
|
---------------------------------
|
||||||
|
Start Port End Port
|
||||||
|
---------- --------
|
||||||
|
8080 8090
|
||||||
|
`,
|
||||||
|
expectedRanges: []excludedPortRange{
|
||||||
|
{start: 8080, end: 8090},
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ranges, err := parseExcludedPortRanges(tt.netshOutput)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expectedRanges, ranges)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) {
|
||||||
|
ranges := getSystemExcludedPortRanges()
|
||||||
|
t.Logf("Found %d excluded port ranges on this system", len(ranges))
|
||||||
|
|
||||||
|
listener1, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
_ = listener1.Close()
|
||||||
|
}()
|
||||||
|
usedPort1 := listener1.Addr().(*net.TCPAddr).Port
|
||||||
|
|
||||||
|
availablePort := 65432
|
||||||
|
|
||||||
|
config := internal.PKCEAuthProviderConfig{
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
Audience: "test-audience",
|
||||||
|
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||||
|
Scope: "openid email profile",
|
||||||
|
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
|
||||||
|
RedirectURLs: []string{
|
||||||
|
fmt.Sprintf("http://127.0.0.1:%d/", usedPort1),
|
||||||
|
fmt.Sprintf("http://127.0.0.1:%d/", availablePort),
|
||||||
|
},
|
||||||
|
UseIDToken: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
flow, err := NewPKCEAuthorizationFlow(config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, flow)
|
||||||
|
assert.Contains(t, flow.oAuthConfig.RedirectURL, fmt.Sprintf(":%d", availablePort),
|
||||||
|
"Should skip port in use and select available port")
|
||||||
|
}
|
||||||
@@ -74,6 +74,7 @@ func (c *ConnectClient) RunOnAndroid(
|
|||||||
networkChangeListener listener.NetworkChangeListener,
|
networkChangeListener listener.NetworkChangeListener,
|
||||||
dnsAddresses []netip.AddrPort,
|
dnsAddresses []netip.AddrPort,
|
||||||
dnsReadyListener dns.ReadyListener,
|
dnsReadyListener dns.ReadyListener,
|
||||||
|
stateFilePath string,
|
||||||
) error {
|
) error {
|
||||||
// in case of non Android os these variables will be nil
|
// in case of non Android os these variables will be nil
|
||||||
mobileDependency := MobileDependency{
|
mobileDependency := MobileDependency{
|
||||||
@@ -82,6 +83,7 @@ func (c *ConnectClient) RunOnAndroid(
|
|||||||
NetworkChangeListener: networkChangeListener,
|
NetworkChangeListener: networkChangeListener,
|
||||||
HostDNSAddresses: dnsAddresses,
|
HostDNSAddresses: dnsAddresses,
|
||||||
DnsReadyListener: dnsReadyListener,
|
DnsReadyListener: dnsReadyListener,
|
||||||
|
StateFilePath: stateFilePath,
|
||||||
}
|
}
|
||||||
return c.run(mobileDependency, nil)
|
return c.run(mobileDependency, nil)
|
||||||
}
|
}
|
||||||
@@ -271,11 +273,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
checks := loginResp.GetChecks()
|
checks := loginResp.GetChecks()
|
||||||
|
|
||||||
c.engineMutex.Lock()
|
c.engineMutex.Lock()
|
||||||
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
|
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
|
||||||
c.engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||||
|
c.engine = engine
|
||||||
c.engineMutex.Unlock()
|
c.engineMutex.Unlock()
|
||||||
|
|
||||||
if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
|
if err := engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
|
||||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
@@ -291,12 +294,14 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
<-engineCtx.Done()
|
<-engineCtx.Done()
|
||||||
|
|
||||||
c.engineMutex.Lock()
|
c.engineMutex.Lock()
|
||||||
engine := c.engine
|
|
||||||
c.engine = nil
|
c.engine = nil
|
||||||
c.engineMutex.Unlock()
|
c.engineMutex.Unlock()
|
||||||
|
|
||||||
if engine != nil && engine.wgInterface != nil {
|
// todo: consider to remove this condition. Is not thread safe.
|
||||||
|
// We should always call Stop(), but we need to verify that it is idempotent
|
||||||
|
if engine.wgInterface != nil {
|
||||||
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
|
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
|
||||||
|
|
||||||
if err := engine.Stop(); err != nil {
|
if err := engine.Stop(); err != nil {
|
||||||
log.Errorf("Failed to stop engine: %v", err)
|
log.Errorf("Failed to stop engine: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,6 +44,8 @@ interfaces.txt: Anonymized network interface information, if --system-info flag
|
|||||||
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
|
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
|
||||||
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
|
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
|
||||||
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
|
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
|
||||||
|
resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided.
|
||||||
|
scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided.
|
||||||
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
|
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
|
||||||
config.txt: Anonymized configuration information of the NetBird client.
|
config.txt: Anonymized configuration information of the NetBird client.
|
||||||
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
||||||
@@ -184,6 +186,20 @@ The ip_rules.txt file contains detailed IP routing rule information:
|
|||||||
The table format provides comprehensive visibility into the IP routing decision process, including how traffic is directed to different routing tables based on various criteria. This is valuable for troubleshooting advanced routing configurations and policy-based routing.
|
The table format provides comprehensive visibility into the IP routing decision process, including how traffic is directed to different routing tables based on various criteria. This is valuable for troubleshooting advanced routing configurations and policy-based routing.
|
||||||
|
|
||||||
For anonymized rules, IP addresses and prefixes are replaced as described above. Interface names are anonymized using string anonymization. Table names, actions, and other non-sensitive information remain unchanged.
|
For anonymized rules, IP addresses and prefixes are replaced as described above. Interface names are anonymized using string anonymization. Table names, actions, and other non-sensitive information remain unchanged.
|
||||||
|
|
||||||
|
DNS Configuration
|
||||||
|
The debug bundle includes platform-specific DNS configuration files:
|
||||||
|
|
||||||
|
resolv.conf (Unix systems):
|
||||||
|
- Contains DNS resolver configuration from /etc/resolv.conf
|
||||||
|
- Includes nameserver entries, search domains, and resolver options
|
||||||
|
- All IP addresses and domain names are anonymized following the same rules as other files
|
||||||
|
|
||||||
|
scutil_dns.txt (macOS only):
|
||||||
|
- Contains detailed DNS configuration from scutil --dns
|
||||||
|
- Shows DNS configuration for all network interfaces
|
||||||
|
- Includes search domains, nameservers, and DNS resolver settings
|
||||||
|
- All IP addresses and domain names are anonymized
|
||||||
`
|
`
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -357,6 +373,10 @@ func (g *BundleGenerator) addSystemInfo() {
|
|||||||
if err := g.addFirewallRules(); err != nil {
|
if err := g.addFirewallRules(); err != nil {
|
||||||
log.Errorf("failed to add firewall rules to debug bundle: %v", err)
|
log.Errorf("failed to add firewall rules to debug bundle: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := g.addDNSInfo(); err != nil {
|
||||||
|
log.Errorf("failed to add DNS info to debug bundle: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *BundleGenerator) addReadme() error {
|
func (g *BundleGenerator) addReadme() error {
|
||||||
|
|||||||
53
client/internal/debug/debug_darwin.go
Normal file
53
client/internal/debug/debug_darwin.go
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
//go:build darwin && !ios
|
||||||
|
|
||||||
|
package debug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// addDNSInfo collects and adds DNS configuration information to the archive
|
||||||
|
func (g *BundleGenerator) addDNSInfo() error {
|
||||||
|
if err := g.addResolvConf(); err != nil {
|
||||||
|
log.Errorf("failed to add resolv.conf: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.addScutilDNS(); err != nil {
|
||||||
|
log.Errorf("failed to add scutil DNS output: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *BundleGenerator) addScutilDNS() error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, "scutil", "--dns")
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("execute scutil --dns: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(bytes.TrimSpace(output)) == 0 {
|
||||||
|
return fmt.Errorf("no scutil DNS output")
|
||||||
|
}
|
||||||
|
|
||||||
|
content := string(output)
|
||||||
|
if g.anonymize {
|
||||||
|
content = g.anonymizer.AnonymizeString(content)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.addFileToZip(strings.NewReader(content), "scutil_dns.txt"); err != nil {
|
||||||
|
return fmt.Errorf("add scutil DNS output to zip: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -5,3 +5,7 @@ package debug
|
|||||||
func (g *BundleGenerator) addRoutes() error {
|
func (g *BundleGenerator) addRoutes() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *BundleGenerator) addDNSInfo() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
16
client/internal/debug/debug_nondarwin.go
Normal file
16
client/internal/debug/debug_nondarwin.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
//go:build unix && !darwin && !android
|
||||||
|
|
||||||
|
package debug
|
||||||
|
|
||||||
|
import (
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// addDNSInfo collects and adds DNS configuration information to the archive
|
||||||
|
func (g *BundleGenerator) addDNSInfo() error {
|
||||||
|
if err := g.addResolvConf(); err != nil {
|
||||||
|
log.Errorf("failed to add resolv.conf: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
7
client/internal/debug/debug_nonunix.go
Normal file
7
client/internal/debug/debug_nonunix.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
//go:build !unix
|
||||||
|
|
||||||
|
package debug
|
||||||
|
|
||||||
|
func (g *BundleGenerator) addDNSInfo() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
29
client/internal/debug/debug_unix.go
Normal file
29
client/internal/debug/debug_unix.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
//go:build unix && !android
|
||||||
|
|
||||||
|
package debug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const resolvConfPath = "/etc/resolv.conf"
|
||||||
|
|
||||||
|
func (g *BundleGenerator) addResolvConf() error {
|
||||||
|
data, err := os.ReadFile(resolvConfPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read %s: %w", resolvConfPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
content := string(data)
|
||||||
|
if g.anonymize {
|
||||||
|
content = g.anonymizer.AnonymizeString(content)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.addFileToZip(strings.NewReader(content), "resolv.conf"); err != nil {
|
||||||
|
return fmt.Errorf("add resolv.conf to zip: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -38,6 +38,8 @@ type DeviceAuthProviderConfig struct {
|
|||||||
Scope string
|
Scope string
|
||||||
// UseIDToken indicates if the id token should be used for authentication
|
// UseIDToken indicates if the id token should be used for authentication
|
||||||
UseIDToken bool
|
UseIDToken bool
|
||||||
|
// LoginHint is used to pre-fill the email/username field during authentication
|
||||||
|
LoginHint string
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
|
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
|
||||||
|
|||||||
@@ -76,6 +76,9 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
|
|||||||
var records []nbdns.SimpleRecord
|
var records []nbdns.SimpleRecord
|
||||||
|
|
||||||
for _, zone := range config.CustomZones {
|
for _, zone := range config.CustomZones {
|
||||||
|
if zone.SkipPTRProcess {
|
||||||
|
continue
|
||||||
|
}
|
||||||
for _, record := range zone.Records {
|
for _, record := range zone.Records {
|
||||||
if record.Type != int(dns.TypeA) {
|
if record.Type != int(dns.TypeA) {
|
||||||
continue
|
continue
|
||||||
@@ -106,8 +109,9 @@ func addReverseZone(config *nbdns.Config, network netip.Prefix) {
|
|||||||
records := collectPTRRecords(config, network)
|
records := collectPTRRecords(config, network)
|
||||||
|
|
||||||
reverseZone := nbdns.CustomZone{
|
reverseZone := nbdns.CustomZone{
|
||||||
Domain: zoneName,
|
Domain: zoneName,
|
||||||
Records: records,
|
Records: records,
|
||||||
|
SearchDomainDisabled: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
config.CustomZones = append(config.CustomZones, reverseZone)
|
config.CustomZones = append(config.CustomZones, reverseZone)
|
||||||
|
|||||||
@@ -11,11 +11,6 @@ import (
|
|||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
ipv4ReverseZone = ".in-addr.arpa."
|
|
||||||
ipv6ReverseZone = ".ip6.arpa."
|
|
||||||
)
|
|
||||||
|
|
||||||
type hostManager interface {
|
type hostManager interface {
|
||||||
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
|
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
|
||||||
restoreHostDNS() error
|
restoreHostDNS() error
|
||||||
@@ -110,10 +105,9 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip netip.Addr, port int) H
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, customZone := range dnsConfig.CustomZones {
|
for _, customZone := range dnsConfig.CustomZones {
|
||||||
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
|
|
||||||
config.Domains = append(config.Domains, DomainConfig{
|
config.Domains = append(config.Domains, DomainConfig{
|
||||||
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
|
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
|
||||||
MatchOnly: matchOnly,
|
MatchOnly: customZone.SearchDomainDisabled,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -335,7 +335,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
for n, testCase := range testCases {
|
for n, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
privKey, _ := wgtypes.GenerateKey()
|
privKey, _ := wgtypes.GenerateKey()
|
||||||
newNet, err := stdnet.NewNet(nil)
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -434,7 +434,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||||
|
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
newNet, err := stdnet.NewNet([]string{"utun2301"})
|
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create stdnet: %v", err)
|
t.Errorf("create stdnet: %v", err)
|
||||||
return
|
return
|
||||||
@@ -915,7 +915,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
|||||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||||
|
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
newNet, err := stdnet.NewNet([]string{"utun2301"})
|
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create stdnet: %v", err)
|
t.Fatalf("create stdnet: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -197,7 +197,7 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
|
|||||||
timeoutMsg += " " + peerInfo
|
timeoutMsg += " " + peerInfo
|
||||||
}
|
}
|
||||||
timeoutMsg += fmt.Sprintf(" - error: %v", err)
|
timeoutMsg += fmt.Sprintf(" - error: %v", err)
|
||||||
logger.Warnf(timeoutMsg)
|
logger.Warn(timeoutMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
||||||
|
|||||||
@@ -234,6 +234,11 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
|
||||||
|
for i, ip := range ips {
|
||||||
|
ips[i] = ip.Unmap()
|
||||||
|
}
|
||||||
|
|
||||||
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
||||||
f.addIPsToResponse(resp, domain, ips)
|
f.addIPsToResponse(resp, domain, ips)
|
||||||
f.cache.set(domain, question.Qtype, ips)
|
f.cache.set(domain, question.Qtype, ips)
|
||||||
|
|||||||
@@ -255,7 +255,7 @@ func NewEngine(
|
|||||||
sm := profilemanager.NewServiceManager("")
|
sm := profilemanager.NewServiceManager("")
|
||||||
|
|
||||||
path := sm.GetStatePath()
|
path := sm.GetStatePath()
|
||||||
if runtime.GOOS == "ios" {
|
if runtime.GOOS == "ios" || runtime.GOOS == "android" {
|
||||||
if !fileExists(mobileDep.StateFilePath) {
|
if !fileExists(mobileDep.StateFilePath) {
|
||||||
err := createFile(mobileDep.StateFilePath)
|
err := createFile(mobileDep.StateFilePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -280,7 +280,6 @@ func (e *Engine) Stop() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
|
||||||
|
|
||||||
if e.connMgr != nil {
|
if e.connMgr != nil {
|
||||||
e.connMgr.Close()
|
e.connMgr.Close()
|
||||||
@@ -298,9 +297,6 @@ func (e *Engine) Stop() error {
|
|||||||
|
|
||||||
e.cleanupSSHConfig()
|
e.cleanupSSHConfig()
|
||||||
|
|
||||||
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
|
||||||
e.stopDNSServer()
|
|
||||||
|
|
||||||
if e.ingressGatewayMgr != nil {
|
if e.ingressGatewayMgr != nil {
|
||||||
if err := e.ingressGatewayMgr.Close(); err != nil {
|
if err := e.ingressGatewayMgr.Close(); err != nil {
|
||||||
log.Warnf("failed to cleanup forward rules: %v", err)
|
log.Warnf("failed to cleanup forward rules: %v", err)
|
||||||
@@ -308,24 +304,29 @@ func (e *Engine) Stop() error {
|
|||||||
e.ingressGatewayMgr = nil
|
e.ingressGatewayMgr = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
e.stopDNSForwarder()
|
|
||||||
|
|
||||||
if e.routeManager != nil {
|
|
||||||
e.routeManager.Stop(e.stateManager)
|
|
||||||
}
|
|
||||||
|
|
||||||
if e.srWatcher != nil {
|
if e.srWatcher != nil {
|
||||||
e.srWatcher.Close()
|
e.srWatcher.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Info("cleaning up status recorder states")
|
||||||
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
|
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
|
||||||
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
||||||
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
|
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
|
||||||
|
|
||||||
if err := e.removeAllPeers(); err != nil {
|
if err := e.removeAllPeers(); err != nil {
|
||||||
return fmt.Errorf("failed to remove all peers: %s", err)
|
log.Errorf("failed to remove all peers: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if e.routeManager != nil {
|
||||||
|
e.routeManager.Stop(e.stateManager)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.stopDNSForwarder()
|
||||||
|
|
||||||
|
// stop/restore DNS after peers are closed but before interface goes down
|
||||||
|
// so dbus and friends don't complain because of a missing interface
|
||||||
|
e.stopDNSServer()
|
||||||
|
|
||||||
if e.cancel != nil {
|
if e.cancel != nil {
|
||||||
e.cancel()
|
e.cancel()
|
||||||
}
|
}
|
||||||
@@ -337,16 +338,18 @@ func (e *Engine) Stop() error {
|
|||||||
e.flowManager.Close()
|
e.flowManager.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
stateCtx, stateCancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
defer cancel()
|
defer stateCancel()
|
||||||
|
|
||||||
if err := e.stateManager.Stop(ctx); err != nil {
|
if err := e.stateManager.Stop(stateCtx); err != nil {
|
||||||
return fmt.Errorf("failed to stop state manager: %w", err)
|
log.Errorf("failed to stop state manager: %v", err)
|
||||||
}
|
}
|
||||||
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
||||||
log.Errorf("failed to persist state: %v", err)
|
log.Errorf("failed to persist state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
timeout := e.calculateShutdownTimeout()
|
timeout := e.calculateShutdownTimeout()
|
||||||
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
@@ -432,8 +435,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create rosenpass manager: %w", err)
|
return fmt.Errorf("create rosenpass manager: %w", err)
|
||||||
}
|
}
|
||||||
err := e.rpManager.Run()
|
if err := e.rpManager.Run(); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("run rosenpass manager: %w", err)
|
return fmt.Errorf("run rosenpass manager: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -485,6 +487,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := e.createFirewall(); err != nil {
|
if err := e.createFirewall(); err != nil {
|
||||||
|
e.close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -750,6 +753,11 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||||
|
if e.ctx.Err() != nil {
|
||||||
|
return e.ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
if update.GetNetbirdConfig() != nil {
|
if update.GetNetbirdConfig() != nil {
|
||||||
wCfg := update.GetNetbirdConfig()
|
wCfg := update.GetNetbirdConfig()
|
||||||
err := e.updateTURNs(wCfg.GetTurns())
|
err := e.updateTURNs(wCfg.GetTurns())
|
||||||
@@ -789,7 +797,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
nm := update.GetNetworkMap()
|
nm := update.GetNetworkMap()
|
||||||
if nm == nil {
|
if nm == nil || update.SkipNetworkMapUpdate {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -955,7 +963,7 @@ func (e *Engine) receiveManagementEvents() {
|
|||||||
e.config.DisableSSHAuth,
|
e.config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
|
|
||||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
err = e.mgmClient.Sync(e.ctx, info, e.networkSerial, e.handleSync)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// happens if management is unavailable for a long time.
|
// happens if management is unavailable for a long time.
|
||||||
// We want to cancel the operation of the whole client
|
// We want to cancel the operation of the whole client
|
||||||
@@ -1192,6 +1200,7 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
|
|||||||
}
|
}
|
||||||
|
|
||||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
|
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
|
||||||
|
//nolint
|
||||||
forwarderPort := uint16(protoDNSConfig.GetForwarderPort())
|
forwarderPort := uint16(protoDNSConfig.GetForwarderPort())
|
||||||
if forwarderPort == 0 {
|
if forwarderPort == 0 {
|
||||||
forwarderPort = nbdns.ForwarderClientPort
|
forwarderPort = nbdns.ForwarderClientPort
|
||||||
@@ -1206,7 +1215,9 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns
|
|||||||
|
|
||||||
for _, zone := range protoDNSConfig.GetCustomZones() {
|
for _, zone := range protoDNSConfig.GetCustomZones() {
|
||||||
dnsZone := nbdns.CustomZone{
|
dnsZone := nbdns.CustomZone{
|
||||||
Domain: zone.GetDomain(),
|
Domain: zone.GetDomain(),
|
||||||
|
SearchDomainDisabled: zone.GetSearchDomainDisabled(),
|
||||||
|
SkipPTRProcess: zone.GetSkipPTRProcess(),
|
||||||
}
|
}
|
||||||
for _, record := range zone.Records {
|
for _, record := range zone.Records {
|
||||||
dnsRecord := nbdns.SimpleRecord{
|
dnsRecord := nbdns.SimpleRecord{
|
||||||
@@ -1366,6 +1377,11 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||||
|
if e.ctx.Err() != nil {
|
||||||
|
return e.ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
conn, ok := e.peerStore.PeerConn(msg.Key)
|
conn, ok := e.peerStore.PeerConn(msg.Key)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
||||||
@@ -1830,6 +1846,18 @@ func (e *Engine) GetWgAddr() netip.Addr {
|
|||||||
return e.wgInterface.Address().IP
|
return e.wgInterface.Address().IP
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Engine) RenewTun(fd int) error {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
wgInterface := e.wgInterface
|
||||||
|
e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
if wgInterface == nil {
|
||||||
|
return fmt.Errorf("wireguard interface not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
return wgInterface.RenewTun(fd)
|
||||||
|
}
|
||||||
|
|
||||||
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
|
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
|
||||||
func (e *Engine) updateDNSForwarder(
|
func (e *Engine) updateDNSForwarder(
|
||||||
enabled bool,
|
enabled bool,
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
type sshServer interface {
|
type sshServer interface {
|
||||||
Start(ctx context.Context, addr netip.AddrPort) error
|
Start(ctx context.Context, addr netip.AddrPort) error
|
||||||
Stop() error
|
Stop() error
|
||||||
|
GetStatus() (bool, []sshserver.SessionInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) setupSSHPortRedirection() error {
|
func (e *Engine) setupSSHPortRedirection() error {
|
||||||
@@ -234,7 +235,17 @@ func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
|
|||||||
|
|
||||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||||
server.SetNetstackNet(netstackNet)
|
server.SetNetstackNet(netstackNet)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.configureSSHServer(server)
|
||||||
|
|
||||||
|
if err := server.Start(e.ctx, listenAddr); err != nil {
|
||||||
|
return fmt.Errorf("start SSH server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.sshServer = server
|
||||||
|
|
||||||
|
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||||
if registrar, ok := e.firewall.(interface {
|
if registrar, ok := e.firewall.(interface {
|
||||||
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||||
}); ok {
|
}); ok {
|
||||||
@@ -243,17 +254,10 @@ func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
e.configureSSHServer(server)
|
|
||||||
e.sshServer = server
|
|
||||||
|
|
||||||
if err := e.setupSSHPortRedirection(); err != nil {
|
if err := e.setupSSHPortRedirection(); err != nil {
|
||||||
log.Warnf("failed to setup SSH port redirection: %v", err)
|
log.Warnf("failed to setup SSH port redirection: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := server.Start(e.ctx, listenAddr); err != nil {
|
|
||||||
return fmt.Errorf("start SSH server: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -336,3 +340,16 @@ func (e *Engine) stopSSHServer() error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSSHServerStatus returns the SSH server status and active sessions
|
||||||
|
func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.SessionInfo) {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
sshServer := e.sshServer
|
||||||
|
e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
if sshServer == nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return sshServer.GetStatus()
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,5 +7,5 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
||||||
return stdnet.NewNet(e.config.IFaceBlackList)
|
return stdnet.NewNet(e.clientCtx, e.config.IFaceBlackList)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,5 +3,5 @@ package internal
|
|||||||
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
|
||||||
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
||||||
return stdnet.NewNetWithDiscover(e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
|
return stdnet.NewNetWithDiscover(e.clientCtx, e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
|
||||||
}
|
}
|
||||||
|
|||||||
79
client/internal/engine_sync_test.go
Normal file
79
client/internal/engine_sync_test.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/client"
|
||||||
|
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Ensures handleSync exits early when SkipNetworkMapUpdate is true
|
||||||
|
func TestEngine_HandleSync_SkipNetworkMapUpdate(t *testing.T) {
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
engine := NewEngine(ctx, cancel, nil, &client.MockClient{}, nil, &EngineConfig{
|
||||||
|
WgIfaceName: "utun199",
|
||||||
|
WgAddr: "100.70.0.1/24",
|
||||||
|
WgPrivateKey: key,
|
||||||
|
WgPort: 33100,
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||||
|
engine.ctx = ctx
|
||||||
|
|
||||||
|
// Precondition
|
||||||
|
if engine.networkSerial != 0 {
|
||||||
|
t.Fatalf("unexpected initial serial: %d", engine.networkSerial)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &mgmtProto.SyncResponse{
|
||||||
|
NetworkMap: &mgmtProto.NetworkMap{Serial: 42},
|
||||||
|
SkipNetworkMapUpdate: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := engine.handleSync(resp); err != nil {
|
||||||
|
t.Fatalf("handleSync returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if engine.networkSerial != 0 {
|
||||||
|
t.Fatalf("networkSerial changed despite SkipNetworkMapUpdate; got %d, want 0", engine.networkSerial)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensures handleSync exits early when NetworkMap is nil
|
||||||
|
func TestEngine_HandleSync_NilNetworkMap(t *testing.T) {
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
engine := NewEngine(ctx, cancel, nil, &client.MockClient{}, nil, &EngineConfig{
|
||||||
|
WgIfaceName: "utun198",
|
||||||
|
WgAddr: "100.70.0.2/24",
|
||||||
|
WgPrivateKey: key,
|
||||||
|
WgPort: 33101,
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||||
|
engine.ctx = ctx
|
||||||
|
|
||||||
|
resp := &mgmtProto.SyncResponse{NetworkMap: nil}
|
||||||
|
|
||||||
|
if err := engine.handleSync(resp); err != nil {
|
||||||
|
t.Fatalf("handleSync returned error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/pion/transport/v3/stdnet"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -25,11 +24,18 @@ import (
|
|||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
@@ -49,7 +55,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/peers"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
@@ -105,6 +110,10 @@ type MockWGIface struct {
|
|||||||
LastActivitiesFunc func() map[string]monotime.Time
|
LastActivitiesFunc func() map[string]monotime.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockWGIface) RenewTun(_ int) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *MockWGIface) RemoveEndpointAddress(_ string) error {
|
func (m *MockWGIface) RemoveEndpointAddress(_ string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -622,7 +631,7 @@ func TestEngine_Sync(t *testing.T) {
|
|||||||
// feed updates to Engine via mocked Management client
|
// feed updates to Engine via mocked Management client
|
||||||
updates := make(chan *mgmtProto.SyncResponse)
|
updates := make(chan *mgmtProto.SyncResponse)
|
||||||
defer close(updates)
|
defer close(updates)
|
||||||
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
syncFunc := func(ctx context.Context, info *system.Info, networkSerial uint64, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||||
for msg := range updates {
|
for msg := range updates {
|
||||||
err := msgHandler(msg)
|
err := msgHandler(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -805,7 +814,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||||
engine.ctx = ctx
|
engine.ctx = ctx
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -1008,7 +1017,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||||
engine.ctx = ctx
|
engine.ctx = ctx
|
||||||
|
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -1590,7 +1599,6 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
}
|
}
|
||||||
t.Cleanup(cleanUp)
|
t.Cleanup(cleanUp)
|
||||||
|
|
||||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
@@ -1618,13 +1626,19 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
|
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
|
|
||||||
accountManager, err := server.BuildManager(context.Background(), config, store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
|
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||||
|
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||||
|
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
type wgIfaceBase interface {
|
type wgIfaceBase interface {
|
||||||
Create() error
|
Create() error
|
||||||
CreateOnAndroid(routeRange []string, ip string, domains []string) error
|
CreateOnAndroid(routeRange []string, ip string, domains []string) error
|
||||||
|
RenewTun(fd int) error
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
Name() string
|
Name() string
|
||||||
Address() wgaddr.Address
|
Address() wgaddr.Address
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
|
|||||||
func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) {
|
func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) {
|
||||||
log.Debugf("Gathering ICE candidates")
|
log.Debugf("Gathering ICE candidates")
|
||||||
|
|
||||||
agent, err := icemaker.NewAgent(cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd)
|
agent, err := icemaker.NewAgent(ctx, cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("create ICE agent: %w", err)
|
return false, fmt.Errorf("create ICE agent: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package ice
|
package ice
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -22,6 +23,8 @@ const (
|
|||||||
iceFailedTimeoutDefault = 6 * time.Second
|
iceFailedTimeoutDefault = 6 * time.Second
|
||||||
// iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
|
// iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
|
||||||
iceRelayAcceptanceMinWaitDefault = 2 * time.Second
|
iceRelayAcceptanceMinWaitDefault = 2 * time.Second
|
||||||
|
// iceAgentCloseTimeout is the maximum time to wait for ICE agent close to complete
|
||||||
|
iceAgentCloseTimeout = 3 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
type ThreadSafeAgent struct {
|
type ThreadSafeAgent struct {
|
||||||
@@ -32,18 +35,28 @@ type ThreadSafeAgent struct {
|
|||||||
func (a *ThreadSafeAgent) Close() error {
|
func (a *ThreadSafeAgent) Close() error {
|
||||||
var err error
|
var err error
|
||||||
a.once.Do(func() {
|
a.once.Do(func() {
|
||||||
err = a.Agent.Close()
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- a.Agent.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err = <-done:
|
||||||
|
case <-time.After(iceAgentCloseTimeout):
|
||||||
|
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
})
|
})
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
||||||
iceKeepAlive := iceKeepAlive()
|
iceKeepAlive := iceKeepAlive()
|
||||||
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
||||||
iceFailedTimeout := iceFailedTimeout()
|
iceFailedTimeout := iceFailedTimeout()
|
||||||
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
|
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
|
||||||
|
|
||||||
transportNet, err := newStdNet(iFaceDiscover, config.InterfaceBlackList)
|
transportNet, err := newStdNet(ctx, iFaceDiscover, config.InterfaceBlackList)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to create pion's stdnet: %s", err)
|
log.Errorf("failed to create pion's stdnet: %s", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,9 +3,11 @@
|
|||||||
package ice
|
package ice
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
func newStdNet(ctx context.Context, _ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
||||||
return stdnet.NewNet(ifaceBlacklist)
|
return stdnet.NewNet(ctx, ifaceBlacklist)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
package ice
|
package ice
|
||||||
|
|
||||||
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist)
|
)
|
||||||
|
|
||||||
|
func newStdNet(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
||||||
|
return stdnet.NewNetWithDiscover(ctx, iFaceDiscover, ifaceBlacklist)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -209,7 +209,7 @@ func (w *WorkerICE) Close() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
|
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
|
||||||
agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
|
agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create agent: %w", err)
|
return nil, fmt.Errorf("create agent: %w", err)
|
||||||
}
|
}
|
||||||
@@ -411,7 +411,7 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
|
|||||||
|
|
||||||
func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
|
func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
|
||||||
if isController(w.config) {
|
if isController(w.config) {
|
||||||
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||||
} else {
|
} else {
|
||||||
return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,6 +44,8 @@ type PKCEAuthProviderConfig struct {
|
|||||||
DisablePromptLogin bool
|
DisablePromptLogin bool
|
||||||
// LoginFlag is used to configure the PKCE flow login behavior
|
// LoginFlag is used to configure the PKCE flow login behavior
|
||||||
LoginFlag common.LoginFlag
|
LoginFlag common.LoginFlag
|
||||||
|
// LoginHint is used to pre-fill the email/username field during authentication
|
||||||
|
LoginHint string
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ type ConfigInput struct {
|
|||||||
EnableSSHLocalPortForwarding *bool
|
EnableSSHLocalPortForwarding *bool
|
||||||
EnableSSHRemotePortForwarding *bool
|
EnableSSHRemotePortForwarding *bool
|
||||||
DisableSSHAuth *bool
|
DisableSSHAuth *bool
|
||||||
|
SSHJWTCacheTTL *int
|
||||||
NATExternalIPs []string
|
NATExternalIPs []string
|
||||||
CustomDNSAddress []byte
|
CustomDNSAddress []byte
|
||||||
RosenpassEnabled *bool
|
RosenpassEnabled *bool
|
||||||
@@ -104,6 +105,7 @@ type Config struct {
|
|||||||
EnableSSHLocalPortForwarding *bool
|
EnableSSHLocalPortForwarding *bool
|
||||||
EnableSSHRemotePortForwarding *bool
|
EnableSSHRemotePortForwarding *bool
|
||||||
DisableSSHAuth *bool
|
DisableSSHAuth *bool
|
||||||
|
SSHJWTCacheTTL *int
|
||||||
|
|
||||||
DisableClientRoutes bool
|
DisableClientRoutes bool
|
||||||
DisableServerRoutes bool
|
DisableServerRoutes bool
|
||||||
@@ -436,6 +438,12 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
|
||||||
|
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
|
||||||
|
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
|
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
|
||||||
log.Infof("updating DNS route interval to %s (old value %s)",
|
log.Infof("updating DNS route interval to %s (old value %s)",
|
||||||
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
|
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
|
||||||
|
|||||||
@@ -132,3 +132,21 @@ func (pm *ProfileManager) setActiveProfileState(profileName string) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetLoginHint retrieves the email from the active profile to use as login_hint.
|
||||||
|
func GetLoginHint() string {
|
||||||
|
pm := NewProfileManager()
|
||||||
|
activeProf, err := pm.GetActiveProfile()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to get active profile for login hint: %v", err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return profileState.Email
|
||||||
|
}
|
||||||
|
|||||||
@@ -197,7 +197,7 @@ func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr stri
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
net, err := stdnet.NewNet(nil)
|
net, err := stdnet.NewNet(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
probeErr = fmt.Errorf("new net: %w", err)
|
probeErr = fmt.Errorf("new net: %w", err)
|
||||||
return
|
return
|
||||||
@@ -286,7 +286,7 @@ func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr stri
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
net, err := stdnet.NewNet(nil)
|
net, err := stdnet.NewNet(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
probeErr = fmt.Errorf("new net: %w", err)
|
probeErr = fmt.Errorf("new net: %w", err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/client"
|
"github.com/netbirdio/netbird/client/internal/routemanager/client"
|
||||||
@@ -39,6 +38,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pion/transport/v3/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -403,7 +403,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
for n, testCase := range testCases {
|
for n, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pion/transport/v3/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
@@ -436,7 +436,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
|
|||||||
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
|
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
opts := iface.WGIFaceOpts{
|
opts := iface.WGIFaceOpts{
|
||||||
|
|||||||
218
client/internal/sleep/detector_darwin.go
Normal file
218
client/internal/sleep/detector_darwin.go
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
//go:build darwin && !ios
|
||||||
|
|
||||||
|
package sleep
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo LDFLAGS: -framework IOKit -framework CoreFoundation
|
||||||
|
#include <IOKit/pwr_mgt/IOPMLib.h>
|
||||||
|
#include <IOKit/IOMessage.h>
|
||||||
|
#include <CoreFoundation/CoreFoundation.h>
|
||||||
|
|
||||||
|
extern void sleepCallbackBridge();
|
||||||
|
extern void poweredOnCallbackBridge();
|
||||||
|
extern void suspendedCallbackBridge();
|
||||||
|
extern void resumedCallbackBridge();
|
||||||
|
|
||||||
|
|
||||||
|
// C global variables for IOKit state
|
||||||
|
static IONotificationPortRef g_notifyPortRef = NULL;
|
||||||
|
static io_object_t g_notifierObject = 0;
|
||||||
|
static io_object_t g_generalInterestNotifier = 0;
|
||||||
|
static io_connect_t g_rootPort = 0;
|
||||||
|
static CFRunLoopRef g_runLoop = NULL;
|
||||||
|
|
||||||
|
static void sleepCallback(void* refCon, io_service_t service, natural_t messageType, void* messageArgument) {
|
||||||
|
switch (messageType) {
|
||||||
|
case kIOMessageSystemWillSleep:
|
||||||
|
sleepCallbackBridge();
|
||||||
|
IOAllowPowerChange(g_rootPort, (long)messageArgument);
|
||||||
|
break;
|
||||||
|
case kIOMessageSystemHasPoweredOn:
|
||||||
|
poweredOnCallbackBridge();
|
||||||
|
break;
|
||||||
|
case kIOMessageServiceIsSuspended:
|
||||||
|
suspendedCallbackBridge();
|
||||||
|
break;
|
||||||
|
case kIOMessageServiceIsResumed:
|
||||||
|
resumedCallbackBridge();
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void registerNotifications() {
|
||||||
|
g_rootPort = IORegisterForSystemPower(
|
||||||
|
NULL,
|
||||||
|
&g_notifyPortRef,
|
||||||
|
(IOServiceInterestCallback)sleepCallback,
|
||||||
|
&g_notifierObject
|
||||||
|
);
|
||||||
|
|
||||||
|
if (g_rootPort == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
CFRunLoopAddSource(CFRunLoopGetCurrent(),
|
||||||
|
IONotificationPortGetRunLoopSource(g_notifyPortRef),
|
||||||
|
kCFRunLoopCommonModes);
|
||||||
|
|
||||||
|
g_runLoop = CFRunLoopGetCurrent();
|
||||||
|
CFRunLoopRun();
|
||||||
|
}
|
||||||
|
|
||||||
|
static void unregisterNotifications() {
|
||||||
|
CFRunLoopRemoveSource(g_runLoop,
|
||||||
|
IONotificationPortGetRunLoopSource(g_notifyPortRef),
|
||||||
|
kCFRunLoopCommonModes);
|
||||||
|
|
||||||
|
IODeregisterForSystemPower(&g_notifierObject);
|
||||||
|
IOServiceClose(g_rootPort);
|
||||||
|
IONotificationPortDestroy(g_notifyPortRef);
|
||||||
|
CFRunLoopStop(g_runLoop);
|
||||||
|
|
||||||
|
g_notifyPortRef = NULL;
|
||||||
|
g_notifierObject = 0;
|
||||||
|
g_rootPort = 0;
|
||||||
|
g_runLoop = NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
serviceRegistry = make(map[*Detector]struct{})
|
||||||
|
serviceRegistryMu sync.Mutex
|
||||||
|
)
|
||||||
|
|
||||||
|
//export sleepCallbackBridge
|
||||||
|
func sleepCallbackBridge() {
|
||||||
|
log.Info("sleepCallbackBridge event triggered")
|
||||||
|
|
||||||
|
serviceRegistryMu.Lock()
|
||||||
|
defer serviceRegistryMu.Unlock()
|
||||||
|
|
||||||
|
for svc := range serviceRegistry {
|
||||||
|
svc.triggerCallback(EventTypeSleep)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//export resumedCallbackBridge
|
||||||
|
func resumedCallbackBridge() {
|
||||||
|
log.Info("resumedCallbackBridge event triggered")
|
||||||
|
}
|
||||||
|
|
||||||
|
//export suspendedCallbackBridge
|
||||||
|
func suspendedCallbackBridge() {
|
||||||
|
log.Info("suspendedCallbackBridge event triggered")
|
||||||
|
}
|
||||||
|
|
||||||
|
//export poweredOnCallbackBridge
|
||||||
|
func poweredOnCallbackBridge() {
|
||||||
|
log.Info("poweredOnCallbackBridge event triggered")
|
||||||
|
serviceRegistryMu.Lock()
|
||||||
|
defer serviceRegistryMu.Unlock()
|
||||||
|
|
||||||
|
for svc := range serviceRegistry {
|
||||||
|
svc.triggerCallback(EventTypeWakeUp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Detector struct {
|
||||||
|
callback func(event EventType)
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDetector() (*Detector, error) {
|
||||||
|
return &Detector{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Detector) Register(callback func(event EventType)) error {
|
||||||
|
serviceRegistryMu.Lock()
|
||||||
|
defer serviceRegistryMu.Unlock()
|
||||||
|
|
||||||
|
if _, exists := serviceRegistry[d]; exists {
|
||||||
|
return fmt.Errorf("detector service already registered")
|
||||||
|
}
|
||||||
|
|
||||||
|
d.callback = callback
|
||||||
|
|
||||||
|
d.ctx, d.cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
if len(serviceRegistry) > 0 {
|
||||||
|
serviceRegistry[d] = struct{}{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
serviceRegistry[d] = struct{}{}
|
||||||
|
|
||||||
|
// CFRunLoop must run on a single fixed OS thread
|
||||||
|
go func() {
|
||||||
|
runtime.LockOSThread()
|
||||||
|
defer runtime.UnlockOSThread()
|
||||||
|
|
||||||
|
C.registerNotifications()
|
||||||
|
}()
|
||||||
|
|
||||||
|
log.Info("sleep detection service started on macOS")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deregister removes the detector. When the last detector is removed, IOKit registration is torn down
|
||||||
|
// and the runloop is stopped and cleaned up.
|
||||||
|
func (d *Detector) Deregister() error {
|
||||||
|
serviceRegistryMu.Lock()
|
||||||
|
defer serviceRegistryMu.Unlock()
|
||||||
|
_, exists := serviceRegistry[d]
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cancel and remove this detector
|
||||||
|
d.cancel()
|
||||||
|
delete(serviceRegistry, d)
|
||||||
|
|
||||||
|
// If other Detectors still exist, leave IOKit running
|
||||||
|
if len(serviceRegistry) > 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("sleep detection service stopping (deregister)")
|
||||||
|
|
||||||
|
// Deregister IOKit notifications, stop runloop, and free resources
|
||||||
|
C.unregisterNotifications()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Detector) triggerCallback(event EventType) {
|
||||||
|
doneChan := make(chan struct{})
|
||||||
|
|
||||||
|
timeout := time.NewTimer(500 * time.Millisecond)
|
||||||
|
defer timeout.Stop()
|
||||||
|
|
||||||
|
cb := d.callback
|
||||||
|
go func(callback func(event EventType)) {
|
||||||
|
log.Info("sleep detection event fired")
|
||||||
|
callback(event)
|
||||||
|
close(doneChan)
|
||||||
|
}(cb)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-doneChan:
|
||||||
|
case <-d.ctx.Done():
|
||||||
|
case <-timeout.C:
|
||||||
|
log.Warnf("sleep callback timed out")
|
||||||
|
}
|
||||||
|
}
|
||||||
9
client/internal/sleep/detector_notsupported.go
Normal file
9
client/internal/sleep/detector_notsupported.go
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
//go:build !darwin || ios
|
||||||
|
|
||||||
|
package sleep
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
func NewDetector() (detector, error) {
|
||||||
|
return nil, fmt.Errorf("sleep not supported on this platform")
|
||||||
|
}
|
||||||
37
client/internal/sleep/service.go
Normal file
37
client/internal/sleep/service.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package sleep
|
||||||
|
|
||||||
|
var (
|
||||||
|
EventTypeUnknown EventType = 0
|
||||||
|
EventTypeSleep EventType = 1
|
||||||
|
EventTypeWakeUp EventType = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
type EventType int
|
||||||
|
|
||||||
|
type detector interface {
|
||||||
|
Register(callback func(eventType EventType)) error
|
||||||
|
Deregister() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Service struct {
|
||||||
|
detector detector
|
||||||
|
}
|
||||||
|
|
||||||
|
func New() (*Service, error) {
|
||||||
|
d, err := NewDetector()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Service{
|
||||||
|
detector: d,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Register(callback func(eventType EventType)) error {
|
||||||
|
return s.detector.Register(callback)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Deregister() error {
|
||||||
|
return s.detector.Deregister()
|
||||||
|
}
|
||||||
@@ -4,17 +4,28 @@
|
|||||||
package stdnet
|
package stdnet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
"github.com/pion/transport/v3/stdnet"
|
"github.com/pion/transport/v3/stdnet"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
)
|
)
|
||||||
|
|
||||||
const updateInterval = 30 * time.Second
|
const (
|
||||||
|
updateInterval = 30 * time.Second
|
||||||
|
dnsResolveTimeout = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
var errNoSuitableAddress = errors.New("no suitable address found")
|
||||||
|
|
||||||
// Net is an implementation of the net.Net interface
|
// Net is an implementation of the net.Net interface
|
||||||
// based on functions of the standard net package.
|
// based on functions of the standard net package.
|
||||||
@@ -28,12 +39,19 @@ type Net struct {
|
|||||||
|
|
||||||
// mu is shared between interfaces and lastUpdate
|
// mu is shared between interfaces and lastUpdate
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
|
||||||
|
// ctx is the context for network operations that supports cancellation
|
||||||
|
ctx context.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewNetWithDiscover creates a new StdNet instance.
|
// NewNetWithDiscover creates a new StdNet instance.
|
||||||
func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
|
func NewNetWithDiscover(ctx context.Context, iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
n := &Net{
|
n := &Net{
|
||||||
interfaceFilter: InterfaceFilter(disallowList),
|
interfaceFilter: InterfaceFilter(disallowList),
|
||||||
|
ctx: ctx,
|
||||||
}
|
}
|
||||||
// current ExternalIFaceDiscover implement in android-client https://github.dev/netbirdio/android-client
|
// current ExternalIFaceDiscover implement in android-client https://github.dev/netbirdio/android-client
|
||||||
// so in android cli use pionDiscover
|
// so in android cli use pionDiscover
|
||||||
@@ -46,14 +64,64 @@ func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewNet creates a new StdNet instance.
|
// NewNet creates a new StdNet instance.
|
||||||
func NewNet(disallowList []string) (*Net, error) {
|
func NewNet(ctx context.Context, disallowList []string) (*Net, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
n := &Net{
|
n := &Net{
|
||||||
iFaceDiscover: pionDiscover{},
|
iFaceDiscover: pionDiscover{},
|
||||||
interfaceFilter: InterfaceFilter(disallowList),
|
interfaceFilter: InterfaceFilter(disallowList),
|
||||||
|
ctx: ctx,
|
||||||
}
|
}
|
||||||
return n, n.UpdateInterfaces()
|
return n, n.UpdateInterfaces()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// resolveAddr performs DNS resolution with context support and timeout.
|
||||||
|
func (n *Net) resolveAddr(network, address string) (netip.AddrPort, error) {
|
||||||
|
host, portStr, err := net.SplitHostPort(address)
|
||||||
|
if err != nil {
|
||||||
|
return netip.AddrPort{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
if err != nil {
|
||||||
|
return netip.AddrPort{}, fmt.Errorf("invalid port: %w", err)
|
||||||
|
}
|
||||||
|
if port < 0 || port > 65535 {
|
||||||
|
return netip.AddrPort{}, fmt.Errorf("invalid port: %d", port)
|
||||||
|
}
|
||||||
|
|
||||||
|
ipNet := "ip"
|
||||||
|
switch network {
|
||||||
|
case "tcp4", "udp4":
|
||||||
|
ipNet = "ip4"
|
||||||
|
case "tcp6", "udp6":
|
||||||
|
ipNet = "ip6"
|
||||||
|
}
|
||||||
|
|
||||||
|
if host == "" {
|
||||||
|
addr := netip.IPv4Unspecified()
|
||||||
|
if ipNet == "ip6" {
|
||||||
|
addr = netip.IPv6Unspecified()
|
||||||
|
}
|
||||||
|
return netip.AddrPortFrom(addr, uint16(port)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(n.ctx, dnsResolveTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
addrs, err := net.DefaultResolver.LookupNetIP(ctx, ipNet, host)
|
||||||
|
if err != nil {
|
||||||
|
return netip.AddrPort{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(addrs) == 0 {
|
||||||
|
return netip.AddrPort{}, errNoSuitableAddress
|
||||||
|
}
|
||||||
|
|
||||||
|
return netip.AddrPortFrom(addrs[0], uint16(port)), nil
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateInterfaces updates the internal list of network interfaces
|
// UpdateInterfaces updates the internal list of network interfaces
|
||||||
// and associated addresses filtering them by name.
|
// and associated addresses filtering them by name.
|
||||||
// The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one
|
// The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one
|
||||||
@@ -137,3 +205,39 @@ func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.I
|
|||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ResolveUDPAddr resolves UDP addresses with context support and timeout.
|
||||||
|
func (n *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) {
|
||||||
|
switch network {
|
||||||
|
case "udp", "udp4", "udp6":
|
||||||
|
case "":
|
||||||
|
network = "udp"
|
||||||
|
default:
|
||||||
|
return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)}
|
||||||
|
}
|
||||||
|
|
||||||
|
addrPort, err := n.resolveAddr(network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.UDPAddr{IP: nil}, Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
return net.UDPAddrFromAddrPort(addrPort), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveTCPAddr resolves TCP addresses with context support and timeout.
|
||||||
|
func (n *Net) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) {
|
||||||
|
switch network {
|
||||||
|
case "tcp", "tcp4", "tcp6":
|
||||||
|
case "":
|
||||||
|
network = "tcp"
|
||||||
|
default:
|
||||||
|
return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)}
|
||||||
|
}
|
||||||
|
|
||||||
|
addrPort, err := n.resolveAddr(network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.TCPAddr{IP: nil}, Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
return net.TCPAddrFromAddrPort(addrPort), nil
|
||||||
|
}
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
299
client/internal/templates/pkce_auth_msg_test.go
Normal file
299
client/internal/templates/pkce_auth_msg_test.go
Normal file
@@ -0,0 +1,299 @@
|
|||||||
|
package templates
|
||||||
|
|
||||||
|
import (
|
||||||
|
"html/template"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPKCEAuthMsgTemplate(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data map[string]string
|
||||||
|
outputFile string
|
||||||
|
expectedTitle string
|
||||||
|
expectedInContent []string
|
||||||
|
notExpectedInContent []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "error_state",
|
||||||
|
data: map[string]string{
|
||||||
|
"Error": "authentication failed: invalid state",
|
||||||
|
},
|
||||||
|
outputFile: "pkce-auth-error.html",
|
||||||
|
expectedTitle: "Login Failed",
|
||||||
|
expectedInContent: []string{
|
||||||
|
"authentication failed: invalid state",
|
||||||
|
"Login Failed",
|
||||||
|
},
|
||||||
|
notExpectedInContent: []string{
|
||||||
|
"Login Successful",
|
||||||
|
"Your device is now registered and logged in to NetBird",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success_state",
|
||||||
|
data: map[string]string{
|
||||||
|
// No error field means success
|
||||||
|
},
|
||||||
|
outputFile: "pkce-auth-success.html",
|
||||||
|
expectedTitle: "Login Successful",
|
||||||
|
expectedInContent: []string{
|
||||||
|
"Login Successful",
|
||||||
|
"Your device is now registered and logged in to NetBird. You can now close this window.",
|
||||||
|
},
|
||||||
|
notExpectedInContent: []string{
|
||||||
|
"Login Failed",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error_state_timeout",
|
||||||
|
data: map[string]string{
|
||||||
|
"Error": "authentication timeout: request expired after 5 minutes",
|
||||||
|
},
|
||||||
|
outputFile: "pkce-auth-timeout.html",
|
||||||
|
expectedTitle: "Login Failed",
|
||||||
|
expectedInContent: []string{
|
||||||
|
"authentication timeout: request expired after 5 minutes",
|
||||||
|
"Login Failed",
|
||||||
|
},
|
||||||
|
notExpectedInContent: []string{
|
||||||
|
"Login Successful",
|
||||||
|
"Your device is now registered and logged in to NetBird",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Parse the template
|
||||||
|
tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to parse template: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create temp directory for this test
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
outputPath := filepath.Join(tempDir, tt.outputFile)
|
||||||
|
|
||||||
|
// Create output file
|
||||||
|
file, err := os.Create(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create output file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the template
|
||||||
|
if err := tmpl.Execute(file, tt.data); err != nil {
|
||||||
|
file.Close()
|
||||||
|
t.Fatalf("Failed to execute template: %v", err)
|
||||||
|
}
|
||||||
|
file.Close()
|
||||||
|
|
||||||
|
t.Logf("Generated test output: %s", outputPath)
|
||||||
|
|
||||||
|
// Read the generated file
|
||||||
|
content, err := os.ReadFile(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read output file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
contentStr := string(content)
|
||||||
|
|
||||||
|
// Verify file has content
|
||||||
|
if len(contentStr) == 0 {
|
||||||
|
t.Error("Output file is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify basic HTML structure
|
||||||
|
basicElements := []string{
|
||||||
|
"<!DOCTYPE html>",
|
||||||
|
"<html",
|
||||||
|
"<head>",
|
||||||
|
"<body>",
|
||||||
|
"NetBird",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, elem := range basicElements {
|
||||||
|
if !contains(contentStr, elem) {
|
||||||
|
t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify expected title
|
||||||
|
if !contains(contentStr, tt.expectedTitle) {
|
||||||
|
t.Errorf("Expected HTML to contain title '%s', but it was not found", tt.expectedTitle)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify expected content is present
|
||||||
|
for _, expected := range tt.expectedInContent {
|
||||||
|
if !contains(contentStr, expected) {
|
||||||
|
t.Errorf("Expected HTML to contain '%s', but it was not found", expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify unexpected content is not present
|
||||||
|
for _, notExpected := range tt.notExpectedInContent {
|
||||||
|
if contains(contentStr, notExpected) {
|
||||||
|
t.Errorf("Expected HTML to NOT contain '%s', but it was found", notExpected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPKCEAuthMsgTemplateValidation(t *testing.T) {
|
||||||
|
// Test that the template can be parsed without errors
|
||||||
|
tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Template parsing failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with empty data
|
||||||
|
t.Run("empty_data", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
outputPath := filepath.Join(tempDir, "empty-data.html")
|
||||||
|
|
||||||
|
file, err := os.Create(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create output file: %v", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
if err := tmpl.Execute(file, nil); err != nil {
|
||||||
|
t.Errorf("Template execution with nil data failed: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test with error data
|
||||||
|
t.Run("with_error", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
outputPath := filepath.Join(tempDir, "with-error.html")
|
||||||
|
|
||||||
|
file, err := os.Create(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create output file: %v", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
data := map[string]string{
|
||||||
|
"Error": "test error message",
|
||||||
|
}
|
||||||
|
if err := tmpl.Execute(file, data); err != nil {
|
||||||
|
t.Errorf("Template execution with error data failed: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPKCEAuthMsgTemplateContent(t *testing.T) {
|
||||||
|
// Test that the template contains expected elements
|
||||||
|
tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Template parsing failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("success_content", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
outputPath := filepath.Join(tempDir, "success.html")
|
||||||
|
|
||||||
|
file, err := os.Create(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create output file: %v", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
data := map[string]string{}
|
||||||
|
if err := tmpl.Execute(file, data); err != nil {
|
||||||
|
t.Fatalf("Template execution failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the file and verify it contains expected content
|
||||||
|
content, err := os.ReadFile(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read output file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for success indicators
|
||||||
|
contentStr := string(content)
|
||||||
|
if len(contentStr) == 0 {
|
||||||
|
t.Error("Generated HTML is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basic HTML structure checks
|
||||||
|
requiredElements := []string{
|
||||||
|
"<!DOCTYPE html>",
|
||||||
|
"<html",
|
||||||
|
"<head>",
|
||||||
|
"<body>",
|
||||||
|
"Login Successful",
|
||||||
|
"NetBird",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, elem := range requiredElements {
|
||||||
|
if !contains(contentStr, elem) {
|
||||||
|
t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("error_content", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
outputPath := filepath.Join(tempDir, "error.html")
|
||||||
|
|
||||||
|
file, err := os.Create(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create output file: %v", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
errorMsg := "test error message"
|
||||||
|
data := map[string]string{
|
||||||
|
"Error": errorMsg,
|
||||||
|
}
|
||||||
|
if err := tmpl.Execute(file, data); err != nil {
|
||||||
|
t.Fatalf("Template execution failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the file and verify it contains expected content
|
||||||
|
content, err := os.ReadFile(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read output file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for error indicators
|
||||||
|
contentStr := string(content)
|
||||||
|
if len(contentStr) == 0 {
|
||||||
|
t.Error("Generated HTML is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basic HTML structure checks
|
||||||
|
requiredElements := []string{
|
||||||
|
"<!DOCTYPE html>",
|
||||||
|
"<html",
|
||||||
|
"<head>",
|
||||||
|
"<body>",
|
||||||
|
"Login Failed",
|
||||||
|
errorMsg,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, elem := range requiredElements {
|
||||||
|
if !contains(contentStr, elem) {
|
||||||
|
t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func contains(s, substr string) bool {
|
||||||
|
return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
|
||||||
|
(len(s) > 0 && len(substr) > 0 && containsHelper(s, substr)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsHelper(s, substr string) bool {
|
||||||
|
for i := 0; i <= len(s)-len(substr); i++ {
|
||||||
|
if s[i:i+len(substr)] == substr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -1,9 +1,12 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -90,7 +93,8 @@ func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName s
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Run start the internal client. It is a blocker function
|
// Run start the internal client. It is a blocker function
|
||||||
func (c *Client) Run(fd int32, interfaceName string) error {
|
func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
||||||
|
exportEnvList(envList)
|
||||||
log.Infof("Starting NetBird client")
|
log.Infof("Starting NetBird client")
|
||||||
log.Debugf("Tunnel uses interface: %s", interfaceName)
|
log.Debugf("Tunnel uses interface: %s", interfaceName)
|
||||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||||
@@ -228,7 +232,7 @@ func (c *Client) LoginForMobile() string {
|
|||||||
ConfigPath: c.cfgFile,
|
ConfigPath: c.cfgFile,
|
||||||
})
|
})
|
||||||
|
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false)
|
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, false, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err.Error()
|
return err.Error()
|
||||||
}
|
}
|
||||||
@@ -433,3 +437,19 @@ func toNetIDs(routes []string) []route.NetID {
|
|||||||
}
|
}
|
||||||
return netIDs
|
return netIDs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func exportEnvList(list *EnvList) {
|
||||||
|
if list == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for k, v := range list.AllItems() {
|
||||||
|
log.Debugf("Env variable %s's value is currently: %s", k, os.Getenv(k))
|
||||||
|
log.Debugf("Setting env variable %s: %s", k, v)
|
||||||
|
|
||||||
|
if err := os.Setenv(k, v); err != nil {
|
||||||
|
log.Errorf("could not set env variable %s: %v", k, err)
|
||||||
|
} else {
|
||||||
|
log.Debugf("Env variable %s was set successfully", k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
34
client/ios/NetBirdSDK/env_list.go
Normal file
34
client/ios/NetBirdSDK/env_list.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
|
package NetBirdSDK
|
||||||
|
|
||||||
|
import "github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
|
||||||
|
// EnvList is an exported struct to be bound by gomobile
|
||||||
|
type EnvList struct {
|
||||||
|
data map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewEnvList creates a new EnvList
|
||||||
|
func NewEnvList() *EnvList {
|
||||||
|
return &EnvList{data: make(map[string]string)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put adds a key-value pair
|
||||||
|
func (el *EnvList) Put(key, value string) {
|
||||||
|
el.data[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a value by key
|
||||||
|
func (el *EnvList) Get(key string) string {
|
||||||
|
return el.data[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (el *EnvList) AllItems() map[string]string {
|
||||||
|
return el.data
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEnvKeyNBForceRelay Exports the environment variable for the iOS client
|
||||||
|
func GetEnvKeyNBForceRelay() string {
|
||||||
|
return peer.EnvKeyNBForceRelay
|
||||||
|
}
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
import _ "golang.org/x/mobile/bind"
|
import _ "golang.org/x/mobile/bind"
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
// PeerInfo describe information about the peers. It designed for the UI usage
|
// PeerInfo describe information about the peers. It designed for the UI usage
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
// RoutesSelectionInfoCollection made for Java layer to get non default types as collection
|
// RoutesSelectionInfoCollection made for Java layer to get non default types as collection
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -24,7 +24,7 @@ service DaemonService {
|
|||||||
// Status of the service.
|
// Status of the service.
|
||||||
rpc Status(StatusRequest) returns (StatusResponse) {}
|
rpc Status(StatusRequest) returns (StatusResponse) {}
|
||||||
|
|
||||||
// Down engine work in the daemon.
|
// Down stops engine work in the daemon.
|
||||||
rpc Down(DownRequest) returns (DownResponse) {}
|
rpc Down(DownRequest) returns (DownResponse) {}
|
||||||
|
|
||||||
// GetConfig of the daemon.
|
// GetConfig of the daemon.
|
||||||
@@ -90,12 +90,29 @@ service DaemonService {
|
|||||||
|
|
||||||
// RequestJWTAuth initiates JWT authentication flow for SSH
|
// RequestJWTAuth initiates JWT authentication flow for SSH
|
||||||
rpc RequestJWTAuth(RequestJWTAuthRequest) returns (RequestJWTAuthResponse) {}
|
rpc RequestJWTAuth(RequestJWTAuthRequest) returns (RequestJWTAuthResponse) {}
|
||||||
|
|
||||||
// WaitJWTToken waits for JWT authentication completion
|
// WaitJWTToken waits for JWT authentication completion
|
||||||
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
|
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
|
||||||
|
|
||||||
|
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
message OSLifecycleRequest {
|
||||||
|
// avoid collision with loglevel enum
|
||||||
|
enum CycleType {
|
||||||
|
UNKNOWN = 0;
|
||||||
|
SLEEP = 1;
|
||||||
|
WAKEUP = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
CycleType type = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message OSLifecycleResponse {}
|
||||||
|
|
||||||
|
|
||||||
message LoginRequest {
|
message LoginRequest {
|
||||||
// setupKey netbird setup key.
|
// setupKey netbird setup key.
|
||||||
string setupKey = 1;
|
string setupKey = 1;
|
||||||
@@ -168,11 +185,15 @@ message LoginRequest {
|
|||||||
|
|
||||||
optional int64 mtu = 32;
|
optional int64 mtu = 32;
|
||||||
|
|
||||||
optional bool enableSSHRoot = 33;
|
// hint is used to pre-fill the email/username field during SSO authentication
|
||||||
optional bool enableSSHSFTP = 34;
|
optional string hint = 33;
|
||||||
optional bool enableSSHLocalPortForwarding = 35;
|
|
||||||
optional bool enableSSHRemotePortForwarding = 36;
|
optional bool enableSSHRoot = 34;
|
||||||
optional bool disableSSHAuth = 37;
|
optional bool enableSSHSFTP = 35;
|
||||||
|
optional bool enableSSHLocalPortForwarding = 36;
|
||||||
|
optional bool enableSSHRemotePortForwarding = 37;
|
||||||
|
optional bool disableSSHAuth = 38;
|
||||||
|
optional int32 sshJWTCacheTTL = 39;
|
||||||
}
|
}
|
||||||
|
|
||||||
message LoginResponse {
|
message LoginResponse {
|
||||||
@@ -202,7 +223,7 @@ message StatusRequest{
|
|||||||
bool getFullPeerStatus = 1;
|
bool getFullPeerStatus = 1;
|
||||||
bool shouldRunProbes = 2;
|
bool shouldRunProbes = 2;
|
||||||
// the UI do not using this yet, but CLIs could use it to wait until the status is ready
|
// the UI do not using this yet, but CLIs could use it to wait until the status is ready
|
||||||
optional bool waitForReady = 3;
|
optional bool waitForReady = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
message StatusResponse{
|
message StatusResponse{
|
||||||
@@ -277,6 +298,8 @@ message GetConfigResponse {
|
|||||||
bool enableSSHRemotePortForwarding = 23;
|
bool enableSSHRemotePortForwarding = 23;
|
||||||
|
|
||||||
bool disableSSHAuth = 25;
|
bool disableSSHAuth = 25;
|
||||||
|
|
||||||
|
int32 sshJWTCacheTTL = 26;
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeerState contains the latest state of a peer
|
// PeerState contains the latest state of a peer
|
||||||
@@ -340,6 +363,20 @@ message NSGroupState {
|
|||||||
string error = 4;
|
string error = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SSHSessionInfo contains information about an active SSH session
|
||||||
|
message SSHSessionInfo {
|
||||||
|
string username = 1;
|
||||||
|
string remoteAddress = 2;
|
||||||
|
string command = 3;
|
||||||
|
string jwtUsername = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHServerState contains the latest state of the SSH server
|
||||||
|
message SSHServerState {
|
||||||
|
bool enabled = 1;
|
||||||
|
repeated SSHSessionInfo sessions = 2;
|
||||||
|
}
|
||||||
|
|
||||||
// FullStatus contains the full state held by the Status instance
|
// FullStatus contains the full state held by the Status instance
|
||||||
message FullStatus {
|
message FullStatus {
|
||||||
ManagementState managementState = 1;
|
ManagementState managementState = 1;
|
||||||
@@ -353,6 +390,7 @@ message FullStatus {
|
|||||||
repeated SystemEvent events = 7;
|
repeated SystemEvent events = 7;
|
||||||
|
|
||||||
bool lazyConnectionEnabled = 9;
|
bool lazyConnectionEnabled = 9;
|
||||||
|
SSHServerState sshServerState = 10;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Networks
|
// Networks
|
||||||
@@ -619,9 +657,10 @@ message SetConfigRequest {
|
|||||||
|
|
||||||
optional bool enableSSHRoot = 29;
|
optional bool enableSSHRoot = 29;
|
||||||
optional bool enableSSHSFTP = 30;
|
optional bool enableSSHSFTP = 30;
|
||||||
optional bool enableSSHLocalPortForward = 31;
|
optional bool enableSSHLocalPortForwarding = 31;
|
||||||
optional bool enableSSHRemotePortForward = 32;
|
optional bool enableSSHRemotePortForwarding = 32;
|
||||||
optional bool disableSSHAuth = 33;
|
optional bool disableSSHAuth = 33;
|
||||||
|
optional int32 sshJWTCacheTTL = 34;
|
||||||
}
|
}
|
||||||
|
|
||||||
message SetConfigResponse{}
|
message SetConfigResponse{}
|
||||||
@@ -694,6 +733,8 @@ message GetPeerSSHHostKeyResponse {
|
|||||||
|
|
||||||
// RequestJWTAuthRequest for initiating JWT authentication flow
|
// RequestJWTAuthRequest for initiating JWT authentication flow
|
||||||
message RequestJWTAuthRequest {
|
message RequestJWTAuthRequest {
|
||||||
|
// hint for OIDC login_hint parameter (typically email address)
|
||||||
|
optional string hint = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequestJWTAuthResponse contains authentication flow information
|
// RequestJWTAuthResponse contains authentication flow information
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ type DaemonServiceClient interface {
|
|||||||
Up(ctx context.Context, in *UpRequest, opts ...grpc.CallOption) (*UpResponse, error)
|
Up(ctx context.Context, in *UpRequest, opts ...grpc.CallOption) (*UpResponse, error)
|
||||||
// Status of the service.
|
// Status of the service.
|
||||||
Status(ctx context.Context, in *StatusRequest, opts ...grpc.CallOption) (*StatusResponse, error)
|
Status(ctx context.Context, in *StatusRequest, opts ...grpc.CallOption) (*StatusResponse, error)
|
||||||
// Down engine work in the daemon.
|
// Down stops engine work in the daemon.
|
||||||
Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error)
|
Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error)
|
||||||
// GetConfig of the daemon.
|
// GetConfig of the daemon.
|
||||||
GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error)
|
GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error)
|
||||||
@@ -70,6 +70,7 @@ type DaemonServiceClient interface {
|
|||||||
RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error)
|
RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error)
|
||||||
// WaitJWTToken waits for JWT authentication completion
|
// WaitJWTToken waits for JWT authentication completion
|
||||||
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
|
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
|
||||||
|
NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type daemonServiceClient struct {
|
type daemonServiceClient struct {
|
||||||
@@ -382,6 +383,15 @@ func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTToken
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) {
|
||||||
|
out := new(OSLifecycleResponse)
|
||||||
|
err := c.cc.Invoke(ctx, "/daemon.DaemonService/NotifyOSLifecycle", in, out, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
// DaemonServiceServer is the server API for DaemonService service.
|
// DaemonServiceServer is the server API for DaemonService service.
|
||||||
// All implementations must embed UnimplementedDaemonServiceServer
|
// All implementations must embed UnimplementedDaemonServiceServer
|
||||||
// for forward compatibility
|
// for forward compatibility
|
||||||
@@ -395,7 +405,7 @@ type DaemonServiceServer interface {
|
|||||||
Up(context.Context, *UpRequest) (*UpResponse, error)
|
Up(context.Context, *UpRequest) (*UpResponse, error)
|
||||||
// Status of the service.
|
// Status of the service.
|
||||||
Status(context.Context, *StatusRequest) (*StatusResponse, error)
|
Status(context.Context, *StatusRequest) (*StatusResponse, error)
|
||||||
// Down engine work in the daemon.
|
// Down stops engine work in the daemon.
|
||||||
Down(context.Context, *DownRequest) (*DownResponse, error)
|
Down(context.Context, *DownRequest) (*DownResponse, error)
|
||||||
// GetConfig of the daemon.
|
// GetConfig of the daemon.
|
||||||
GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error)
|
GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error)
|
||||||
@@ -438,6 +448,7 @@ type DaemonServiceServer interface {
|
|||||||
RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error)
|
RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error)
|
||||||
// WaitJWTToken waits for JWT authentication completion
|
// WaitJWTToken waits for JWT authentication completion
|
||||||
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
|
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
|
||||||
|
NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
|
||||||
mustEmbedUnimplementedDaemonServiceServer()
|
mustEmbedUnimplementedDaemonServiceServer()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -538,6 +549,9 @@ func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *Request
|
|||||||
func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) {
|
func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) {
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented")
|
||||||
}
|
}
|
||||||
|
func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) {
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented")
|
||||||
|
}
|
||||||
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
||||||
|
|
||||||
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
|
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||||
@@ -1112,6 +1126,24 @@ func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, d
|
|||||||
return interceptor(ctx, in, info, handler)
|
return interceptor(ctx, in, info, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(OSLifecycleRequest)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(DaemonServiceServer).NotifyOSLifecycle(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: "/daemon.DaemonService/NotifyOSLifecycle",
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(DaemonServiceServer).NotifyOSLifecycle(ctx, req.(*OSLifecycleRequest))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
||||||
// It's only intended for direct use with grpc.RegisterService,
|
// It's only intended for direct use with grpc.RegisterService,
|
||||||
// and not to be introspected or modified (even as a copy)
|
// and not to be introspected or modified (even as a copy)
|
||||||
@@ -1239,6 +1271,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
|||||||
MethodName: "WaitJWTToken",
|
MethodName: "WaitJWTToken",
|
||||||
Handler: _DaemonService_WaitJWTToken_Handler,
|
Handler: _DaemonService_WaitJWTToken_Handler,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
MethodName: "NotifyOSLifecycle",
|
||||||
|
Handler: _DaemonService_NotifyOSLifecycle_Handler,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Streams: []grpc.StreamDesc{
|
Streams: []grpc.StreamDesc{
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -37,13 +37,18 @@ func (c *jwtCache) store(token string, maxAge time.Duration) {
|
|||||||
|
|
||||||
c.expiresAt = time.Now().Add(maxAge)
|
c.expiresAt = time.Now().Add(maxAge)
|
||||||
|
|
||||||
c.timer = time.AfterFunc(maxAge, func() {
|
var timer *time.Timer
|
||||||
|
timer = time.AfterFunc(maxAge, func() {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
if c.timer != timer {
|
||||||
|
return
|
||||||
|
}
|
||||||
c.cleanup()
|
c.cleanup()
|
||||||
c.timer = nil
|
c.timer = nil
|
||||||
log.Debugf("JWT token cache expired after %v, securely wiped from memory", maxAge)
|
log.Debugf("JWT token cache expired after %v, securely wiped from memory", maxAge)
|
||||||
})
|
})
|
||||||
|
c.timer = timer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *jwtCache) get() (string, bool) {
|
func (c *jwtCache) get() (string, bool) {
|
||||||
@@ -70,4 +75,5 @@ func (c *jwtCache) cleanup() {
|
|||||||
if c.enclave != nil {
|
if c.enclave != nil {
|
||||||
c.enclave = nil
|
c.enclave = nil
|
||||||
}
|
}
|
||||||
|
c.expiresAt = time.Time{}
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user