mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 01:06:45 +00:00
Compare commits
87 Commits
feature/ke
...
handle-use
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fa4b8c1d42 | ||
|
|
7682fe2e45 | ||
|
|
c9b2ce08eb | ||
|
|
246abda46d | ||
|
|
e4bc76c4de | ||
|
|
bdb8383485 | ||
|
|
bb40325977 | ||
|
|
8524cc75d6 | ||
|
|
c1f164c9cb | ||
|
|
4e2d075413 | ||
|
|
f89c200ce9 | ||
|
|
d51dc4fd33 | ||
|
|
00dddb9458 | ||
|
|
1a9301b684 | ||
|
|
80d9b5fca5 | ||
|
|
ac0b7dc8cb | ||
|
|
e586eca16c | ||
|
|
892db25021 | ||
|
|
da75a76d41 | ||
|
|
3ac32fd78a | ||
|
|
3aa657599b | ||
|
|
d4e9087f94 | ||
|
|
da8447a67d | ||
|
|
8e3bcd57a2 | ||
|
|
4572c6c1f8 | ||
|
|
01f2b0ecb7 | ||
|
|
442ba7cbc8 | ||
|
|
6c2b364966 | ||
|
|
0f0c7ec2ed | ||
|
|
2dec016201 | ||
|
|
06125acb8d | ||
|
|
a9b9b3fa0a | ||
|
|
cdf57275b7 | ||
|
|
e5e69b1f75 | ||
|
|
8eca83f3cb | ||
|
|
973316d194 | ||
|
|
a0a6ced148 | ||
|
|
0fc6c477a9 | ||
|
|
401a462398 | ||
|
|
a3839a6ef7 | ||
|
|
8aa4f240c7 | ||
|
|
d9686bae92 | ||
|
|
24e19ae287 | ||
|
|
74fde0ea2c | ||
|
|
890e09b787 | ||
|
|
48098c994d | ||
|
|
64f6343fcc | ||
|
|
24713fbe59 | ||
|
|
7794b744f8 | ||
|
|
0d0c30c16d | ||
|
|
b0364da67c | ||
|
|
6dee89379b | ||
|
|
76db4f801a | ||
|
|
6c2ed4b4f2 | ||
|
|
2541c78dd0 | ||
|
|
97b6e79809 | ||
|
|
6ad3847615 | ||
|
|
a4d830ef83 | ||
|
|
9e540cd5b4 | ||
|
|
3027d8f27e | ||
|
|
e69ec6ab6a | ||
|
|
7ddde41c92 | ||
|
|
7ebe58f20a | ||
|
|
9c2c0e7934 | ||
|
|
c6af1037d9 | ||
|
|
5cb9a126f1 | ||
|
|
f40951cdf5 | ||
|
|
6e264d9de7 | ||
|
|
42db9773f4 | ||
|
|
bb9f6f6d0a | ||
|
|
829ce6573e | ||
|
|
a366d9e208 | ||
|
|
e074c24487 | ||
|
|
54fe05f6d8 | ||
|
|
33a155d9aa | ||
|
|
51878659f8 | ||
|
|
c000c05435 | ||
|
|
b39ffef22c | ||
|
|
d96f882acb | ||
|
|
d409219b51 | ||
|
|
8b619a8224 | ||
|
|
ed075bc9b9 | ||
|
|
8eb098d6fd | ||
|
|
68a8687c80 | ||
|
|
f7d97b02fd | ||
|
|
2691e729cd | ||
|
|
b524a9d49d |
6
.github/workflows/golang-test-darwin.yml
vendored
6
.github/workflows/golang-test-darwin.yml
vendored
@@ -15,14 +15,14 @@ jobs:
|
|||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v2
|
uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
go-version: "1.20.x"
|
go-version: "1.20.x"
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v3
|
||||||
with:
|
with:
|
||||||
path: ~/go/pkg/mod
|
path: ~/go/pkg/mod
|
||||||
key: macos-go-${{ hashFiles('**/go.sum') }}
|
key: macos-go-${{ hashFiles('**/go.sum') }}
|
||||||
|
|||||||
12
.github/workflows/golang-test-linux.yml
vendored
12
.github/workflows/golang-test-linux.yml
vendored
@@ -18,13 +18,13 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v2
|
uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
go-version: "1.20.x"
|
go-version: "1.20.x"
|
||||||
|
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v3
|
||||||
with:
|
with:
|
||||||
path: ~/go/pkg/mod
|
path: ~/go/pkg/mod
|
||||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||||
@@ -32,7 +32,7 @@ jobs:
|
|||||||
${{ runner.os }}-go-
|
${{ runner.os }}-go-
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
|
||||||
@@ -47,13 +47,13 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v2
|
uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
go-version: "1.20.x"
|
go-version: "1.20.x"
|
||||||
|
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v3
|
||||||
with:
|
with:
|
||||||
path: ~/go/pkg/mod
|
path: ~/go/pkg/mod
|
||||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||||
@@ -61,7 +61,7 @@ jobs:
|
|||||||
${{ runner.os }}-go-
|
${{ runner.os }}-go-
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
|
||||||
|
|||||||
9
.github/workflows/golangci-lint.yml
vendored
9
.github/workflows/golangci-lint.yml
vendored
@@ -8,14 +8,13 @@ jobs:
|
|||||||
name: lint
|
name: lint
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v3
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v2
|
uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
go-version: "1.20.x"
|
go-version: "1.20.x"
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
|
||||||
- name: golangci-lint
|
- name: golangci-lint
|
||||||
uses: golangci/golangci-lint-action@v2
|
uses: golangci/golangci-lint-action@v3
|
||||||
with:
|
|
||||||
args: --timeout=6m
|
|
||||||
36
.github/workflows/install-script-test.yml
vendored
Normal file
36
.github/workflows/install-script-test.yml
vendored
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
name: Test installation
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- "release_files/install.sh"
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
jobs:
|
||||||
|
test-install-script:
|
||||||
|
strategy:
|
||||||
|
max-parallel: 2
|
||||||
|
matrix:
|
||||||
|
os: [ubuntu-latest, macos-latest]
|
||||||
|
skip_ui_mode: [true, false]
|
||||||
|
install_binary: [true, false]
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: run install script
|
||||||
|
env:
|
||||||
|
SKIP_UI_APP: ${{ matrix.skip_ui_mode }}
|
||||||
|
USE_BIN_INSTALL: ${{ matrix.install_binary }}
|
||||||
|
GITHUB_TOKEN: ${{ secrets.RO_API_CALLER_TOKEN }}
|
||||||
|
run: |
|
||||||
|
[ "$SKIP_UI_APP" == "false" ] && export XDG_CURRENT_DESKTOP="none"
|
||||||
|
cat release_files/install.sh | sh -x
|
||||||
|
|
||||||
|
- name: check cli binary
|
||||||
|
run: command -v netbird
|
||||||
60
.github/workflows/install-test-darwin.yml
vendored
60
.github/workflows/install-test-darwin.yml
vendored
@@ -1,60 +0,0 @@
|
|||||||
name: Test installation Darwin
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- main
|
|
||||||
pull_request:
|
|
||||||
paths:
|
|
||||||
- "release_files/install.sh"
|
|
||||||
concurrency:
|
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
|
||||||
cancel-in-progress: true
|
|
||||||
jobs:
|
|
||||||
install-cli-only:
|
|
||||||
runs-on: macos-latest
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v2
|
|
||||||
|
|
||||||
- name: Rename brew package
|
|
||||||
if: ${{ matrix.check_bin_install }}
|
|
||||||
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
|
|
||||||
|
|
||||||
- name: Run install script
|
|
||||||
run: |
|
|
||||||
sh ./release_files/install.sh
|
|
||||||
env:
|
|
||||||
SKIP_UI_APP: true
|
|
||||||
|
|
||||||
- name: Run tests
|
|
||||||
run: |
|
|
||||||
if ! command -v netbird &> /dev/null; then
|
|
||||||
echo "Error: netbird is not installed"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
install-all:
|
|
||||||
runs-on: macos-latest
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v2
|
|
||||||
|
|
||||||
- name: Rename brew package
|
|
||||||
if: ${{ matrix.check_bin_install }}
|
|
||||||
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
|
|
||||||
|
|
||||||
- name: Run install script
|
|
||||||
run: |
|
|
||||||
sh ./release_files/install.sh
|
|
||||||
|
|
||||||
- name: Run tests
|
|
||||||
run: |
|
|
||||||
if ! command -v netbird &> /dev/null; then
|
|
||||||
echo "Error: netbird is not installed"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [[ $(mdfind "kMDItemContentType == 'com.apple.application-bundle' && kMDItemFSName == '*NetBird UI.app'") ]]; then
|
|
||||||
echo "Error: NetBird UI is not installed"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
38
.github/workflows/install-test-linux.yml
vendored
38
.github/workflows/install-test-linux.yml
vendored
@@ -1,38 +0,0 @@
|
|||||||
name: Test installation Linux
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- main
|
|
||||||
pull_request:
|
|
||||||
paths:
|
|
||||||
- "release_files/install.sh"
|
|
||||||
concurrency:
|
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
|
||||||
cancel-in-progress: true
|
|
||||||
jobs:
|
|
||||||
install-cli-only:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
check_bin_install: [true, false]
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v2
|
|
||||||
|
|
||||||
- name: Rename apt package
|
|
||||||
if: ${{ matrix.check_bin_install }}
|
|
||||||
run: |
|
|
||||||
sudo mv /usr/bin/apt /usr/bin/apt.bak
|
|
||||||
sudo mv /usr/bin/apt-get /usr/bin/apt-get.bak
|
|
||||||
|
|
||||||
- name: Run install script
|
|
||||||
run: |
|
|
||||||
sh ./release_files/install.sh
|
|
||||||
|
|
||||||
- name: Run tests
|
|
||||||
run: |
|
|
||||||
if ! command -v netbird &> /dev/null; then
|
|
||||||
echo "Error: netbird is not installed"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
62
.github/workflows/release.yml
vendored
62
.github/workflows/release.yml
vendored
@@ -7,9 +7,19 @@ on:
|
|||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
pull_request:
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- 'go.mod'
|
||||||
|
- 'go.sum'
|
||||||
|
- '.goreleaser.yml'
|
||||||
|
- '.goreleaser_ui.yaml'
|
||||||
|
- '.goreleaser_ui_darwin.yaml'
|
||||||
|
- '.github/workflows/release.yml'
|
||||||
|
- 'release_files/**'
|
||||||
|
- '**/Dockerfile'
|
||||||
|
- '**/Dockerfile.*'
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.0.8"
|
SIGN_PIPE_VER: "v0.0.9"
|
||||||
GORELEASER_VER: "v1.14.1"
|
GORELEASER_VER: "v1.14.1"
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
@@ -19,20 +29,24 @@ concurrency:
|
|||||||
jobs:
|
jobs:
|
||||||
release:
|
release:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
env:
|
||||||
|
flags: ""
|
||||||
steps:
|
steps:
|
||||||
|
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||||
|
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||||
-
|
-
|
||||||
name: Checkout
|
name: Checkout
|
||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v3
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||||
-
|
-
|
||||||
name: Set up Go
|
name: Set up Go
|
||||||
uses: actions/setup-go@v2
|
uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
go-version: "1.20"
|
go-version: "1.20"
|
||||||
-
|
-
|
||||||
name: Cache Go modules
|
name: Cache Go modules
|
||||||
uses: actions/cache@v1
|
uses: actions/cache@v3
|
||||||
with:
|
with:
|
||||||
path: ~/go/pkg/mod
|
path: ~/go/pkg/mod
|
||||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||||
@@ -46,10 +60,10 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
-
|
-
|
||||||
name: Set up QEMU
|
name: Set up QEMU
|
||||||
uses: docker/setup-qemu-action@v1
|
uses: docker/setup-qemu-action@v2
|
||||||
-
|
-
|
||||||
name: Set up Docker Buildx
|
name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v1
|
uses: docker/setup-buildx-action@v2
|
||||||
-
|
-
|
||||||
name: Login to Docker hub
|
name: Login to Docker hub
|
||||||
if: github.event_name != 'pull_request'
|
if: github.event_name != 'pull_request'
|
||||||
@@ -72,10 +86,10 @@ jobs:
|
|||||||
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
|
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
|
||||||
-
|
-
|
||||||
name: Run GoReleaser
|
name: Run GoReleaser
|
||||||
uses: goreleaser/goreleaser-action@v2
|
uses: goreleaser/goreleaser-action@v4
|
||||||
with:
|
with:
|
||||||
version: ${{ env.GORELEASER_VER }}
|
version: ${{ env.GORELEASER_VER }}
|
||||||
args: release --rm-dist
|
args: release --rm-dist ${{ env.flags }}
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||||
@@ -83,7 +97,7 @@ jobs:
|
|||||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||||
-
|
-
|
||||||
name: upload non tags for debug purposes
|
name: upload non tags for debug purposes
|
||||||
uses: actions/upload-artifact@v2
|
uses: actions/upload-artifact@v3
|
||||||
with:
|
with:
|
||||||
name: release
|
name: release
|
||||||
path: dist/
|
path: dist/
|
||||||
@@ -92,17 +106,19 @@ jobs:
|
|||||||
release_ui:
|
release_ui:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||||
|
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v3
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||||
|
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v2
|
uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
go-version: "1.20"
|
go-version: "1.20"
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v1
|
uses: actions/cache@v3
|
||||||
with:
|
with:
|
||||||
path: ~/go/pkg/mod
|
path: ~/go/pkg/mod
|
||||||
key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }}
|
key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }}
|
||||||
@@ -116,23 +132,23 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-mingw-w64-x86-64
|
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
||||||
- name: Install rsrc
|
- name: Install rsrc
|
||||||
run: go install github.com/akavel/rsrc@v0.10.2
|
run: go install github.com/akavel/rsrc@v0.10.2
|
||||||
- name: Generate windows rsrc
|
- name: Generate windows rsrc
|
||||||
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso
|
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso
|
||||||
- name: Run GoReleaser
|
- name: Run GoReleaser
|
||||||
uses: goreleaser/goreleaser-action@v2
|
uses: goreleaser/goreleaser-action@v4
|
||||||
with:
|
with:
|
||||||
version: ${{ env.GORELEASER_VER }}
|
version: ${{ env.GORELEASER_VER }}
|
||||||
args: release --config .goreleaser_ui.yaml --rm-dist
|
args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }}
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||||
- name: upload non tags for debug purposes
|
- name: upload non tags for debug purposes
|
||||||
uses: actions/upload-artifact@v2
|
uses: actions/upload-artifact@v3
|
||||||
with:
|
with:
|
||||||
name: release-ui
|
name: release-ui
|
||||||
path: dist/
|
path: dist/
|
||||||
@@ -141,19 +157,21 @@ jobs:
|
|||||||
release_ui_darwin:
|
release_ui_darwin:
|
||||||
runs-on: macos-11
|
runs-on: macos-11
|
||||||
steps:
|
steps:
|
||||||
|
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||||
|
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||||
-
|
-
|
||||||
name: Checkout
|
name: Checkout
|
||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v3
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||||
-
|
-
|
||||||
name: Set up Go
|
name: Set up Go
|
||||||
uses: actions/setup-go@v2
|
uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
go-version: "1.20"
|
go-version: "1.20"
|
||||||
-
|
-
|
||||||
name: Cache Go modules
|
name: Cache Go modules
|
||||||
uses: actions/cache@v1
|
uses: actions/cache@v3
|
||||||
with:
|
with:
|
||||||
path: ~/go/pkg/mod
|
path: ~/go/pkg/mod
|
||||||
key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }}
|
key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }}
|
||||||
@@ -165,15 +183,15 @@ jobs:
|
|||||||
-
|
-
|
||||||
name: Run GoReleaser
|
name: Run GoReleaser
|
||||||
id: goreleaser
|
id: goreleaser
|
||||||
uses: goreleaser/goreleaser-action@v2
|
uses: goreleaser/goreleaser-action@v4
|
||||||
with:
|
with:
|
||||||
version: ${{ env.GORELEASER_VER }}
|
version: ${{ env.GORELEASER_VER }}
|
||||||
args: release --config .goreleaser_ui_darwin.yaml --rm-dist
|
args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }}
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
-
|
-
|
||||||
name: upload non tags for debug purposes
|
name: upload non tags for debug purposes
|
||||||
uses: actions/upload-artifact@v2
|
uses: actions/upload-artifact@v3
|
||||||
with:
|
with:
|
||||||
name: release-ui-darwin
|
name: release-ui-darwin
|
||||||
path: dist/
|
path: dist/
|
||||||
|
|||||||
@@ -1,18 +1,20 @@
|
|||||||
name: Test Docker Compose Linux
|
name: Test Infrastructure files
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
pull_request:
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- 'infrastructure_files/**'
|
||||||
|
- '.github/workflows/test-infrastructure-files.yml'
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test-docker-compose:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Install jq
|
- name: Install jq
|
||||||
@@ -22,12 +24,12 @@ jobs:
|
|||||||
run: sudo apt-get install -y curl
|
run: sudo apt-get install -y curl
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v2
|
uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
go-version: "1.20.x"
|
go-version: "1.20.x"
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v3
|
||||||
with:
|
with:
|
||||||
path: ~/go/pkg/mod
|
path: ~/go/pkg/mod
|
||||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||||
@@ -35,7 +37,7 @@ jobs:
|
|||||||
${{ runner.os }}-go-
|
${{ runner.os }}-go-
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
- name: cp setup.env
|
- name: cp setup.env
|
||||||
run: cp infrastructure_files/tests/setup.env infrastructure_files/
|
run: cp infrastructure_files/tests/setup.env infrastructure_files/
|
||||||
@@ -53,6 +55,7 @@ jobs:
|
|||||||
CI_NETBIRD_MGMT_IDP: "none"
|
CI_NETBIRD_MGMT_IDP: "none"
|
||||||
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
|
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
|
||||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
||||||
|
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
|
||||||
|
|
||||||
- name: check values
|
- name: check values
|
||||||
working-directory: infrastructure_files
|
working-directory: infrastructure_files
|
||||||
@@ -68,6 +71,7 @@ jobs:
|
|||||||
CI_NETBIRD_AUTH_JWT_CERTS: https://example.eu.auth0.com/.well-known/jwks.json
|
CI_NETBIRD_AUTH_JWT_CERTS: https://example.eu.auth0.com/.well-known/jwks.json
|
||||||
CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token
|
CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token
|
||||||
CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code
|
CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code
|
||||||
|
CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT: https://example.eu.auth0.com/authorize
|
||||||
CI_NETBIRD_AUTH_REDIRECT_URI: "/peers"
|
CI_NETBIRD_AUTH_REDIRECT_URI: "/peers"
|
||||||
CI_NETBIRD_TOKEN_SOURCE: "idToken"
|
CI_NETBIRD_TOKEN_SOURCE: "idToken"
|
||||||
CI_NETBIRD_AUTH_USER_ID_CLAIM: "email"
|
CI_NETBIRD_AUTH_USER_ID_CLAIM: "email"
|
||||||
@@ -90,8 +94,8 @@ jobs:
|
|||||||
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
|
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
|
||||||
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
|
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
|
||||||
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
|
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
|
||||||
grep -A 1 ProviderConfig management.json | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
|
grep -A 3 DeviceAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
|
||||||
grep Scope management.json | grep "$CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE"
|
grep -A 8 DeviceAuthorizationFlow management.json | grep -A 6 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE"
|
||||||
grep UseIDToken management.json | grep false
|
grep UseIDToken management.json | grep false
|
||||||
grep -A 1 IdpManagerConfig management.json | grep ManagerType | grep $CI_NETBIRD_MGMT_IDP
|
grep -A 1 IdpManagerConfig management.json | grep ManagerType | grep $CI_NETBIRD_MGMT_IDP
|
||||||
grep -A 3 IdpManagerConfig management.json | grep -A 1 ClientConfig | grep Issuer | grep $CI_NETBIRD_AUTH_AUTHORITY
|
grep -A 3 IdpManagerConfig management.json | grep -A 1 ClientConfig | grep Issuer | grep $CI_NETBIRD_AUTH_AUTHORITY
|
||||||
@@ -99,6 +103,12 @@ jobs:
|
|||||||
grep -A 5 IdpManagerConfig management.json | grep -A 3 ClientConfig | grep ClientID | grep $CI_NETBIRD_IDP_MGMT_CLIENT_ID
|
grep -A 5 IdpManagerConfig management.json | grep -A 3 ClientConfig | grep ClientID | grep $CI_NETBIRD_IDP_MGMT_CLIENT_ID
|
||||||
grep -A 6 IdpManagerConfig management.json | grep -A 4 ClientConfig | grep ClientSecret | grep $CI_NETBIRD_IDP_MGMT_CLIENT_SECRET
|
grep -A 6 IdpManagerConfig management.json | grep -A 4 ClientConfig | grep ClientSecret | grep $CI_NETBIRD_IDP_MGMT_CLIENT_SECRET
|
||||||
grep -A 7 IdpManagerConfig management.json | grep -A 5 ClientConfig | grep GrantType | grep client_credentials
|
grep -A 7 IdpManagerConfig management.json | grep -A 5 ClientConfig | grep GrantType | grep client_credentials
|
||||||
|
grep -A 2 PKCEAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_AUDIENCE
|
||||||
|
grep -A 3 PKCEAuthorizationFlow management.json | grep -A 2 ProviderConfig | grep ClientID | grep $CI_NETBIRD_AUTH_CLIENT_ID
|
||||||
|
grep -A 4 PKCEAuthorizationFlow management.json | grep -A 3 ProviderConfig | grep ClientSecret | grep $CI_NETBIRD_AUTH_CLIENT_SECRET
|
||||||
|
grep -A 5 PKCEAuthorizationFlow management.json | grep -A 4 ProviderConfig | grep AuthorizationEndpoint | grep $CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
|
||||||
|
grep -A 6 PKCEAuthorizationFlow management.json | grep -A 5 ProviderConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
|
||||||
|
grep -A 7 PKCEAuthorizationFlow management.json | grep -A 6 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
|
||||||
|
|
||||||
- name: run docker compose up
|
- name: run docker compose up
|
||||||
working-directory: infrastructure_files
|
working-directory: infrastructure_files
|
||||||
@@ -113,3 +123,28 @@ jobs:
|
|||||||
count=$(docker compose ps --format json | jq '.[] | select(.Project | contains("infrastructure_files")) | .State' | grep -c running)
|
count=$(docker compose ps --format json | jq '.[] | select(.Project | contains("infrastructure_files")) | .State' | grep -c running)
|
||||||
test $count -eq 4
|
test $count -eq 4
|
||||||
working-directory: infrastructure_files
|
working-directory: infrastructure_files
|
||||||
|
|
||||||
|
test-getting-started-script:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Install jq
|
||||||
|
run: sudo apt-get install -y jq
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: run script
|
||||||
|
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
|
||||||
|
|
||||||
|
- name: test Caddy file gen
|
||||||
|
run: test -f Caddyfile
|
||||||
|
- name: test docker-compose file gen
|
||||||
|
run: test -f docker-compose.yml
|
||||||
|
- name: test management.json file gen
|
||||||
|
run: test -f management.json
|
||||||
|
- name: test turnserver.conf file gen
|
||||||
|
run: test -f turnserver.conf
|
||||||
|
- name: test zitadel.env file gen
|
||||||
|
run: test -f zitadel.env
|
||||||
|
- name: test dashboard.env file gen
|
||||||
|
run: test -f dashboard.env
|
||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -19,3 +19,4 @@ client/.distfiles/
|
|||||||
infrastructure_files/setup.env
|
infrastructure_files/setup.env
|
||||||
infrastructure_files/setup-*.env
|
infrastructure_files/setup-*.env
|
||||||
.vscode
|
.vscode
|
||||||
|
.DS_Store
|
||||||
54
.golangci.yaml
Normal file
54
.golangci.yaml
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
run:
|
||||||
|
# Timeout for analysis, e.g. 30s, 5m.
|
||||||
|
# Default: 1m
|
||||||
|
timeout: 6m
|
||||||
|
|
||||||
|
# This file contains only configs which differ from defaults.
|
||||||
|
# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
|
||||||
|
linters-settings:
|
||||||
|
errcheck:
|
||||||
|
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
|
||||||
|
# Such cases aren't reported by default.
|
||||||
|
# Default: false
|
||||||
|
check-type-assertions: false
|
||||||
|
|
||||||
|
govet:
|
||||||
|
# Enable all analyzers.
|
||||||
|
# Default: false
|
||||||
|
enable-all: false
|
||||||
|
enable:
|
||||||
|
- nilness
|
||||||
|
|
||||||
|
linters:
|
||||||
|
disable-all: true
|
||||||
|
enable:
|
||||||
|
## enabled by default
|
||||||
|
- errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases
|
||||||
|
- gosimple # specializes in simplifying a code
|
||||||
|
- govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
|
||||||
|
- ineffassign # detects when assignments to existing variables are not used
|
||||||
|
- staticcheck # is a go vet on steroids, applying a ton of static analysis checks
|
||||||
|
- typecheck # like the front-end of a Go compiler, parses and type-checks Go code
|
||||||
|
- unused # checks for unused constants, variables, functions and types
|
||||||
|
## disable by default but the have interesting results so lets add them
|
||||||
|
- bodyclose # checks whether HTTP response body is closed successfully
|
||||||
|
- nilerr # finds the code that returns nil even if it checks that the error is not nil
|
||||||
|
- nilnil # checks that there is no simultaneous return of nil error and an invalid value
|
||||||
|
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
|
||||||
|
- wastedassign # wastedassign finds wasted assignment statements
|
||||||
|
issues:
|
||||||
|
# Maximum count of issues with the same text.
|
||||||
|
# Set to 0 to disable.
|
||||||
|
# Default: 3
|
||||||
|
max-same-issues: 5
|
||||||
|
|
||||||
|
exclude-rules:
|
||||||
|
- path: sharedsock/filter.go
|
||||||
|
linters:
|
||||||
|
- unused
|
||||||
|
- path: client/firewall/iptables/rule.go
|
||||||
|
linters:
|
||||||
|
- unused
|
||||||
|
- path: mock.go
|
||||||
|
linters:
|
||||||
|
- nilnil
|
||||||
@@ -377,3 +377,13 @@ uploads:
|
|||||||
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
|
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
|
||||||
username: dev@wiretrustee.com
|
username: dev@wiretrustee.com
|
||||||
method: PUT
|
method: PUT
|
||||||
|
|
||||||
|
checksum:
|
||||||
|
extra_files:
|
||||||
|
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||||
|
- glob: ./release_files/install.sh
|
||||||
|
|
||||||
|
release:
|
||||||
|
extra_files:
|
||||||
|
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||||
|
- glob: ./release_files/install.sh
|
||||||
@@ -11,6 +11,8 @@ builds:
|
|||||||
- amd64
|
- amd64
|
||||||
ldflags:
|
ldflags:
|
||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
|
tags:
|
||||||
|
- legacy_appindicator
|
||||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||||
|
|
||||||
- id: netbird-ui-windows
|
- id: netbird-ui-windows
|
||||||
@@ -55,9 +57,6 @@ nfpms:
|
|||||||
- src: client/ui/disconnected.png
|
- src: client/ui/disconnected.png
|
||||||
dst: /usr/share/pixmaps/netbird.png
|
dst: /usr/share/pixmaps/netbird.png
|
||||||
dependencies:
|
dependencies:
|
||||||
- libayatana-appindicator3-1
|
|
||||||
- libgtk-3-dev
|
|
||||||
- libappindicator3-dev
|
|
||||||
- netbird
|
- netbird
|
||||||
|
|
||||||
- maintainer: Netbird <dev@netbird.io>
|
- maintainer: Netbird <dev@netbird.io>
|
||||||
@@ -75,9 +74,6 @@ nfpms:
|
|||||||
- src: client/ui/disconnected.png
|
- src: client/ui/disconnected.png
|
||||||
dst: /usr/share/pixmaps/netbird.png
|
dst: /usr/share/pixmaps/netbird.png
|
||||||
dependencies:
|
dependencies:
|
||||||
- libayatana-appindicator3-1
|
|
||||||
- libgtk-3-dev
|
|
||||||
- libappindicator3-dev
|
|
||||||
- netbird
|
- netbird
|
||||||
|
|
||||||
uploads:
|
uploads:
|
||||||
|
|||||||
95
README.md
95
README.md
@@ -1,6 +1,6 @@
|
|||||||
<p align="center">
|
<p align="center">
|
||||||
<strong>:hatching_chick: New Release! Peer expiration.</strong>
|
<strong>:hatching_chick: New Release! Self-hosting in under 5 min.</strong>
|
||||||
<a href="https://github.com/netbirdio/netbird/releases">
|
<a href="https://github.com/netbirdio/netbird#quickstart-with-self-hosted-netbird">
|
||||||
Learn more
|
Learn more
|
||||||
</a>
|
</a>
|
||||||
</p>
|
</p>
|
||||||
@@ -24,7 +24,7 @@
|
|||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<strong>
|
<strong>
|
||||||
Start using NetBird at <a href="https://app.netbird.io/">app.netbird.io</a>
|
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
|
||||||
<br/>
|
<br/>
|
||||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||||
<br/>
|
<br/>
|
||||||
@@ -36,47 +36,62 @@
|
|||||||
|
|
||||||
<br>
|
<br>
|
||||||
|
|
||||||
**NetBird is an open-source VPN management platform built on top of WireGuard® making it easy to create secure private networks for your organization or home.**
|
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
|
||||||
|
|
||||||
It requires zero configuration effort leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
||||||
|
|
||||||
NetBird uses [NAT traversal techniques](https://en.wikipedia.org/wiki/Interactive_Connectivity_Establishment) to automatically create an overlay peer-to-peer network connecting machines regardless of location (home, office, data center, container, cloud, or edge environments), unifying virtual private network management experience.
|
**Secure.** NetBird enables secure remote access by applying granular access policies, while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
||||||
|
|
||||||
**Key features:**
|
|
||||||
- \[x] Automatic IP allocation and network management with a Web UI ([separate repo](https://github.com/netbirdio/dashboard))
|
|
||||||
- \[x] Automatic WireGuard peer (machine) discovery and configuration.
|
|
||||||
- \[x] Encrypted peer-to-peer connections without a central VPN gateway.
|
|
||||||
- \[x] Connection relay fallback in case a peer-to-peer connection is not possible.
|
|
||||||
- \[x] Desktop client applications for Linux, MacOS, and Windows (systray).
|
|
||||||
- \[x] Multiuser support - sharing network between multiple users.
|
|
||||||
- \[x] SSO and MFA support.
|
|
||||||
- \[x] Multicloud and hybrid-cloud support.
|
|
||||||
- \[x] Kernel WireGuard usage when possible.
|
|
||||||
- \[x] Access Controls - groups & rules.
|
|
||||||
- \[x] Remote SSH access without managing SSH keys.
|
|
||||||
- \[x] Network Routes.
|
|
||||||
- \[x] Private DNS.
|
|
||||||
- \[x] Network Activity Monitoring.
|
|
||||||
|
|
||||||
**Coming soon:**
|
|
||||||
- \[ ] Mobile clients.
|
|
||||||
|
|
||||||
### Secure peer-to-peer VPN with SSO and MFA in minutes
|
### Secure peer-to-peer VPN with SSO and MFA in minutes
|
||||||
|
|
||||||
https://user-images.githubusercontent.com/700848/197345890-2e2cded5-7b7a-436f-a444-94e80dd24f46.mov
|
https://user-images.githubusercontent.com/700848/197345890-2e2cded5-7b7a-436f-a444-94e80dd24f46.mov
|
||||||
|
|
||||||
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
### Key features
|
||||||
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
|
||||||
|
|
||||||
### Start using NetBird
|
| Connectivity | Management | Automation | Platforms |
|
||||||
- Hosted version: [https://app.netbird.io/](https://app.netbird.io/).
|
|-------------------------------------------------------------------|--------------------------------------------------------------------------|----------------------------------------------------------------------------|---------------------------------------|
|
||||||
- See our documentation for [Quickstart Guide](https://netbird.io/docs/getting-started/quickstart).
|
| <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
|
||||||
- If you are looking to self-host NetBird, check our [Self-Hosting Guide](https://netbird.io/docs/getting-started/self-hosting).
|
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
|
||||||
- Step-by-step [Installation Guide](https://netbird.io/docs/getting-started/installation) for different platforms.
|
| <ul><li> - \[x] Peer-to-peer encryption </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
|
||||||
- Web UI [repository](https://github.com/netbirdio/dashboard).
|
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
|
||||||
- 5 min [demo video](https://youtu.be/Tu9tPsUWaY0) on YouTube.
|
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | | <ul><li> - \[ ] iOS </ul></li> |
|
||||||
|
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | | <ul><li> - \[x] Docker </ul></li> |
|
||||||
|
| | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
|
||||||
|
| | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | | |
|
||||||
|
| | <ul><li> - \[x] SSH access management </ul></li> | | |
|
||||||
|
|
||||||
|
|
||||||
|
### Quickstart with NetBird Cloud
|
||||||
|
|
||||||
|
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
|
||||||
|
- Follow the steps to sign-up with Google, Microsoft, GitHub or your email address.
|
||||||
|
- Check NetBird [admin UI](https://app.netbird.io/).
|
||||||
|
- Add more machines.
|
||||||
|
|
||||||
|
### Quickstart with self-hosted NetBird
|
||||||
|
|
||||||
|
> This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM.
|
||||||
|
Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IDPs.
|
||||||
|
|
||||||
|
**Infrastructure requirements:**
|
||||||
|
- A Linux VM with at least **1CPU** and **2GB** of memory.
|
||||||
|
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP ports: **3478**, **49152-65535**.
|
||||||
|
- **Public domain** name pointing to the VM.
|
||||||
|
|
||||||
|
**Software requirements:**
|
||||||
|
- Docker installed on the VM with the docker compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
|
||||||
|
- [jq](https://jqlang.github.io/jq/) installed. In most distributions
|
||||||
|
Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`
|
||||||
|
- [curl](https://curl.se/) installed.
|
||||||
|
Usually available in the official repositories and can be installed with `sudo apt install curl` or `sudo yum install curl`
|
||||||
|
|
||||||
|
**Steps**
|
||||||
|
- Download and run the installation script:
|
||||||
|
```bash
|
||||||
|
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started-with-zitadel.sh | bash
|
||||||
|
```
|
||||||
|
- Once finished, you can manage the resources via `docker-compose`
|
||||||
|
|
||||||
### A bit on NetBird internals
|
### A bit on NetBird internals
|
||||||
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
|
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
|
||||||
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
|
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
|
||||||
@@ -88,18 +103,18 @@ For stable versions, see [releases](https://github.com/netbirdio/netbird/release
|
|||||||
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
|
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
|
||||||
|
|
||||||
<p float="left" align="middle">
|
<p float="left" align="middle">
|
||||||
<img src="https://netbird.io/docs/img/architecture/high-level-dia.png" width="700"/>
|
<img src="https://docs.netbird.io/docs-static/img/architecture/high-level-dia.png" width="700"/>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
See a complete [architecture overview](https://netbird.io/docs/overview/architecture) for details.
|
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
||||||
|
|
||||||
### Roadmap
|
|
||||||
- [Public Roadmap](https://github.com/netbirdio/netbird/projects/2)
|
|
||||||
|
|
||||||
### Community projects
|
### Community projects
|
||||||
- [NetBird on OpenWRT](https://github.com/messense/openwrt-netbird)
|
- [NetBird on OpenWRT](https://github.com/messense/openwrt-netbird)
|
||||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||||
|
|
||||||
|
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
||||||
|
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
||||||
|
|
||||||
### Support acknowledgement
|
### Support acknowledgement
|
||||||
|
|
||||||
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by The Federal Ministry of Education and Research of The Federal Republic of Germany. Together with [CISPA Helmholtz Center for Information Security](https://cispa.de/en) NetBird brings the security best practices and simplicity to private networking.
|
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by The Federal Ministry of Education and Research of The Federal Republic of Germany. Together with [CISPA Helmholtz Center for Information Security](https://cispa.de/en) NetBird brings the security best practices and simplicity to private networking.
|
||||||
@@ -107,7 +122,7 @@ In November 2022, NetBird joined the [StartUpSecure program](https://www.forschu
|
|||||||

|

|
||||||
|
|
||||||
### Testimonials
|
### Testimonials
|
||||||
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), and [Coturn](https://github.com/coturn/coturn). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution).
|
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution).
|
||||||
|
|
||||||
### Legal
|
### Legal
|
||||||
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.
|
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.
|
||||||
|
|||||||
@@ -18,10 +18,9 @@ func Encode(num uint32) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var encoded strings.Builder
|
var encoded strings.Builder
|
||||||
remainder := uint32(0)
|
|
||||||
|
|
||||||
for num > 0 {
|
for num > 0 {
|
||||||
remainder = num % base
|
remainder := num % base
|
||||||
encoded.WriteByte(alphabet[remainder])
|
encoded.WriteByte(alphabet[remainder])
|
||||||
num /= base
|
num /= base
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
FROM gcr.io/distroless/base:debug
|
FROM alpine:3
|
||||||
|
RUN apk add --no-cache ca-certificates iptables ip6tables
|
||||||
ENV NB_FOREGROUND_MODE=true
|
ENV NB_FOREGROUND_MODE=true
|
||||||
ENV PATH=/sbin:/usr/sbin:/bin:/usr/bin:/busybox
|
|
||||||
SHELL ["/busybox/sh","-c"]
|
|
||||||
RUN sed -i -E 's/(^root:.+)\/sbin\/nologin/\1\/busybox\/sh/g' /etc/passwd
|
|
||||||
ENTRYPOINT [ "/go/bin/netbird","up"]
|
ENTRYPOINT [ "/go/bin/netbird","up"]
|
||||||
COPY netbird /go/bin/netbird
|
COPY netbird /go/bin/netbird
|
||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
@@ -35,6 +36,11 @@ type RouteListener interface {
|
|||||||
routemanager.RouteListener
|
routemanager.RouteListener
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DnsReadyListener export internal dns ReadyListener for mobile
|
||||||
|
type DnsReadyListener interface {
|
||||||
|
dns.ReadyListener
|
||||||
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
formatter.SetLogcatFormatter(log.StandardLogger())
|
formatter.SetLogcatFormatter(log.StandardLogger())
|
||||||
}
|
}
|
||||||
@@ -65,7 +71,7 @@ func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Run start the internal client. It is a blocker function
|
// Run start the internal client. It is a blocker function
|
||||||
func (c *Client) Run(urlOpener URLOpener) error {
|
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
|
||||||
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
||||||
ConfigPath: c.cfgFile,
|
ConfigPath: c.cfgFile,
|
||||||
})
|
})
|
||||||
@@ -90,7 +96,31 @@ func (c *Client) Run(urlOpener URLOpener) error {
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener)
|
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||||
|
// In this case make no sense handle registration steps.
|
||||||
|
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error {
|
||||||
|
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
||||||
|
ConfigPath: c.cfgFile,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
|
||||||
|
|
||||||
|
var ctx context.Context
|
||||||
|
//nolint
|
||||||
|
ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName)
|
||||||
|
c.ctxCancelLock.Lock()
|
||||||
|
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
|
||||||
|
defer c.ctxCancel()
|
||||||
|
c.ctxCancelLock.Unlock()
|
||||||
|
|
||||||
|
// todo do not throw error in case of cancelled context
|
||||||
|
ctx = internal.CtxInitState(ctx)
|
||||||
|
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
@@ -126,6 +156,17 @@ func (c *Client) PeersList() *PeerInfoArray {
|
|||||||
return &PeerInfoArray{items: peerInfos}
|
return &PeerInfoArray{items: peerInfos}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OnUpdatedHostDNS update the DNS servers addresses for root zones
|
||||||
|
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
||||||
|
dnsServer, err := dns.GetServerDns()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsServer.OnUpdatedHostDNSServer(list.items)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// SetConnectionListener set the network connection listener
|
// SetConnectionListener set the network connection listener
|
||||||
func (c *Client) SetConnectionListener(listener ConnectionListener) {
|
func (c *Client) SetConnectionListener(listener ConnectionListener) {
|
||||||
c.recorder.SetConnectionListener(listener)
|
c.recorder.SetConnectionListener(listener)
|
||||||
|
|||||||
26
client/android/dns_list.go
Normal file
26
client/android/dns_list.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package android
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// DNSList is a wrapper of []string
|
||||||
|
type DNSList struct {
|
||||||
|
items []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new DNS address to the collection
|
||||||
|
func (array *DNSList) Add(s string) {
|
||||||
|
array.items = append(array.items, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get return an element of the collection
|
||||||
|
func (array *DNSList) Get(i int) (string, error) {
|
||||||
|
if i >= len(array.items) || i < 0 {
|
||||||
|
return "", fmt.Errorf("out of range")
|
||||||
|
}
|
||||||
|
return array.items[i], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size return with the size of the collection
|
||||||
|
func (array *DNSList) Size() int {
|
||||||
|
return len(array.items)
|
||||||
|
}
|
||||||
24
client/android/dns_list_test.go
Normal file
24
client/android/dns_list_test.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package android
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestDNSList_Get(t *testing.T) {
|
||||||
|
l := DNSList{
|
||||||
|
items: make([]string, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := l.Get(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("invalid error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = l.Get(-1)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error but got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = l.Get(1)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error but got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,15 +6,14 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/cmd"
|
"github.com/netbirdio/netbird/client/cmd"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/auth"
|
||||||
|
"github.com/netbirdio/netbird/client/system"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SSOListener is async listener for mobile framework
|
// SSOListener is async listener for mobile framework
|
||||||
@@ -87,9 +86,15 @@ func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
|||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
err := a.withBackOff(a.ctx, func() (err error) {
|
||||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
|
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
supportsSSO = false
|
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||||
err = nil
|
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
|
supportsSSO = false
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -183,27 +188,15 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*internal.TokenInfo, error) {
|
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
|
||||||
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s, ok := gstatus.FromError(err)
|
return nil, err
|
||||||
if ok && s.Code() == codes.NotFound {
|
|
||||||
return nil, fmt.Errorf("no SSO provider returned from management. " +
|
|
||||||
"If you are using hosting Netbird see documentation at " +
|
|
||||||
"https://github.com/netbirdio/netbird/tree/main/management for details")
|
|
||||||
} else if ok && s.Code() == codes.Unimplemented {
|
|
||||||
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
|
||||||
"please update your servver or use Setup Keys to login", a.config.ManagementURL)
|
|
||||||
} else {
|
|
||||||
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig)
|
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
||||||
|
|
||||||
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("getting a request device code failed: %v", err)
|
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
go urlOpener.Open(flowInfo.VerificationURIComplete)
|
go urlOpener.Open(flowInfo.VerificationURIComplete)
|
||||||
@@ -211,7 +204,7 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*internal.TokenInfo,
|
|||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn)
|
waitTimeout := time.Duration(flowInfo.ExpiresIn)
|
||||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout*time.Second)
|
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
tokenInfo, err := hostedClient.WaitToken(waitCTX, flowInfo)
|
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,20 +3,21 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/skratchdot/open-golang/open"
|
"github.com/skratchdot/open-golang/open"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/util"
|
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/auth"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var loginCmd = &cobra.Command{
|
var loginCmd = &cobra.Command{
|
||||||
@@ -163,31 +164,15 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*internal.TokenInfo, error) {
|
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
|
||||||
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
oAuthFlow, err := auth.NewOAuthFlow(ctx, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s, ok := gstatus.FromError(err)
|
return nil, err
|
||||||
if ok && s.Code() == codes.NotFound {
|
|
||||||
return nil, fmt.Errorf("no SSO provider returned from management. " +
|
|
||||||
"If you are using hosting Netbird see documentation at " +
|
|
||||||
"https://github.com/netbirdio/netbird/tree/main/management for details")
|
|
||||||
} else if ok && s.Code() == codes.Unimplemented {
|
|
||||||
mgmtURL := managementURL
|
|
||||||
if mgmtURL == "" {
|
|
||||||
mgmtURL = internal.DefaultManagementURL
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
|
||||||
"please update your servver or use Setup Keys to login", mgmtURL)
|
|
||||||
} else {
|
|
||||||
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig)
|
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
||||||
|
|
||||||
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("getting a request device code failed: %v", err)
|
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||||
@@ -196,7 +181,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
|||||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout*time.Second)
|
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout*time.Second)
|
||||||
defer c()
|
defer c()
|
||||||
|
|
||||||
tokenInfo, err := hostedClient.WaitToken(waitCTX, flowInfo)
|
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -206,15 +191,64 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
|||||||
|
|
||||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
|
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
|
||||||
var codeMsg string
|
var codeMsg string
|
||||||
if !strings.Contains(verificationURIComplete, userCode) {
|
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
|
||||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := open.Run(verificationURIComplete)
|
browserAuthMsg := "Please do the SSO login in your browser. \n" +
|
||||||
cmd.Printf("Please do the SSO login in your browser. \n" +
|
|
||||||
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
||||||
" " + verificationURIComplete + " " + codeMsg + " \n\n")
|
verificationURIComplete + " " + codeMsg
|
||||||
if err != nil {
|
|
||||||
cmd.Printf("Alternatively, you may want to use a setup key, see:\n\n https://www.netbird.io/docs/overview/setup-keys\n")
|
setupKeyAuthMsg := "\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||||
|
"https://docs.netbird.io/how-to/register-machines-using-setup-keys"
|
||||||
|
|
||||||
|
authenticateUsingBrowser := func() {
|
||||||
|
cmd.Println(browserAuthMsg)
|
||||||
|
cmd.Println("")
|
||||||
|
if err := open.Run(verificationURIComplete); err != nil {
|
||||||
|
cmd.Println(setupKeyAuthMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "windows", "darwin":
|
||||||
|
authenticateUsingBrowser()
|
||||||
|
case "linux":
|
||||||
|
if isLinuxRunningDesktop() {
|
||||||
|
authenticateUsingBrowser()
|
||||||
|
} else {
|
||||||
|
// If current flow is PKCE, it implies the server is anticipating the redirect to localhost.
|
||||||
|
// Devices lacking browser support are incompatible with this flow.Therefore,
|
||||||
|
// these devices will need to resort to setup keys instead.
|
||||||
|
if isPKCEFlow(verificationURIComplete) {
|
||||||
|
cmd.Println("Please proceed with setting up this device using setup keys, see:\n\n" +
|
||||||
|
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||||
|
} else {
|
||||||
|
cmd.Println(browserAuthMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment.
|
||||||
|
func isLinuxRunningDesktop() bool {
|
||||||
|
for _, env := range os.Environ() {
|
||||||
|
values := strings.Split(env, "=")
|
||||||
|
if len(values) == 2 {
|
||||||
|
key, value := values[0], values[1]
|
||||||
|
if key == "XDG_CURRENT_DESKTOP" && value != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPKCEFlow determines if the PKCE flow is active or not,
|
||||||
|
// by checking the existence of redirect_uri inside the verification URL.
|
||||||
|
func isPKCEFlow(verificationURL string) bool {
|
||||||
|
if verificationURL == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.Contains(verificationURL, "redirect_uri")
|
||||||
|
}
|
||||||
|
|||||||
@@ -73,7 +73,8 @@ var sshCmd = &cobra.Command{
|
|||||||
go func() {
|
go func() {
|
||||||
// blocking
|
// blocking
|
||||||
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
||||||
log.Print(err)
|
log.Debug(err)
|
||||||
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
cancel()
|
cancel()
|
||||||
}()
|
}()
|
||||||
@@ -92,12 +93,10 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
|
|||||||
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
|
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cmd.Printf("Error: %v\n", err)
|
cmd.Printf("Error: %v\n", err)
|
||||||
cmd.Printf("Couldn't connect. " +
|
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
|
||||||
"You might be disconnected from the NetBird network, or the NetBird agent isn't running.\n" +
|
"You can verify the connection by running:\n\n" +
|
||||||
"Run the status command: \n\n" +
|
" netbird status\n\n")
|
||||||
" netbird status\n\n" +
|
return err
|
||||||
"It might also be that the SSH server is disabled on the agent you are trying to connect to.\n")
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
|
|||||||
@@ -109,9 +109,9 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
ctx := internal.CtxInitState(context.Background())
|
ctx := internal.CtxInitState(context.Background())
|
||||||
|
|
||||||
resp, _ := getStatus(ctx, cmd)
|
resp, err := getStatus(ctx, cmd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
|
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
|
||||||
@@ -120,7 +120,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
" netbird up \n\n"+
|
" netbird up \n\n"+
|
||||||
"If you are running a self-hosted version and no SSO provider has been configured in your Management Server,\n"+
|
"If you are running a self-hosted version and no SSO provider has been configured in your Management Server,\n"+
|
||||||
"you can use a setup-key:\n\n netbird up --management-url <YOUR_MANAGEMENT_URL> --setup-key <YOUR_SETUP_KEY>\n\n"+
|
"you can use a setup-key:\n\n netbird up --management-url <YOUR_MANAGEMENT_URL> --setup-key <YOUR_SETUP_KEY>\n\n"+
|
||||||
"More info: https://www.netbird.io/docs/overview/setup-keys\n\n",
|
"More info: https://docs.netbird.io/how-to/register-machines-using-setup-keys\n\n",
|
||||||
resp.GetStatus(),
|
resp.GetStatus(),
|
||||||
)
|
)
|
||||||
return nil
|
return nil
|
||||||
@@ -133,7 +133,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
outputInformationHolder := convertToStatusOutputOverview(resp)
|
outputInformationHolder := convertToStatusOutputOverview(resp)
|
||||||
|
|
||||||
statusOutputString := ""
|
var statusOutputString string
|
||||||
switch {
|
switch {
|
||||||
case detailFlag:
|
case detailFlag:
|
||||||
statusOutputString = parseToFullDetailSummary(outputInformationHolder)
|
statusOutputString = parseToFullDetailSummary(outputInformationHolder)
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||||
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
|
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
var cancel context.CancelFunc
|
var cancel context.CancelFunc
|
||||||
ctx, cancel = context.WithCancel(ctx)
|
ctx, cancel = context.WithCancel(ctx)
|
||||||
SetupCloseHandler(ctx, cancel)
|
SetupCloseHandler(ctx, cancel)
|
||||||
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil, nil, nil)
|
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||||
|
|||||||
@@ -40,6 +40,9 @@ const (
|
|||||||
// It declares methods which handle actions required by the
|
// It declares methods which handle actions required by the
|
||||||
// Netbird client for ACL and routing functionality
|
// Netbird client for ACL and routing functionality
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
|
// AllowNetbird allows netbird interface traffic
|
||||||
|
AllowNetbird() error
|
||||||
|
|
||||||
// AddFiltering rule to the firewall
|
// AddFiltering rule to the firewall
|
||||||
//
|
//
|
||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
@@ -51,6 +54,7 @@ type Manager interface {
|
|||||||
dPort *Port,
|
dPort *Port,
|
||||||
direction RuleDirection,
|
direction RuleDirection,
|
||||||
action Action,
|
action Action,
|
||||||
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) (Rule, error)
|
) (Rule, error)
|
||||||
|
|
||||||
@@ -60,5 +64,8 @@ type Manager interface {
|
|||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
Reset() error
|
Reset() error
|
||||||
|
|
||||||
|
// Flush the changes to firewall controller
|
||||||
|
Flush() error
|
||||||
|
|
||||||
// TODO: migrate routemanager firewal actions to this interface
|
// TODO: migrate routemanager firewal actions to this interface
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/nadoo/ipset"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall"
|
fw "github.com/netbirdio/netbird/client/firewall"
|
||||||
@@ -35,36 +36,53 @@ type Manager struct {
|
|||||||
inputDefaultRuleSpecs []string
|
inputDefaultRuleSpecs []string
|
||||||
outputDefaultRuleSpecs []string
|
outputDefaultRuleSpecs []string
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
|
|
||||||
|
rulesets map[string]ruleset
|
||||||
}
|
}
|
||||||
|
|
||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
type iFaceMapper interface {
|
type iFaceMapper interface {
|
||||||
Name() string
|
Name() string
|
||||||
Address() iface.WGAddress
|
Address() iface.WGAddress
|
||||||
|
IsUserspaceBind() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type ruleset struct {
|
||||||
|
rule *Rule
|
||||||
|
ips map[string]string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create iptables firewall manager
|
// Create iptables firewall manager
|
||||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) {
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
inputDefaultRuleSpecs: []string{
|
inputDefaultRuleSpecs: []string{
|
||||||
"-i", wgIface.Name(), "-j", ChainInputFilterName, "-s", wgIface.Address().String()},
|
"-i", wgIface.Name(), "-j", ChainInputFilterName, "-s", wgIface.Address().String()},
|
||||||
outputDefaultRuleSpecs: []string{
|
outputDefaultRuleSpecs: []string{
|
||||||
"-o", wgIface.Name(), "-j", ChainOutputFilterName, "-d", wgIface.Address().String()},
|
"-o", wgIface.Name(), "-j", ChainOutputFilterName, "-d", wgIface.Address().String()},
|
||||||
|
rulesets: make(map[string]ruleset),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := ipset.Init()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("init ipset: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// init clients for booth ipv4 and ipv6
|
// init clients for booth ipv4 and ipv6
|
||||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
m.ipv4Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
||||||
}
|
}
|
||||||
m.ipv4Client = ipv4Client
|
|
||||||
|
|
||||||
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
if ipv6Supported {
|
||||||
if err != nil {
|
m.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||||
log.Errorf("ip6tables is not installed in the system or not supported: %v", err)
|
if err != nil {
|
||||||
} else {
|
log.Warnf("ip6tables is not installed in the system or not supported: %v. Access rules for this protocol won't be applied.", err)
|
||||||
m.ipv6Client = ipv6Client
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.ipv4Client == nil && m.ipv6Client == nil {
|
||||||
|
return nil, fmt.Errorf("iptables is not installed in the system or not enough permissions to use it")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.Reset(); err != nil {
|
if err := m.Reset(); err != nil {
|
||||||
@@ -83,6 +101,7 @@ func (m *Manager) AddFiltering(
|
|||||||
dPort *fw.Port,
|
dPort *fw.Port,
|
||||||
direction fw.RuleDirection,
|
direction fw.RuleDirection,
|
||||||
action fw.Action,
|
action fw.Action,
|
||||||
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) (fw.Rule, error) {
|
) (fw.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
@@ -101,22 +120,45 @@ func (m *Manager) AddFiltering(
|
|||||||
if sPort != nil && sPort.Values != nil {
|
if sPort != nil && sPort.Values != nil {
|
||||||
sPortVal = strconv.Itoa(sPort.Values[0])
|
sPortVal = strconv.Itoa(sPort.Values[0])
|
||||||
}
|
}
|
||||||
|
ipsetName = m.transformIPsetName(ipsetName, sPortVal, dPortVal)
|
||||||
|
|
||||||
ruleID := uuid.New().String()
|
ruleID := uuid.New().String()
|
||||||
if comment == "" {
|
if comment == "" {
|
||||||
comment = ruleID
|
comment = ruleID
|
||||||
}
|
}
|
||||||
|
|
||||||
specs := m.filterRuleSpecs(
|
if ipsetName != "" {
|
||||||
"filter",
|
rs, rsExists := m.rulesets[ipsetName]
|
||||||
ip,
|
if !rsExists {
|
||||||
string(protocol),
|
if err := ipset.Flush(ipsetName); err != nil {
|
||||||
sPortVal,
|
log.Errorf("flush ipset %q before use it: %v", ipsetName, err)
|
||||||
dPortVal,
|
}
|
||||||
direction,
|
if err := ipset.Create(ipsetName); err != nil {
|
||||||
action,
|
return nil, fmt.Errorf("failed to create ipset: %w", err)
|
||||||
comment,
|
}
|
||||||
)
|
}
|
||||||
|
|
||||||
|
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rsExists {
|
||||||
|
// if ruleset already exists it means we already have the firewall rule
|
||||||
|
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
||||||
|
rs.ips[ip.String()] = ruleID
|
||||||
|
return &Rule{
|
||||||
|
ruleID: ruleID,
|
||||||
|
ipsetName: ipsetName,
|
||||||
|
ip: ip.String(),
|
||||||
|
dst: direction == fw.RuleDirectionOUT,
|
||||||
|
v6: ip.To4() == nil,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
// this is new ipset so we need to create firewall rule for it
|
||||||
|
}
|
||||||
|
|
||||||
|
specs := m.filterRuleSpecs("filter", ip, string(protocol), sPortVal, dPortVal,
|
||||||
|
direction, action, comment, ipsetName)
|
||||||
|
|
||||||
if direction == fw.RuleDirectionOUT {
|
if direction == fw.RuleDirectionOUT {
|
||||||
ok, err := client.Exists("filter", ChainOutputFilterName, specs...)
|
ok, err := client.Exists("filter", ChainOutputFilterName, specs...)
|
||||||
@@ -144,12 +186,24 @@ func (m *Manager) AddFiltering(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Rule{
|
rule := &Rule{
|
||||||
id: ruleID,
|
ruleID: ruleID,
|
||||||
specs: specs,
|
specs: specs,
|
||||||
dst: direction == fw.RuleDirectionOUT,
|
ipsetName: ipsetName,
|
||||||
v6: ip.To4() == nil,
|
ip: ip.String(),
|
||||||
}, nil
|
dst: direction == fw.RuleDirectionOUT,
|
||||||
|
v6: ip.To4() == nil,
|
||||||
|
}
|
||||||
|
if ipsetName != "" {
|
||||||
|
// ipset name is defined and it means that this rule was created
|
||||||
|
// for it, need to assosiate it with ruleset
|
||||||
|
m.rulesets[ipsetName] = ruleset{
|
||||||
|
rule: rule,
|
||||||
|
ips: map[string]string{rule.ip: ruleID},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRule from the firewall by rule definition
|
// DeleteRule from the firewall by rule definition
|
||||||
@@ -170,6 +224,31 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
|
|||||||
client = m.ipv6Client
|
client = m.ipv6Client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if rs, ok := m.rulesets[r.ipsetName]; ok {
|
||||||
|
// delete IP from ruleset IPs list and ipset
|
||||||
|
if _, ok := rs.ips[r.ip]; ok {
|
||||||
|
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete ip from ipset: %w", err)
|
||||||
|
}
|
||||||
|
delete(rs.ips, r.ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// if after delete, set still contains other IPs,
|
||||||
|
// no need to delete firewall rule and we should exit here
|
||||||
|
if len(rs.ips) != 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// we delete last IP from the set, that means we need to delete
|
||||||
|
// set itself and assosiated firewall rule too
|
||||||
|
delete(m.rulesets, r.ipsetName)
|
||||||
|
|
||||||
|
if err := ipset.Destroy(r.ipsetName); err != nil {
|
||||||
|
log.Errorf("delete empty ipset: %v", err)
|
||||||
|
}
|
||||||
|
r = rs.rule
|
||||||
|
}
|
||||||
|
|
||||||
if r.dst {
|
if r.dst {
|
||||||
return client.Delete("filter", ChainOutputFilterName, r.specs...)
|
return client.Delete("filter", ChainOutputFilterName, r.specs...)
|
||||||
}
|
}
|
||||||
@@ -193,6 +272,41 @@ func (m *Manager) Reset() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AllowNetbird allows netbird interface traffic
|
||||||
|
func (m *Manager) AllowNetbird() error {
|
||||||
|
if m.wgIface.IsUserspaceBind() {
|
||||||
|
_, err := m.AddFiltering(
|
||||||
|
net.ParseIP("0.0.0.0"),
|
||||||
|
"all",
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.RuleDirectionIN,
|
||||||
|
fw.ActionAccept,
|
||||||
|
"",
|
||||||
|
"allow netbird interface traffic",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
|
||||||
|
}
|
||||||
|
_, err = m.AddFiltering(
|
||||||
|
net.ParseIP("0.0.0.0"),
|
||||||
|
"all",
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.RuleDirectionOUT,
|
||||||
|
fw.ActionAccept,
|
||||||
|
"",
|
||||||
|
"allow netbird interface traffic",
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush doesn't need to be implemented for this manager
|
||||||
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
// reset firewall chain, clear it and drop it
|
// reset firewall chain, clear it and drop it
|
||||||
func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
||||||
ok, err := client.ChainExists(table, ChainInputFilterName)
|
ok, err := client.ChainExists(table, ChainInputFilterName)
|
||||||
@@ -233,6 +347,16 @@ func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for ipsetName := range m.rulesets {
|
||||||
|
if err := ipset.Flush(ipsetName); err != nil {
|
||||||
|
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||||
|
}
|
||||||
|
if err := ipset.Destroy(ipsetName); err != nil {
|
||||||
|
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
|
||||||
|
}
|
||||||
|
delete(m.rulesets, ipsetName)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -240,6 +364,7 @@ func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
|||||||
func (m *Manager) filterRuleSpecs(
|
func (m *Manager) filterRuleSpecs(
|
||||||
table string, ip net.IP, protocol string, sPort, dPort string,
|
table string, ip net.IP, protocol string, sPort, dPort string,
|
||||||
direction fw.RuleDirection, action fw.Action, comment string,
|
direction fw.RuleDirection, action fw.Action, comment string,
|
||||||
|
ipsetName string,
|
||||||
) (specs []string) {
|
) (specs []string) {
|
||||||
matchByIP := true
|
matchByIP := true
|
||||||
// don't use IP matching if IP is ip 0.0.0.0
|
// don't use IP matching if IP is ip 0.0.0.0
|
||||||
@@ -249,11 +374,19 @@ func (m *Manager) filterRuleSpecs(
|
|||||||
switch direction {
|
switch direction {
|
||||||
case fw.RuleDirectionIN:
|
case fw.RuleDirectionIN:
|
||||||
if matchByIP {
|
if matchByIP {
|
||||||
specs = append(specs, "-s", ip.String())
|
if ipsetName != "" {
|
||||||
|
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
||||||
|
} else {
|
||||||
|
specs = append(specs, "-s", ip.String())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case fw.RuleDirectionOUT:
|
case fw.RuleDirectionOUT:
|
||||||
if matchByIP {
|
if matchByIP {
|
||||||
specs = append(specs, "-d", ip.String())
|
if ipsetName != "" {
|
||||||
|
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
|
||||||
|
} else {
|
||||||
|
specs = append(specs, "-d", ip.String())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if protocol != "all" {
|
if protocol != "all" {
|
||||||
@@ -301,7 +434,7 @@ func (m *Manager) client(ip net.IP) (*iptables.IPTables, error) {
|
|||||||
return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err)
|
return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := client.AppendUnique("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil {
|
if err := client.Insert("filter", "INPUT", 1, m.inputDefaultRuleSpecs...); err != nil {
|
||||||
return nil, fmt.Errorf("failed to create input chain jump rule: %w", err)
|
return nil, fmt.Errorf("failed to create input chain jump rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -335,3 +468,16 @@ func (m *Manager) actionToStr(action fw.Action) string {
|
|||||||
}
|
}
|
||||||
return "DROP"
|
return "DROP"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) transformIPsetName(ipsetName string, sPort, dPort string) string {
|
||||||
|
if ipsetName == "" {
|
||||||
|
return ""
|
||||||
|
} else if sPort != "" && dPort != "" {
|
||||||
|
return ipsetName + "-sport-dport"
|
||||||
|
} else if sPort != "" {
|
||||||
|
return ipsetName + "-sport"
|
||||||
|
} else if dPort != "" {
|
||||||
|
return ipsetName + "-dport"
|
||||||
|
}
|
||||||
|
return ipsetName
|
||||||
|
}
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ func (i *iFaceMock) Address() iface.WGAddress {
|
|||||||
panic("AddressFunc is not set")
|
panic("AddressFunc is not set")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
||||||
|
|
||||||
func TestIptablesManager(t *testing.T) {
|
func TestIptablesManager(t *testing.T) {
|
||||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -53,14 +55,15 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(mock)
|
manager, err := Create(mock, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := manager.Reset(); err != nil {
|
err := manager.Reset()
|
||||||
t.Errorf("clear the manager state: %v", err)
|
require.NoError(t, err, "clear the manager state")
|
||||||
}
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -68,7 +71,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
t.Run("add first rule", func(t *testing.T) {
|
t.Run("add first rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.2")
|
ip := net.ParseIP("10.20.0.2")
|
||||||
port := &fw.Port{Values: []int{8080}}
|
port := &fw.Port{Values: []int{8080}}
|
||||||
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
||||||
@@ -81,33 +84,31 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
Values: []int{8043: 8046},
|
Values: []int{8043: 8046},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddFiltering(
|
rule2, err = manager.AddFiltering(
|
||||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTPS traffic from ports range")
|
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...)
|
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("delete first rule", func(t *testing.T) {
|
t.Run("delete first rule", func(t *testing.T) {
|
||||||
if err := manager.DeleteRule(rule1); err != nil {
|
err := manager.DeleteRule(rule1)
|
||||||
require.NoError(t, err, "failed to delete rule")
|
require.NoError(t, err, "failed to delete rule")
|
||||||
}
|
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, false, rule1.(*Rule).specs...)
|
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, false, rule1.(*Rule).specs...)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("delete second rule", func(t *testing.T) {
|
t.Run("delete second rule", func(t *testing.T) {
|
||||||
if err := manager.DeleteRule(rule2); err != nil {
|
err := manager.DeleteRule(rule2)
|
||||||
require.NoError(t, err, "failed to delete rule")
|
require.NoError(t, err, "failed to delete rule")
|
||||||
}
|
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, false, rule2.(*Rule).specs...)
|
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("reset check", func(t *testing.T) {
|
t.Run("reset check", func(t *testing.T) {
|
||||||
// add second rule
|
// add second rule
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{Values: []int{5353}}
|
port := &fw.Port{Values: []int{5353}}
|
||||||
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept Fake DNS traffic")
|
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Reset()
|
err = manager.Reset()
|
||||||
@@ -122,6 +123,88 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIptablesManagerIPSet(t *testing.T) {
|
||||||
|
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mock := &iFaceMock{
|
||||||
|
NameFunc: func() string {
|
||||||
|
return "lo"
|
||||||
|
},
|
||||||
|
AddressFunc: func() iface.WGAddress {
|
||||||
|
return iface.WGAddress{
|
||||||
|
IP: net.ParseIP("10.20.0.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("10.20.0.0"),
|
||||||
|
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// just check on the local interface
|
||||||
|
manager, err := Create(mock, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := manager.Reset()
|
||||||
|
require.NoError(t, err, "clear the manager state")
|
||||||
|
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
}()
|
||||||
|
|
||||||
|
var rule1 fw.Rule
|
||||||
|
t.Run("add first rule with set", func(t *testing.T) {
|
||||||
|
ip := net.ParseIP("10.20.0.2")
|
||||||
|
port := &fw.Port{Values: []int{8080}}
|
||||||
|
rule1, err = manager.AddFiltering(
|
||||||
|
ip, "tcp", nil, port, fw.RuleDirectionOUT,
|
||||||
|
fw.ActionAccept, "default", "accept HTTP traffic",
|
||||||
|
)
|
||||||
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
|
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
||||||
|
require.Equal(t, rule1.(*Rule).ipsetName, "default-dport", "ipset name must be set")
|
||||||
|
require.Equal(t, rule1.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
|
||||||
|
})
|
||||||
|
|
||||||
|
var rule2 fw.Rule
|
||||||
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
|
ip := net.ParseIP("10.20.0.3")
|
||||||
|
port := &fw.Port{
|
||||||
|
Values: []int{443},
|
||||||
|
}
|
||||||
|
rule2, err = manager.AddFiltering(
|
||||||
|
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
|
||||||
|
"default", "accept HTTPS traffic from ports range",
|
||||||
|
)
|
||||||
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
require.Equal(t, rule2.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||||
|
require.Equal(t, rule2.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("delete first rule", func(t *testing.T) {
|
||||||
|
err := manager.DeleteRule(rule1)
|
||||||
|
require.NoError(t, err, "failed to delete rule")
|
||||||
|
|
||||||
|
require.NotContains(t, manager.rulesets, rule1.(*Rule).ruleID, "rule must be removed form the ruleset index")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("delete second rule", func(t *testing.T) {
|
||||||
|
err := manager.DeleteRule(rule2)
|
||||||
|
require.NoError(t, err, "failed to delete rule")
|
||||||
|
|
||||||
|
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("reset check", func(t *testing.T) {
|
||||||
|
err = manager.Reset()
|
||||||
|
require.NoError(t, err, "failed to reset")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) {
|
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) {
|
||||||
exists, err := ipv4Client.Exists("filter", chainName, rulespec...)
|
exists, err := ipv4Client.Exists("filter", chainName, rulespec...)
|
||||||
require.NoError(t, err, "failed to check rule")
|
require.NoError(t, err, "failed to check rule")
|
||||||
@@ -148,14 +231,14 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(mock)
|
manager, err := Create(mock, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := manager.Reset(); err != nil {
|
err := manager.Reset()
|
||||||
t.Errorf("clear the manager state: %v", err)
|
require.NoError(t, err, "clear the manager state")
|
||||||
}
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -167,9 +250,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
port := &fw.Port{Values: []int{1000 + i}}
|
||||||
if i%2 == 0 {
|
if i%2 == 0 {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
} else {
|
} else {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
|
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|||||||
@@ -2,13 +2,16 @@ package iptables
|
|||||||
|
|
||||||
// Rule to handle management of rules
|
// Rule to handle management of rules
|
||||||
type Rule struct {
|
type Rule struct {
|
||||||
id string
|
ruleID string
|
||||||
|
ipsetName string
|
||||||
|
|
||||||
specs []string
|
specs []string
|
||||||
|
ip string
|
||||||
dst bool
|
dst bool
|
||||||
v6 bool
|
v6 bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
func (r *Rule) GetRuleID() string {
|
func (r *Rule) GetRuleID() string {
|
||||||
return r.id
|
return r.ruleID
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,12 +6,14 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
"github.com/google/uuid"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall"
|
fw "github.com/netbirdio/netbird/client/firewall"
|
||||||
@@ -27,13 +29,18 @@ const (
|
|||||||
|
|
||||||
// FilterOutputChainName is the name of the chain that is used for filtering outgoing packets
|
// FilterOutputChainName is the name of the chain that is used for filtering outgoing packets
|
||||||
FilterOutputChainName = "netbird-acl-output-filter"
|
FilterOutputChainName = "netbird-acl-output-filter"
|
||||||
|
|
||||||
|
AllowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||||
|
|
||||||
// Manager of iptables firewall
|
// Manager of iptables firewall
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
|
|
||||||
conn *nftables.Conn
|
rConn *nftables.Conn
|
||||||
|
sConn *nftables.Conn
|
||||||
tableIPv4 *nftables.Table
|
tableIPv4 *nftables.Table
|
||||||
tableIPv6 *nftables.Table
|
tableIPv6 *nftables.Table
|
||||||
|
|
||||||
@@ -43,6 +50,10 @@ type Manager struct {
|
|||||||
filterInputChainIPv6 *nftables.Chain
|
filterInputChainIPv6 *nftables.Chain
|
||||||
filterOutputChainIPv6 *nftables.Chain
|
filterOutputChainIPv6 *nftables.Chain
|
||||||
|
|
||||||
|
rulesetManager *rulesetManager
|
||||||
|
setRemovedIPs map[string]struct{}
|
||||||
|
setRemoved map[string]*nftables.Set
|
||||||
|
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,8 +65,23 @@ type iFaceMapper interface {
|
|||||||
|
|
||||||
// Create nftables firewall manager
|
// Create nftables firewall manager
|
||||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||||
|
// sConn is used for creating sets and adding/removing elements from them
|
||||||
|
// it's differ then rConn (which does create new conn for each flush operation)
|
||||||
|
// and is permanent. Using same connection for booth type of operations
|
||||||
|
// overloads netlink with high amount of rules ( > 10000)
|
||||||
|
sConn, err := nftables.New(nftables.AsLasting())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
conn: &nftables.Conn{},
|
rConn: &nftables.Conn{},
|
||||||
|
sConn: sConn,
|
||||||
|
|
||||||
|
rulesetManager: newRuleManager(),
|
||||||
|
setRemovedIPs: map[string]struct{}{},
|
||||||
|
setRemoved: map[string]*nftables.Set{},
|
||||||
|
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,6 +103,7 @@ func (m *Manager) AddFiltering(
|
|||||||
dPort *fw.Port,
|
dPort *fw.Port,
|
||||||
direction fw.RuleDirection,
|
direction fw.RuleDirection,
|
||||||
action fw.Action,
|
action fw.Action,
|
||||||
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) (fw.Rule, error) {
|
) (fw.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
@@ -84,6 +111,7 @@ func (m *Manager) AddFiltering(
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
|
ipset *nftables.Set
|
||||||
table *nftables.Table
|
table *nftables.Table
|
||||||
chain *nftables.Chain
|
chain *nftables.Chain
|
||||||
)
|
)
|
||||||
@@ -107,6 +135,46 @@ func (m *Manager) AddFiltering(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rawIP := ip.To4()
|
||||||
|
if rawIP == nil {
|
||||||
|
rawIP = ip.To16()
|
||||||
|
}
|
||||||
|
|
||||||
|
rulesetID := m.getRulesetID(ip, proto, sPort, dPort, direction, action, ipsetName)
|
||||||
|
|
||||||
|
if ipsetName != "" {
|
||||||
|
// if we already have set with given name, just add ip to the set
|
||||||
|
// and return rule with new ID in other case let's create rule
|
||||||
|
// with fresh created set and set element
|
||||||
|
|
||||||
|
var isSetNew bool
|
||||||
|
ipset, err = m.rConn.GetSetByName(table, ipsetName)
|
||||||
|
if err != nil {
|
||||||
|
if ipset, err = m.createSet(table, rawIP, ipsetName); err != nil {
|
||||||
|
return nil, fmt.Errorf("get set name: %v", err)
|
||||||
|
}
|
||||||
|
isSetNew = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
|
||||||
|
return nil, fmt.Errorf("add set element for the first time: %v", err)
|
||||||
|
}
|
||||||
|
if err := m.sConn.Flush(); err != nil {
|
||||||
|
return nil, fmt.Errorf("flush add elements: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isSetNew {
|
||||||
|
// if we already have nftables rules with set for given direction
|
||||||
|
// just add new rule to the ruleset and return new fw.Rule object
|
||||||
|
|
||||||
|
if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok {
|
||||||
|
return m.rulesetManager.addRule(ruleset, rawIP)
|
||||||
|
}
|
||||||
|
// if ipset exists but it is not linked to rule for given direction
|
||||||
|
// create new rule for direction and bind ipset to it later
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ifaceKey := expr.MetaKeyIIFNAME
|
ifaceKey := expr.MetaKeyIIFNAME
|
||||||
if direction == fw.RuleDirectionOUT {
|
if direction == fw.RuleDirectionOUT {
|
||||||
ifaceKey = expr.MetaKeyOIFNAME
|
ifaceKey = expr.MetaKeyOIFNAME
|
||||||
@@ -146,39 +214,47 @@ func (m *Manager) AddFiltering(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// don't use IP matching if IP is ip 0.0.0.0
|
// check if rawIP contains zeroed IPv4 0.0.0.0 or same IPv6 value
|
||||||
if s := ip.String(); s != "0.0.0.0" && s != "::" {
|
// in that case not add IP match expression into the rule definition
|
||||||
|
if !bytes.HasPrefix(anyIP, rawIP) {
|
||||||
// source address position
|
// source address position
|
||||||
var adrLen, adrOffset uint32
|
addrLen := uint32(len(rawIP))
|
||||||
if ip.To4() == nil {
|
addrOffset := uint32(12)
|
||||||
adrLen = 16
|
if addrLen == 16 {
|
||||||
adrOffset = 8
|
addrOffset = 8
|
||||||
} else {
|
|
||||||
adrLen = 4
|
|
||||||
adrOffset = 12
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// change to destination address position if need
|
// change to destination address position if need
|
||||||
if direction == fw.RuleDirectionOUT {
|
if direction == fw.RuleDirectionOUT {
|
||||||
adrOffset += adrLen
|
addrOffset += addrLen
|
||||||
}
|
}
|
||||||
|
|
||||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
|
||||||
add := ipToAdd.Unmap()
|
|
||||||
|
|
||||||
expressions = append(expressions,
|
expressions = append(expressions,
|
||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
Offset: adrOffset,
|
Offset: addrOffset,
|
||||||
Len: adrLen,
|
Len: addrLen,
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: add.AsSlice(),
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
// add individual IP for match if no ipset defined
|
||||||
|
if ipset == nil {
|
||||||
|
expressions = append(expressions,
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: rawIP,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
expressions = append(expressions,
|
||||||
|
&expr.Lookup{
|
||||||
|
SourceRegister: 1,
|
||||||
|
SetName: ipsetName,
|
||||||
|
SetID: ipset.ID,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if sPort != nil && len(sPort.Values) != 0 {
|
if sPort != nil && len(sPort.Values) != 0 {
|
||||||
@@ -219,39 +295,76 @@ func (m *Manager) AddFiltering(
|
|||||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||||
}
|
}
|
||||||
|
|
||||||
id := uuid.New().String()
|
userData := []byte(strings.Join([]string{rulesetID, comment}, " "))
|
||||||
userData := []byte(strings.Join([]string{id, comment}, " "))
|
|
||||||
|
|
||||||
_ = m.conn.InsertRule(&nftables.Rule{
|
rule := m.rConn.InsertRule(&nftables.Rule{
|
||||||
Table: table,
|
Table: table,
|
||||||
Chain: chain,
|
Chain: chain,
|
||||||
Position: 0,
|
Position: 0,
|
||||||
Exprs: expressions,
|
Exprs: expressions,
|
||||||
UserData: userData,
|
UserData: userData,
|
||||||
})
|
})
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
if err := m.conn.Flush(); err != nil {
|
return nil, fmt.Errorf("flush insert rule: %v", err)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
list, err := m.conn.GetRules(table, chain)
|
ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset)
|
||||||
if err != nil {
|
return m.rulesetManager.addRule(ruleset, rawIP)
|
||||||
return nil, err
|
}
|
||||||
|
|
||||||
|
// getRulesetID returns ruleset ID based on given parameters
|
||||||
|
func (m *Manager) getRulesetID(
|
||||||
|
ip net.IP,
|
||||||
|
proto fw.Protocol,
|
||||||
|
sPort *fw.Port,
|
||||||
|
dPort *fw.Port,
|
||||||
|
direction fw.RuleDirection,
|
||||||
|
action fw.Action,
|
||||||
|
ipsetName string,
|
||||||
|
) string {
|
||||||
|
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
|
||||||
|
if sPort != nil {
|
||||||
|
rulesetID += sPort.String()
|
||||||
|
}
|
||||||
|
rulesetID += ":"
|
||||||
|
if dPort != nil {
|
||||||
|
rulesetID += dPort.String()
|
||||||
|
}
|
||||||
|
rulesetID += ":"
|
||||||
|
rulesetID += strconv.Itoa(int(action))
|
||||||
|
if ipsetName == "" {
|
||||||
|
return "ip:" + ip.String() + rulesetID
|
||||||
|
}
|
||||||
|
return "set:" + ipsetName + rulesetID
|
||||||
|
}
|
||||||
|
|
||||||
|
// createSet in given table by name
|
||||||
|
func (m *Manager) createSet(
|
||||||
|
table *nftables.Table,
|
||||||
|
rawIP []byte,
|
||||||
|
name string,
|
||||||
|
) (*nftables.Set, error) {
|
||||||
|
keyType := nftables.TypeIPAddr
|
||||||
|
if len(rawIP) == 16 {
|
||||||
|
keyType = nftables.TypeIP6Addr
|
||||||
|
}
|
||||||
|
// else we create new ipset and continue creating rule
|
||||||
|
ipset := &nftables.Set{
|
||||||
|
Name: name,
|
||||||
|
Table: table,
|
||||||
|
Dynamic: true,
|
||||||
|
KeyType: keyType,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the rule to the chain
|
if err := m.rConn.AddSet(ipset, nil); err != nil {
|
||||||
rule := &Rule{id: id}
|
return nil, fmt.Errorf("create set: %v", err)
|
||||||
for _, r := range list {
|
|
||||||
if bytes.Equal(r.UserData, userData) {
|
|
||||||
rule.Rule = r
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if rule.Rule == nil {
|
|
||||||
return nil, fmt.Errorf("rule not found")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return rule, nil
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return nil, fmt.Errorf("flush created set: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ipset, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// chain returns the chain for the given IP address with specific settings
|
// chain returns the chain for the given IP address with specific settings
|
||||||
@@ -268,7 +381,7 @@ func (m *Manager) chain(
|
|||||||
if c != nil {
|
if c != nil {
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
return m.createChainIfNotExists(tf, name, hook, priority, cType)
|
return m.createChainIfNotExists(tf, FilterTableName, name, hook, priority, cType)
|
||||||
}
|
}
|
||||||
|
|
||||||
if ip.To4() != nil {
|
if ip.To4() != nil {
|
||||||
@@ -288,13 +401,20 @@ func (m *Manager) chain(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// table returns the table for the given family of the IP address
|
// table returns the table for the given family of the IP address
|
||||||
func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
|
func (m *Manager) table(
|
||||||
|
family nftables.TableFamily, tableName string,
|
||||||
|
) (*nftables.Table, error) {
|
||||||
|
// we cache access to Netbird ACL table only
|
||||||
|
if tableName != FilterTableName {
|
||||||
|
return m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
|
||||||
|
}
|
||||||
|
|
||||||
if family == nftables.TableFamilyIPv4 {
|
if family == nftables.TableFamilyIPv4 {
|
||||||
if m.tableIPv4 != nil {
|
if m.tableIPv4 != nil {
|
||||||
return m.tableIPv4, nil
|
return m.tableIPv4, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4)
|
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -306,7 +426,7 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
|
|||||||
return m.tableIPv6, nil
|
return m.tableIPv6, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6)
|
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6, tableName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -314,34 +434,41 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
|
|||||||
return m.tableIPv6, nil
|
return m.tableIPv6, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) {
|
func (m *Manager) createTableIfNotExists(
|
||||||
tables, err := m.conn.ListTablesOfFamily(family)
|
family nftables.TableFamily, tableName string,
|
||||||
|
) (*nftables.Table, error) {
|
||||||
|
tables, err := m.rConn.ListTablesOfFamily(family)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("list of tables: %w", err)
|
return nil, fmt.Errorf("list of tables: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, t := range tables {
|
for _, t := range tables {
|
||||||
if t.Name == FilterTableName {
|
if t.Name == tableName {
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.conn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4}), nil
|
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return table, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) createChainIfNotExists(
|
func (m *Manager) createChainIfNotExists(
|
||||||
family nftables.TableFamily,
|
family nftables.TableFamily,
|
||||||
|
tableName string,
|
||||||
name string,
|
name string,
|
||||||
hooknum nftables.ChainHook,
|
hooknum nftables.ChainHook,
|
||||||
priority nftables.ChainPriority,
|
priority nftables.ChainPriority,
|
||||||
chainType nftables.ChainType,
|
chainType nftables.ChainType,
|
||||||
) (*nftables.Chain, error) {
|
) (*nftables.Chain, error) {
|
||||||
table, err := m.table(family)
|
table, err := m.table(family, tableName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
chains, err := m.conn.ListChainsOfTableFamily(family)
|
chains, err := m.rConn.ListChainsOfTableFamily(family)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("list of chains: %w", err)
|
return nil, fmt.Errorf("list of chains: %w", err)
|
||||||
}
|
}
|
||||||
@@ -362,7 +489,7 @@ func (m *Manager) createChainIfNotExists(
|
|||||||
Policy: &polAccept,
|
Policy: &polAccept,
|
||||||
}
|
}
|
||||||
|
|
||||||
chain = m.conn.AddChain(chain)
|
chain = m.rConn.AddChain(chain)
|
||||||
|
|
||||||
ifaceKey := expr.MetaKeyIIFNAME
|
ifaceKey := expr.MetaKeyIIFNAME
|
||||||
shiftDSTAddr := 0
|
shiftDSTAddr := 0
|
||||||
@@ -429,7 +556,7 @@ func (m *Manager) createChainIfNotExists(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = m.conn.AddRule(&nftables.Rule{
|
_ = m.rConn.AddRule(&nftables.Rule{
|
||||||
Table: table,
|
Table: table,
|
||||||
Chain: chain,
|
Chain: chain,
|
||||||
Exprs: expressions,
|
Exprs: expressions,
|
||||||
@@ -444,12 +571,13 @@ func (m *Manager) createChainIfNotExists(
|
|||||||
},
|
},
|
||||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||||
}
|
}
|
||||||
_ = m.conn.AddRule(&nftables.Rule{
|
_ = m.rConn.AddRule(&nftables.Rule{
|
||||||
Table: table,
|
Table: table,
|
||||||
Chain: chain,
|
Chain: chain,
|
||||||
Exprs: expressions,
|
Exprs: expressions,
|
||||||
})
|
})
|
||||||
if err := m.conn.Flush(); err != nil {
|
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -458,16 +586,58 @@ func (m *Manager) createChainIfNotExists(
|
|||||||
|
|
||||||
// DeleteRule from the firewall by rule definition
|
// DeleteRule from the firewall by rule definition
|
||||||
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
nativeRule, ok := rule.(*Rule)
|
nativeRule, ok := rule.(*Rule)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("invalid rule type")
|
return fmt.Errorf("invalid rule type")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.conn.DelRule(nativeRule.Rule); err != nil {
|
if nativeRule.nftRule == nil {
|
||||||
return err
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.conn.Flush()
|
if nativeRule.nftSet != nil {
|
||||||
|
// call twice of delete set element raises error
|
||||||
|
// so we need to check if element is already removed
|
||||||
|
key := fmt.Sprintf("%s:%v", nativeRule.nftSet.Name, nativeRule.ip)
|
||||||
|
if _, ok := m.setRemovedIPs[key]; !ok {
|
||||||
|
err := m.sConn.SetDeleteElements(nativeRule.nftSet, []nftables.SetElement{{Key: nativeRule.ip}})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("delete elements for set %q: %v", nativeRule.nftSet.Name, err)
|
||||||
|
}
|
||||||
|
if err := m.sConn.Flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.setRemovedIPs[key] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.rulesetManager.deleteRule(nativeRule) {
|
||||||
|
// deleteRule indicates that we still have IP in the ruleset
|
||||||
|
// it means we should not remove the nftables rule but need to update set
|
||||||
|
// so we prepare IP to be removed from set on the next flush call
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ruleset doesn't contain IP anymore (or contains only one), remove nft rule
|
||||||
|
if err := m.rConn.DelRule(nativeRule.nftRule); err != nil {
|
||||||
|
log.Errorf("failed to delete rule: %v", err)
|
||||||
|
}
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
nativeRule.nftRule = nil
|
||||||
|
|
||||||
|
if nativeRule.nftSet != nil {
|
||||||
|
if _, ok := m.setRemoved[nativeRule.nftSet.Name]; !ok {
|
||||||
|
m.setRemoved[nativeRule.nftSet.Name] = nativeRule.nftSet
|
||||||
|
}
|
||||||
|
nativeRule.nftSet = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
@@ -475,27 +645,217 @@ func (m *Manager) Reset() error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
chains, err := m.conn.ListChains()
|
chains, err := m.rConn.ListChains()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("list of chains: %w", err)
|
return fmt.Errorf("list of chains: %w", err)
|
||||||
}
|
}
|
||||||
for _, c := range chains {
|
for _, c := range chains {
|
||||||
|
// delete Netbird allow input traffic rule if it exists
|
||||||
|
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
||||||
|
rules, err := m.rConn.GetRules(c.Table, c)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, r := range rules {
|
||||||
|
if bytes.Equal(r.UserData, []byte(AllowNetbirdInputRuleID)) {
|
||||||
|
if err := m.rConn.DelRule(r); err != nil {
|
||||||
|
log.Errorf("delete rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
|
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
|
||||||
m.conn.DelChain(c)
|
m.rConn.DelChain(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tables, err := m.conn.ListTables()
|
tables, err := m.rConn.ListTables()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("list of tables: %w", err)
|
return fmt.Errorf("list of tables: %w", err)
|
||||||
}
|
}
|
||||||
for _, t := range tables {
|
for _, t := range tables {
|
||||||
if t.Name == FilterTableName {
|
if t.Name == FilterTableName {
|
||||||
m.conn.DelTable(t)
|
m.rConn.DelTable(t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.conn.Flush()
|
return m.rConn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush rule/chain/set operations from the buffer
|
||||||
|
//
|
||||||
|
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||||
|
func (m *Manager) Flush() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if err := m.flushWithBackoff(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// set must be removed after flush rule changes
|
||||||
|
// otherwise we will get error
|
||||||
|
for _, s := range m.setRemoved {
|
||||||
|
m.rConn.FlushSet(s)
|
||||||
|
m.rConn.DelSet(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(m.setRemoved) > 0 {
|
||||||
|
if err := m.flushWithBackoff(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.setRemovedIPs = map[string]struct{}{}
|
||||||
|
m.setRemoved = map[string]*nftables.Set{}
|
||||||
|
|
||||||
|
if err := m.refreshRuleHandles(m.tableIPv4, m.filterInputChainIPv4); err != nil {
|
||||||
|
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.refreshRuleHandles(m.tableIPv4, m.filterOutputChainIPv4); err != nil {
|
||||||
|
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.refreshRuleHandles(m.tableIPv6, m.filterInputChainIPv6); err != nil {
|
||||||
|
log.Errorf("failed to refresh rule handles IPv6 input chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.refreshRuleHandles(m.tableIPv6, m.filterOutputChainIPv6); err != nil {
|
||||||
|
log.Errorf("failed to refresh rule handles IPv6 output chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowNetbird allows netbird interface traffic
|
||||||
|
func (m *Manager) AllowNetbird() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
tf := nftables.TableFamilyIPv4
|
||||||
|
if m.wgIface.Address().IP.To4() == nil {
|
||||||
|
tf = nftables.TableFamilyIPv6
|
||||||
|
}
|
||||||
|
|
||||||
|
chains, err := m.rConn.ListChainsOfTableFamily(tf)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("list of chains: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var chain *nftables.Chain
|
||||||
|
for _, c := range chains {
|
||||||
|
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
||||||
|
chain = c
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if chain == nil {
|
||||||
|
log.Debugf("chain INPUT not found. Skiping add allow netbird rule")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rules, err := m.rConn.GetRules(chain.Table, chain)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get rules for the INPUT chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
|
||||||
|
log.Debugf("allow netbird rule already exists: %v", rule)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.applyAllowNetbirdRules(chain)
|
||||||
|
|
||||||
|
err = m.rConn.Flush()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to flush allow input netbird rules: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) flushWithBackoff() (err error) {
|
||||||
|
backoff := 4
|
||||||
|
backoffTime := 1000 * time.Millisecond
|
||||||
|
for i := 0; ; i++ {
|
||||||
|
err = m.rConn.Flush()
|
||||||
|
if err != nil {
|
||||||
|
if !strings.Contains(err.Error(), "busy") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Error("failed to flush nftables, retrying...")
|
||||||
|
if i == backoff-1 {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
time.Sleep(backoffTime)
|
||||||
|
backoffTime = backoffTime * 2
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chain) error {
|
||||||
|
if table == nil || chain == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
list, err := m.rConn.GetRules(table, chain)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range list {
|
||||||
|
if len(rule.UserData) != 0 {
|
||||||
|
if err := m.rulesetManager.setNftRuleHandle(rule); err != nil {
|
||||||
|
log.Errorf("failed to set rule handle: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
||||||
|
rule := &nftables.Rule{
|
||||||
|
Table: chain.Table,
|
||||||
|
Chain: chain,
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(m.wgIface.Name()),
|
||||||
|
},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
UserData: []byte(AllowNetbirdInputRuleID),
|
||||||
|
}
|
||||||
|
_ = m.rConn.InsertRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
|
||||||
|
ifName := ifname(m.wgIface.Name())
|
||||||
|
for _, rule := range existedRules {
|
||||||
|
if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" {
|
||||||
|
if len(rule.Exprs) < 4 {
|
||||||
|
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return rule
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func encodePort(port fw.Port) []byte {
|
func encodePort(port fw.Port) []byte {
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(mock)
|
manager, err := Create(mock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second * 3)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err = manager.Reset()
|
err = manager.Reset()
|
||||||
@@ -75,11 +75,16 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
fw.RuleDirectionIN,
|
fw.RuleDirectionIN,
|
||||||
fw.ActionDrop,
|
fw.ActionDrop,
|
||||||
"",
|
"",
|
||||||
|
"",
|
||||||
)
|
)
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
|
err = manager.Flush()
|
||||||
|
require.NoError(t, err, "failed to flush")
|
||||||
|
|
||||||
rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
||||||
require.NoError(t, err, "failed to get rules")
|
require.NoError(t, err, "failed to get rules")
|
||||||
|
|
||||||
// test expectations:
|
// test expectations:
|
||||||
// 1) regular rule
|
// 1) regular rule
|
||||||
// 2) "accept extra routed traffic rule" for the interface
|
// 2) "accept extra routed traffic rule" for the interface
|
||||||
@@ -135,6 +140,9 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
err = manager.DeleteRule(rule)
|
err = manager.DeleteRule(rule)
|
||||||
require.NoError(t, err, "failed to delete rule")
|
require.NoError(t, err, "failed to delete rule")
|
||||||
|
|
||||||
|
err = manager.Flush()
|
||||||
|
require.NoError(t, err, "failed to flush")
|
||||||
|
|
||||||
rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
||||||
require.NoError(t, err, "failed to get rules")
|
require.NoError(t, err, "failed to get rules")
|
||||||
// test expectations:
|
// test expectations:
|
||||||
@@ -167,7 +175,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(mock)
|
manager, err := Create(mock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second * 3)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := manager.Reset(); err != nil {
|
if err := manager.Reset(); err != nil {
|
||||||
@@ -181,13 +189,18 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
port := &fw.Port{Values: []int{1000 + i}}
|
||||||
if i%2 == 0 {
|
if i%2 == 0 {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
} else {
|
} else {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
|
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
|
if i%100 == 0 {
|
||||||
|
err = manager.Flush()
|
||||||
|
require.NoError(t, err, "failed to flush")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax))
|
t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,11 +6,14 @@ import (
|
|||||||
|
|
||||||
// Rule to handle management of rules
|
// Rule to handle management of rules
|
||||||
type Rule struct {
|
type Rule struct {
|
||||||
*nftables.Rule
|
nftRule *nftables.Rule
|
||||||
id string
|
nftSet *nftables.Set
|
||||||
|
|
||||||
|
ruleID string
|
||||||
|
ip []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
func (r *Rule) GetRuleID() string {
|
func (r *Rule) GetRuleID() string {
|
||||||
return r.id
|
return r.ruleID
|
||||||
}
|
}
|
||||||
|
|||||||
115
client/firewall/nftables/ruleset_linux.go
Normal file
115
client/firewall/nftables/ruleset_linux.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/google/nftables"
|
||||||
|
"github.com/rs/xid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// nftRuleset links native firewall rule and ipset to ACL generated rules
|
||||||
|
type nftRuleset struct {
|
||||||
|
nftRule *nftables.Rule
|
||||||
|
nftSet *nftables.Set
|
||||||
|
issuedRules map[string]*Rule
|
||||||
|
rulesetID string
|
||||||
|
}
|
||||||
|
|
||||||
|
type rulesetManager struct {
|
||||||
|
rulesets map[string]*nftRuleset
|
||||||
|
|
||||||
|
nftSetName2rulesetID map[string]string
|
||||||
|
issuedRuleID2rulesetID map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRuleManager() *rulesetManager {
|
||||||
|
return &rulesetManager{
|
||||||
|
rulesets: map[string]*nftRuleset{},
|
||||||
|
|
||||||
|
nftSetName2rulesetID: map[string]string{},
|
||||||
|
issuedRuleID2rulesetID: map[string]string{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rulesetManager) getRuleset(rulesetID string) (*nftRuleset, bool) {
|
||||||
|
ruleset, ok := r.rulesets[rulesetID]
|
||||||
|
return ruleset, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rulesetManager) createRuleset(
|
||||||
|
rulesetID string,
|
||||||
|
nftRule *nftables.Rule,
|
||||||
|
nftSet *nftables.Set,
|
||||||
|
) *nftRuleset {
|
||||||
|
ruleset := nftRuleset{
|
||||||
|
rulesetID: rulesetID,
|
||||||
|
nftRule: nftRule,
|
||||||
|
nftSet: nftSet,
|
||||||
|
issuedRules: map[string]*Rule{},
|
||||||
|
}
|
||||||
|
r.rulesets[ruleset.rulesetID] = &ruleset
|
||||||
|
if nftSet != nil {
|
||||||
|
r.nftSetName2rulesetID[nftSet.Name] = ruleset.rulesetID
|
||||||
|
}
|
||||||
|
return &ruleset
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rulesetManager) addRule(
|
||||||
|
ruleset *nftRuleset,
|
||||||
|
ip []byte,
|
||||||
|
) (*Rule, error) {
|
||||||
|
if _, ok := r.rulesets[ruleset.rulesetID]; !ok {
|
||||||
|
return nil, fmt.Errorf("ruleset not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := Rule{
|
||||||
|
nftRule: ruleset.nftRule,
|
||||||
|
nftSet: ruleset.nftSet,
|
||||||
|
ruleID: xid.New().String(),
|
||||||
|
ip: ip,
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleset.issuedRules[rule.ruleID] = &rule
|
||||||
|
r.issuedRuleID2rulesetID[rule.ruleID] = ruleset.rulesetID
|
||||||
|
|
||||||
|
return &rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteRule from ruleset and returns true if contains other rules
|
||||||
|
func (r *rulesetManager) deleteRule(rule *Rule) bool {
|
||||||
|
rulesetID, ok := r.issuedRuleID2rulesetID[rule.ruleID]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleset := r.rulesets[rulesetID]
|
||||||
|
if ruleset.nftRule == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
delete(r.issuedRuleID2rulesetID, rule.ruleID)
|
||||||
|
delete(ruleset.issuedRules, rule.ruleID)
|
||||||
|
|
||||||
|
if len(ruleset.issuedRules) == 0 {
|
||||||
|
delete(r.rulesets, ruleset.rulesetID)
|
||||||
|
if rule.nftSet != nil {
|
||||||
|
delete(r.nftSetName2rulesetID, rule.nftSet.Name)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// setNftRuleHandle finds rule by userdata which contains rulesetID and updates it's handle number
|
||||||
|
//
|
||||||
|
// This is important to do, because after we add rule to the nftables we can't update it until
|
||||||
|
// we set correct handle value to it.
|
||||||
|
func (r *rulesetManager) setNftRuleHandle(nftRule *nftables.Rule) error {
|
||||||
|
split := bytes.Split(nftRule.UserData, []byte(" "))
|
||||||
|
ruleset, ok := r.rulesets[string(split[0])]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("ruleset not found")
|
||||||
|
}
|
||||||
|
*ruleset.nftRule = *nftRule
|
||||||
|
return nil
|
||||||
|
}
|
||||||
122
client/firewall/nftables/ruleset_linux_test.go
Normal file
122
client/firewall/nftables/ruleset_linux_test.go
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/nftables"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRulesetManager_createRuleset(t *testing.T) {
|
||||||
|
// Create a ruleset manager.
|
||||||
|
rulesetManager := newRuleManager()
|
||||||
|
|
||||||
|
// Create a ruleset.
|
||||||
|
rulesetID := "ruleset-1"
|
||||||
|
nftRule := nftables.Rule{
|
||||||
|
UserData: []byte(rulesetID),
|
||||||
|
}
|
||||||
|
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||||
|
require.NotNil(t, ruleset, "createRuleset() failed")
|
||||||
|
require.Equal(t, ruleset.rulesetID, rulesetID, "rulesetID is incorrect")
|
||||||
|
require.Equal(t, ruleset.nftRule, &nftRule, "nftRule is incorrect")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRulesetManager_addRule(t *testing.T) {
|
||||||
|
// Create a ruleset manager.
|
||||||
|
rulesetManager := newRuleManager()
|
||||||
|
|
||||||
|
// Create a ruleset.
|
||||||
|
rulesetID := "ruleset-1"
|
||||||
|
nftRule := nftables.Rule{}
|
||||||
|
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||||
|
|
||||||
|
// Add a rule to the ruleset.
|
||||||
|
ip := []byte("192.168.1.1")
|
||||||
|
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||||
|
require.NoError(t, err, "addRule() failed")
|
||||||
|
require.NotNil(t, rule, "rule should not be nil")
|
||||||
|
require.NotEqual(t, rule.ruleID, "ruleID is empty")
|
||||||
|
require.EqualValues(t, rule.ip, ip, "ip is incorrect")
|
||||||
|
require.Contains(t, ruleset.issuedRules, rule.ruleID, "ruleID already exists in ruleset")
|
||||||
|
require.Contains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "ruleID already exists in ruleset manager")
|
||||||
|
|
||||||
|
ruleset2 := &nftRuleset{
|
||||||
|
rulesetID: "ruleset-2",
|
||||||
|
}
|
||||||
|
_, err = rulesetManager.addRule(ruleset2, ip)
|
||||||
|
require.Error(t, err, "addRule() should have failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRulesetManager_deleteRule(t *testing.T) {
|
||||||
|
// Create a ruleset manager.
|
||||||
|
rulesetManager := newRuleManager()
|
||||||
|
|
||||||
|
// Create a ruleset.
|
||||||
|
rulesetID := "ruleset-1"
|
||||||
|
nftRule := nftables.Rule{}
|
||||||
|
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||||
|
|
||||||
|
// Add a rule to the ruleset.
|
||||||
|
ip := []byte("192.168.1.1")
|
||||||
|
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||||
|
require.NoError(t, err, "addRule() failed")
|
||||||
|
require.NotNil(t, rule, "rule should not be nil")
|
||||||
|
|
||||||
|
ip2 := []byte("192.168.1.1")
|
||||||
|
rule2, err := rulesetManager.addRule(ruleset, ip2)
|
||||||
|
require.NoError(t, err, "addRule() failed")
|
||||||
|
require.NotNil(t, rule2, "rule should not be nil")
|
||||||
|
|
||||||
|
hasNext := rulesetManager.deleteRule(rule)
|
||||||
|
require.True(t, hasNext, "deleteRule() should have returned true")
|
||||||
|
|
||||||
|
// Check that the rule is no longer in the manager.
|
||||||
|
require.NotContains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "rule should have been deleted")
|
||||||
|
|
||||||
|
hasNext = rulesetManager.deleteRule(rule2)
|
||||||
|
require.False(t, hasNext, "deleteRule() should have returned false")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRulesetManager_setNftRuleHandle(t *testing.T) {
|
||||||
|
// Create a ruleset manager.
|
||||||
|
rulesetManager := newRuleManager()
|
||||||
|
// Create a ruleset.
|
||||||
|
rulesetID := "ruleset-1"
|
||||||
|
nftRule := nftables.Rule{}
|
||||||
|
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||||
|
// Add a rule to the ruleset.
|
||||||
|
ip := []byte("192.168.0.1")
|
||||||
|
|
||||||
|
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||||
|
require.NoError(t, err, "addRule() failed")
|
||||||
|
require.NotNil(t, rule, "rule should not be nil")
|
||||||
|
|
||||||
|
nftRuleCopy := nftRule
|
||||||
|
nftRuleCopy.Handle = 2
|
||||||
|
nftRuleCopy.UserData = []byte(rulesetID)
|
||||||
|
err = rulesetManager.setNftRuleHandle(&nftRuleCopy)
|
||||||
|
require.NoError(t, err, "setNftRuleHandle() failed")
|
||||||
|
// check correct work with references
|
||||||
|
require.Equal(t, nftRule.Handle, uint64(2), "nftRule.Handle is incorrect")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRulesetManager_getRuleset(t *testing.T) {
|
||||||
|
// Create a ruleset manager.
|
||||||
|
rulesetManager := newRuleManager()
|
||||||
|
// Create a ruleset.
|
||||||
|
rulesetID := "ruleset-1"
|
||||||
|
nftRule := nftables.Rule{}
|
||||||
|
nftSet := nftables.Set{
|
||||||
|
ID: 2,
|
||||||
|
}
|
||||||
|
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, &nftSet)
|
||||||
|
require.NotNil(t, ruleset, "createRuleset() failed")
|
||||||
|
|
||||||
|
find, ok := rulesetManager.getRuleset(rulesetID)
|
||||||
|
require.True(t, ok, "getRuleset() failed")
|
||||||
|
require.Equal(t, ruleset, find, "getRulesetBySetID() failed")
|
||||||
|
|
||||||
|
_, ok = rulesetManager.getRuleset("does-not-exist")
|
||||||
|
require.False(t, ok, "getRuleset() failed")
|
||||||
|
}
|
||||||
19
client/firewall/uspfilter/allow_netbird.go
Normal file
19
client/firewall/uspfilter/allow_netbird.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
//go:build !windows && !linux
|
||||||
|
|
||||||
|
package uspfilter
|
||||||
|
|
||||||
|
// Reset firewall to the default state
|
||||||
|
func (m *Manager) Reset() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
m.outgoingRules = make(map[string]RuleSet)
|
||||||
|
m.incomingRules = make(map[string]RuleSet)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowNetbird allows netbird interface traffic
|
||||||
|
func (m *Manager) AllowNetbird() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
21
client/firewall/uspfilter/allow_netbird_linux.go
Normal file
21
client/firewall/uspfilter/allow_netbird_linux.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
// AllowNetbird allows netbird interface traffic
|
||||||
|
func (m *Manager) AllowNetbird() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset firewall to the default state
|
||||||
|
func (m *Manager) Reset() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
m.outgoingRules = make(map[string]RuleSet)
|
||||||
|
m.incomingRules = make(map[string]RuleSet)
|
||||||
|
|
||||||
|
if m.resetHook != nil {
|
||||||
|
return m.resetHook()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
91
client/firewall/uspfilter/allow_netbird_windows.go
Normal file
91
client/firewall/uspfilter/allow_netbird_windows.go
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
type action string
|
||||||
|
|
||||||
|
const (
|
||||||
|
addRule action = "add"
|
||||||
|
deleteRule action = "delete"
|
||||||
|
|
||||||
|
firewallRuleName = "Netbird"
|
||||||
|
noRulesMatchCriteria = "No rules match the specified criteria"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Reset firewall to the default state
|
||||||
|
func (m *Manager) Reset() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
m.outgoingRules = make(map[string]RuleSet)
|
||||||
|
m.incomingRules = make(map[string]RuleSet)
|
||||||
|
|
||||||
|
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
|
||||||
|
return fmt.Errorf("couldn't remove windows firewall: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowNetbird allows netbird interface traffic
|
||||||
|
func (m *Manager) AllowNetbird() error {
|
||||||
|
return manageFirewallRule(firewallRuleName,
|
||||||
|
addRule,
|
||||||
|
"dir=in",
|
||||||
|
"enable=yes",
|
||||||
|
"action=allow",
|
||||||
|
"profile=any",
|
||||||
|
"localip="+m.wgIface.Address().IP.String(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func manageFirewallRule(ruleName string, action action, args ...string) error {
|
||||||
|
active, err := isFirewallRuleActive(ruleName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if (action == addRule && !active) || (action == deleteRule && active) {
|
||||||
|
baseArgs := []string{"advfirewall", "firewall", string(action), "rule", "name=" + ruleName}
|
||||||
|
args := append(baseArgs, args...)
|
||||||
|
|
||||||
|
cmd := exec.Command("netsh", args...)
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isFirewallRuleActive(ruleName string) (bool, error) {
|
||||||
|
args := []string{"advfirewall", "firewall", "show", "rule", "name=" + ruleName}
|
||||||
|
|
||||||
|
cmd := exec.Command("netsh", args...)
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
||||||
|
output, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
var exitError *exec.ExitError
|
||||||
|
if errors.As(err, &exitError) {
|
||||||
|
// if the firewall rule is not active, we expect last exit code to be 1
|
||||||
|
exitStatus := exitError.Sys().(syscall.WaitStatus).ExitStatus()
|
||||||
|
if exitStatus == 1 {
|
||||||
|
if strings.Contains(string(output), noRulesMatchCriteria) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(string(output), noRulesMatchCriteria) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
@@ -19,15 +19,20 @@ const layerTypeAll = 0
|
|||||||
// IFaceMapper defines subset methods of interface required for manager
|
// IFaceMapper defines subset methods of interface required for manager
|
||||||
type IFaceMapper interface {
|
type IFaceMapper interface {
|
||||||
SetFilter(iface.PacketFilter) error
|
SetFilter(iface.PacketFilter) error
|
||||||
|
Address() iface.WGAddress
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RuleSet is a set of rules grouped by a string key
|
||||||
|
type RuleSet map[string]Rule
|
||||||
|
|
||||||
// Manager userspace firewall manager
|
// Manager userspace firewall manager
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
outgoingRules []Rule
|
outgoingRules map[string]RuleSet
|
||||||
incomingRules []Rule
|
incomingRules map[string]RuleSet
|
||||||
rulesIndex map[string]int
|
|
||||||
wgNetwork *net.IPNet
|
wgNetwork *net.IPNet
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
|
wgIface IFaceMapper
|
||||||
|
resetHook func() error
|
||||||
|
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
@@ -48,7 +53,6 @@ type decoder struct {
|
|||||||
// Create userspace firewall manager constructor
|
// Create userspace firewall manager constructor
|
||||||
func Create(iface IFaceMapper) (*Manager, error) {
|
func Create(iface IFaceMapper) (*Manager, error) {
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
rulesIndex: make(map[string]int),
|
|
||||||
decoders: sync.Pool{
|
decoders: sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
d := &decoder{
|
d := &decoder{
|
||||||
@@ -62,6 +66,9 @@ func Create(iface IFaceMapper) (*Manager, error) {
|
|||||||
return d
|
return d
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
outgoingRules: make(map[string]RuleSet),
|
||||||
|
incomingRules: make(map[string]RuleSet),
|
||||||
|
wgIface: iface,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := iface.SetFilter(m); err != nil {
|
if err := iface.SetFilter(m); err != nil {
|
||||||
@@ -81,6 +88,7 @@ func (m *Manager) AddFiltering(
|
|||||||
dPort *fw.Port,
|
dPort *fw.Port,
|
||||||
direction fw.RuleDirection,
|
direction fw.RuleDirection,
|
||||||
action fw.Action,
|
action fw.Action,
|
||||||
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) (fw.Rule, error) {
|
) (fw.Rule, error) {
|
||||||
r := Rule{
|
r := Rule{
|
||||||
@@ -124,15 +132,17 @@ func (m *Manager) AddFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
var p int
|
|
||||||
if direction == fw.RuleDirectionIN {
|
if direction == fw.RuleDirectionIN {
|
||||||
m.incomingRules = append(m.incomingRules, r)
|
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||||
p = len(m.incomingRules) - 1
|
m.incomingRules[r.ip.String()] = make(RuleSet)
|
||||||
|
}
|
||||||
|
m.incomingRules[r.ip.String()][r.id] = r
|
||||||
} else {
|
} else {
|
||||||
m.outgoingRules = append(m.outgoingRules, r)
|
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||||
p = len(m.outgoingRules) - 1
|
m.outgoingRules[r.ip.String()] = make(RuleSet)
|
||||||
|
}
|
||||||
|
m.outgoingRules[r.ip.String()][r.id] = r
|
||||||
}
|
}
|
||||||
m.rulesIndex[r.id] = p
|
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
|
|
||||||
return &r, nil
|
return &r, nil
|
||||||
@@ -148,38 +158,25 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
|
|||||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
p, ok := m.rulesIndex[r.id]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
|
||||||
}
|
|
||||||
delete(m.rulesIndex, r.id)
|
|
||||||
|
|
||||||
var toUpdate []Rule
|
|
||||||
if r.direction == fw.RuleDirectionIN {
|
if r.direction == fw.RuleDirectionIN {
|
||||||
m.incomingRules = append(m.incomingRules[:p], m.incomingRules[p+1:]...)
|
_, ok := m.incomingRules[r.ip.String()][r.id]
|
||||||
toUpdate = m.incomingRules
|
if !ok {
|
||||||
|
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||||
|
}
|
||||||
|
delete(m.incomingRules[r.ip.String()], r.id)
|
||||||
} else {
|
} else {
|
||||||
m.outgoingRules = append(m.outgoingRules[:p], m.outgoingRules[p+1:]...)
|
_, ok := m.outgoingRules[r.ip.String()][r.id]
|
||||||
toUpdate = m.outgoingRules
|
if !ok {
|
||||||
|
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||||
|
}
|
||||||
|
delete(m.outgoingRules[r.ip.String()], r.id)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < len(toUpdate); i++ {
|
|
||||||
m.rulesIndex[toUpdate[i].id] = i
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Reset() error {
|
func (m *Manager) Flush() error { return nil }
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
m.outgoingRules = m.outgoingRules[:0]
|
|
||||||
m.incomingRules = m.incomingRules[:0]
|
|
||||||
m.rulesIndex = make(map[string]int)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DropOutgoing filter outgoing packets
|
// DropOutgoing filter outgoing packets
|
||||||
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
||||||
@@ -192,7 +189,7 @@ func (m *Manager) DropIncoming(packetData []byte) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// dropFilter imlements same logic for booth direction of the traffic
|
// dropFilter imlements same logic for booth direction of the traffic
|
||||||
func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket bool) bool {
|
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool {
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
defer m.mutex.RUnlock()
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
@@ -224,37 +221,49 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
|
|||||||
log.Errorf("unknown layer: %v", d.decoded[0])
|
log.Errorf("unknown layer: %v", d.decoded[0])
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
payloadLayer := d.decoded[1]
|
|
||||||
|
|
||||||
// check if IP address match by IP
|
var ip net.IP
|
||||||
|
switch ipLayer {
|
||||||
|
case layers.LayerTypeIPv4:
|
||||||
|
if isIncomingPacket {
|
||||||
|
ip = d.ip4.SrcIP
|
||||||
|
} else {
|
||||||
|
ip = d.ip4.DstIP
|
||||||
|
}
|
||||||
|
case layers.LayerTypeIPv6:
|
||||||
|
if isIncomingPacket {
|
||||||
|
ip = d.ip6.SrcIP
|
||||||
|
} else {
|
||||||
|
ip = d.ip6.DstIP
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
filter, ok := validateRule(ip, packetData, rules[ip.String()], d)
|
||||||
|
if ok {
|
||||||
|
return filter
|
||||||
|
}
|
||||||
|
filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d)
|
||||||
|
if ok {
|
||||||
|
return filter
|
||||||
|
}
|
||||||
|
filter, ok = validateRule(ip, packetData, rules["::"], d)
|
||||||
|
if ok {
|
||||||
|
return filter
|
||||||
|
}
|
||||||
|
|
||||||
|
// default policy is DROP ALL
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
|
||||||
|
payloadLayer := d.decoded[1]
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if rule.matchByIP {
|
if rule.matchByIP && !ip.Equal(rule.ip) {
|
||||||
switch ipLayer {
|
continue
|
||||||
case layers.LayerTypeIPv4:
|
|
||||||
if isIncomingPacket {
|
|
||||||
if !d.ip4.SrcIP.Equal(rule.ip) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if !d.ip4.DstIP.Equal(rule.ip) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case layers.LayerTypeIPv6:
|
|
||||||
if isIncomingPacket {
|
|
||||||
if !d.ip6.SrcIP.Equal(rule.ip) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if !d.ip6.DstIP.Equal(rule.ip) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if rule.protoLayer == layerTypeAll {
|
if rule.protoLayer == layerTypeAll {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
|
|
||||||
if payloadLayer != rule.protoLayer {
|
if payloadLayer != rule.protoLayer {
|
||||||
@@ -264,38 +273,36 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
|
|||||||
switch payloadLayer {
|
switch payloadLayer {
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
if rule.sPort == 0 && rule.dPort == 0 {
|
if rule.sPort == 0 && rule.dPort == 0 {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
|
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
|
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
case layers.LayerTypeUDP:
|
case layers.LayerTypeUDP:
|
||||||
// if rule has UDP hook (and if we are here we match this rule)
|
// if rule has UDP hook (and if we are here we match this rule)
|
||||||
// we ignore rule.drop and call this hook
|
// we ignore rule.drop and call this hook
|
||||||
if rule.udpHook != nil {
|
if rule.udpHook != nil {
|
||||||
return rule.udpHook(packetData)
|
return rule.udpHook(packetData), true
|
||||||
}
|
}
|
||||||
|
|
||||||
if rule.sPort == 0 && rule.dPort == 0 {
|
if rule.sPort == 0 && rule.dPort == 0 {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
|
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
|
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return false, false
|
||||||
// default policy is DROP ALL
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetNetwork of the wireguard interface to which filtering applied
|
// SetNetwork of the wireguard interface to which filtering applied
|
||||||
@@ -325,19 +332,19 @@ func (m *Manager) AddUDPPacketHook(
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
var toUpdate []Rule
|
|
||||||
if in {
|
if in {
|
||||||
r.direction = fw.RuleDirectionIN
|
r.direction = fw.RuleDirectionIN
|
||||||
m.incomingRules = append([]Rule{r}, m.incomingRules...)
|
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||||
toUpdate = m.incomingRules
|
m.incomingRules[r.ip.String()] = make(map[string]Rule)
|
||||||
|
}
|
||||||
|
m.incomingRules[r.ip.String()][r.id] = r
|
||||||
} else {
|
} else {
|
||||||
m.outgoingRules = append([]Rule{r}, m.outgoingRules...)
|
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||||
toUpdate = m.outgoingRules
|
m.outgoingRules[r.ip.String()] = make(map[string]Rule)
|
||||||
|
}
|
||||||
|
m.outgoingRules[r.ip.String()][r.id] = r
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range toUpdate {
|
|
||||||
m.rulesIndex[toUpdate[i].id] = i
|
|
||||||
}
|
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
|
|
||||||
return r.id
|
return r.id
|
||||||
@@ -345,15 +352,24 @@ func (m *Manager) AddUDPPacketHook(
|
|||||||
|
|
||||||
// RemovePacketHook removes packet hook by given ID
|
// RemovePacketHook removes packet hook by given ID
|
||||||
func (m *Manager) RemovePacketHook(hookID string) error {
|
func (m *Manager) RemovePacketHook(hookID string) error {
|
||||||
for _, r := range m.incomingRules {
|
for _, arr := range m.incomingRules {
|
||||||
if r.id == hookID {
|
for _, r := range arr {
|
||||||
return m.DeleteRule(&r)
|
if r.id == hookID {
|
||||||
|
return m.DeleteRule(&r)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, r := range m.outgoingRules {
|
for _, arr := range m.outgoingRules {
|
||||||
if r.id == hookID {
|
for _, r := range arr {
|
||||||
return m.DeleteRule(&r)
|
if r.id == hookID {
|
||||||
|
return m.DeleteRule(&r)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return fmt.Errorf("hook with given id not found")
|
return fmt.Errorf("hook with given id not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetResetHook which will be executed in the end of Reset method
|
||||||
|
func (m *Manager) SetResetHook(hook func() error) {
|
||||||
|
m.resetHook = hook
|
||||||
|
}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
|
|
||||||
type IFaceMock struct {
|
type IFaceMock struct {
|
||||||
SetFilterFunc func(iface.PacketFilter) error
|
SetFilterFunc func(iface.PacketFilter) error
|
||||||
|
AddressFunc func() iface.WGAddress
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
|
func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
|
||||||
@@ -25,6 +26,13 @@ func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
|
|||||||
return i.SetFilterFunc(iface)
|
return i.SetFilterFunc(iface)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *IFaceMock) Address() iface.WGAddress {
|
||||||
|
if i.AddressFunc == nil {
|
||||||
|
return iface.WGAddress{}
|
||||||
|
}
|
||||||
|
return i.AddressFunc()
|
||||||
|
}
|
||||||
|
|
||||||
func TestManagerCreate(t *testing.T) {
|
func TestManagerCreate(t *testing.T) {
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||||
@@ -63,7 +71,7 @@ func TestManagerAddFiltering(t *testing.T) {
|
|||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -98,7 +106,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -111,7 +119,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
action = fw.ActionDrop
|
action = fw.ActionDrop
|
||||||
comment = "Test rule 2"
|
comment = "Test rule 2"
|
||||||
|
|
||||||
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -123,8 +131,8 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if idx, ok := m.rulesIndex[rule2.GetRuleID()]; !ok || len(m.incomingRules) != 1 || idx != 0 {
|
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; !ok {
|
||||||
t.Errorf("rule2 is not in the rulesIndex")
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.DeleteRule(rule2)
|
err = m.DeleteRule(rule2)
|
||||||
@@ -133,8 +141,8 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(m.rulesIndex) != 0 || len(m.incomingRules) != 0 {
|
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; ok {
|
||||||
t.Errorf("rule1 still in the rulesIndex")
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -169,26 +177,29 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
manager := &Manager{
|
manager := &Manager{
|
||||||
incomingRules: []Rule{},
|
incomingRules: map[string]RuleSet{},
|
||||||
outgoingRules: []Rule{},
|
outgoingRules: map[string]RuleSet{},
|
||||||
rulesIndex: make(map[string]int),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||||
|
|
||||||
var addedRule Rule
|
var addedRule Rule
|
||||||
if tt.in {
|
if tt.in {
|
||||||
if len(manager.incomingRules) != 1 {
|
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
||||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
addedRule = manager.incomingRules[0]
|
for _, rule := range manager.incomingRules[tt.ip.String()] {
|
||||||
|
addedRule = rule
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if len(manager.outgoingRules) != 1 {
|
if len(manager.outgoingRules) != 1 {
|
||||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
addedRule = manager.outgoingRules[0]
|
for _, rule := range manager.outgoingRules[tt.ip.String()] {
|
||||||
|
addedRule = rule
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tt.ip.Equal(addedRule.ip) {
|
if !tt.ip.Equal(addedRule.ip) {
|
||||||
@@ -211,17 +222,6 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
t.Errorf("expected udpHook to be set")
|
t.Errorf("expected udpHook to be set")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure rulesIndex is correctly updated
|
|
||||||
index, ok := manager.rulesIndex[addedRule.id]
|
|
||||||
if !ok {
|
|
||||||
t.Errorf("expected rule to be in rulesIndex")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if index != 0 {
|
|
||||||
t.Errorf("expected rule index to be 0, got %d", index)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -244,7 +244,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -256,7 +256,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(m.rulesIndex) != 0 || len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
|
if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
|
||||||
t.Errorf("rules is not empty")
|
t.Errorf("rules is not empty")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -282,7 +282,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
action := fw.ActionAccept
|
action := fw.ActionAccept
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, comment)
|
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -346,10 +346,12 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
|
|
||||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||||
found := false
|
found := false
|
||||||
for _, rule := range manager.outgoingRules {
|
for _, arr := range manager.outgoingRules {
|
||||||
if rule.id == hookID {
|
for _, rule := range arr {
|
||||||
found = true
|
if rule.id == hookID {
|
||||||
break
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -364,9 +366,11 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Assert the hook is removed by checking it in the manager's outgoing rules
|
// Assert the hook is removed by checking it in the manager's outgoing rules
|
||||||
for _, rule := range manager.outgoingRules {
|
for _, arr := range manager.outgoingRules {
|
||||||
if rule.id == hookID {
|
for _, rule := range arr {
|
||||||
t.Fatalf("The hook was not removed properly.")
|
if rule.id == hookID {
|
||||||
|
t.Fatalf("The hook was not removed properly.")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -394,9 +398,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
port := &fw.Port{Values: []int{1000 + i}}
|
||||||
if i%2 == 0 {
|
if i%2 == 0 {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
} else {
|
} else {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
|
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|||||||
@@ -33,9 +33,22 @@ type Manager interface {
|
|||||||
|
|
||||||
// DefaultManager uses firewall manager to handle
|
// DefaultManager uses firewall manager to handle
|
||||||
type DefaultManager struct {
|
type DefaultManager struct {
|
||||||
manager firewall.Manager
|
manager firewall.Manager
|
||||||
rulesPairs map[string][]firewall.Rule
|
ipsetCounter int
|
||||||
mutex sync.Mutex
|
rulesPairs map[string][]firewall.Rule
|
||||||
|
mutex sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
type ipsetInfo struct {
|
||||||
|
name string
|
||||||
|
ipCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDefaultManager(fm firewall.Manager) *DefaultManager {
|
||||||
|
return &DefaultManager{
|
||||||
|
manager: fm,
|
||||||
|
rulesPairs: make(map[string][]firewall.Rule),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
|
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
|
||||||
@@ -61,6 +74,12 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if err := d.manager.Flush(); err != nil {
|
||||||
|
log.Error("failed to flush firewall rules: ", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
rules, squashedProtocols := d.squashAcceptRules(networkMap)
|
rules, squashedProtocols := d.squashAcceptRules(networkMap)
|
||||||
|
|
||||||
enableSSH := (networkMap.PeerConfig != nil &&
|
enableSSH := (networkMap.PeerConfig != nil &&
|
||||||
@@ -108,8 +127,31 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
|||||||
|
|
||||||
applyFailed := false
|
applyFailed := false
|
||||||
newRulePairs := make(map[string][]firewall.Rule)
|
newRulePairs := make(map[string][]firewall.Rule)
|
||||||
|
ipsetByRuleSelectors := make(map[string]*ipsetInfo)
|
||||||
|
|
||||||
|
// calculate which IP's can be grouped in by which ipset
|
||||||
|
// to do that we use rule selector (which is just rule properties without IP's)
|
||||||
for _, r := range rules {
|
for _, r := range rules {
|
||||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r)
|
selector := d.getRuleGroupingSelector(r)
|
||||||
|
ipset, ok := ipsetByRuleSelectors[selector]
|
||||||
|
if !ok {
|
||||||
|
ipset = &ipsetInfo{}
|
||||||
|
}
|
||||||
|
|
||||||
|
ipset.ipCount++
|
||||||
|
ipsetByRuleSelectors[selector] = ipset
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range rules {
|
||||||
|
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
||||||
|
// it's IP address can be used in the ipset for firewall manager which supports it
|
||||||
|
ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)]
|
||||||
|
if ipset.name == "" {
|
||||||
|
d.ipsetCounter++
|
||||||
|
ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter)
|
||||||
|
}
|
||||||
|
ipsetName := ipset.name
|
||||||
|
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
||||||
applyFailed = true
|
applyFailed = true
|
||||||
@@ -154,7 +196,10 @@ func (d *DefaultManager) Stop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) (string, []firewall.Rule, error) {
|
func (d *DefaultManager) protoRuleToFirewallRule(
|
||||||
|
r *mgmProto.FirewallRule,
|
||||||
|
ipsetName string,
|
||||||
|
) (string, []firewall.Rule, error) {
|
||||||
ip := net.ParseIP(r.PeerIP)
|
ip := net.ParseIP(r.PeerIP)
|
||||||
if ip == nil {
|
if ip == nil {
|
||||||
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||||
@@ -190,9 +235,9 @@ func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) (stri
|
|||||||
var err error
|
var err error
|
||||||
switch r.Direction {
|
switch r.Direction {
|
||||||
case mgmProto.FirewallRule_IN:
|
case mgmProto.FirewallRule_IN:
|
||||||
rules, err = d.addInRules(ip, protocol, port, action, "")
|
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
||||||
case mgmProto.FirewallRule_OUT:
|
case mgmProto.FirewallRule_OUT:
|
||||||
rules, err = d.addOutRules(ip, protocol, port, action, "")
|
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
||||||
default:
|
default:
|
||||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||||
}
|
}
|
||||||
@@ -205,9 +250,17 @@ func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) (stri
|
|||||||
return ruleID, rules, nil
|
return ruleID, rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) {
|
func (d *DefaultManager) addInRules(
|
||||||
|
ip net.IP,
|
||||||
|
protocol firewall.Protocol,
|
||||||
|
port *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
ipsetName string,
|
||||||
|
comment string,
|
||||||
|
) ([]firewall.Rule, error) {
|
||||||
var rules []firewall.Rule
|
var rules []firewall.Rule
|
||||||
rule, err := d.manager.AddFiltering(ip, protocol, nil, port, firewall.RuleDirectionIN, action, comment)
|
rule, err := d.manager.AddFiltering(
|
||||||
|
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||||
}
|
}
|
||||||
@@ -217,7 +270,8 @@ func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port
|
|||||||
return rules, nil
|
return rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err = d.manager.AddFiltering(ip, protocol, port, nil, firewall.RuleDirectionOUT, action, comment)
|
rule, err = d.manager.AddFiltering(
|
||||||
|
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||||
}
|
}
|
||||||
@@ -225,9 +279,17 @@ func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port
|
|||||||
return append(rules, rule), nil
|
return append(rules, rule), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) {
|
func (d *DefaultManager) addOutRules(
|
||||||
|
ip net.IP,
|
||||||
|
protocol firewall.Protocol,
|
||||||
|
port *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
ipsetName string,
|
||||||
|
comment string,
|
||||||
|
) ([]firewall.Rule, error) {
|
||||||
var rules []firewall.Rule
|
var rules []firewall.Rule
|
||||||
rule, err := d.manager.AddFiltering(ip, protocol, nil, port, firewall.RuleDirectionOUT, action, comment)
|
rule, err := d.manager.AddFiltering(
|
||||||
|
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||||
}
|
}
|
||||||
@@ -237,7 +299,8 @@ func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port
|
|||||||
return rules, nil
|
return rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err = d.manager.AddFiltering(ip, protocol, port, nil, firewall.RuleDirectionIN, action, comment)
|
rule, err = d.manager.AddFiltering(
|
||||||
|
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||||
}
|
}
|
||||||
@@ -282,6 +345,10 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
in := protoMatch{}
|
in := protoMatch{}
|
||||||
out := protoMatch{}
|
out := protoMatch{}
|
||||||
|
|
||||||
|
// trace which type of protocols was squashed
|
||||||
|
squashedRules := []*mgmProto.FirewallRule{}
|
||||||
|
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
|
||||||
|
|
||||||
// this function we use to do calculation, can we squash the rules by protocol or not.
|
// this function we use to do calculation, can we squash the rules by protocol or not.
|
||||||
// We summ amount of Peers IP for given protocol we found in original rules list.
|
// We summ amount of Peers IP for given protocol we found in original rules list.
|
||||||
// But we zeroed the IP's for protocol if:
|
// But we zeroed the IP's for protocol if:
|
||||||
@@ -298,12 +365,22 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
if _, ok := protocols[r.Protocol]; !ok {
|
if _, ok := protocols[r.Protocol]; !ok {
|
||||||
protocols[r.Protocol] = map[string]int{}
|
protocols[r.Protocol] = map[string]int{}
|
||||||
}
|
}
|
||||||
match := protocols[r.Protocol]
|
|
||||||
|
|
||||||
if _, ok := match[r.PeerIP]; ok {
|
// special case, when we recieve this all network IP address
|
||||||
|
// it means that rules for that protocol was already optimized on the
|
||||||
|
// management side
|
||||||
|
if r.PeerIP == "0.0.0.0" {
|
||||||
|
squashedRules = append(squashedRules, r)
|
||||||
|
squashedProtocols[r.Protocol] = struct{}{}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
match[r.PeerIP] = i
|
|
||||||
|
ipset := protocols[r.Protocol]
|
||||||
|
|
||||||
|
if _, ok := ipset[r.PeerIP]; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ipset[r.PeerIP] = i
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, r := range networkMap.FirewallRules {
|
for i, r := range networkMap.FirewallRules {
|
||||||
@@ -324,9 +401,6 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
mgmProto.FirewallRule_UDP,
|
mgmProto.FirewallRule_UDP,
|
||||||
}
|
}
|
||||||
|
|
||||||
// trace which type of protocols was squashed
|
|
||||||
squashedRules := []*mgmProto.FirewallRule{}
|
|
||||||
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
|
|
||||||
squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
|
squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
|
||||||
for _, protocol := range protocolOrders {
|
for _, protocol := range protocolOrders {
|
||||||
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
|
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
|
||||||
@@ -382,6 +456,11 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
return append(rules, squashedRules...), squashedProtocols
|
return append(rules, squashedRules...), squashedProtocols
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
||||||
|
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
||||||
|
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
|
||||||
|
}
|
||||||
|
|
||||||
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol {
|
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol {
|
||||||
switch protocol {
|
switch protocol {
|
||||||
case mgmProto.FirewallRule_TCP:
|
case mgmProto.FirewallRule_TCP:
|
||||||
|
|||||||
@@ -6,7 +6,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -18,10 +19,10 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &DefaultManager{
|
if err := fm.AllowNetbird(); err != nil {
|
||||||
manager: fm,
|
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||||
rulesPairs: make(map[string][]firewall.Rule),
|
}
|
||||||
}, nil
|
return newDefaultManager(fm), nil
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,30 +7,69 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall/iptables"
|
"github.com/netbirdio/netbird/client/firewall/iptables"
|
||||||
"github.com/netbirdio/netbird/client/firewall/nftables"
|
"github.com/netbirdio/netbird/client/firewall/nftables"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/checkfw"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Create creates a firewall manager instance for the Linux
|
// Create creates a firewall manager instance for the Linux
|
||||||
func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
|
func Create(iface IFaceMapper) (*DefaultManager, error) {
|
||||||
|
// on the linux system we try to user nftables or iptables
|
||||||
|
// in any case, because we need to allow netbird interface traffic
|
||||||
|
// so we use AllowNetbird traffic from these firewall managers
|
||||||
|
// for the userspace packet filtering firewall
|
||||||
var fm firewall.Manager
|
var fm firewall.Manager
|
||||||
if iface.IsUserspaceBind() {
|
var err error
|
||||||
// use userspace packet filtering firewall
|
|
||||||
if fm, err = uspfilter.Create(iface); err != nil {
|
checkResult := checkfw.Check()
|
||||||
log.Debugf("failed to create userspace filtering firewall: %s", err)
|
switch checkResult {
|
||||||
return nil, err
|
case checkfw.IPTABLES, checkfw.IPTABLESWITHV6:
|
||||||
|
log.Debug("creating an iptables firewall manager for access control")
|
||||||
|
ipv6Supported := checkResult == checkfw.IPTABLESWITHV6
|
||||||
|
if fm, err = iptables.Create(iface, ipv6Supported); err != nil {
|
||||||
|
log.Infof("failed to create iptables manager for access control: %s", err)
|
||||||
}
|
}
|
||||||
} else {
|
case checkfw.NFTABLES:
|
||||||
|
log.Debug("creating an nftables firewall manager for access control")
|
||||||
if fm, err = nftables.Create(iface); err != nil {
|
if fm, err = nftables.Create(iface); err != nil {
|
||||||
log.Debugf("failed to create nftables manager: %s", err)
|
log.Debugf("failed to create nftables manager for access control: %s", err)
|
||||||
// fallback to iptables
|
|
||||||
if fm, err = iptables.Create(iface); err != nil {
|
|
||||||
log.Errorf("failed to create iptables manager: %s", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &DefaultManager{
|
var resetHookForUserspace func() error
|
||||||
manager: fm,
|
if fm != nil && err == nil {
|
||||||
rulesPairs: make(map[string][]firewall.Rule),
|
// err shadowing is used here, to ignore this error
|
||||||
}, nil
|
if err := fm.AllowNetbird(); err != nil {
|
||||||
|
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||||
|
}
|
||||||
|
resetHookForUserspace = fm.Reset
|
||||||
|
}
|
||||||
|
|
||||||
|
if iface.IsUserspaceBind() {
|
||||||
|
// use userspace packet filtering firewall
|
||||||
|
usfm, err := uspfilter.Create(iface)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to create userspace filtering firewall: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// set kernel space firewall Reset as hook for userspace firewall
|
||||||
|
// manager Reset method, to clean up
|
||||||
|
if resetHookForUserspace != nil {
|
||||||
|
usfm.SetResetHook(resetHookForUserspace)
|
||||||
|
}
|
||||||
|
|
||||||
|
// to be consistent for any future extensions.
|
||||||
|
// ignore this error
|
||||||
|
if err := usfm.AllowNetbird(); err != nil {
|
||||||
|
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||||
|
}
|
||||||
|
fm = usfm
|
||||||
|
}
|
||||||
|
|
||||||
|
if fm == nil || err != nil {
|
||||||
|
log.Errorf("failed to create firewall manager: %s", err)
|
||||||
|
// no firewall manager found or initialized correctly
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return newDefaultManager(fm), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
package acl
|
package acl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -32,13 +34,22 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
defer ctrl.Finish()
|
defer ctrl.Finish()
|
||||||
|
|
||||||
iface := mocks.NewMockIFaceMapper(ctrl)
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
iface.EXPECT().IsUserspaceBind().Return(true)
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true)
|
||||||
// iface.EXPECT().Name().Return("lo")
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
iface.EXPECT().SetFilter(gomock.Any())
|
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse IP address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
|
||||||
|
IP: ip,
|
||||||
|
Network: network,
|
||||||
|
}).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
acl, err := Create(iface)
|
acl, err := Create(ifaceMock)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create ACL manager: %v", err)
|
t.Errorf("create ACL manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -311,13 +322,22 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
defer ctrl.Finish()
|
defer ctrl.Finish()
|
||||||
|
|
||||||
iface := mocks.NewMockIFaceMapper(ctrl)
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
iface.EXPECT().IsUserspaceBind().Return(true)
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true)
|
||||||
// iface.EXPECT().Name().Return("lo")
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
iface.EXPECT().SetFilter(gomock.Any())
|
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse IP address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
|
||||||
|
IP: ip,
|
||||||
|
Network: network,
|
||||||
|
}).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
acl, err := Create(iface)
|
acl, err := Create(ifaceMock)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create ACL manager: %v", err)
|
t.Errorf("create ACL manager: %v", err)
|
||||||
return
|
return
|
||||||
|
|||||||
202
client/internal/auth/device_flow.go
Normal file
202
client/internal/auth/device_flow.go
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HostedGrantType grant type for device flow on Hosted
|
||||||
|
const (
|
||||||
|
HostedGrantType = "urn:ietf:params:oauth:grant-type:device_code"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ OAuthFlow = &DeviceAuthorizationFlow{}
|
||||||
|
|
||||||
|
// DeviceAuthorizationFlow implements the OAuthFlow interface,
|
||||||
|
// for the Device Authorization Flow.
|
||||||
|
type DeviceAuthorizationFlow struct {
|
||||||
|
providerConfig internal.DeviceAuthProviderConfig
|
||||||
|
|
||||||
|
HTTPClient HTTPClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestDeviceCodePayload used for request device code payload for auth0
|
||||||
|
type RequestDeviceCodePayload struct {
|
||||||
|
Audience string `json:"audience"`
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenRequestPayload used for requesting the auth0 token
|
||||||
|
type TokenRequestPayload struct {
|
||||||
|
GrantType string `json:"grant_type"`
|
||||||
|
DeviceCode string `json:"device_code,omitempty"`
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
RefreshToken string `json:"refresh_token,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenRequestResponse used for parsing Hosted token's response
|
||||||
|
type TokenRequestResponse struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
ErrorDescription string `json:"error_description"`
|
||||||
|
TokenInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDeviceAuthorizationFlow returns device authorization flow client
|
||||||
|
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
||||||
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
Transport: httpTransport,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &DeviceAuthorizationFlow{
|
||||||
|
providerConfig: config,
|
||||||
|
HTTPClient: httpClient,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClientID returns the provider client id
|
||||||
|
func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
|
||||||
|
return d.providerConfig.ClientID
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestAuthInfo requests a device code login flow information from Hosted
|
||||||
|
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
||||||
|
form := url.Values{}
|
||||||
|
form.Add("client_id", d.providerConfig.ClientID)
|
||||||
|
form.Add("audience", d.providerConfig.Audience)
|
||||||
|
form.Add("scope", d.providerConfig.Scope)
|
||||||
|
req, err := http.NewRequest("POST", d.providerConfig.DeviceAuthEndpoint,
|
||||||
|
strings.NewReader(form.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return AuthFlowInfo{}, fmt.Errorf("creating request failed with error: %v", err)
|
||||||
|
}
|
||||||
|
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
res, err := d.HTTPClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return AuthFlowInfo{}, fmt.Errorf("doing request failed with error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer res.Body.Close()
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
return AuthFlowInfo{}, fmt.Errorf("reading body failed with error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.StatusCode != 200 {
|
||||||
|
return AuthFlowInfo{}, fmt.Errorf("request device code returned status %d error: %s", res.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceCode := AuthFlowInfo{}
|
||||||
|
err = json.Unmarshal(body, &deviceCode)
|
||||||
|
if err != nil {
|
||||||
|
return AuthFlowInfo{}, fmt.Errorf("unmarshaling response failed with error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to the verification_uri if the IdP doesn't support verification_uri_complete
|
||||||
|
if deviceCode.VerificationURIComplete == "" {
|
||||||
|
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
|
||||||
|
}
|
||||||
|
|
||||||
|
return deviceCode, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) {
|
||||||
|
form := url.Values{}
|
||||||
|
form.Add("client_id", d.providerConfig.ClientID)
|
||||||
|
form.Add("grant_type", HostedGrantType)
|
||||||
|
form.Add("device_code", info.DeviceCode)
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", d.providerConfig.TokenEndpoint, strings.NewReader(form.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return TokenRequestResponse{}, fmt.Errorf("failed to create request access token: %v", err)
|
||||||
|
}
|
||||||
|
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
res, err := d.HTTPClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return TokenRequestResponse{}, fmt.Errorf("failed to request access token with error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := res.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
return TokenRequestResponse{}, fmt.Errorf("failed reading access token response body with error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.StatusCode > 499 {
|
||||||
|
return TokenRequestResponse{}, fmt.Errorf("access token response returned code: %s", string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenResponse := TokenRequestResponse{}
|
||||||
|
err = json.Unmarshal(body, &tokenResponse)
|
||||||
|
if err != nil {
|
||||||
|
return TokenRequestResponse{}, fmt.Errorf("parsing token response failed with error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenResponse, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
||||||
|
// it retrieves the access token from Hosted's endpoint and validates it before returning
|
||||||
|
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
||||||
|
interval := time.Duration(info.Interval) * time.Second
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return TokenInfo{}, ctx.Err()
|
||||||
|
case <-ticker.C:
|
||||||
|
|
||||||
|
tokenResponse, err := d.requestToken(info)
|
||||||
|
if err != nil {
|
||||||
|
return TokenInfo{}, fmt.Errorf("parsing token response failed with error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tokenResponse.Error != "" {
|
||||||
|
if tokenResponse.Error == "authorization_pending" {
|
||||||
|
continue
|
||||||
|
} else if tokenResponse.Error == "slow_down" {
|
||||||
|
interval = interval + (3 * time.Second)
|
||||||
|
ticker.Reset(interval)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenInfo := TokenInfo{
|
||||||
|
AccessToken: tokenResponse.AccessToken,
|
||||||
|
TokenType: tokenResponse.TokenType,
|
||||||
|
RefreshToken: tokenResponse.RefreshToken,
|
||||||
|
IDToken: tokenResponse.IDToken,
|
||||||
|
ExpiresIn: tokenResponse.ExpiresIn,
|
||||||
|
UseIDToken: d.providerConfig.UseIDToken,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = isValidAccessToken(tokenInfo.GetTokenToUse(), d.providerConfig.Audience)
|
||||||
|
if err != nil {
|
||||||
|
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenInfo, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,17 +1,17 @@
|
|||||||
package internal
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockHTTPClient struct {
|
type mockHTTPClient struct {
|
||||||
@@ -53,7 +53,7 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
|||||||
testingErrFunc require.ErrorAssertionFunc
|
testingErrFunc require.ErrorAssertionFunc
|
||||||
expectedErrorMSG string
|
expectedErrorMSG string
|
||||||
testingFunc require.ComparisonAssertionFunc
|
testingFunc require.ComparisonAssertionFunc
|
||||||
expectedOut DeviceAuthInfo
|
expectedOut AuthFlowInfo
|
||||||
expectedMSG string
|
expectedMSG string
|
||||||
expectPayload string
|
expectPayload string
|
||||||
}
|
}
|
||||||
@@ -92,7 +92,7 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
|||||||
testingFunc: require.EqualValues,
|
testingFunc: require.EqualValues,
|
||||||
expectPayload: expectPayload,
|
expectPayload: expectPayload,
|
||||||
}
|
}
|
||||||
testCase4Out := DeviceAuthInfo{ExpiresIn: 10}
|
testCase4Out := AuthFlowInfo{ExpiresIn: 10}
|
||||||
testCase4 := test{
|
testCase4 := test{
|
||||||
name: "Got Device Code",
|
name: "Got Device Code",
|
||||||
inputResBody: fmt.Sprintf("{\"expires_in\":%d}", testCase4Out.ExpiresIn),
|
inputResBody: fmt.Sprintf("{\"expires_in\":%d}", testCase4Out.ExpiresIn),
|
||||||
@@ -113,8 +113,8 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
|||||||
err: testCase.inputReqError,
|
err: testCase.inputReqError,
|
||||||
}
|
}
|
||||||
|
|
||||||
hosted := Hosted{
|
deviceFlow := &DeviceAuthorizationFlow{
|
||||||
providerConfig: ProviderConfig{
|
providerConfig: internal.DeviceAuthProviderConfig{
|
||||||
Audience: expectedAudience,
|
Audience: expectedAudience,
|
||||||
ClientID: expectedClientID,
|
ClientID: expectedClientID,
|
||||||
Scope: expectedScope,
|
Scope: expectedScope,
|
||||||
@@ -125,7 +125,7 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
|||||||
HTTPClient: &httpClient,
|
HTTPClient: &httpClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
authInfo, err := hosted.RequestDeviceCode(context.TODO())
|
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
|
||||||
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
||||||
|
|
||||||
require.EqualValues(t, expectPayload, httpClient.reqBody, "payload should match")
|
require.EqualValues(t, expectPayload, httpClient.reqBody, "payload should match")
|
||||||
@@ -145,7 +145,7 @@ func TestHosted_WaitToken(t *testing.T) {
|
|||||||
inputMaxReqs int
|
inputMaxReqs int
|
||||||
inputCountResBody string
|
inputCountResBody string
|
||||||
inputTimeout time.Duration
|
inputTimeout time.Duration
|
||||||
inputInfo DeviceAuthInfo
|
inputInfo AuthFlowInfo
|
||||||
inputAudience string
|
inputAudience string
|
||||||
testingErrFunc require.ErrorAssertionFunc
|
testingErrFunc require.ErrorAssertionFunc
|
||||||
expectedErrorMSG string
|
expectedErrorMSG string
|
||||||
@@ -155,7 +155,7 @@ func TestHosted_WaitToken(t *testing.T) {
|
|||||||
expectPayload string
|
expectPayload string
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultInfo := DeviceAuthInfo{
|
defaultInfo := AuthFlowInfo{
|
||||||
DeviceCode: "test",
|
DeviceCode: "test",
|
||||||
ExpiresIn: 10,
|
ExpiresIn: 10,
|
||||||
Interval: 1,
|
Interval: 1,
|
||||||
@@ -278,8 +278,8 @@ func TestHosted_WaitToken(t *testing.T) {
|
|||||||
countResBody: testCase.inputCountResBody,
|
countResBody: testCase.inputCountResBody,
|
||||||
}
|
}
|
||||||
|
|
||||||
hosted := Hosted{
|
deviceFlow := DeviceAuthorizationFlow{
|
||||||
providerConfig: ProviderConfig{
|
providerConfig: internal.DeviceAuthProviderConfig{
|
||||||
Audience: testCase.inputAudience,
|
Audience: testCase.inputAudience,
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
TokenEndpoint: "test.hosted.com/token",
|
TokenEndpoint: "test.hosted.com/token",
|
||||||
@@ -287,11 +287,12 @@ func TestHosted_WaitToken(t *testing.T) {
|
|||||||
Scope: "openid",
|
Scope: "openid",
|
||||||
UseIDToken: false,
|
UseIDToken: false,
|
||||||
},
|
},
|
||||||
HTTPClient: &httpClient}
|
HTTPClient: &httpClient,
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
tokenInfo, err := hosted.WaitToken(ctx, testCase.inputInfo)
|
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)
|
||||||
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
||||||
|
|
||||||
require.EqualValues(t, testCase.expectPayload, httpClient.reqBody, "payload should match")
|
require.EqualValues(t, testCase.expectPayload, httpClient.reqBody, "payload should match")
|
||||||
88
client/internal/auth/oauth.go
Normal file
88
client/internal/auth/oauth.go
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OAuthFlow represents an interface for authorization using different OAuth 2.0 flows
|
||||||
|
type OAuthFlow interface {
|
||||||
|
RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error)
|
||||||
|
WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error)
|
||||||
|
GetClientID(ctx context.Context) string
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPClient http client interface for API calls
|
||||||
|
type HTTPClient interface {
|
||||||
|
Do(req *http.Request) (*http.Response, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthFlowInfo holds information for the OAuth 2.0 authorization flow
|
||||||
|
type AuthFlowInfo struct {
|
||||||
|
DeviceCode string `json:"device_code"`
|
||||||
|
UserCode string `json:"user_code"`
|
||||||
|
VerificationURI string `json:"verification_uri"`
|
||||||
|
VerificationURIComplete string `json:"verification_uri_complete"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
Interval int `json:"interval"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Claims used when validating the access token
|
||||||
|
type Claims struct {
|
||||||
|
Audience interface{} `json:"aud"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenInfo holds information of issued access token
|
||||||
|
type TokenInfo struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
IDToken string `json:"id_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
UseIDToken bool `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTokenToUse returns either the access or id token based on UseIDToken field
|
||||||
|
func (t TokenInfo) GetTokenToUse() string {
|
||||||
|
if t.UseIDToken {
|
||||||
|
return t.IDToken
|
||||||
|
}
|
||||||
|
return t.AccessToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration.
|
||||||
|
func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
||||||
|
log.Debug("loading pkce authorization flow info")
|
||||||
|
|
||||||
|
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||||
|
if err == nil {
|
||||||
|
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("loading pkce authorization flow info failed with error: %v", err)
|
||||||
|
log.Debugf("falling back to device authorization flow info")
|
||||||
|
|
||||||
|
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||||
|
if err != nil {
|
||||||
|
s, ok := gstatus.FromError(err)
|
||||||
|
if ok && s.Code() == codes.NotFound {
|
||||||
|
return nil, fmt.Errorf("no SSO provider returned from management. " +
|
||||||
|
"If you are using hosting Netbird see documentation at " +
|
||||||
|
"https://github.com/netbirdio/netbird/tree/main/management for details")
|
||||||
|
} else if ok && s.Code() == codes.Unimplemented {
|
||||||
|
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
||||||
|
"please update your server or use Setup Keys to login", config.ManagementURL)
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
||||||
|
}
|
||||||
252
client/internal/auth/pkce_flow.go
Normal file
252
client/internal/auth/pkce_flow.go
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"html/template"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/templates"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ OAuthFlow = &PKCEAuthorizationFlow{}
|
||||||
|
|
||||||
|
const (
|
||||||
|
queryState = "state"
|
||||||
|
queryCode = "code"
|
||||||
|
queryError = "error"
|
||||||
|
queryErrorDesc = "error_description"
|
||||||
|
defaultPKCETimeoutSeconds = 300
|
||||||
|
)
|
||||||
|
|
||||||
|
// PKCEAuthorizationFlow implements the OAuthFlow interface for
|
||||||
|
// the Authorization Code Flow with PKCE.
|
||||||
|
type PKCEAuthorizationFlow struct {
|
||||||
|
providerConfig internal.PKCEAuthProviderConfig
|
||||||
|
state string
|
||||||
|
codeVerifier string
|
||||||
|
oAuthConfig *oauth2.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
|
||||||
|
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||||
|
var availableRedirectURL string
|
||||||
|
|
||||||
|
// find the first available redirect URL
|
||||||
|
for _, redirectURL := range config.RedirectURLs {
|
||||||
|
if !isRedirectURLPortUsed(redirectURL) {
|
||||||
|
availableRedirectURL = redirectURL
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if availableRedirectURL == "" {
|
||||||
|
return nil, fmt.Errorf("no available port found from configured redirect URLs: %q", config.RedirectURLs)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &oauth2.Config{
|
||||||
|
ClientID: config.ClientID,
|
||||||
|
ClientSecret: config.ClientSecret,
|
||||||
|
Endpoint: oauth2.Endpoint{
|
||||||
|
AuthURL: config.AuthorizationEndpoint,
|
||||||
|
TokenURL: config.TokenEndpoint,
|
||||||
|
},
|
||||||
|
RedirectURL: availableRedirectURL,
|
||||||
|
Scopes: strings.Split(config.Scope, " "),
|
||||||
|
}
|
||||||
|
|
||||||
|
return &PKCEAuthorizationFlow{
|
||||||
|
providerConfig: config,
|
||||||
|
oAuthConfig: cfg,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClientID returns the provider client id
|
||||||
|
func (p *PKCEAuthorizationFlow) GetClientID(_ context.Context) string {
|
||||||
|
return p.providerConfig.ClientID
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestAuthInfo requests a authorization code login flow information.
|
||||||
|
func (p *PKCEAuthorizationFlow) RequestAuthInfo(_ context.Context) (AuthFlowInfo, error) {
|
||||||
|
state, err := randomBytesInHex(24)
|
||||||
|
if err != nil {
|
||||||
|
return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err)
|
||||||
|
}
|
||||||
|
p.state = state
|
||||||
|
|
||||||
|
codeVerifier, err := randomBytesInHex(64)
|
||||||
|
if err != nil {
|
||||||
|
return AuthFlowInfo{}, fmt.Errorf("could not create a code verifier: %v", err)
|
||||||
|
}
|
||||||
|
p.codeVerifier = codeVerifier
|
||||||
|
|
||||||
|
codeChallenge := createCodeChallenge(codeVerifier)
|
||||||
|
authURL := p.oAuthConfig.AuthCodeURL(
|
||||||
|
state,
|
||||||
|
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||||
|
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
||||||
|
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
||||||
|
)
|
||||||
|
|
||||||
|
return AuthFlowInfo{
|
||||||
|
VerificationURIComplete: authURL,
|
||||||
|
ExpiresIn: defaultPKCETimeoutSeconds,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
|
||||||
|
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
|
||||||
|
// Once the token is received, it is converted to TokenInfo and validated before returning.
|
||||||
|
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
|
||||||
|
tokenChan := make(chan *oauth2.Token, 1)
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
|
||||||
|
go p.startServer(tokenChan, errChan)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return TokenInfo{}, ctx.Err()
|
||||||
|
case token := <-tokenChan:
|
||||||
|
return p.handleOAuthToken(token)
|
||||||
|
case err := <-errChan:
|
||||||
|
return TokenInfo{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errChan chan<- error) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- fmt.Errorf("failed to parse redirect URL: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
server := http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
|
||||||
|
go func() {
|
||||||
|
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
errChan <- err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
tokenValidatorFunc := func() (*oauth2.Token, error) {
|
||||||
|
query := req.URL.Query()
|
||||||
|
|
||||||
|
if authError := query.Get(queryError); authError != "" {
|
||||||
|
authErrorDesc := query.Get(queryErrorDesc)
|
||||||
|
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prevent timing attacks on state
|
||||||
|
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
||||||
|
return nil, fmt.Errorf("invalid state")
|
||||||
|
}
|
||||||
|
|
||||||
|
code := query.Get(queryCode)
|
||||||
|
if code == "" {
|
||||||
|
return nil, fmt.Errorf("missing code")
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.oAuthConfig.Exchange(
|
||||||
|
req.Context(),
|
||||||
|
code,
|
||||||
|
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := tokenValidatorFunc()
|
||||||
|
if err != nil {
|
||||||
|
renderPKCEFlowTmpl(w, err)
|
||||||
|
errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
renderPKCEFlowTmpl(w, nil)
|
||||||
|
tokenChan <- token
|
||||||
|
})
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
if err := server.Shutdown(context.Background()); err != nil {
|
||||||
|
log.Errorf("error while shutting down pkce flow server: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PKCEAuthorizationFlow) handleOAuthToken(token *oauth2.Token) (TokenInfo, error) {
|
||||||
|
tokenInfo := TokenInfo{
|
||||||
|
AccessToken: token.AccessToken,
|
||||||
|
RefreshToken: token.RefreshToken,
|
||||||
|
TokenType: token.TokenType,
|
||||||
|
ExpiresIn: token.Expiry.Second(),
|
||||||
|
UseIDToken: p.providerConfig.UseIDToken,
|
||||||
|
}
|
||||||
|
if idToken, ok := token.Extra("id_token").(string); ok {
|
||||||
|
tokenInfo.IDToken = idToken
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), p.providerConfig.Audience); err != nil {
|
||||||
|
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createCodeChallenge(codeVerifier string) string {
|
||||||
|
sha2 := sha256.Sum256([]byte(codeVerifier))
|
||||||
|
return base64.RawURLEncoding.EncodeToString(sha2[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use.
|
||||||
|
func isRedirectURLPortUsed(redirectURL string) bool {
|
||||||
|
parsedURL, err := url.Parse(redirectURL)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to parse redirect URL: %v", err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := fmt.Sprintf(":%s", parsedURL.Port())
|
||||||
|
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Errorf("error while closing the connection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) {
|
||||||
|
tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data := make(map[string]string)
|
||||||
|
if authError != nil {
|
||||||
|
data["Error"] = authError.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tmpl.Execute(w, data); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
62
client/internal/auth/util.go
Normal file
62
client/internal/auth/util.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func randomBytesInHex(count int) (string, error) {
|
||||||
|
buf := make([]byte, count)
|
||||||
|
_, err := io.ReadFull(rand.Reader, buf)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("could not generate %d random bytes: %v", count, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return hex.EncodeToString(buf), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isValidAccessToken is a simple validation of the access token
|
||||||
|
func isValidAccessToken(token string, audience string) error {
|
||||||
|
if token == "" {
|
||||||
|
return fmt.Errorf("token received is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
encodedClaims := strings.Split(token, ".")[1]
|
||||||
|
claimsString, err := base64.RawURLEncoding.DecodeString(encodedClaims)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := Claims{}
|
||||||
|
err = json.Unmarshal(claimsString, &claims)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims.Audience == nil {
|
||||||
|
return fmt.Errorf("required token field audience is absent")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Audience claim of JWT can be a string or an array of strings
|
||||||
|
typ := reflect.TypeOf(claims.Audience)
|
||||||
|
switch typ.Kind() {
|
||||||
|
case reflect.String:
|
||||||
|
if claims.Audience == audience {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case reflect.Slice:
|
||||||
|
for _, aud := range claims.Audience.([]interface{}) {
|
||||||
|
if audience == aud {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("invalid JWT token audience field")
|
||||||
|
}
|
||||||
3
client/internal/checkfw/check.go
Normal file
3
client/internal/checkfw/check.go
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
package checkfw
|
||||||
56
client/internal/checkfw/check_linux.go
Normal file
56
client/internal/checkfw/check_linux.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package checkfw
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/coreos/go-iptables/iptables"
|
||||||
|
"github.com/google/nftables"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// UNKNOWN is the default value for the firewall type for unknown firewall type
|
||||||
|
UNKNOWN FWType = iota
|
||||||
|
// IPTABLES is the value for the iptables firewall type
|
||||||
|
IPTABLES
|
||||||
|
// IPTABLESWITHV6 is the value for the iptables firewall type with ipv6
|
||||||
|
IPTABLESWITHV6
|
||||||
|
// NFTABLES is the value for the nftables firewall type
|
||||||
|
NFTABLES
|
||||||
|
)
|
||||||
|
|
||||||
|
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
|
||||||
|
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||||
|
|
||||||
|
// FWType is the type for the firewall type
|
||||||
|
type FWType int
|
||||||
|
|
||||||
|
// Check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
||||||
|
func Check() FWType {
|
||||||
|
nf := nftables.Conn{}
|
||||||
|
if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
|
||||||
|
return NFTABLES
|
||||||
|
}
|
||||||
|
|
||||||
|
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
|
if err == nil {
|
||||||
|
if isIptablesClientAvailable(ip) {
|
||||||
|
ipSupport := IPTABLES
|
||||||
|
ipv6, ip6Err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||||
|
if ip6Err == nil {
|
||||||
|
if isIptablesClientAvailable(ipv6) {
|
||||||
|
ipSupport = IPTABLESWITHV6
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ipSupport
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return UNKNOWN
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||||
|
_, err := client.ListChains("filter")
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
@@ -215,10 +215,12 @@ func update(input ConfigInput) (*Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
|
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
|
||||||
log.Infof("new pre-shared key provided, updated to %s (old value %s)",
|
if *input.PreSharedKey != "" {
|
||||||
*input.PreSharedKey, config.PreSharedKey)
|
log.Infof("new pre-shared key provides, updated to %s (old value %s)",
|
||||||
config.PreSharedKey = *input.PreSharedKey
|
*input.PreSharedKey, config.PreSharedKey)
|
||||||
refresh = true
|
config.PreSharedKey = *input.PreSharedKey
|
||||||
|
refresh = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.SSHKey == "" {
|
if config.SSHKey == "" {
|
||||||
|
|||||||
@@ -23,9 +23,6 @@ func TestGetConfig(t *testing.T) {
|
|||||||
assert.Equal(t, config.ManagementURL.String(), DefaultManagementURL)
|
assert.Equal(t, config.ManagementURL.String(), DefaultManagementURL)
|
||||||
assert.Equal(t, config.AdminURL.String(), DefaultAdminURL)
|
assert.Equal(t, config.AdminURL.String(), DefaultAdminURL)
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
managementURL := "https://test.management.url:33071"
|
managementURL := "https://test.management.url:33071"
|
||||||
adminURL := "https://app.admin.url:443"
|
adminURL := "https://app.admin.url:443"
|
||||||
path := filepath.Join(t.TempDir(), "config.json")
|
path := filepath.Join(t.TempDir(), "config.json")
|
||||||
@@ -63,7 +60,22 @@ func TestGetConfig(t *testing.T) {
|
|||||||
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
||||||
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
||||||
|
|
||||||
// case 4: existing config, but new managementURL has been provided -> update config
|
// case 4: new empty pre-shared key config -> fetch it
|
||||||
|
newPreSharedKey := ""
|
||||||
|
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||||
|
ManagementURL: managementURL,
|
||||||
|
AdminURL: adminURL,
|
||||||
|
ConfigPath: path,
|
||||||
|
PreSharedKey: &newPreSharedKey,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
||||||
|
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
||||||
|
|
||||||
|
// case 5: existing config, but new managementURL has been provided -> update config
|
||||||
newManagementURL := "https://test.newManagement.url:33071"
|
newManagementURL := "https://test.newManagement.url:33071"
|
||||||
config, err = UpdateOrCreateConfig(ConfigInput{
|
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||||
ManagementURL: newManagementURL,
|
ManagementURL: newManagementURL,
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
@@ -24,7 +25,24 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// RunClient with main logic.
|
// RunClient with main logic.
|
||||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener) error {
|
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
|
||||||
|
return runClient(ctx, config, statusRecorder, MobileDependency{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunClientMobile with main logic on mobile system
|
||||||
|
func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener, dnsAddresses []string, dnsReadyListener dns.ReadyListener) error {
|
||||||
|
// in case of non Android os these variables will be nil
|
||||||
|
mobileDependency := MobileDependency{
|
||||||
|
TunAdapter: tunAdapter,
|
||||||
|
IFaceDiscover: iFaceDiscover,
|
||||||
|
RouteListener: routeListener,
|
||||||
|
HostDNSAddresses: dnsAddresses,
|
||||||
|
DnsReadyListener: dnsReadyListener,
|
||||||
|
}
|
||||||
|
return runClient(ctx, config, statusRecorder, mobileDependency)
|
||||||
|
}
|
||||||
|
|
||||||
|
func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status, mobileDependency MobileDependency) error {
|
||||||
backOff := &backoff.ExponentialBackOff{
|
backOff := &backoff.ExponentialBackOff{
|
||||||
InitialInterval: time.Second,
|
InitialInterval: time.Second,
|
||||||
RandomizationFactor: 1,
|
RandomizationFactor: 1,
|
||||||
@@ -151,14 +169,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
|||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// in case of non Android os these variables will be nil
|
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder)
|
||||||
md := MobileDependency{
|
|
||||||
TunAdapter: tunAdapter,
|
|
||||||
IFaceDiscover: iFaceDiscover,
|
|
||||||
RouteListener: routeListener,
|
|
||||||
}
|
|
||||||
|
|
||||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, md, statusRecorder)
|
|
||||||
err = engine.Start()
|
err = engine.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||||
@@ -168,8 +179,6 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
|||||||
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
|
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
|
||||||
state.Set(StatusConnected)
|
state.Set(StatusConnected)
|
||||||
|
|
||||||
statusRecorder.ClientStart()
|
|
||||||
|
|
||||||
<-engineCtx.Done()
|
<-engineCtx.Done()
|
||||||
statusRecorder.ClientTeardown()
|
statusRecorder.ClientTeardown()
|
||||||
|
|
||||||
@@ -190,6 +199,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
statusRecorder.ClientStart()
|
||||||
err = backoff.Retry(operation, backOff)
|
err = backoff.Retry(operation, backOff)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
||||||
|
|||||||
@@ -16,11 +16,11 @@ import (
|
|||||||
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
||||||
type DeviceAuthorizationFlow struct {
|
type DeviceAuthorizationFlow struct {
|
||||||
Provider string
|
Provider string
|
||||||
ProviderConfig ProviderConfig
|
ProviderConfig DeviceAuthProviderConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProviderConfig has all attributes needed to initiate a device authorization flow
|
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
||||||
type ProviderConfig struct {
|
type DeviceAuthProviderConfig struct {
|
||||||
// ClientID An IDP application client id
|
// ClientID An IDP application client id
|
||||||
ClientID string
|
ClientID string
|
||||||
// ClientSecret An IDP application client secret
|
// ClientSecret An IDP application client secret
|
||||||
@@ -88,7 +88,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
|
|||||||
deviceAuthorizationFlow := DeviceAuthorizationFlow{
|
deviceAuthorizationFlow := DeviceAuthorizationFlow{
|
||||||
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
||||||
|
|
||||||
ProviderConfig: ProviderConfig{
|
ProviderConfig: DeviceAuthProviderConfig{
|
||||||
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
|
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
|
||||||
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
|
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
|
||||||
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
||||||
@@ -105,7 +105,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
|
|||||||
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
|
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
|
||||||
}
|
}
|
||||||
|
|
||||||
err = isProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return DeviceAuthorizationFlow{}, err
|
return DeviceAuthorizationFlow{}, err
|
||||||
}
|
}
|
||||||
@@ -113,7 +113,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
|
|||||||
return deviceAuthorizationFlow, nil
|
return deviceAuthorizationFlow, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isProviderConfigValid(config ProviderConfig) error {
|
func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
|
||||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||||
if config.Audience == "" {
|
if config.Audience == "" {
|
||||||
return fmt.Errorf(errorMSGFormat, "Audience")
|
return fmt.Errorf(errorMSGFormat, "Audience")
|
||||||
|
|||||||
@@ -15,7 +15,8 @@ const (
|
|||||||
fileGeneratedResolvConfSearchBeginContent = "search "
|
fileGeneratedResolvConfSearchBeginContent = "search "
|
||||||
fileGeneratedResolvConfContentFormat = fileGeneratedResolvConfContentHeader +
|
fileGeneratedResolvConfContentFormat = fileGeneratedResolvConfContentHeader +
|
||||||
"\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" +
|
"\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" +
|
||||||
fileGeneratedResolvConfSearchBeginContent + "%s\n"
|
fileGeneratedResolvConfSearchBeginContent + "%s\n\n" +
|
||||||
|
"%s\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -91,7 +92,12 @@ func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
|||||||
searchDomains += " " + dConf.domain
|
searchDomains += " " + dConf.domain
|
||||||
appendedDomains++
|
appendedDomains++
|
||||||
}
|
}
|
||||||
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains)
|
|
||||||
|
originalContent, err := os.ReadFile(fileDefaultResolvConfBackupLocation)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Could not read existing resolv.conf")
|
||||||
|
}
|
||||||
|
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains, string(originalContent))
|
||||||
err = writeDNSConfig(content, defaultResolvConfPath, f.originalPerms)
|
err = writeDNSConfig(content, defaultResolvConfPath, f.originalPerms)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = f.restore()
|
err = f.restore()
|
||||||
|
|||||||
@@ -1,13 +1,9 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
|
||||||
|
|
||||||
type androidHostManager struct {
|
type androidHostManager struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
|
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||||
return &androidHostManager{}, nil
|
return &androidHostManager{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,8 +9,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -34,7 +32,7 @@ type systemConfigurator struct {
|
|||||||
createdKeys map[string]struct{}
|
createdKeys map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(_ *iface.WGIface) (hostManager, error) {
|
func newHostManager(_ WGIface) (hostManager, error) {
|
||||||
return &systemConfigurator{
|
return &systemConfigurator{
|
||||||
createdKeys: make(map[string]struct{}),
|
createdKeys: make(map[string]struct{}),
|
||||||
}, nil
|
}, nil
|
||||||
@@ -184,12 +182,11 @@ func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error {
|
func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error {
|
||||||
primaryServiceKey := s.getPrimaryService()
|
primaryServiceKey, existingNameserver := s.getPrimaryService()
|
||||||
if primaryServiceKey == "" {
|
if primaryServiceKey == "" {
|
||||||
return fmt.Errorf("couldn't find the primary service key")
|
return fmt.Errorf("couldn't find the primary service key")
|
||||||
}
|
}
|
||||||
|
err := s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver)
|
||||||
err := s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -198,27 +195,32 @@ func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) getPrimaryService() string {
|
func (s *systemConfigurator) getPrimaryService() (string, string) {
|
||||||
line := buildCommandLine("show", globalIPv4State, "")
|
line := buildCommandLine("show", globalIPv4State, "")
|
||||||
stdinCommands := wrapCommand(line)
|
stdinCommands := wrapCommand(line)
|
||||||
b, err := runSystemConfigCommand(stdinCommands)
|
b, err := runSystemConfigCommand(stdinCommands)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("got error while sending the command: ", err)
|
log.Error("got error while sending the command: ", err)
|
||||||
return ""
|
return "", ""
|
||||||
}
|
}
|
||||||
scanner := bufio.NewScanner(bytes.NewReader(b))
|
scanner := bufio.NewScanner(bytes.NewReader(b))
|
||||||
|
primaryService := ""
|
||||||
|
router := ""
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
text := scanner.Text()
|
text := scanner.Text()
|
||||||
if strings.Contains(text, "PrimaryService") {
|
if strings.Contains(text, "PrimaryService") {
|
||||||
return strings.TrimSpace(strings.Split(text, ":")[1])
|
primaryService = strings.TrimSpace(strings.Split(text, ":")[1])
|
||||||
|
}
|
||||||
|
if strings.Contains(text, "Router") {
|
||||||
|
router = strings.TrimSpace(strings.Split(text, ":")[1])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ""
|
return primaryService, router
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int) error {
|
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, existingDNSServer string) error {
|
||||||
lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0))
|
lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0))
|
||||||
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer)
|
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer+" "+existingDNSServer)
|
||||||
lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port))
|
lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port))
|
||||||
addDomainCommand := buildCreateStateWithOperation(setupKey, lines)
|
addDomainCommand := buildCreateStateWithOperation(setupKey, lines)
|
||||||
stdinCommands := wrapCommand(addDomainCommand)
|
stdinCommands := wrapCommand(addDomainCommand)
|
||||||
|
|||||||
@@ -5,10 +5,10 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -25,7 +25,7 @@ const (
|
|||||||
|
|
||||||
type osManagerType int
|
type osManagerType int
|
||||||
|
|
||||||
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
|
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||||
osManager, err := getOSDNSManagerType()
|
osManager, err := getOSDNSManagerType()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -6,8 +6,6 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/windows/registry"
|
"golang.org/x/sys/windows/registry"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -33,7 +31,7 @@ type registryConfigurator struct {
|
|||||||
existingSearchDomains []string
|
existingSearchDomains []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
|
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||||
guid, err := wgInterface.GetInterfaceGUIDString()
|
guid, err := wgInterface.GetInterfaceGUIDString()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -31,6 +31,11 @@ func (m *MockServer) DnsIP() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) {
|
||||||
|
//TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateDNSServer mock implementation of UpdateDNSServer from Server interface
|
// UpdateDNSServer mock implementation of UpdateDNSServer from Server interface
|
||||||
func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||||
if m.UpdateDNSServerFunc != nil {
|
if m.UpdateDNSServerFunc != nil {
|
||||||
|
|||||||
@@ -14,8 +14,6 @@ import (
|
|||||||
"github.com/hashicorp/go-version"
|
"github.com/hashicorp/go-version"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -72,7 +70,7 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newNetworkManagerDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
func newNetworkManagerDbusConfigurator(wgInterface WGIface) (hostManager, error) {
|
||||||
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -4,12 +4,11 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const resolvconfCommand = "resolvconf"
|
const resolvconfCommand = "resolvconf"
|
||||||
@@ -18,7 +17,7 @@ type resolvconf struct {
|
|||||||
ifaceName string
|
ifaceName string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newResolvConfConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
func newResolvConfConfigurator(wgInterface WGIface) (hostManager, error) {
|
||||||
return &resolvconf{
|
return &resolvconf{
|
||||||
ifaceName: wgInterface.Name(),
|
ifaceName: wgInterface.Name(),
|
||||||
}, nil
|
}, nil
|
||||||
@@ -61,7 +60,11 @@ func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error {
|
|||||||
appendedDomains++
|
appendedDomains++
|
||||||
}
|
}
|
||||||
|
|
||||||
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains)
|
originalContent, err := os.ReadFile(fileDefaultResolvConfBackupLocation)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Could not read existing resolv.conf")
|
||||||
|
}
|
||||||
|
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains, string(originalContent))
|
||||||
|
|
||||||
err = r.applyConfig(content)
|
err = r.applyConfig(content)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -3,29 +3,20 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/mitchellh/hashstructure/v2"
|
"github.com/mitchellh/hashstructure/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||||
defaultPort = 53
|
type ReadyListener interface {
|
||||||
customPort = 5053
|
OnReady()
|
||||||
defaultIP = "127.0.0.1"
|
}
|
||||||
customIP = "127.0.0.153"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Server is a dns server interface
|
// Server is a dns server interface
|
||||||
type Server interface {
|
type Server interface {
|
||||||
@@ -33,6 +24,7 @@ type Server interface {
|
|||||||
Stop()
|
Stop()
|
||||||
DnsIP() string
|
DnsIP() string
|
||||||
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
||||||
|
OnUpdatedHostDNSServer(strings []string)
|
||||||
}
|
}
|
||||||
|
|
||||||
type registeredHandlerMap map[string]handlerWithStop
|
type registeredHandlerMap map[string]handlerWithStop
|
||||||
@@ -42,21 +34,19 @@ type DefaultServer struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
ctxCancel context.CancelFunc
|
ctxCancel context.CancelFunc
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
udpFilterHookID string
|
service service
|
||||||
server *dns.Server
|
|
||||||
dnsMux *dns.ServeMux
|
|
||||||
dnsMuxMap registeredHandlerMap
|
dnsMuxMap registeredHandlerMap
|
||||||
localResolver *localResolver
|
localResolver *localResolver
|
||||||
wgInterface *iface.WGIface
|
wgInterface WGIface
|
||||||
hostManager hostManager
|
hostManager hostManager
|
||||||
updateSerial uint64
|
updateSerial uint64
|
||||||
listenerIsRunning bool
|
|
||||||
runtimePort int
|
|
||||||
runtimeIP string
|
|
||||||
previousConfigHash uint64
|
previousConfigHash uint64
|
||||||
currentConfig hostDNSConfig
|
currentConfig hostDNSConfig
|
||||||
customAddress *netip.AddrPort
|
|
||||||
enabled bool
|
// permanent related properties
|
||||||
|
permanent bool
|
||||||
|
hostsDnsList []string
|
||||||
|
hostsDnsListLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
type handlerWithStop interface {
|
type handlerWithStop interface {
|
||||||
@@ -70,9 +60,7 @@ type muxUpdate struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewDefaultServer returns a new dns server
|
// NewDefaultServer returns a new dns server
|
||||||
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string, initialDnsCfg *nbdns.Config) (*DefaultServer, error) {
|
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string) (*DefaultServer, error) {
|
||||||
mux := dns.NewServeMux()
|
|
||||||
|
|
||||||
var addrPort *netip.AddrPort
|
var addrPort *netip.AddrPort
|
||||||
if customAddress != "" {
|
if customAddress != "" {
|
||||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
||||||
@@ -82,37 +70,44 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
|
|||||||
addrPort = &parsedAddrPort
|
addrPort = &parsedAddrPort
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, stop := context.WithCancel(ctx)
|
var dnsService service
|
||||||
|
if wgInterface.IsUserspaceBind() {
|
||||||
|
dnsService = newServiceViaMemory(wgInterface)
|
||||||
|
} else {
|
||||||
|
dnsService = newServiceViaListener(wgInterface, addrPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
return newDefaultServer(ctx, wgInterface, dnsService), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
|
||||||
|
func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string) *DefaultServer {
|
||||||
|
log.Debugf("host dns address list is: %v", hostsDnsList)
|
||||||
|
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
|
||||||
|
ds.permanent = true
|
||||||
|
ds.hostsDnsList = hostsDnsList
|
||||||
|
ds.addHostRootZone()
|
||||||
|
setServerDns(ds)
|
||||||
|
return ds
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer {
|
||||||
|
ctx, stop := context.WithCancel(ctx)
|
||||||
defaultServer := &DefaultServer{
|
defaultServer := &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
ctxCancel: stop,
|
ctxCancel: stop,
|
||||||
server: &dns.Server{
|
service: dnsService,
|
||||||
Net: "udp",
|
|
||||||
Handler: mux,
|
|
||||||
UDPSize: 65535,
|
|
||||||
},
|
|
||||||
dnsMux: mux,
|
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
dnsMuxMap: make(registeredHandlerMap),
|
||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
customAddress: addrPort,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if initialDnsCfg != nil {
|
return defaultServer
|
||||||
defaultServer.enabled = hasValidDnsServer(initialDnsCfg)
|
|
||||||
}
|
|
||||||
|
|
||||||
if wgInterface.IsUserspaceBind() {
|
|
||||||
defaultServer.evelRuntimeAddressForUserspace()
|
|
||||||
}
|
|
||||||
|
|
||||||
return defaultServer, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize instantiate host manager. It required to be initialized wginterface
|
// Initialize instantiate host manager and the dns service
|
||||||
func (s *DefaultServer) Initialize() (err error) {
|
func (s *DefaultServer) Initialize() (err error) {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
@@ -121,72 +116,23 @@ func (s *DefaultServer) Initialize() (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if !s.wgInterface.IsUserspaceBind() {
|
if s.permanent {
|
||||||
s.evalRuntimeAddress()
|
err = s.service.Listen()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.hostManager, err = newHostManager(s.wgInterface)
|
s.hostManager, err = newHostManager(s.wgInterface)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// listen runs the listener in a go routine
|
|
||||||
func (s *DefaultServer) listen() {
|
|
||||||
// nil check required in unit tests
|
|
||||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
|
|
||||||
s.udpFilterHookID = s.filterDNSTraffic()
|
|
||||||
s.setListenerStatus(true)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("starting dns on %s", s.server.Addr)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
s.setListenerStatus(true)
|
|
||||||
defer s.setListenerStatus(false)
|
|
||||||
|
|
||||||
err := s.server.ListenAndServe()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// DnsIP returns the DNS resolver server IP address
|
// DnsIP returns the DNS resolver server IP address
|
||||||
//
|
//
|
||||||
// When kernel space interface used it return real DNS server listener IP address
|
// When kernel space interface used it return real DNS server listener IP address
|
||||||
// For bind interface, fake DNS resolver address returned (second last IP address from Nebird network)
|
// For bind interface, fake DNS resolver address returned (second last IP address from Nebird network)
|
||||||
func (s *DefaultServer) DnsIP() string {
|
func (s *DefaultServer) DnsIP() string {
|
||||||
if !s.enabled {
|
return s.service.RuntimeIP()
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return s.runtimeIP
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) {
|
|
||||||
ips := []string{defaultIP, customIP}
|
|
||||||
if runtime.GOOS != "darwin" && s.wgInterface != nil {
|
|
||||||
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
|
|
||||||
}
|
|
||||||
ports := []int{defaultPort, customPort}
|
|
||||||
for _, port := range ports {
|
|
||||||
for _, ip := range ips {
|
|
||||||
addrString := fmt.Sprintf("%s:%d", ip, port)
|
|
||||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
|
||||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
|
||||||
if err == nil {
|
|
||||||
err = probeListener.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
|
||||||
}
|
|
||||||
return ip, port, nil
|
|
||||||
}
|
|
||||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) setListenerStatus(running bool) {
|
|
||||||
s.listenerIsRunning = running
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop stops the server
|
// Stop stops the server
|
||||||
@@ -202,37 +148,23 @@ func (s *DefaultServer) Stop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err := s.stopListener()
|
s.service.Stop()
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) stopListener() error {
|
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
|
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
||||||
// udpFilterHookID here empty only in the unit tests
|
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
||||||
if filter := s.wgInterface.GetFilter(); filter != nil && s.udpFilterHookID != "" {
|
s.hostsDnsListLock.Lock()
|
||||||
if err := filter.RemovePacketHook(s.udpFilterHookID); err != nil {
|
defer s.hostsDnsListLock.Unlock()
|
||||||
log.Errorf("unable to remove DNS packet hook: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s.udpFilterHookID = ""
|
|
||||||
s.listenerIsRunning = false
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if !s.listenerIsRunning {
|
s.hostsDnsList = hostsDnsList
|
||||||
return nil
|
_, ok := s.dnsMuxMap[nbdns.RootZone]
|
||||||
|
if ok {
|
||||||
|
log.Debugf("on new host DNS config but skip to apply it")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
log.Debugf("update host DNS settings: %+v", hostsDnsList)
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
s.addHostRootZone()
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
err := s.server.ShutdownContext(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateDNSServer processes an update received from the management service
|
// UpdateDNSServer processes an update received from the management service
|
||||||
@@ -283,12 +215,10 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
|
|||||||
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||||
// is the service should be disabled, we stop the listener or fake resolver
|
// is the service should be disabled, we stop the listener or fake resolver
|
||||||
// and proceed with a regular update to clean up the handlers and records
|
// and proceed with a regular update to clean up the handlers and records
|
||||||
if !update.ServiceEnable {
|
if update.ServiceEnable {
|
||||||
if err := s.stopListener(); err != nil {
|
_ = s.service.Listen()
|
||||||
log.Error(err)
|
} else if !s.permanent {
|
||||||
}
|
s.service.Stop()
|
||||||
} else if !s.listenerIsRunning {
|
|
||||||
s.listen()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||||
@@ -299,17 +229,16 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
||||||
|
|
||||||
s.updateMux(muxUpdates)
|
s.updateMux(muxUpdates)
|
||||||
s.updateLocalResolver(localRecords)
|
s.updateLocalResolver(localRecords)
|
||||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)
|
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
||||||
|
|
||||||
hostUpdate := s.currentConfig
|
hostUpdate := s.currentConfig
|
||||||
if s.runtimePort != defaultPort && !s.hostManager.supportCustomPort() {
|
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
|
||||||
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
||||||
"Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver")
|
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
|
||||||
hostUpdate.routeAll = false
|
hostUpdate.routeAll = false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -412,19 +341,32 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||||
muxUpdateMap := make(registeredHandlerMap)
|
muxUpdateMap := make(registeredHandlerMap)
|
||||||
|
|
||||||
|
var isContainRootUpdate bool
|
||||||
|
|
||||||
for _, update := range muxUpdates {
|
for _, update := range muxUpdates {
|
||||||
s.registerMux(update.domain, update.handler)
|
s.service.RegisterMux(update.domain, update.handler)
|
||||||
muxUpdateMap[update.domain] = update.handler
|
muxUpdateMap[update.domain] = update.handler
|
||||||
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
||||||
existingHandler.stop()
|
existingHandler.stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if update.domain == nbdns.RootZone {
|
||||||
|
isContainRootUpdate = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, existingHandler := range s.dnsMuxMap {
|
for key, existingHandler := range s.dnsMuxMap {
|
||||||
_, found := muxUpdateMap[key]
|
_, found := muxUpdateMap[key]
|
||||||
if !found {
|
if !found {
|
||||||
existingHandler.stop()
|
if !isContainRootUpdate && key == nbdns.RootZone {
|
||||||
s.deregisterMux(key)
|
s.hostsDnsListLock.Lock()
|
||||||
|
s.addHostRootZone()
|
||||||
|
s.hostsDnsListLock.Unlock()
|
||||||
|
existingHandler.stop()
|
||||||
|
} else {
|
||||||
|
existingHandler.stop()
|
||||||
|
s.service.DeregisterMux(key)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -455,14 +397,6 @@ func getNSHostPort(ns nbdns.NameServer) string {
|
|||||||
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
|
|
||||||
s.dnsMux.Handle(pattern, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) deregisterMux(pattern string) {
|
|
||||||
s.dnsMux.HandleRemove(pattern)
|
|
||||||
}
|
|
||||||
|
|
||||||
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
||||||
// the upstream resolver from the configuration, the second one is used to
|
// the upstream resolver from the configuration, the second one is used to
|
||||||
// reactivate it. Not allowed to call reactivate before deactivate.
|
// reactivate it. Not allowed to call reactivate before deactivate.
|
||||||
@@ -490,7 +424,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
for i, item := range s.currentConfig.domains {
|
for i, item := range s.currentConfig.domains {
|
||||||
if _, found := removeIndex[item.domain]; found {
|
if _, found := removeIndex[item.domain]; found {
|
||||||
s.currentConfig.domains[i].disabled = true
|
s.currentConfig.domains[i].disabled = true
|
||||||
s.deregisterMux(item.domain)
|
s.service.DeregisterMux(item.domain)
|
||||||
removeIndex[item.domain] = i
|
removeIndex[item.domain] = i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -507,7 +441,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.currentConfig.domains[i].disabled = false
|
s.currentConfig.domains[i].disabled = false
|
||||||
s.registerMux(domain, handler)
|
s.service.RegisterMux(domain, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||||
@@ -523,93 +457,13 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) filterDNSTraffic() string {
|
func (s *DefaultServer) addHostRootZone() {
|
||||||
filter := s.wgInterface.GetFilter()
|
handler := newUpstreamResolver(s.ctx)
|
||||||
if filter == nil {
|
handler.upstreamServers = make([]string, len(s.hostsDnsList))
|
||||||
log.Error("can't set DNS filter, filter not initialized")
|
for n, ua := range s.hostsDnsList {
|
||||||
return ""
|
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ua)
|
||||||
}
|
}
|
||||||
|
handler.deactivate = func() {}
|
||||||
firstLayerDecoder := layers.LayerTypeIPv4
|
handler.reactivate = func() {}
|
||||||
if s.wgInterface.Address().Network.IP.To4() == nil {
|
s.service.RegisterMux(nbdns.RootZone, handler)
|
||||||
firstLayerDecoder = layers.LayerTypeIPv6
|
|
||||||
}
|
|
||||||
|
|
||||||
hook := func(packetData []byte) bool {
|
|
||||||
// Decode the packet
|
|
||||||
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
|
||||||
|
|
||||||
// Get the UDP layer
|
|
||||||
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
|
||||||
udp := udpLayer.(*layers.UDP)
|
|
||||||
|
|
||||||
msg := new(dns.Msg)
|
|
||||||
if err := msg.Unpack(udp.Payload); err != nil {
|
|
||||||
log.Tracef("parse DNS request: %v", err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
writer := responseWriter{
|
|
||||||
packet: packet,
|
|
||||||
device: s.wgInterface.GetDevice().Device,
|
|
||||||
}
|
|
||||||
go s.dnsMux.ServeDNS(&writer, msg)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) evelRuntimeAddressForUserspace() {
|
|
||||||
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
|
|
||||||
s.runtimePort = defaultPort
|
|
||||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) evalRuntimeAddress() {
|
|
||||||
defer func() {
|
|
||||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
|
||||||
}()
|
|
||||||
|
|
||||||
if s.customAddress != nil {
|
|
||||||
s.runtimeIP = s.customAddress.Addr().String()
|
|
||||||
s.runtimePort = int(s.customAddress.Port())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ip, port, err := s.getFirstListenerAvailable()
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.runtimeIP = ip
|
|
||||||
s.runtimePort = port
|
|
||||||
}
|
|
||||||
|
|
||||||
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
|
|
||||||
// Calculate the last IP in the CIDR range
|
|
||||||
var endIP net.IP
|
|
||||||
for i := 0; i < len(network.IP); i++ {
|
|
||||||
endIP = append(endIP, network.IP[i]|^network.Mask[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
// convert to big.Int
|
|
||||||
endInt := big.NewInt(0)
|
|
||||||
endInt.SetBytes(endIP)
|
|
||||||
|
|
||||||
// subtract fromEnd from the last ip
|
|
||||||
fromEndBig := big.NewInt(int64(fromEnd))
|
|
||||||
resultInt := big.NewInt(0)
|
|
||||||
resultInt.Sub(endInt, fromEndBig)
|
|
||||||
|
|
||||||
return net.IP(resultInt.Bytes()).String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasValidDnsServer(cfg *nbdns.Config) bool {
|
|
||||||
for _, c := range cfg.NameServerGroups {
|
|
||||||
if c.Primary {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|||||||
29
client/internal/dns/server_export.go
Normal file
29
client/internal/dns/server_export.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
mutex sync.Mutex
|
||||||
|
server Server
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetServerDns export the DNS server instance in static way. It used by the Mobile client
|
||||||
|
func GetServerDns() (Server, error) {
|
||||||
|
mutex.Lock()
|
||||||
|
if server == nil {
|
||||||
|
mutex.Unlock()
|
||||||
|
return nil, fmt.Errorf("DNS server not instantiated yet")
|
||||||
|
}
|
||||||
|
s := server
|
||||||
|
mutex.Unlock()
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setServerDns(newServerServer Server) {
|
||||||
|
mutex.Lock()
|
||||||
|
server = newServerServer
|
||||||
|
defer mutex.Unlock()
|
||||||
|
}
|
||||||
24
client/internal/dns/server_export_test.go
Normal file
24
client/internal/dns/server_export_test.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetServerDns(t *testing.T) {
|
||||||
|
_, err := GetServerDns()
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("invalid dns server instance")
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := &MockServer{}
|
||||||
|
setServerDns(srv)
|
||||||
|
|
||||||
|
srvB, err := GetServerDns()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("invalid dns server instance: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if srvB != srv {
|
||||||
|
t.Errorf("missmatch dns instances")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -11,14 +11,53 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/miekg/dns"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/formatter"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
pfmock "github.com/netbirdio/netbird/iface/mocks"
|
pfmock "github.com/netbirdio/netbird/iface/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type mocWGIface struct {
|
||||||
|
filter iface.PacketFilter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mocWGIface) Name() string {
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mocWGIface) Address() iface.WGAddress {
|
||||||
|
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
|
||||||
|
return iface.WGAddress{
|
||||||
|
IP: ip,
|
||||||
|
Network: network,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mocWGIface) GetFilter() iface.PacketFilter {
|
||||||
|
return w.filter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mocWGIface) GetDevice() *iface.DeviceWrapper {
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mocWGIface) GetInterfaceGUIDString() (string, error) {
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mocWGIface) IsUserspaceBind() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mocWGIface) SetFilter(filter iface.PacketFilter) error {
|
||||||
|
w.filter = filter
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var zoneRecords = []nbdns.SimpleRecord{
|
var zoneRecords = []nbdns.SimpleRecord{
|
||||||
{
|
{
|
||||||
Name: "peera.netbird.cloud",
|
Name: "peera.netbird.cloud",
|
||||||
@@ -29,6 +68,11 @@ var zoneRecords = []nbdns.SimpleRecord{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
log.SetLevel(log.TraceLevel)
|
||||||
|
formatter.SetTextFormatter(log.StandardLogger())
|
||||||
|
}
|
||||||
|
|
||||||
func TestUpdateDNSServer(t *testing.T) {
|
func TestUpdateDNSServer(t *testing.T) {
|
||||||
nameServers := []nbdns.NameServer{
|
nameServers := []nbdns.NameServer{
|
||||||
{
|
{
|
||||||
@@ -224,7 +268,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
t.Log(err)
|
t.Log(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil)
|
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -242,8 +286,6 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
||||||
dnsServer.localResolver.registeredMap = testCase.initLocalMap
|
dnsServer.localResolver.registeredMap = testCase.initLocalMap
|
||||||
dnsServer.updateSerial = testCase.initSerial
|
dnsServer.updateSerial = testCase.initSerial
|
||||||
// pretend we are running
|
|
||||||
dnsServer.listenerIsRunning = true
|
|
||||||
|
|
||||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -282,7 +324,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||||
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||||
|
|
||||||
os.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
_ = os.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
newNet, err := stdnet.NewNet(nil)
|
newNet, err := stdnet.NewNet(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create stdnet: %v", err)
|
t.Errorf("create stdnet: %v", err)
|
||||||
@@ -316,17 +358,17 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||||
packetfilter.EXPECT().SetNetwork(ipNet)
|
|
||||||
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
|
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
|
||||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any()).AnyTimes()
|
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||||
|
packetfilter.EXPECT().SetNetwork(ipNet)
|
||||||
|
|
||||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||||
t.Errorf("set packet filter: %v", err)
|
t.Errorf("set packet filter: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil)
|
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create DNS server: %v", err)
|
t.Errorf("create DNS server: %v", err)
|
||||||
return
|
return
|
||||||
@@ -421,21 +463,23 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
dnsServer := getDefaultServerWithNoHostManager(t, testCase.addrPort)
|
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort)
|
||||||
|
if err != nil {
|
||||||
dnsServer.hostManager = newNoopHostMocker()
|
t.Fatalf("%v", err)
|
||||||
dnsServer.listen()
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
if !dnsServer.listenerIsRunning {
|
|
||||||
t.Fatal("dns server listener is not running")
|
|
||||||
}
|
}
|
||||||
|
dnsServer.hostManager = newNoopHostMocker()
|
||||||
|
err = dnsServer.service.Listen()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dns server is not running: %s", err)
|
||||||
|
}
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
defer dnsServer.Stop()
|
defer dnsServer.Stop()
|
||||||
err := dnsServer.localResolver.registerRecord(zoneRecords[0])
|
err = dnsServer.localResolver.registerRecord(zoneRecords[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer.dnsMux.Handle("netbird.cloud", dnsServer.localResolver)
|
dnsServer.service.RegisterMux("netbird.cloud", dnsServer.localResolver)
|
||||||
|
|
||||||
resolver := &net.Resolver{
|
resolver := &net.Resolver{
|
||||||
PreferGo: true,
|
PreferGo: true,
|
||||||
@@ -443,7 +487,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
d := net.Dialer{
|
d := net.Dialer{
|
||||||
Timeout: time.Second * 5,
|
Timeout: time.Second * 5,
|
||||||
}
|
}
|
||||||
addr := fmt.Sprintf("%s:%d", dnsServer.runtimeIP, dnsServer.runtimePort)
|
addr := fmt.Sprintf("%s:%d", dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||||
conn, err := d.DialContext(ctx, network, addr)
|
conn, err := d.DialContext(ctx, network, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
@@ -478,7 +522,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||||
hostManager := &mockHostConfigurator{}
|
hostManager := &mockHostConfigurator{}
|
||||||
server := DefaultServer{
|
server := DefaultServer{
|
||||||
dnsMux: dns.DefaultServeMux,
|
service: newServiceViaMemory(&mocWGIface{}),
|
||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
@@ -541,62 +585,237 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultServer {
|
func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
||||||
mux := dns.NewServeMux()
|
wgIFace, err := createWgInterfaceWithBind(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("failed to initialize wg interface")
|
||||||
|
}
|
||||||
|
defer wgIFace.Close()
|
||||||
|
|
||||||
var parsedAddrPort *netip.AddrPort
|
var dnsList []string
|
||||||
if addrPort != "" {
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList)
|
||||||
parsed, err := netip.ParseAddrPort(addrPort)
|
err = dnsServer.Initialize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Errorf("failed to initialize DNS server: %v", err)
|
||||||
}
|
return
|
||||||
parsedAddrPort = &parsed
|
}
|
||||||
|
defer dnsServer.Stop()
|
||||||
|
|
||||||
|
dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"})
|
||||||
|
|
||||||
|
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||||
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to resolve: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSPermanent_updateUpstream(t *testing.T) {
|
||||||
|
wgIFace, err := createWgInterfaceWithBind(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("failed to initialize wg interface")
|
||||||
|
}
|
||||||
|
defer wgIFace.Close()
|
||||||
|
|
||||||
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"})
|
||||||
|
err = dnsServer.Initialize()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to initialize DNS server: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer dnsServer.Stop()
|
||||||
|
|
||||||
|
// check initial state
|
||||||
|
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||||
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to resolve: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer := &dns.Server{
|
update := nbdns.Config{
|
||||||
Net: "udp",
|
ServiceEnable: true,
|
||||||
Handler: mux,
|
CustomZones: []nbdns.CustomZone{
|
||||||
UDPSize: 65535,
|
{
|
||||||
}
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
},
|
||||||
|
},
|
||||||
ds := &DefaultServer{
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
ctx: ctx,
|
{
|
||||||
ctxCancel: cancel,
|
NameServers: []nbdns.NameServer{
|
||||||
server: dnsServer,
|
{
|
||||||
dnsMux: mux,
|
IP: netip.MustParseAddr("8.8.4.4"),
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
NSType: nbdns.UDPNameServerType,
|
||||||
localResolver: &localResolver{
|
Port: 53,
|
||||||
registeredMap: make(registrationMap),
|
},
|
||||||
|
},
|
||||||
|
Enabled: true,
|
||||||
|
Primary: true,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
customAddress: parsedAddrPort,
|
|
||||||
}
|
|
||||||
ds.evalRuntimeAddress()
|
|
||||||
return ds
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetLastIPFromNetwork(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
addr string
|
|
||||||
ip string
|
|
||||||
}{
|
|
||||||
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
|
|
||||||
{"192.168.0.0/30", "192.168.0.2"},
|
|
||||||
{"192.168.0.0/16", "192.168.255.254"},
|
|
||||||
{"192.168.0.0/24", "192.168.0.254"},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
err = dnsServer.UpdateDNSServer(1, update)
|
||||||
_, ipnet, err := net.ParseCIDR(tt.addr)
|
if err != nil {
|
||||||
if err != nil {
|
t.Errorf("failed to update dns server: %s", err)
|
||||||
t.Errorf("Error parsing CIDR: %v", err)
|
}
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
lastIP := getLastIPFromNetwork(ipnet, 1)
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||||
if lastIP != tt.ip {
|
if err != nil {
|
||||||
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
|
t.Errorf("failed to resolve: %s", err)
|
||||||
}
|
}
|
||||||
|
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed resolve zone record: %v", err)
|
||||||
|
}
|
||||||
|
if ips[0] != zoneRecords[0].RData {
|
||||||
|
t.Fatalf("invalid zone record: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
update2 := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = dnsServer.UpdateDNSServer(2, update2)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to update dns server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to resolve: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ips, err = resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed resolve zone record: %v", err)
|
||||||
|
}
|
||||||
|
if ips[0] != zoneRecords[0].RData {
|
||||||
|
t.Fatalf("invalid zone record: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||||
|
wgIFace, err := createWgInterfaceWithBind(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("failed to initialize wg interface")
|
||||||
|
}
|
||||||
|
defer wgIFace.Close()
|
||||||
|
|
||||||
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"})
|
||||||
|
err = dnsServer.Initialize()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to initialize DNS server: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer dnsServer.Stop()
|
||||||
|
|
||||||
|
// check initial state
|
||||||
|
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||||
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to resolve: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
update := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
NameServers: []nbdns.NameServer{
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.4.4"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Domains: []string{"customdomain.com"},
|
||||||
|
Primary: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = dnsServer.UpdateDNSServer(1, update)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to update dns server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to resolve: %s", err)
|
||||||
|
}
|
||||||
|
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed resolve zone record: %v", err)
|
||||||
|
}
|
||||||
|
if ips[0] != zoneRecords[0].RData {
|
||||||
|
t.Fatalf("invalid zone record: %v", err)
|
||||||
|
}
|
||||||
|
_, err = resolver.LookupHost(context.Background(), "customdomain.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to resolve: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||||
|
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||||
|
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||||
|
|
||||||
|
_ = os.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
newNet, err := stdnet.NewNet(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create stdnet: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", iface.DefaultMTU, nil, newNet)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("build interface wireguard: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = wgIface.Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("crate and init wireguard interface: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pf, err := uspfilter.Create(wgIface)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create uspfilter: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = wgIface.SetFilter(pf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("set packet filter: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return wgIface, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDnsResolver(ip string, port int) *net.Resolver {
|
||||||
|
return &net.Resolver{
|
||||||
|
PreferGo: true,
|
||||||
|
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
d := net.Dialer{
|
||||||
|
Timeout: time.Second * 3,
|
||||||
|
}
|
||||||
|
addr := fmt.Sprintf("%s:%d", ip, port)
|
||||||
|
return d.DialContext(ctx, network, addr)
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
18
client/internal/dns/service.go
Normal file
18
client/internal/dns/service.go
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultPort = 53
|
||||||
|
)
|
||||||
|
|
||||||
|
type service interface {
|
||||||
|
Listen() error
|
||||||
|
Stop()
|
||||||
|
RegisterMux(domain string, handler dns.Handler)
|
||||||
|
DeregisterMux(key string)
|
||||||
|
RuntimePort() int
|
||||||
|
RuntimeIP() string
|
||||||
|
}
|
||||||
193
client/internal/dns/service_listener.go
Normal file
193
client/internal/dns/service_listener.go
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||||
|
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
customPort = 5053
|
||||||
|
defaultIP = "127.0.0.1"
|
||||||
|
customIP = "127.0.0.153"
|
||||||
|
)
|
||||||
|
|
||||||
|
type serviceViaListener struct {
|
||||||
|
wgInterface WGIface
|
||||||
|
dnsMux *dns.ServeMux
|
||||||
|
customAddr *netip.AddrPort
|
||||||
|
server *dns.Server
|
||||||
|
listenIP string
|
||||||
|
listenPort int
|
||||||
|
listenerIsRunning bool
|
||||||
|
listenerFlagLock sync.Mutex
|
||||||
|
ebpfService ebpfMgr.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
|
||||||
|
mux := dns.NewServeMux()
|
||||||
|
|
||||||
|
s := &serviceViaListener{
|
||||||
|
wgInterface: wgIface,
|
||||||
|
dnsMux: mux,
|
||||||
|
customAddr: customAddr,
|
||||||
|
server: &dns.Server{
|
||||||
|
Net: "udp",
|
||||||
|
Handler: mux,
|
||||||
|
UDPSize: 65535,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) Listen() error {
|
||||||
|
s.listenerFlagLock.Lock()
|
||||||
|
defer s.listenerFlagLock.Unlock()
|
||||||
|
|
||||||
|
if s.listenerIsRunning {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
s.listenIP, s.listenPort, err = s.evalListenAddress()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to eval runtime address: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort)
|
||||||
|
|
||||||
|
if s.shouldApplyPortFwd() {
|
||||||
|
s.ebpfService = ebpf.GetEbpfManagerInstance()
|
||||||
|
err = s.ebpfService.LoadDNSFwd(s.listenIP, s.listenPort)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to load DNS port forwarder, custom port may not work well on some Linux operating systems: %s", err)
|
||||||
|
s.ebpfService = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Debugf("starting dns on %s", s.server.Addr)
|
||||||
|
go func() {
|
||||||
|
s.setListenerStatus(true)
|
||||||
|
defer s.setListenerStatus(false)
|
||||||
|
|
||||||
|
err := s.server.ListenAndServe()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) Stop() {
|
||||||
|
s.listenerFlagLock.Lock()
|
||||||
|
defer s.listenerFlagLock.Unlock()
|
||||||
|
|
||||||
|
if !s.listenerIsRunning {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := s.server.ShutdownContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("stopping dns server listener returned an error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.ebpfService != nil {
|
||||||
|
err = s.ebpfService.FreeDNSFwd()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("stopping traffic forwarder returned an error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
||||||
|
s.dnsMux.Handle(pattern, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) DeregisterMux(pattern string) {
|
||||||
|
s.dnsMux.HandleRemove(pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) RuntimePort() int {
|
||||||
|
s.listenerFlagLock.Lock()
|
||||||
|
defer s.listenerFlagLock.Unlock()
|
||||||
|
|
||||||
|
if s.ebpfService != nil {
|
||||||
|
return defaultPort
|
||||||
|
} else {
|
||||||
|
return s.listenPort
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) RuntimeIP() string {
|
||||||
|
return s.listenIP
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) setListenerStatus(running bool) {
|
||||||
|
s.listenerIsRunning = running
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) getFirstListenerAvailable() (string, int, error) {
|
||||||
|
ips := []string{defaultIP, customIP}
|
||||||
|
if runtime.GOOS != "darwin" {
|
||||||
|
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
|
||||||
|
}
|
||||||
|
ports := []int{defaultPort, customPort}
|
||||||
|
for _, port := range ports {
|
||||||
|
for _, ip := range ips {
|
||||||
|
addrString := fmt.Sprintf("%s:%d", ip, port)
|
||||||
|
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||||
|
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||||
|
if err == nil {
|
||||||
|
err = probeListener.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||||
|
}
|
||||||
|
return ip, port, nil
|
||||||
|
}
|
||||||
|
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) evalListenAddress() (string, int, error) {
|
||||||
|
if s.customAddr != nil {
|
||||||
|
return s.customAddr.Addr().String(), int(s.customAddr.Port()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.getFirstListenerAvailable()
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldApplyPortFwd decides whether to apply eBPF program to capture DNS traffic on port 53.
|
||||||
|
// This is needed because on some operating systems if we start a DNS server not on a default port 53, the domain name
|
||||||
|
// resolution won't work.
|
||||||
|
// So, in case we are running on Linux and picked a non-default port (53) we should fall back to the eBPF solution that will capture
|
||||||
|
// traffic on port 53 and forward it to a local DNS server running on 5053.
|
||||||
|
func (s *serviceViaListener) shouldApplyPortFwd() bool {
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.customAddr != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.listenPort == defaultPort {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
139
client/internal/dns/service_memory.go
Normal file
139
client/internal/dns/service_memory.go
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type serviceViaMemory struct {
|
||||||
|
wgInterface WGIface
|
||||||
|
dnsMux *dns.ServeMux
|
||||||
|
runtimeIP string
|
||||||
|
runtimePort int
|
||||||
|
udpFilterHookID string
|
||||||
|
listenerIsRunning bool
|
||||||
|
listenerFlagLock sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
|
||||||
|
s := &serviceViaMemory{
|
||||||
|
wgInterface: wgIface,
|
||||||
|
dnsMux: dns.NewServeMux(),
|
||||||
|
|
||||||
|
runtimeIP: getLastIPFromNetwork(wgIface.Address().Network, 1),
|
||||||
|
runtimePort: defaultPort,
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaMemory) Listen() error {
|
||||||
|
s.listenerFlagLock.Lock()
|
||||||
|
defer s.listenerFlagLock.Unlock()
|
||||||
|
|
||||||
|
if s.listenerIsRunning {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
s.udpFilterHookID, err = s.filterDNSTraffic()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.listenerIsRunning = true
|
||||||
|
|
||||||
|
log.Debugf("dns service listening on: %s", s.RuntimeIP())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaMemory) Stop() {
|
||||||
|
s.listenerFlagLock.Lock()
|
||||||
|
defer s.listenerFlagLock.Unlock()
|
||||||
|
|
||||||
|
if !s.listenerIsRunning {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
|
||||||
|
log.Errorf("unable to remove DNS packet hook: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.listenerIsRunning = false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
||||||
|
s.dnsMux.Handle(pattern, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaMemory) DeregisterMux(pattern string) {
|
||||||
|
s.dnsMux.HandleRemove(pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaMemory) RuntimePort() int {
|
||||||
|
return s.runtimePort
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaMemory) RuntimeIP() string {
|
||||||
|
return s.runtimeIP
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaMemory) filterDNSTraffic() (string, error) {
|
||||||
|
filter := s.wgInterface.GetFilter()
|
||||||
|
if filter == nil {
|
||||||
|
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
firstLayerDecoder := layers.LayerTypeIPv4
|
||||||
|
if s.wgInterface.Address().Network.IP.To4() == nil {
|
||||||
|
firstLayerDecoder = layers.LayerTypeIPv6
|
||||||
|
}
|
||||||
|
|
||||||
|
hook := func(packetData []byte) bool {
|
||||||
|
// Decode the packet
|
||||||
|
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
||||||
|
|
||||||
|
// Get the UDP layer
|
||||||
|
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
||||||
|
udp := udpLayer.(*layers.UDP)
|
||||||
|
|
||||||
|
msg := new(dns.Msg)
|
||||||
|
if err := msg.Unpack(udp.Payload); err != nil {
|
||||||
|
log.Tracef("parse DNS request: %v", err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := responseWriter{
|
||||||
|
packet: packet,
|
||||||
|
device: s.wgInterface.GetDevice().Device,
|
||||||
|
}
|
||||||
|
go s.dnsMux.ServeDNS(&writer, msg)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
|
||||||
|
// Calculate the last IP in the CIDR range
|
||||||
|
var endIP net.IP
|
||||||
|
for i := 0; i < len(network.IP); i++ {
|
||||||
|
endIP = append(endIP, network.IP[i]|^network.Mask[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
// convert to big.Int
|
||||||
|
endInt := big.NewInt(0)
|
||||||
|
endInt.SetBytes(endIP)
|
||||||
|
|
||||||
|
// subtract fromEnd from the last ip
|
||||||
|
fromEndBig := big.NewInt(int64(fromEnd))
|
||||||
|
resultInt := big.NewInt(0)
|
||||||
|
resultInt.Sub(endInt, fromEndBig)
|
||||||
|
|
||||||
|
return net.IP(resultInt.Bytes()).String()
|
||||||
|
}
|
||||||
31
client/internal/dns/service_memory_test.go
Normal file
31
client/internal/dns/service_memory_test.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetLastIPFromNetwork(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
addr string
|
||||||
|
ip string
|
||||||
|
}{
|
||||||
|
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
|
||||||
|
{"192.168.0.0/30", "192.168.0.2"},
|
||||||
|
{"192.168.0.0/16", "192.168.255.254"},
|
||||||
|
{"192.168.0.0/24", "192.168.0.254"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
_, ipnet, err := net.ParseCIDR(tt.addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error parsing CIDR: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
lastIP := getLastIPFromNetwork(ipnet, 1)
|
||||||
|
if lastIP != tt.ip {
|
||||||
|
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -53,7 +52,7 @@ type systemdDbusLinkDomainsInput struct {
|
|||||||
MatchOnly bool
|
MatchOnly bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSystemdDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
func newSystemdDbusConfigurator(wgInterface WGIface) (hostManager, error) {
|
||||||
iface, err := net.InterfaceByName(wgInterface.Name())
|
iface, err := net.InterfaceByName(wgInterface.Name())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
14
client/internal/dns/wgiface.go
Normal file
14
client/internal/dns/wgiface.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import "github.com/netbirdio/netbird/iface"
|
||||||
|
|
||||||
|
// WGIface defines subset methods of interface required for manager
|
||||||
|
type WGIface interface {
|
||||||
|
Name() string
|
||||||
|
Address() iface.WGAddress
|
||||||
|
IsUserspaceBind() bool
|
||||||
|
GetFilter() iface.PacketFilter
|
||||||
|
GetDevice() *iface.DeviceWrapper
|
||||||
|
}
|
||||||
13
client/internal/dns/wgiface_windows.go
Normal file
13
client/internal/dns/wgiface_windows.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import "github.com/netbirdio/netbird/iface"
|
||||||
|
|
||||||
|
// WGIface defines subset methods of interface required for manager
|
||||||
|
type WGIface interface {
|
||||||
|
Name() string
|
||||||
|
Address() iface.WGAddress
|
||||||
|
IsUserspaceBind() bool
|
||||||
|
GetFilter() iface.PacketFilter
|
||||||
|
GetDevice() *iface.DeviceWrapper
|
||||||
|
GetInterfaceGUIDString() (string, error)
|
||||||
|
}
|
||||||
129
client/internal/ebpf/ebpf/bpf_bpfeb.go
Normal file
129
client/internal/ebpf/ebpf/bpf_bpfeb.go
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
// Code generated by bpf2go; DO NOT EDIT.
|
||||||
|
//go:build arm64be || armbe || mips || mips64 || mips64p32 || ppc64 || s390 || s390x || sparc || sparc64
|
||||||
|
// +build arm64be armbe mips mips64 mips64p32 ppc64 s390 s390x sparc sparc64
|
||||||
|
|
||||||
|
package ebpf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
_ "embed"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/cilium/ebpf"
|
||||||
|
)
|
||||||
|
|
||||||
|
// loadBpf returns the embedded CollectionSpec for bpf.
|
||||||
|
func loadBpf() (*ebpf.CollectionSpec, error) {
|
||||||
|
reader := bytes.NewReader(_BpfBytes)
|
||||||
|
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("can't load bpf: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return spec, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadBpfObjects loads bpf and converts it into a struct.
|
||||||
|
//
|
||||||
|
// The following types are suitable as obj argument:
|
||||||
|
//
|
||||||
|
// *bpfObjects
|
||||||
|
// *bpfPrograms
|
||||||
|
// *bpfMaps
|
||||||
|
//
|
||||||
|
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
|
||||||
|
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
|
||||||
|
spec, err := loadBpf()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return spec.LoadAndAssign(obj, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// bpfSpecs contains maps and programs before they are loaded into the kernel.
|
||||||
|
//
|
||||||
|
// It can be passed ebpf.CollectionSpec.Assign.
|
||||||
|
type bpfSpecs struct {
|
||||||
|
bpfProgramSpecs
|
||||||
|
bpfMapSpecs
|
||||||
|
}
|
||||||
|
|
||||||
|
// bpfSpecs contains programs before they are loaded into the kernel.
|
||||||
|
//
|
||||||
|
// It can be passed ebpf.CollectionSpec.Assign.
|
||||||
|
type bpfProgramSpecs struct {
|
||||||
|
NbXdpProg *ebpf.ProgramSpec `ebpf:"nb_xdp_prog"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// bpfMapSpecs contains maps before they are loaded into the kernel.
|
||||||
|
//
|
||||||
|
// It can be passed ebpf.CollectionSpec.Assign.
|
||||||
|
type bpfMapSpecs struct {
|
||||||
|
NbFeatures *ebpf.MapSpec `ebpf:"nb_features"`
|
||||||
|
NbMapDnsIp *ebpf.MapSpec `ebpf:"nb_map_dns_ip"`
|
||||||
|
NbMapDnsPort *ebpf.MapSpec `ebpf:"nb_map_dns_port"`
|
||||||
|
NbWgProxySettingsMap *ebpf.MapSpec `ebpf:"nb_wg_proxy_settings_map"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// bpfObjects contains all objects after they have been loaded into the kernel.
|
||||||
|
//
|
||||||
|
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||||
|
type bpfObjects struct {
|
||||||
|
bpfPrograms
|
||||||
|
bpfMaps
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *bpfObjects) Close() error {
|
||||||
|
return _BpfClose(
|
||||||
|
&o.bpfPrograms,
|
||||||
|
&o.bpfMaps,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// bpfMaps contains all maps after they have been loaded into the kernel.
|
||||||
|
//
|
||||||
|
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||||
|
type bpfMaps struct {
|
||||||
|
NbFeatures *ebpf.Map `ebpf:"nb_features"`
|
||||||
|
NbMapDnsIp *ebpf.Map `ebpf:"nb_map_dns_ip"`
|
||||||
|
NbMapDnsPort *ebpf.Map `ebpf:"nb_map_dns_port"`
|
||||||
|
NbWgProxySettingsMap *ebpf.Map `ebpf:"nb_wg_proxy_settings_map"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *bpfMaps) Close() error {
|
||||||
|
return _BpfClose(
|
||||||
|
m.NbFeatures,
|
||||||
|
m.NbMapDnsIp,
|
||||||
|
m.NbMapDnsPort,
|
||||||
|
m.NbWgProxySettingsMap,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// bpfPrograms contains all programs after they have been loaded into the kernel.
|
||||||
|
//
|
||||||
|
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||||
|
type bpfPrograms struct {
|
||||||
|
NbXdpProg *ebpf.Program `ebpf:"nb_xdp_prog"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *bpfPrograms) Close() error {
|
||||||
|
return _BpfClose(
|
||||||
|
p.NbXdpProg,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _BpfClose(closers ...io.Closer) error {
|
||||||
|
for _, closer := range closers {
|
||||||
|
if err := closer.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do not access this directly.
|
||||||
|
//
|
||||||
|
//go:embed bpf_bpfeb.o
|
||||||
|
var _BpfBytes []byte
|
||||||
BIN
client/internal/ebpf/ebpf/bpf_bpfeb.o
Normal file
BIN
client/internal/ebpf/ebpf/bpf_bpfeb.o
Normal file
Binary file not shown.
129
client/internal/ebpf/ebpf/bpf_bpfel.go
Normal file
129
client/internal/ebpf/ebpf/bpf_bpfel.go
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
// Code generated by bpf2go; DO NOT EDIT.
|
||||||
|
//go:build 386 || amd64 || amd64p32 || arm || arm64 || mips64le || mips64p32le || mipsle || ppc64le || riscv64
|
||||||
|
// +build 386 amd64 amd64p32 arm arm64 mips64le mips64p32le mipsle ppc64le riscv64
|
||||||
|
|
||||||
|
package ebpf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
_ "embed"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/cilium/ebpf"
|
||||||
|
)
|
||||||
|
|
||||||
|
// loadBpf returns the embedded CollectionSpec for bpf.
|
||||||
|
func loadBpf() (*ebpf.CollectionSpec, error) {
|
||||||
|
reader := bytes.NewReader(_BpfBytes)
|
||||||
|
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("can't load bpf: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return spec, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadBpfObjects loads bpf and converts it into a struct.
|
||||||
|
//
|
||||||
|
// The following types are suitable as obj argument:
|
||||||
|
//
|
||||||
|
// *bpfObjects
|
||||||
|
// *bpfPrograms
|
||||||
|
// *bpfMaps
|
||||||
|
//
|
||||||
|
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
|
||||||
|
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
|
||||||
|
spec, err := loadBpf()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return spec.LoadAndAssign(obj, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// bpfSpecs contains maps and programs before they are loaded into the kernel.
|
||||||
|
//
|
||||||
|
// It can be passed ebpf.CollectionSpec.Assign.
|
||||||
|
type bpfSpecs struct {
|
||||||
|
bpfProgramSpecs
|
||||||
|
bpfMapSpecs
|
||||||
|
}
|
||||||
|
|
||||||
|
// bpfSpecs contains programs before they are loaded into the kernel.
|
||||||
|
//
|
||||||
|
// It can be passed ebpf.CollectionSpec.Assign.
|
||||||
|
type bpfProgramSpecs struct {
|
||||||
|
NbXdpProg *ebpf.ProgramSpec `ebpf:"nb_xdp_prog"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// bpfMapSpecs contains maps before they are loaded into the kernel.
|
||||||
|
//
|
||||||
|
// It can be passed ebpf.CollectionSpec.Assign.
|
||||||
|
type bpfMapSpecs struct {
|
||||||
|
NbFeatures *ebpf.MapSpec `ebpf:"nb_features"`
|
||||||
|
NbMapDnsIp *ebpf.MapSpec `ebpf:"nb_map_dns_ip"`
|
||||||
|
NbMapDnsPort *ebpf.MapSpec `ebpf:"nb_map_dns_port"`
|
||||||
|
NbWgProxySettingsMap *ebpf.MapSpec `ebpf:"nb_wg_proxy_settings_map"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// bpfObjects contains all objects after they have been loaded into the kernel.
|
||||||
|
//
|
||||||
|
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||||
|
type bpfObjects struct {
|
||||||
|
bpfPrograms
|
||||||
|
bpfMaps
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *bpfObjects) Close() error {
|
||||||
|
return _BpfClose(
|
||||||
|
&o.bpfPrograms,
|
||||||
|
&o.bpfMaps,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// bpfMaps contains all maps after they have been loaded into the kernel.
|
||||||
|
//
|
||||||
|
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||||
|
type bpfMaps struct {
|
||||||
|
NbFeatures *ebpf.Map `ebpf:"nb_features"`
|
||||||
|
NbMapDnsIp *ebpf.Map `ebpf:"nb_map_dns_ip"`
|
||||||
|
NbMapDnsPort *ebpf.Map `ebpf:"nb_map_dns_port"`
|
||||||
|
NbWgProxySettingsMap *ebpf.Map `ebpf:"nb_wg_proxy_settings_map"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *bpfMaps) Close() error {
|
||||||
|
return _BpfClose(
|
||||||
|
m.NbFeatures,
|
||||||
|
m.NbMapDnsIp,
|
||||||
|
m.NbMapDnsPort,
|
||||||
|
m.NbWgProxySettingsMap,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// bpfPrograms contains all programs after they have been loaded into the kernel.
|
||||||
|
//
|
||||||
|
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||||
|
type bpfPrograms struct {
|
||||||
|
NbXdpProg *ebpf.Program `ebpf:"nb_xdp_prog"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *bpfPrograms) Close() error {
|
||||||
|
return _BpfClose(
|
||||||
|
p.NbXdpProg,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _BpfClose(closers ...io.Closer) error {
|
||||||
|
for _, closer := range closers {
|
||||||
|
if err := closer.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do not access this directly.
|
||||||
|
//
|
||||||
|
//go:embed bpf_bpfel.o
|
||||||
|
var _BpfBytes []byte
|
||||||
BIN
client/internal/ebpf/ebpf/bpf_bpfel.o
Normal file
BIN
client/internal/ebpf/ebpf/bpf_bpfel.o
Normal file
Binary file not shown.
51
client/internal/ebpf/ebpf/dns_fwd_linux.go
Normal file
51
client/internal/ebpf/ebpf/dns_fwd_linux.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package ebpf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
mapKeyDNSIP uint32 = 0
|
||||||
|
mapKeyDNSPort uint32 = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
func (tf *GeneralManager) LoadDNSFwd(ip string, dnsPort int) error {
|
||||||
|
log.Debugf("load ebpf DNS forwarder: address: %s:%d", ip, dnsPort)
|
||||||
|
tf.lock.Lock()
|
||||||
|
defer tf.lock.Unlock()
|
||||||
|
|
||||||
|
err := tf.loadXdp()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tf.bpfObjs.NbMapDnsIp.Put(mapKeyDNSIP, ip2int(ip))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tf.bpfObjs.NbMapDnsPort.Put(mapKeyDNSPort, uint16(dnsPort))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
tf.setFeatureFlag(featureFlagDnsForwarder)
|
||||||
|
err = tf.bpfObjs.NbFeatures.Put(mapKeyFeatures, tf.featureFlags)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tf *GeneralManager) FreeDNSFwd() error {
|
||||||
|
log.Debugf("free ebpf DNS forwarder")
|
||||||
|
return tf.unsetFeatureFlag(featureFlagDnsForwarder)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ip2int(ipString string) uint32 {
|
||||||
|
ip := net.ParseIP(ipString)
|
||||||
|
return binary.BigEndian.Uint32(ip.To4())
|
||||||
|
}
|
||||||
116
client/internal/ebpf/ebpf/manager_linux.go
Normal file
116
client/internal/ebpf/ebpf/manager_linux.go
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
package ebpf
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "embed"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/cilium/ebpf/link"
|
||||||
|
"github.com/cilium/ebpf/rlimit"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
mapKeyFeatures uint32 = 0
|
||||||
|
|
||||||
|
featureFlagWGProxy = 0b00000001
|
||||||
|
featureFlagDnsForwarder = 0b00000010
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
singleton manager.Manager
|
||||||
|
singletonLock = &sync.Mutex{}
|
||||||
|
)
|
||||||
|
|
||||||
|
// required packages libbpf-dev, libc6-dev-i386-amd64-cross
|
||||||
|
|
||||||
|
// GeneralManager is used to load multiple eBPF programs with a custom check (if then) done in prog.c
|
||||||
|
// The manager simply adds a feature (byte) of each program to a map that is shared between the userspace and kernel.
|
||||||
|
// When packet arrives, the C code checks for each feature (if it is set) and executes each enabled program (e.g., dns_fwd.c and wg_proxy.c).
|
||||||
|
//
|
||||||
|
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang-14 bpf src/prog.c -- -I /usr/x86_64-linux-gnu/include
|
||||||
|
type GeneralManager struct {
|
||||||
|
lock sync.Mutex
|
||||||
|
link link.Link
|
||||||
|
featureFlags uint16
|
||||||
|
bpfObjs bpfObjects
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEbpfManagerInstance return a static eBpf Manager instance
|
||||||
|
func GetEbpfManagerInstance() manager.Manager {
|
||||||
|
singletonLock.Lock()
|
||||||
|
defer singletonLock.Unlock()
|
||||||
|
if singleton != nil {
|
||||||
|
return singleton
|
||||||
|
}
|
||||||
|
singleton = &GeneralManager{}
|
||||||
|
return singleton
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tf *GeneralManager) setFeatureFlag(feature uint16) {
|
||||||
|
tf.featureFlags = tf.featureFlags | feature
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tf *GeneralManager) loadXdp() error {
|
||||||
|
if tf.link != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// it required for Docker
|
||||||
|
err := rlimit.RemoveMemlock()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
iFace, err := net.InterfaceByName("lo")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// load pre-compiled programs into the kernel.
|
||||||
|
err = loadBpfObjects(&tf.bpfObjs, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
tf.link, err = link.AttachXDP(link.XDPOptions{
|
||||||
|
Program: tf.bpfObjs.NbXdpProg,
|
||||||
|
Interface: iFace.Index,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
_ = tf.bpfObjs.Close()
|
||||||
|
tf.link = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tf *GeneralManager) unsetFeatureFlag(feature uint16) error {
|
||||||
|
tf.lock.Lock()
|
||||||
|
defer tf.lock.Unlock()
|
||||||
|
tf.featureFlags &^= feature
|
||||||
|
|
||||||
|
if tf.link == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if tf.featureFlags == 0 {
|
||||||
|
return tf.close()
|
||||||
|
}
|
||||||
|
|
||||||
|
return tf.bpfObjs.NbFeatures.Put(mapKeyFeatures, tf.featureFlags)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tf *GeneralManager) close() error {
|
||||||
|
log.Debugf("detach ebpf program ")
|
||||||
|
err := tf.bpfObjs.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to close eBpf objects: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tf.link.Close()
|
||||||
|
tf.link = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
40
client/internal/ebpf/ebpf/manager_linux_test.go
Normal file
40
client/internal/ebpf/ebpf/manager_linux_test.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package ebpf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestManager_setFeatureFlag(t *testing.T) {
|
||||||
|
mgr := GeneralManager{}
|
||||||
|
mgr.setFeatureFlag(featureFlagWGProxy)
|
||||||
|
if mgr.featureFlags != 1 {
|
||||||
|
t.Errorf("invalid faeture state")
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr.setFeatureFlag(featureFlagDnsForwarder)
|
||||||
|
if mgr.featureFlags != 3 {
|
||||||
|
t.Errorf("invalid faeture state")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_unsetFeatureFlag(t *testing.T) {
|
||||||
|
mgr := GeneralManager{}
|
||||||
|
mgr.setFeatureFlag(featureFlagWGProxy)
|
||||||
|
mgr.setFeatureFlag(featureFlagDnsForwarder)
|
||||||
|
|
||||||
|
err := mgr.unsetFeatureFlag(featureFlagWGProxy)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error: %s", err)
|
||||||
|
}
|
||||||
|
if mgr.featureFlags != 2 {
|
||||||
|
t.Errorf("invalid faeture state, expected: %d, got: %d", 2, mgr.featureFlags)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = mgr.unsetFeatureFlag(featureFlagDnsForwarder)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error: %s", err)
|
||||||
|
}
|
||||||
|
if mgr.featureFlags != 0 {
|
||||||
|
t.Errorf("invalid faeture state, expected: %d, got: %d", 0, mgr.featureFlags)
|
||||||
|
}
|
||||||
|
}
|
||||||
64
client/internal/ebpf/ebpf/src/dns_fwd.c
Normal file
64
client/internal/ebpf/ebpf/src/dns_fwd.c
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
const __u32 map_key_dns_ip = 0;
|
||||||
|
const __u32 map_key_dns_port = 1;
|
||||||
|
|
||||||
|
struct bpf_map_def SEC("maps") nb_map_dns_ip = {
|
||||||
|
.type = BPF_MAP_TYPE_ARRAY,
|
||||||
|
.key_size = sizeof(__u32),
|
||||||
|
.value_size = sizeof(__u32),
|
||||||
|
.max_entries = 10,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct bpf_map_def SEC("maps") nb_map_dns_port = {
|
||||||
|
.type = BPF_MAP_TYPE_ARRAY,
|
||||||
|
.key_size = sizeof(__u32),
|
||||||
|
.value_size = sizeof(__u16),
|
||||||
|
.max_entries = 10,
|
||||||
|
};
|
||||||
|
|
||||||
|
__be32 dns_ip = 0;
|
||||||
|
__be16 dns_port = 0;
|
||||||
|
|
||||||
|
// 13568 is 53 in big endian
|
||||||
|
__be16 GENERAL_DNS_PORT = 13568;
|
||||||
|
|
||||||
|
bool read_settings() {
|
||||||
|
__u16 *port_value;
|
||||||
|
__u32 *ip_value;
|
||||||
|
|
||||||
|
// read dns ip
|
||||||
|
ip_value = bpf_map_lookup_elem(&nb_map_dns_ip, &map_key_dns_ip);
|
||||||
|
if(!ip_value) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
dns_ip = htonl(*ip_value);
|
||||||
|
|
||||||
|
// read dns port
|
||||||
|
port_value = bpf_map_lookup_elem(&nb_map_dns_port, &map_key_dns_port);
|
||||||
|
if (!port_value) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
dns_port = htons(*port_value);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
int xdp_dns_fwd(struct iphdr *ip, struct udphdr *udp) {
|
||||||
|
if (dns_port == 0) {
|
||||||
|
if(!read_settings()){
|
||||||
|
return XDP_PASS;
|
||||||
|
}
|
||||||
|
bpf_printk("dns port: %d", ntohs(dns_port));
|
||||||
|
bpf_printk("dns ip: %d", ntohl(dns_ip));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (udp->dest == GENERAL_DNS_PORT && ip->daddr == dns_ip) {
|
||||||
|
udp->dest = dns_port;
|
||||||
|
return XDP_PASS;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (udp->source == dns_port && ip->saddr == dns_ip) {
|
||||||
|
udp->source = GENERAL_DNS_PORT;
|
||||||
|
return XDP_PASS;
|
||||||
|
}
|
||||||
|
|
||||||
|
return XDP_PASS;
|
||||||
|
}
|
||||||
66
client/internal/ebpf/ebpf/src/prog.c
Normal file
66
client/internal/ebpf/ebpf/src/prog.c
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
#include <stdbool.h>
|
||||||
|
#include <linux/if_ether.h> // ETH_P_IP
|
||||||
|
#include <linux/udp.h>
|
||||||
|
#include <linux/ip.h>
|
||||||
|
#include <netinet/in.h>
|
||||||
|
#include <linux/bpf.h>
|
||||||
|
#include <bpf/bpf_helpers.h>
|
||||||
|
#include "dns_fwd.c"
|
||||||
|
#include "wg_proxy.c"
|
||||||
|
|
||||||
|
#define bpf_printk(fmt, ...) \
|
||||||
|
({ \
|
||||||
|
char ____fmt[] = fmt; \
|
||||||
|
bpf_trace_printk(____fmt, sizeof(____fmt), ##__VA_ARGS__); \
|
||||||
|
})
|
||||||
|
|
||||||
|
const __u16 flag_feature_wg_proxy = 0b01;
|
||||||
|
const __u16 flag_feature_dns_fwd = 0b10;
|
||||||
|
|
||||||
|
const __u32 map_key_features = 0;
|
||||||
|
struct bpf_map_def SEC("maps") nb_features = {
|
||||||
|
.type = BPF_MAP_TYPE_ARRAY,
|
||||||
|
.key_size = sizeof(__u32),
|
||||||
|
.value_size = sizeof(__u16),
|
||||||
|
.max_entries = 10,
|
||||||
|
};
|
||||||
|
|
||||||
|
SEC("xdp")
|
||||||
|
int nb_xdp_prog(struct xdp_md *ctx) {
|
||||||
|
__u16 *features;
|
||||||
|
features = bpf_map_lookup_elem(&nb_features, &map_key_features);
|
||||||
|
if (!features) {
|
||||||
|
return XDP_PASS;
|
||||||
|
}
|
||||||
|
|
||||||
|
void *data = (void *)(long)ctx->data;
|
||||||
|
void *data_end = (void *)(long)ctx->data_end;
|
||||||
|
struct ethhdr *eth = data;
|
||||||
|
struct iphdr *ip = (data + sizeof(struct ethhdr));
|
||||||
|
struct udphdr *udp = (data + sizeof(struct ethhdr) + sizeof(struct iphdr));
|
||||||
|
|
||||||
|
// return early if not enough data
|
||||||
|
if (data + sizeof(struct ethhdr) + sizeof(struct iphdr) + sizeof(struct udphdr) > data_end){
|
||||||
|
return XDP_PASS;
|
||||||
|
}
|
||||||
|
|
||||||
|
// skip non IPv4 packages
|
||||||
|
if (eth->h_proto != htons(ETH_P_IP)) {
|
||||||
|
return XDP_PASS;
|
||||||
|
}
|
||||||
|
|
||||||
|
// skip non UPD packages
|
||||||
|
if (ip->protocol != IPPROTO_UDP) {
|
||||||
|
return XDP_PASS;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (*features & flag_feature_dns_fwd) {
|
||||||
|
xdp_dns_fwd(ip, udp);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (*features & flag_feature_wg_proxy) {
|
||||||
|
xdp_wg_proxy(ip, udp);
|
||||||
|
}
|
||||||
|
return XDP_PASS;
|
||||||
|
}
|
||||||
|
char _license[] SEC("license") = "GPL";
|
||||||
54
client/internal/ebpf/ebpf/src/wg_proxy.c
Normal file
54
client/internal/ebpf/ebpf/src/wg_proxy.c
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
const __u32 map_key_proxy_port = 0;
|
||||||
|
const __u32 map_key_wg_port = 1;
|
||||||
|
|
||||||
|
struct bpf_map_def SEC("maps") nb_wg_proxy_settings_map = {
|
||||||
|
.type = BPF_MAP_TYPE_ARRAY,
|
||||||
|
.key_size = sizeof(__u32),
|
||||||
|
.value_size = sizeof(__u16),
|
||||||
|
.max_entries = 10,
|
||||||
|
};
|
||||||
|
|
||||||
|
__u16 proxy_port = 0;
|
||||||
|
__u16 wg_port = 0;
|
||||||
|
|
||||||
|
bool read_port_settings() {
|
||||||
|
__u16 *value;
|
||||||
|
value = bpf_map_lookup_elem(&nb_wg_proxy_settings_map, &map_key_proxy_port);
|
||||||
|
if (!value) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy_port = *value;
|
||||||
|
|
||||||
|
value = bpf_map_lookup_elem(&nb_wg_proxy_settings_map, &map_key_wg_port);
|
||||||
|
if (!value) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
wg_port = htons(*value);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
int xdp_wg_proxy(struct iphdr *ip, struct udphdr *udp) {
|
||||||
|
if (proxy_port == 0 || wg_port == 0) {
|
||||||
|
if (!read_port_settings()){
|
||||||
|
return XDP_PASS;
|
||||||
|
}
|
||||||
|
bpf_printk("proxy port: %d, wg port: %d", proxy_port, wg_port);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2130706433 = 127.0.0.1
|
||||||
|
if (ip->daddr != htonl(2130706433)) {
|
||||||
|
return XDP_PASS;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (udp->source != wg_port){
|
||||||
|
return XDP_PASS;
|
||||||
|
}
|
||||||
|
|
||||||
|
__be16 new_src_port = udp->dest;
|
||||||
|
__be16 new_dst_port = htons(proxy_port);
|
||||||
|
udp->dest = new_dst_port;
|
||||||
|
udp->source = new_src_port;
|
||||||
|
return XDP_PASS;
|
||||||
|
}
|
||||||
41
client/internal/ebpf/ebpf/wg_proxy_linux.go
Normal file
41
client/internal/ebpf/ebpf/wg_proxy_linux.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package ebpf
|
||||||
|
|
||||||
|
import log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
const (
|
||||||
|
mapKeyProxyPort uint32 = 0
|
||||||
|
mapKeyWgPort uint32 = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
func (tf *GeneralManager) LoadWgProxy(proxyPort, wgPort int) error {
|
||||||
|
log.Debugf("load ebpf WG proxy")
|
||||||
|
tf.lock.Lock()
|
||||||
|
defer tf.lock.Unlock()
|
||||||
|
|
||||||
|
err := tf.loadXdp()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tf.bpfObjs.NbWgProxySettingsMap.Put(mapKeyProxyPort, uint16(proxyPort))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tf.bpfObjs.NbWgProxySettingsMap.Put(mapKeyWgPort, uint16(wgPort))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
tf.setFeatureFlag(featureFlagWGProxy)
|
||||||
|
err = tf.bpfObjs.NbFeatures.Put(mapKeyFeatures, tf.featureFlags)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tf *GeneralManager) FreeWGProxy() error {
|
||||||
|
log.Debugf("free ebpf WG proxy")
|
||||||
|
return tf.unsetFeatureFlag(featureFlagWGProxy)
|
||||||
|
}
|
||||||
15
client/internal/ebpf/instantiater_linux.go
Normal file
15
client/internal/ebpf/instantiater_linux.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package ebpf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/internal/ebpf/ebpf"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetEbpfManagerInstance is a wrapper function. This encapsulation is required because if the code import the internal
|
||||||
|
// ebpf package the Go compiler will include the object files. But it is not supported on Android. It can cause instant
|
||||||
|
// panic on older Android version.
|
||||||
|
func GetEbpfManagerInstance() manager.Manager {
|
||||||
|
return ebpf.GetEbpfManagerInstance()
|
||||||
|
}
|
||||||
10
client/internal/ebpf/instantiater_nonlinux.go
Normal file
10
client/internal/ebpf/instantiater_nonlinux.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
//go:build !linux || android
|
||||||
|
|
||||||
|
package ebpf
|
||||||
|
|
||||||
|
import "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||||
|
|
||||||
|
// GetEbpfManagerInstance return error because ebpf is not supported on all os
|
||||||
|
func GetEbpfManagerInstance() manager.Manager {
|
||||||
|
panic("unsupported os")
|
||||||
|
}
|
||||||
9
client/internal/ebpf/manager/manager.go
Normal file
9
client/internal/ebpf/manager/manager.go
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
// Manager is used to load multiple eBPF programs. E.g., current DNS programs and WireGuard proxy
|
||||||
|
type Manager interface {
|
||||||
|
LoadDNSFwd(ip string, dnsPort int) error
|
||||||
|
FreeDNSFwd() error
|
||||||
|
LoadWgProxy(proxyPort, wgPort int) error
|
||||||
|
FreeWGProxy() error
|
||||||
|
}
|
||||||
@@ -20,8 +20,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
@@ -101,7 +101,8 @@ type Engine struct {
|
|||||||
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
|
||||||
wgInterface *iface.WGIface
|
wgInterface *iface.WGIface
|
||||||
|
wgProxyFactory *wgproxy.Factory
|
||||||
|
|
||||||
udpMux *bind.UniversalUDPMuxDefault
|
udpMux *bind.UniversalUDPMuxDefault
|
||||||
udpMuxConn io.Closer
|
udpMuxConn io.Closer
|
||||||
@@ -132,6 +133,7 @@ func NewEngine(
|
|||||||
signalClient signal.Client, mgmClient mgm.Client,
|
signalClient signal.Client, mgmClient mgm.Client,
|
||||||
config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status,
|
config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status,
|
||||||
) *Engine {
|
) *Engine {
|
||||||
|
|
||||||
return &Engine{
|
return &Engine{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
@@ -146,6 +148,7 @@ func NewEngine(
|
|||||||
networkSerial: 0,
|
networkSerial: 0,
|
||||||
sshServerFunc: nbssh.DefaultSSHServer,
|
sshServerFunc: nbssh.DefaultSSHServer,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
|
wgProxyFactory: wgproxy.NewFactory(config.WgPort),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,23 +193,25 @@ func (e *Engine) Start() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var routes []*route.Route
|
var routes []*route.Route
|
||||||
var dnsCfg *nbdns.Config
|
|
||||||
|
|
||||||
if runtime.GOOS == "android" {
|
if runtime.GOOS == "android" {
|
||||||
routes, dnsCfg, err = e.readInitialSettings()
|
routes, err = e.readInitialSettings()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
if e.dnsServer == nil {
|
||||||
|
e.dnsServer = dns.NewDefaultServerPermanentUpstream(e.ctx, e.wgInterface, e.mobileDep.HostDNSAddresses)
|
||||||
if e.dnsServer == nil {
|
go e.mobileDep.DnsReadyListener.OnReady()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
// todo fix custom address
|
// todo fix custom address
|
||||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, dnsCfg)
|
if e.dnsServer == nil {
|
||||||
if err != nil {
|
e.dnsServer, err = dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
|
||||||
e.close()
|
if err != nil {
|
||||||
return err
|
e.close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
e.dnsServer = dnsServer
|
|
||||||
}
|
}
|
||||||
|
|
||||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes)
|
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes)
|
||||||
@@ -280,7 +285,7 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
|||||||
for _, p := range peersUpdate {
|
for _, p := range peersUpdate {
|
||||||
peerPubKey := p.GetWgPubKey()
|
peerPubKey := p.GetWgPubKey()
|
||||||
if peerConn, ok := e.peerConns[peerPubKey]; ok {
|
if peerConn, ok := e.peerConns[peerPubKey]; ok {
|
||||||
if peerConn.GetConf().ProxyConfig.AllowedIps != strings.Join(p.AllowedIps, ",") {
|
if peerConn.WgConfig().AllowedIps != strings.Join(p.AllowedIps, ",") {
|
||||||
modified = append(modified, p)
|
modified = append(modified, p)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -793,9 +798,7 @@ func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
|
|||||||
|
|
||||||
// we might have received new STUN and TURN servers meanwhile, so update them
|
// we might have received new STUN and TURN servers meanwhile, so update them
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
conf := conn.GetConf()
|
conn.UpdateStunTurn(append(e.STUNs, e.TURNs...))
|
||||||
conf.StunTurn = append(e.STUNs, e.TURNs...)
|
|
||||||
conn.UpdateConf(conf)
|
|
||||||
e.syncMsgMux.Unlock()
|
e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
err := conn.Open()
|
err := conn.Open()
|
||||||
@@ -824,9 +827,9 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
|
|||||||
stunTurn = append(stunTurn, e.STUNs...)
|
stunTurn = append(stunTurn, e.STUNs...)
|
||||||
stunTurn = append(stunTurn, e.TURNs...)
|
stunTurn = append(stunTurn, e.TURNs...)
|
||||||
|
|
||||||
proxyConfig := proxy.Config{
|
wgConfig := peer.WgConfig{
|
||||||
RemoteKey: pubKey,
|
RemoteKey: pubKey,
|
||||||
WgListenAddr: fmt.Sprintf("127.0.0.1:%d", e.config.WgPort),
|
WgListenPort: e.config.WgPort,
|
||||||
WgInterface: e.wgInterface,
|
WgInterface: e.wgInterface,
|
||||||
AllowedIps: allowedIPs,
|
AllowedIps: allowedIPs,
|
||||||
PreSharedKey: e.config.PreSharedKey,
|
PreSharedKey: e.config.PreSharedKey,
|
||||||
@@ -843,13 +846,13 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
|
|||||||
Timeout: timeout,
|
Timeout: timeout,
|
||||||
UDPMux: e.udpMux.UDPMuxDefault,
|
UDPMux: e.udpMux.UDPMuxDefault,
|
||||||
UDPMuxSrflx: e.udpMux,
|
UDPMuxSrflx: e.udpMux,
|
||||||
ProxyConfig: proxyConfig,
|
WgConfig: wgConfig,
|
||||||
LocalWgPort: e.config.WgPort,
|
LocalWgPort: e.config.WgPort,
|
||||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||||
UserspaceBind: e.wgInterface.IsUserspaceBind(),
|
UserspaceBind: e.wgInterface.IsUserspaceBind(),
|
||||||
}
|
}
|
||||||
|
|
||||||
peerConn, err := peer.NewConn(config, e.statusRecorder, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
|
peerConn, err := peer.NewConn(config, e.statusRecorder, e.wgProxyFactory, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -989,14 +992,12 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
|||||||
log.Warnf("invalid external IP, %s, ignoring external IP mapping '%s'", external, mapping)
|
log.Warnf("invalid external IP, %s, ignoring external IP mapping '%s'", external, mapping)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if externalIP != nil {
|
mappedIP := externalIP.String()
|
||||||
mappedIP := externalIP.String()
|
if internalIP != nil {
|
||||||
if internalIP != nil {
|
mappedIP = mappedIP + "/" + internalIP.String()
|
||||||
mappedIP = mappedIP + "/" + internalIP.String()
|
|
||||||
}
|
|
||||||
mappedIPs = append(mappedIPs, mappedIP)
|
|
||||||
log.Infof("parsed external IP mapping of '%s' as '%s'", mapping, mappedIP)
|
|
||||||
}
|
}
|
||||||
|
mappedIPs = append(mappedIPs, mappedIP)
|
||||||
|
log.Infof("parsed external IP mapping of '%s' as '%s'", mapping, mappedIP)
|
||||||
}
|
}
|
||||||
if len(mappedIPs) != len(e.config.NATExternalIPs) {
|
if len(mappedIPs) != len(e.config.NATExternalIPs) {
|
||||||
log.Warnf("one or more external IP mappings failed to parse, ignoring all mappings")
|
log.Warnf("one or more external IP mappings failed to parse, ignoring all mappings")
|
||||||
@@ -1006,6 +1007,10 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) close() {
|
func (e *Engine) close() {
|
||||||
|
if err := e.wgProxyFactory.Free(); err != nil {
|
||||||
|
log.Errorf("failed closing ebpf proxy: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||||
if e.wgInterface != nil {
|
if e.wgInterface != nil {
|
||||||
if err := e.wgInterface.Close(); err != nil {
|
if err := e.wgInterface.Close(); err != nil {
|
||||||
@@ -1045,14 +1050,13 @@ func (e *Engine) close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
func (e *Engine) readInitialSettings() ([]*route.Route, error) {
|
||||||
netMap, err := e.mgmClient.GetNetworkMap()
|
netMap, err := e.mgmClient.GetNetworkMap()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
routes := toRoutes(netMap.GetRoutes())
|
routes := toRoutes(netMap.GetRoutes())
|
||||||
dnsCfg := toDNSConfig(netMap.GetDNSConfig())
|
return routes, nil
|
||||||
return routes, &dnsCfg, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||||
|
|||||||
@@ -367,9 +367,9 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
t.Errorf("expecting Engine.peerConns to contain peer %s", p)
|
t.Errorf("expecting Engine.peerConns to contain peer %s", p)
|
||||||
}
|
}
|
||||||
expectedAllowedIPs := strings.Join(p.AllowedIps, ",")
|
expectedAllowedIPs := strings.Join(p.AllowedIps, ",")
|
||||||
if conn.GetConf().ProxyConfig.AllowedIps != expectedAllowedIPs {
|
if conn.WgConfig().AllowedIps != expectedAllowedIPs {
|
||||||
t.Errorf("expecting peer %s to have AllowedIPs= %s, got %s", p.GetWgPubKey(),
|
t.Errorf("expecting peer %s to have AllowedIPs= %s, got %s", p.GetWgPubKey(),
|
||||||
expectedAllowedIPs, conn.GetConf().ProxyConfig.AllowedIps)
|
expectedAllowedIPs, conn.WgConfig().AllowedIps)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -1046,7 +1046,7 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
|
|||||||
peersUpdateManager := server.NewPeersUpdateManager()
|
peersUpdateManager := server.NewPeersUpdateManager()
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", nil
|
return nil, "", err
|
||||||
}
|
}
|
||||||
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "",
|
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "",
|
||||||
eventStore)
|
eventStore)
|
||||||
@@ -1054,7 +1054,7 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
|
|||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||||
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
|
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package internal
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
@@ -8,7 +9,9 @@ import (
|
|||||||
|
|
||||||
// MobileDependency collect all dependencies for mobile platform
|
// MobileDependency collect all dependencies for mobile platform
|
||||||
type MobileDependency struct {
|
type MobileDependency struct {
|
||||||
TunAdapter iface.TunAdapter
|
TunAdapter iface.TunAdapter
|
||||||
IFaceDiscover stdnet.ExternalIFaceDiscover
|
IFaceDiscover stdnet.ExternalIFaceDiscover
|
||||||
RouteListener routemanager.RouteListener
|
RouteListener routemanager.RouteListener
|
||||||
|
HostDNSAddresses []string
|
||||||
|
DnsReadyListener dns.ReadyListener
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,286 +0,0 @@
|
|||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// OAuthClient is a OAuth client interface for various idp providers
|
|
||||||
type OAuthClient interface {
|
|
||||||
RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
|
|
||||||
WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error)
|
|
||||||
GetClientID(ctx context.Context) string
|
|
||||||
}
|
|
||||||
|
|
||||||
// HTTPClient http client interface for API calls
|
|
||||||
type HTTPClient interface {
|
|
||||||
Do(req *http.Request) (*http.Response, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeviceAuthInfo holds information for the OAuth device login flow
|
|
||||||
type DeviceAuthInfo struct {
|
|
||||||
DeviceCode string `json:"device_code"`
|
|
||||||
UserCode string `json:"user_code"`
|
|
||||||
VerificationURI string `json:"verification_uri"`
|
|
||||||
VerificationURIComplete string `json:"verification_uri_complete"`
|
|
||||||
ExpiresIn int `json:"expires_in"`
|
|
||||||
Interval int `json:"interval"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// HostedGrantType grant type for device flow on Hosted
|
|
||||||
const (
|
|
||||||
HostedGrantType = "urn:ietf:params:oauth:grant-type:device_code"
|
|
||||||
HostedRefreshGrant = "refresh_token"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Hosted client
|
|
||||||
type Hosted struct {
|
|
||||||
providerConfig ProviderConfig
|
|
||||||
|
|
||||||
HTTPClient HTTPClient
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestDeviceCodePayload used for request device code payload for auth0
|
|
||||||
type RequestDeviceCodePayload struct {
|
|
||||||
Audience string `json:"audience"`
|
|
||||||
ClientID string `json:"client_id"`
|
|
||||||
Scope string `json:"scope"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TokenRequestPayload used for requesting the auth0 token
|
|
||||||
type TokenRequestPayload struct {
|
|
||||||
GrantType string `json:"grant_type"`
|
|
||||||
DeviceCode string `json:"device_code,omitempty"`
|
|
||||||
ClientID string `json:"client_id"`
|
|
||||||
RefreshToken string `json:"refresh_token,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TokenRequestResponse used for parsing Hosted token's response
|
|
||||||
type TokenRequestResponse struct {
|
|
||||||
Error string `json:"error"`
|
|
||||||
ErrorDescription string `json:"error_description"`
|
|
||||||
TokenInfo
|
|
||||||
}
|
|
||||||
|
|
||||||
// Claims used when validating the access token
|
|
||||||
type Claims struct {
|
|
||||||
Audience interface{} `json:"aud"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TokenInfo holds information of issued access token
|
|
||||||
type TokenInfo struct {
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
RefreshToken string `json:"refresh_token"`
|
|
||||||
IDToken string `json:"id_token"`
|
|
||||||
TokenType string `json:"token_type"`
|
|
||||||
ExpiresIn int `json:"expires_in"`
|
|
||||||
UseIDToken bool `json:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetTokenToUse returns either the access or id token based on UseIDToken field
|
|
||||||
func (t TokenInfo) GetTokenToUse() string {
|
|
||||||
if t.UseIDToken {
|
|
||||||
return t.IDToken
|
|
||||||
}
|
|
||||||
return t.AccessToken
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewHostedDeviceFlow returns an Hosted OAuth client
|
|
||||||
func NewHostedDeviceFlow(config ProviderConfig) *Hosted {
|
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
|
||||||
httpTransport.MaxIdleConns = 5
|
|
||||||
|
|
||||||
httpClient := &http.Client{
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
Transport: httpTransport,
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Hosted{
|
|
||||||
providerConfig: config,
|
|
||||||
HTTPClient: httpClient,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetClientID returns the provider client id
|
|
||||||
func (h *Hosted) GetClientID(ctx context.Context) string {
|
|
||||||
return h.providerConfig.ClientID
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestDeviceCode requests a device code login flow information from Hosted
|
|
||||||
func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) {
|
|
||||||
form := url.Values{}
|
|
||||||
form.Add("client_id", h.providerConfig.ClientID)
|
|
||||||
form.Add("audience", h.providerConfig.Audience)
|
|
||||||
form.Add("scope", h.providerConfig.Scope)
|
|
||||||
req, err := http.NewRequest("POST", h.providerConfig.DeviceAuthEndpoint,
|
|
||||||
strings.NewReader(form.Encode()))
|
|
||||||
if err != nil {
|
|
||||||
return DeviceAuthInfo{}, fmt.Errorf("creating request failed with error: %v", err)
|
|
||||||
}
|
|
||||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
|
|
||||||
res, err := h.HTTPClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return DeviceAuthInfo{}, fmt.Errorf("doing request failed with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer res.Body.Close()
|
|
||||||
body, err := io.ReadAll(res.Body)
|
|
||||||
if err != nil {
|
|
||||||
return DeviceAuthInfo{}, fmt.Errorf("reading body failed with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.StatusCode != 200 {
|
|
||||||
return DeviceAuthInfo{}, fmt.Errorf("request device code returned status %d error: %s", res.StatusCode, string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
deviceCode := DeviceAuthInfo{}
|
|
||||||
err = json.Unmarshal(body, &deviceCode)
|
|
||||||
if err != nil {
|
|
||||||
return DeviceAuthInfo{}, fmt.Errorf("unmarshaling response failed with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback to the verification_uri if the IdP doesn't support verification_uri_complete
|
|
||||||
if deviceCode.VerificationURIComplete == "" {
|
|
||||||
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
|
|
||||||
}
|
|
||||||
|
|
||||||
return deviceCode, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Hosted) requestToken(info DeviceAuthInfo) (TokenRequestResponse, error) {
|
|
||||||
form := url.Values{}
|
|
||||||
form.Add("client_id", h.providerConfig.ClientID)
|
|
||||||
form.Add("grant_type", HostedGrantType)
|
|
||||||
form.Add("device_code", info.DeviceCode)
|
|
||||||
req, err := http.NewRequest("POST", h.providerConfig.TokenEndpoint, strings.NewReader(form.Encode()))
|
|
||||||
if err != nil {
|
|
||||||
return TokenRequestResponse{}, fmt.Errorf("failed to create request access token: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
|
|
||||||
res, err := h.HTTPClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return TokenRequestResponse{}, fmt.Errorf("failed to request access token with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err := res.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
body, err := io.ReadAll(res.Body)
|
|
||||||
if err != nil {
|
|
||||||
return TokenRequestResponse{}, fmt.Errorf("failed reading access token response body with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.StatusCode > 499 {
|
|
||||||
return TokenRequestResponse{}, fmt.Errorf("access token response returned code: %s", string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenResponse := TokenRequestResponse{}
|
|
||||||
err = json.Unmarshal(body, &tokenResponse)
|
|
||||||
if err != nil {
|
|
||||||
return TokenRequestResponse{}, fmt.Errorf("parsing token response failed with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokenResponse, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
|
||||||
// it retrieves the access token from Hosted's endpoint and validates it before returning
|
|
||||||
func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error) {
|
|
||||||
interval := time.Duration(info.Interval) * time.Second
|
|
||||||
ticker := time.NewTicker(interval)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return TokenInfo{}, ctx.Err()
|
|
||||||
case <-ticker.C:
|
|
||||||
|
|
||||||
tokenResponse, err := h.requestToken(info)
|
|
||||||
if err != nil {
|
|
||||||
return TokenInfo{}, fmt.Errorf("parsing token response failed with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tokenResponse.Error != "" {
|
|
||||||
if tokenResponse.Error == "authorization_pending" {
|
|
||||||
continue
|
|
||||||
} else if tokenResponse.Error == "slow_down" {
|
|
||||||
interval = interval + (3 * time.Second)
|
|
||||||
ticker.Reset(interval)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenInfo := TokenInfo{
|
|
||||||
AccessToken: tokenResponse.AccessToken,
|
|
||||||
TokenType: tokenResponse.TokenType,
|
|
||||||
RefreshToken: tokenResponse.RefreshToken,
|
|
||||||
IDToken: tokenResponse.IDToken,
|
|
||||||
ExpiresIn: tokenResponse.ExpiresIn,
|
|
||||||
UseIDToken: h.providerConfig.UseIDToken,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = isValidAccessToken(tokenInfo.GetTokenToUse(), h.providerConfig.Audience)
|
|
||||||
if err != nil {
|
|
||||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokenInfo, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// isValidAccessToken is a simple validation of the access token
|
|
||||||
func isValidAccessToken(token string, audience string) error {
|
|
||||||
if token == "" {
|
|
||||||
return fmt.Errorf("token received is empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
encodedClaims := strings.Split(token, ".")[1]
|
|
||||||
claimsString, err := base64.RawURLEncoding.DecodeString(encodedClaims)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
claims := Claims{}
|
|
||||||
err = json.Unmarshal(claimsString, &claims)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if claims.Audience == nil {
|
|
||||||
return fmt.Errorf("required token field audience is absent")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Audience claim of JWT can be a string or an array of strings
|
|
||||||
typ := reflect.TypeOf(claims.Audience)
|
|
||||||
switch typ.Kind() {
|
|
||||||
case reflect.String:
|
|
||||||
if claims.Audience == audience {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
case reflect.Slice:
|
|
||||||
for _, aud := range claims.Audience.([]interface{}) {
|
|
||||||
if audience == aud {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("invalid JWT token audience field")
|
|
||||||
}
|
|
||||||
@@ -10,9 +10,10 @@ import (
|
|||||||
|
|
||||||
"github.com/pion/ice/v2"
|
"github.com/pion/ice/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
"github.com/netbirdio/netbird/iface/bind"
|
"github.com/netbirdio/netbird/iface/bind"
|
||||||
signal "github.com/netbirdio/netbird/signal/client"
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
@@ -23,8 +24,18 @@ import (
|
|||||||
const (
|
const (
|
||||||
iceKeepAliveDefault = 4 * time.Second
|
iceKeepAliveDefault = 4 * time.Second
|
||||||
iceDisconnectedTimeoutDefault = 6 * time.Second
|
iceDisconnectedTimeoutDefault = 6 * time.Second
|
||||||
|
|
||||||
|
defaultWgKeepAlive = 25 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type WgConfig struct {
|
||||||
|
WgListenPort int
|
||||||
|
RemoteKey string
|
||||||
|
WgInterface *iface.WGIface
|
||||||
|
AllowedIps string
|
||||||
|
PreSharedKey *wgtypes.Key
|
||||||
|
}
|
||||||
|
|
||||||
// ConnConfig is a peer Connection configuration
|
// ConnConfig is a peer Connection configuration
|
||||||
type ConnConfig struct {
|
type ConnConfig struct {
|
||||||
|
|
||||||
@@ -43,7 +54,7 @@ type ConnConfig struct {
|
|||||||
|
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
|
|
||||||
ProxyConfig proxy.Config
|
WgConfig WgConfig
|
||||||
|
|
||||||
UDPMux ice.UDPMux
|
UDPMux ice.UDPMux
|
||||||
UDPMuxSrflx ice.UniversalUDPMux
|
UDPMuxSrflx ice.UniversalUDPMux
|
||||||
@@ -98,7 +109,9 @@ type Conn struct {
|
|||||||
|
|
||||||
statusRecorder *Status
|
statusRecorder *Status
|
||||||
|
|
||||||
proxy proxy.Proxy
|
wgProxyFactory *wgproxy.Factory
|
||||||
|
wgProxy wgproxy.Proxy
|
||||||
|
|
||||||
remoteModeCh chan ModeMessage
|
remoteModeCh chan ModeMessage
|
||||||
meta meta
|
meta meta
|
||||||
|
|
||||||
@@ -122,14 +135,19 @@ func (conn *Conn) GetConf() ConnConfig {
|
|||||||
return conn.config
|
return conn.config
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateConf updates the connection config
|
// WgConfig returns the WireGuard config
|
||||||
func (conn *Conn) UpdateConf(conf ConnConfig) {
|
func (conn *Conn) WgConfig() WgConfig {
|
||||||
conn.config = conf
|
return conn.config.WgConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateStunTurn update the turn and stun addresses
|
||||||
|
func (conn *Conn) UpdateStunTurn(turnStun []*ice.URL) {
|
||||||
|
conn.config.StunTurn = turnStun
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewConn creates a new not opened Conn to the remote peer.
|
// NewConn creates a new not opened Conn to the remote peer.
|
||||||
// To establish a connection run Conn.Open
|
// To establish a connection run Conn.Open
|
||||||
func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) {
|
func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) {
|
||||||
return &Conn{
|
return &Conn{
|
||||||
config: config,
|
config: config,
|
||||||
mu: sync.Mutex{},
|
mu: sync.Mutex{},
|
||||||
@@ -139,6 +157,7 @@ func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter
|
|||||||
remoteAnswerCh: make(chan OfferAnswer),
|
remoteAnswerCh: make(chan OfferAnswer),
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
remoteModeCh: make(chan ModeMessage, 1),
|
remoteModeCh: make(chan ModeMessage, 1),
|
||||||
|
wgProxyFactory: wgProxyFactory,
|
||||||
adapter: adapter,
|
adapter: adapter,
|
||||||
iFaceDiscover: iFaceDiscover,
|
iFaceDiscover: iFaceDiscover,
|
||||||
}, nil
|
}, nil
|
||||||
@@ -215,12 +234,12 @@ func (conn *Conn) candidateTypes() []ice.CandidateType {
|
|||||||
func (conn *Conn) Open() error {
|
func (conn *Conn) Open() error {
|
||||||
log.Debugf("trying to connect to peer %s", conn.config.Key)
|
log.Debugf("trying to connect to peer %s", conn.config.Key)
|
||||||
|
|
||||||
peerState := State{PubKey: conn.config.Key}
|
peerState := State{
|
||||||
|
PubKey: conn.config.Key,
|
||||||
peerState.IP = strings.Split(conn.config.ProxyConfig.AllowedIps, "/")[0]
|
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
|
||||||
peerState.ConnStatusUpdate = time.Now()
|
ConnStatusUpdate: time.Now(),
|
||||||
peerState.ConnStatus = conn.status
|
ConnStatus: conn.status,
|
||||||
|
}
|
||||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err)
|
log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err)
|
||||||
@@ -275,10 +294,11 @@ func (conn *Conn) Open() error {
|
|||||||
defer conn.notifyDisconnected()
|
defer conn.notifyDisconnected()
|
||||||
conn.mu.Unlock()
|
conn.mu.Unlock()
|
||||||
|
|
||||||
peerState = State{PubKey: conn.config.Key}
|
peerState = State{
|
||||||
|
PubKey: conn.config.Key,
|
||||||
peerState.ConnStatus = conn.status
|
ConnStatus: conn.status,
|
||||||
peerState.ConnStatusUpdate = time.Now()
|
ConnStatusUpdate: time.Now(),
|
||||||
|
}
|
||||||
err = conn.statusRecorder.UpdatePeerState(peerState)
|
err = conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err)
|
log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err)
|
||||||
@@ -309,19 +329,12 @@ func (conn *Conn) Open() error {
|
|||||||
remoteWgPort = remoteOfferAnswer.WgListenPort
|
remoteWgPort = remoteOfferAnswer.WgListenPort
|
||||||
}
|
}
|
||||||
// the ice connection has been established successfully so we are ready to start the proxy
|
// the ice connection has been established successfully so we are ready to start the proxy
|
||||||
err = conn.startProxy(remoteConn, remoteWgPort)
|
remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.proxy.Type() == proxy.TypeDirectNoProxy {
|
log.Infof("connected to peer %s, endpoint address: %s", conn.config.Key, remoteAddr.String())
|
||||||
host, _, _ := net.SplitHostPort(remoteConn.LocalAddr().String())
|
|
||||||
rhost, _, _ := net.SplitHostPort(remoteConn.RemoteAddr().String())
|
|
||||||
// direct Wireguard connection
|
|
||||||
log.Infof("directly connected to peer %s [laddr <-> raddr] [%s:%d <-> %s:%d]", conn.config.Key, host, conn.config.LocalWgPort, rhost, remoteWgPort)
|
|
||||||
} else {
|
|
||||||
log.Infof("connected to peer %s [laddr <-> raddr] [%s <-> %s]", conn.config.Key, remoteConn.LocalAddr().String(), remoteConn.RemoteAddr().String())
|
|
||||||
}
|
|
||||||
|
|
||||||
// wait until connection disconnected or has been closed externally (upper layer, e.g. engine)
|
// wait until connection disconnected or has been closed externally (upper layer, e.g. engine)
|
||||||
select {
|
select {
|
||||||
@@ -338,54 +351,60 @@ func isRelayCandidate(candidate ice.Candidate) bool {
|
|||||||
return candidate.Type() == ice.CandidateTypeRelay
|
return candidate.Type() == ice.CandidateTypeRelay
|
||||||
}
|
}
|
||||||
|
|
||||||
// startProxy starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
||||||
func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
|
func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int) (net.Addr, error) {
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
var pair *ice.CandidatePair
|
|
||||||
pair, err := conn.agent.GetSelectedCandidatePair()
|
pair, err := conn.agent.GetSelectedCandidatePair()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
peerState := State{PubKey: conn.config.Key}
|
var endpoint net.Addr
|
||||||
p := conn.getProxy(pair, remoteWgPort)
|
if isRelayCandidate(pair.Local) {
|
||||||
conn.proxy = p
|
log.Debugf("setup relay connection")
|
||||||
err = p.Start(remoteConn)
|
conn.wgProxy = conn.wgProxyFactory.GetProxy()
|
||||||
|
endpoint, err = conn.wgProxy.AddTurnConn(remoteConn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// To support old version's with direct mode we attempt to punch an additional role with the remote wireguard port
|
||||||
|
go conn.punchRemoteWGPort(pair, remoteWgPort)
|
||||||
|
endpoint = remoteConn.RemoteAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
||||||
|
|
||||||
|
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
if conn.wgProxy != nil {
|
||||||
|
_ = conn.wgProxy.CloseConn()
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.status = StatusConnected
|
conn.status = StatusConnected
|
||||||
|
|
||||||
peerState.ConnStatus = conn.status
|
peerState := State{
|
||||||
peerState.ConnStatusUpdate = time.Now()
|
PubKey: conn.config.Key,
|
||||||
peerState.LocalIceCandidateType = pair.Local.Type().String()
|
ConnStatus: conn.status,
|
||||||
peerState.RemoteIceCandidateType = pair.Remote.Type().String()
|
ConnStatusUpdate: time.Now(),
|
||||||
|
LocalIceCandidateType: pair.Local.Type().String(),
|
||||||
|
RemoteIceCandidateType: pair.Remote.Type().String(),
|
||||||
|
Direct: !isRelayCandidate(pair.Local),
|
||||||
|
}
|
||||||
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
||||||
peerState.Relayed = true
|
peerState.Relayed = true
|
||||||
}
|
}
|
||||||
peerState.Direct = p.Type() == proxy.TypeDirectNoProxy || p.Type() == proxy.TypeNoProxy
|
|
||||||
|
|
||||||
err = conn.statusRecorder.UpdatePeerState(peerState)
|
err = conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("unable to save peer's state, got error: %v", err)
|
log.Warnf("unable to save peer's state, got error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return endpoint, nil
|
||||||
}
|
|
||||||
|
|
||||||
// todo rename this method and the proxy package to something more appropriate
|
|
||||||
func (conn *Conn) getProxy(pair *ice.CandidatePair, remoteWgPort int) proxy.Proxy {
|
|
||||||
if isRelayCandidate(pair.Local) {
|
|
||||||
return proxy.NewWireGuardProxy(conn.config.ProxyConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
// To support old version's with direct mode we attempt to punch an additional role with the remote wireguard port
|
|
||||||
go conn.punchRemoteWGPort(pair, remoteWgPort)
|
|
||||||
|
|
||||||
return proxy.NewNoProxy(conn.config.ProxyConfig)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
func (conn *Conn) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
||||||
@@ -414,22 +433,22 @@ func (conn *Conn) cleanup() error {
|
|||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
|
var err1, err2, err3 error
|
||||||
if conn.agent != nil {
|
if conn.agent != nil {
|
||||||
err := conn.agent.Close()
|
err1 = conn.agent.Close()
|
||||||
if err != nil {
|
if err1 == nil {
|
||||||
return err
|
conn.agent = nil
|
||||||
}
|
}
|
||||||
conn.agent = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.proxy != nil {
|
if conn.wgProxy != nil {
|
||||||
err := conn.proxy.Close()
|
err2 = conn.wgProxy.CloseConn()
|
||||||
if err != nil {
|
conn.wgProxy = nil
|
||||||
return err
|
|
||||||
}
|
|
||||||
conn.proxy = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// todo: is it problem if we try to remove a peer what is never existed?
|
||||||
|
err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||||
|
|
||||||
if conn.notifyDisconnected != nil {
|
if conn.notifyDisconnected != nil {
|
||||||
conn.notifyDisconnected()
|
conn.notifyDisconnected()
|
||||||
conn.notifyDisconnected = nil
|
conn.notifyDisconnected = nil
|
||||||
@@ -437,10 +456,11 @@ func (conn *Conn) cleanup() error {
|
|||||||
|
|
||||||
conn.status = StatusDisconnected
|
conn.status = StatusDisconnected
|
||||||
|
|
||||||
peerState := State{PubKey: conn.config.Key}
|
peerState := State{
|
||||||
peerState.ConnStatus = conn.status
|
PubKey: conn.config.Key,
|
||||||
peerState.ConnStatusUpdate = time.Now()
|
ConnStatus: conn.status,
|
||||||
|
ConnStatusUpdate: time.Now(),
|
||||||
|
}
|
||||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// pretty common error because by that time Engine can already remove the peer and status won't be available.
|
// pretty common error because by that time Engine can already remove the peer and status won't be available.
|
||||||
@@ -449,8 +469,13 @@ func (conn *Conn) cleanup() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("cleaned up connection to peer %s", conn.config.Key)
|
log.Debugf("cleaned up connection to peer %s", conn.config.Key)
|
||||||
|
if err1 != nil {
|
||||||
return nil
|
return err1
|
||||||
|
}
|
||||||
|
if err2 != nil {
|
||||||
|
return err2
|
||||||
|
}
|
||||||
|
return err3
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSignalOffer sets a handler function to be triggered by Conn when a new connection offer has to be signalled to the remote peer
|
// SetSignalOffer sets a handler function to be triggered by Conn when a new connection offer has to be signalled to the remote peer
|
||||||
|
|||||||
@@ -5,12 +5,11 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
|
||||||
|
|
||||||
"github.com/magiconair/properties/assert"
|
"github.com/magiconair/properties/assert"
|
||||||
"github.com/pion/ice/v2"
|
"github.com/pion/ice/v2"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -20,7 +19,6 @@ var connConf = ConnConfig{
|
|||||||
StunTurn: []*ice.URL{},
|
StunTurn: []*ice.URL{},
|
||||||
InterfaceBlackList: nil,
|
InterfaceBlackList: nil,
|
||||||
Timeout: time.Second,
|
Timeout: time.Second,
|
||||||
ProxyConfig: proxy.Config{},
|
|
||||||
LocalWgPort: 51820,
|
LocalWgPort: 51820,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -37,7 +35,11 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConn_GetKey(t *testing.T) {
|
func TestConn_GetKey(t *testing.T) {
|
||||||
conn, err := NewConn(connConf, nil, nil, nil)
|
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
||||||
|
defer func() {
|
||||||
|
_ = wgProxyFactory.Free()
|
||||||
|
}()
|
||||||
|
conn, err := NewConn(connConf, nil, wgProxyFactory, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -48,8 +50,11 @@ func TestConn_GetKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConn_OnRemoteOffer(t *testing.T) {
|
func TestConn_OnRemoteOffer(t *testing.T) {
|
||||||
|
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
||||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
defer func() {
|
||||||
|
_ = wgProxyFactory.Free()
|
||||||
|
}()
|
||||||
|
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -82,8 +87,11 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConn_OnRemoteAnswer(t *testing.T) {
|
func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||||
|
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
||||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
defer func() {
|
||||||
|
_ = wgProxyFactory.Free()
|
||||||
|
}()
|
||||||
|
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -115,8 +123,11 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
func TestConn_Status(t *testing.T) {
|
func TestConn_Status(t *testing.T) {
|
||||||
|
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
||||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
defer func() {
|
||||||
|
_ = wgProxyFactory.Free()
|
||||||
|
}()
|
||||||
|
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -142,8 +153,11 @@ func TestConn_Status(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConn_Close(t *testing.T) {
|
func TestConn_Close(t *testing.T) {
|
||||||
|
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
||||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
defer func() {
|
||||||
|
_ = wgProxyFactory.Free()
|
||||||
|
}()
|
||||||
|
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ type notifier struct {
|
|||||||
listener Listener
|
listener Listener
|
||||||
currentClientState bool
|
currentClientState bool
|
||||||
lastNotification int
|
lastNotification int
|
||||||
|
lastNumberOfPeers int
|
||||||
}
|
}
|
||||||
|
|
||||||
func newNotifier() *notifier {
|
func newNotifier() *notifier {
|
||||||
@@ -29,6 +30,7 @@ func (n *notifier) setListener(listener Listener) {
|
|||||||
|
|
||||||
n.serverStateLock.Lock()
|
n.serverStateLock.Lock()
|
||||||
n.notifyListener(listener, n.lastNotification)
|
n.notifyListener(listener, n.lastNotification)
|
||||||
|
listener.OnPeersListChanged(n.lastNumberOfPeers)
|
||||||
n.serverStateLock.Unlock()
|
n.serverStateLock.Unlock()
|
||||||
|
|
||||||
n.listener = listener
|
n.listener = listener
|
||||||
@@ -59,7 +61,7 @@ func (n *notifier) clientStart() {
|
|||||||
n.serverStateLock.Lock()
|
n.serverStateLock.Lock()
|
||||||
defer n.serverStateLock.Unlock()
|
defer n.serverStateLock.Unlock()
|
||||||
n.currentClientState = true
|
n.currentClientState = true
|
||||||
n.lastNotification = stateConnected
|
n.lastNotification = stateConnecting
|
||||||
n.notify(n.lastNotification)
|
n.notify(n.lastNotification)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -112,7 +114,7 @@ func (n *notifier) calculateState(managementConn, signalConn bool) int {
|
|||||||
return stateConnected
|
return stateConnected
|
||||||
}
|
}
|
||||||
|
|
||||||
if !managementConn && !signalConn {
|
if !managementConn && !signalConn && !n.currentClientState {
|
||||||
return stateDisconnected
|
return stateDisconnected
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,6 +126,7 @@ func (n *notifier) calculateState(managementConn, signalConn bool) int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (n *notifier) peerListChanged(numOfPeers int) {
|
func (n *notifier) peerListChanged(numOfPeers int) {
|
||||||
|
n.lastNumberOfPeers = numOfPeers
|
||||||
n.listenersLock.Lock()
|
n.listenersLock.Lock()
|
||||||
defer n.listenersLock.Unlock()
|
defer n.listenersLock.Unlock()
|
||||||
if n.listener == nil {
|
if n.listener == nil {
|
||||||
|
|||||||
@@ -353,9 +353,13 @@ func (d *Status) onConnectionChanged() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) notifyPeerListChanged() {
|
func (d *Status) notifyPeerListChanged() {
|
||||||
d.notifier.peerListChanged(len(d.peers) + len(d.offlinePeers))
|
d.notifier.peerListChanged(d.numOfPeers())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) notifyAddressChanged() {
|
func (d *Status) notifyAddressChanged() {
|
||||||
d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP)
|
d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Status) numOfPeers() int {
|
||||||
|
return len(d.peers) + len(d.offlinePeers)
|
||||||
|
}
|
||||||
|
|||||||
128
client/internal/pkce_auth.go
Normal file
128
client/internal/pkce_auth.go
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
mgm "github.com/netbirdio/netbird/management/client"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
|
||||||
|
type PKCEAuthorizationFlow struct {
|
||||||
|
ProviderConfig PKCEAuthProviderConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// PKCEAuthProviderConfig has all attributes needed to initiate pkce authorization flow
|
||||||
|
type PKCEAuthProviderConfig struct {
|
||||||
|
// ClientID An IDP application client id
|
||||||
|
ClientID string
|
||||||
|
// ClientSecret An IDP application client secret
|
||||||
|
ClientSecret string
|
||||||
|
// Audience An Audience for to authorization validation
|
||||||
|
Audience string
|
||||||
|
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||||
|
TokenEndpoint string
|
||||||
|
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
||||||
|
AuthorizationEndpoint string
|
||||||
|
// Scopes provides the scopes to be included in the token request
|
||||||
|
Scope string
|
||||||
|
// RedirectURL handles authorization code from IDP manager
|
||||||
|
RedirectURLs []string
|
||||||
|
// UseIDToken indicates if the id token should be used for authentication
|
||||||
|
UseIDToken bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
||||||
|
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (PKCEAuthorizationFlow, error) {
|
||||||
|
// validate our peer's Wireguard PRIVATE key
|
||||||
|
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||||
|
return PKCEAuthorizationFlow{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var mgmTLSEnabled bool
|
||||||
|
if mgmURL.Scheme == "https" {
|
||||||
|
mgmTLSEnabled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||||
|
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
|
||||||
|
return PKCEAuthorizationFlow{}, err
|
||||||
|
}
|
||||||
|
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err = mgmClient.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to close the Management service client %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
serverKey, err := mgmClient.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
|
return PKCEAuthorizationFlow{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
protoPKCEAuthorizationFlow, err := mgmClient.GetPKCEAuthorizationFlow(*serverKey)
|
||||||
|
if err != nil {
|
||||||
|
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
|
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||||
|
return PKCEAuthorizationFlow{}, err
|
||||||
|
}
|
||||||
|
log.Errorf("failed to retrieve pkce flow: %v", err)
|
||||||
|
return PKCEAuthorizationFlow{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
authFlow := PKCEAuthorizationFlow{
|
||||||
|
ProviderConfig: PKCEAuthProviderConfig{
|
||||||
|
Audience: protoPKCEAuthorizationFlow.GetProviderConfig().GetAudience(),
|
||||||
|
ClientID: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientID(),
|
||||||
|
ClientSecret: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
||||||
|
TokenEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
||||||
|
AuthorizationEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetAuthorizationEndpoint(),
|
||||||
|
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
|
||||||
|
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
|
||||||
|
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = isPKCEProviderConfigValid(authFlow.ProviderConfig)
|
||||||
|
if err != nil {
|
||||||
|
return PKCEAuthorizationFlow{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return authFlow, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error {
|
||||||
|
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||||
|
if config.Audience == "" {
|
||||||
|
return fmt.Errorf(errorMSGFormat, "Audience")
|
||||||
|
}
|
||||||
|
if config.ClientID == "" {
|
||||||
|
return fmt.Errorf(errorMSGFormat, "Client ID")
|
||||||
|
}
|
||||||
|
if config.TokenEndpoint == "" {
|
||||||
|
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
||||||
|
}
|
||||||
|
if config.AuthorizationEndpoint == "" {
|
||||||
|
return fmt.Errorf(errorMSGFormat, "Authorization Auth Endpoint")
|
||||||
|
}
|
||||||
|
if config.Scope == "" {
|
||||||
|
return fmt.Errorf(errorMSGFormat, "PKCE Auth Scopes")
|
||||||
|
}
|
||||||
|
if config.RedirectURLs == nil {
|
||||||
|
return fmt.Errorf(errorMSGFormat, "PKCE Redirect URLs")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,72 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DummyProxy just sends pings to the RemoteKey peer and reads responses
|
|
||||||
type DummyProxy struct {
|
|
||||||
conn net.Conn
|
|
||||||
remote string
|
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewDummyProxy(remote string) *DummyProxy {
|
|
||||||
p := &DummyProxy{remote: remote}
|
|
||||||
p.ctx, p.cancel = context.WithCancel(context.Background())
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *DummyProxy) Close() error {
|
|
||||||
p.cancel()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *DummyProxy) Start(remoteConn net.Conn) error {
|
|
||||||
p.conn = remoteConn
|
|
||||||
go func() {
|
|
||||||
buf := make([]byte, 1500)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-p.ctx.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
_, err := p.conn.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("error while reading RemoteKey %s proxy %v", p.remote, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
//log.Debugf("received %s from %s", string(buf[:n]), p.remote)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-p.ctx.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
_, err := p.conn.Write([]byte("hello"))
|
|
||||||
//log.Debugf("sent ping to %s", p.remote)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("error while writing to RemoteKey %s proxy %v", p.remote, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *DummyProxy) Type() Type {
|
|
||||||
return TypeDummy
|
|
||||||
}
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NoProxy is used just to configure WireGuard without any local proxy in between.
|
|
||||||
// Used when the WireGuard interface is userspace and uses bind.ICEBind
|
|
||||||
type NoProxy struct {
|
|
||||||
config Config
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewNoProxy creates a new NoProxy with a provided config
|
|
||||||
func NewNoProxy(config Config) *NoProxy {
|
|
||||||
return &NoProxy{config: config}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close removes peer from the WireGuard interface
|
|
||||||
func (p *NoProxy) Close() error {
|
|
||||||
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start just updates WireGuard peer with the remote address
|
|
||||||
func (p *NoProxy) Start(remoteConn net.Conn) error {
|
|
||||||
|
|
||||||
log.Debugf("using NoProxy to connect to peer %s at %s", p.config.RemoteKey, remoteConn.RemoteAddr().String())
|
|
||||||
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
|
|
||||||
addr, p.config.PreSharedKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *NoProxy) Type() Type {
|
|
||||||
return TypeNoProxy
|
|
||||||
}
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
const DefaultWgKeepAlive = 25 * time.Second
|
|
||||||
|
|
||||||
type Type string
|
|
||||||
|
|
||||||
const (
|
|
||||||
TypeDirectNoProxy Type = "DirectNoProxy"
|
|
||||||
TypeWireGuard Type = "WireGuard"
|
|
||||||
TypeDummy Type = "Dummy"
|
|
||||||
TypeNoProxy Type = "NoProxy"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
WgListenAddr string
|
|
||||||
RemoteKey string
|
|
||||||
WgInterface *iface.WGIface
|
|
||||||
AllowedIps string
|
|
||||||
PreSharedKey *wgtypes.Key
|
|
||||||
}
|
|
||||||
|
|
||||||
type Proxy interface {
|
|
||||||
io.Closer
|
|
||||||
// Start creates a local remoteConn and starts proxying data from/to remoteConn
|
|
||||||
Start(remoteConn net.Conn) error
|
|
||||||
Type() Type
|
|
||||||
}
|
|
||||||
@@ -1,128 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
// WireGuardProxy proxies
|
|
||||||
type WireGuardProxy struct {
|
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
|
|
||||||
config Config
|
|
||||||
|
|
||||||
remoteConn net.Conn
|
|
||||||
localConn net.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewWireGuardProxy(config Config) *WireGuardProxy {
|
|
||||||
p := &WireGuardProxy{config: config}
|
|
||||||
p.ctx, p.cancel = context.WithCancel(context.Background())
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *WireGuardProxy) updateEndpoint() error {
|
|
||||||
udpAddr, err := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// add local proxy connection as a Wireguard peer
|
|
||||||
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
|
|
||||||
udpAddr, p.config.PreSharedKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *WireGuardProxy) Start(remoteConn net.Conn) error {
|
|
||||||
p.remoteConn = remoteConn
|
|
||||||
|
|
||||||
var err error
|
|
||||||
p.localConn, err = net.Dial("udp", p.config.WgListenAddr)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed dialing to local Wireguard port %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = p.updateEndpoint()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("error while updating Wireguard peer endpoint [%s] %v", p.config.RemoteKey, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
go p.proxyToRemote()
|
|
||||||
go p.proxyToLocal()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *WireGuardProxy) Close() error {
|
|
||||||
p.cancel()
|
|
||||||
if c := p.localConn; c != nil {
|
|
||||||
err := p.localConn.Close()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
|
|
||||||
// blocks
|
|
||||||
func (p *WireGuardProxy) proxyToRemote() {
|
|
||||||
|
|
||||||
buf := make([]byte, 1500)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-p.ctx.Done():
|
|
||||||
log.Debugf("stopped proxying to remote peer %s due to closed connection", p.config.RemoteKey)
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
n, err := p.localConn.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = p.remoteConn.Write(buf[:n])
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
|
|
||||||
// blocks
|
|
||||||
func (p *WireGuardProxy) proxyToLocal() {
|
|
||||||
|
|
||||||
buf := make([]byte, 1500)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-p.ctx.Done():
|
|
||||||
log.Debugf("stopped proxying from remote peer %s due to closed connection", p.config.RemoteKey)
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
n, err := p.remoteConn.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = p.localConn.Write(buf[:n])
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *WireGuardProxy) Type() Type {
|
|
||||||
return TypeWireGuard
|
|
||||||
}
|
|
||||||
@@ -155,7 +155,10 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() {
|
|||||||
|
|
||||||
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
||||||
state, err := c.statusRecorder.GetPeer(peerKey)
|
state, err := c.statusRecorder.GetPeer(peerKey)
|
||||||
if err != nil || state.ConnStatus != peer.StatusConnected {
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if state.ConnStatus != peer.StatusConnected {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user