mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 01:36:46 +00:00
Compare commits
3 Commits
sqlite-asy
...
separate_p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7b43d7e8ef | ||
|
|
dcc83c8741 | ||
|
|
d56669ec2e |
@@ -1,15 +0,0 @@
|
|||||||
FROM golang:1.21-bullseye
|
|
||||||
|
|
||||||
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
|
||||||
&& apt-get -y install --no-install-recommends\
|
|
||||||
gettext-base=0.21-4 \
|
|
||||||
iptables=1.8.7-1 \
|
|
||||||
libgl1-mesa-dev=20.3.5-1 \
|
|
||||||
xorg-dev=1:7.7+22 \
|
|
||||||
libayatana-appindicator3-dev=0.5.5-2+deb11u2 \
|
|
||||||
&& apt-get clean \
|
|
||||||
&& rm -rf /var/lib/apt/lists/* \
|
|
||||||
&& go install -v golang.org/x/tools/gopls@latest
|
|
||||||
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
@@ -1,20 +0,0 @@
|
|||||||
{
|
|
||||||
"name": "NetBird",
|
|
||||||
"build": {
|
|
||||||
"context": "..",
|
|
||||||
"dockerfile": "Dockerfile"
|
|
||||||
},
|
|
||||||
"features": {
|
|
||||||
"ghcr.io/devcontainers/features/docker-in-docker:2": {},
|
|
||||||
"ghcr.io/devcontainers/features/go:1": {
|
|
||||||
"version": "1.21"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}",
|
|
||||||
"capAdd": [
|
|
||||||
"NET_ADMIN",
|
|
||||||
"SYS_ADMIN",
|
|
||||||
"SYS_RESOURCE"
|
|
||||||
],
|
|
||||||
"privileged": true
|
|
||||||
}
|
|
||||||
1
.gitattributes
vendored
1
.gitattributes
vendored
@@ -1 +0,0 @@
|
|||||||
*.go text eol=lf
|
|
||||||
48
.github/workflows/android-build-validation.yml
vendored
48
.github/workflows/android-build-validation.yml
vendored
@@ -1,48 +0,0 @@
|
|||||||
name: Android build validation
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- main
|
|
||||||
pull_request:
|
|
||||||
|
|
||||||
concurrency:
|
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
|
||||||
cancel-in-progress: true
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
build:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
- name: Install Go
|
|
||||||
uses: actions/setup-go@v4
|
|
||||||
with:
|
|
||||||
go-version: "1.21.x"
|
|
||||||
- name: Setup Android SDK
|
|
||||||
uses: android-actions/setup-android@v3
|
|
||||||
with:
|
|
||||||
cmdline-tools-version: 8512546
|
|
||||||
- name: Setup Java
|
|
||||||
uses: actions/setup-java@v3
|
|
||||||
with:
|
|
||||||
java-version: "11"
|
|
||||||
distribution: "adopt"
|
|
||||||
- name: NDK Cache
|
|
||||||
id: ndk-cache
|
|
||||||
uses: actions/cache@v3
|
|
||||||
with:
|
|
||||||
path: /usr/local/lib/android/sdk/ndk
|
|
||||||
key: ndk-cache-23.1.7779620
|
|
||||||
- name: Setup NDK
|
|
||||||
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
|
|
||||||
- name: install gomobile
|
|
||||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
|
|
||||||
- name: gomobile init
|
|
||||||
run: gomobile init
|
|
||||||
- name: build android nebtird lib
|
|
||||||
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
|
|
||||||
env:
|
|
||||||
CGO_ENABLED: 0
|
|
||||||
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
|
|
||||||
13
.github/workflows/golang-test-darwin.yml
vendored
13
.github/workflows/golang-test-darwin.yml
vendored
@@ -12,20 +12,17 @@ concurrency:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
store: ['jsonfile', 'sqlite']
|
|
||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v4
|
uses: actions/setup-go@v2
|
||||||
with:
|
with:
|
||||||
go-version: "1.21.x"
|
go-version: "1.20.x"
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v3
|
uses: actions/cache@v2
|
||||||
with:
|
with:
|
||||||
path: ~/go/pkg/mod
|
path: ~/go/pkg/mod
|
||||||
key: macos-go-${{ hashFiles('**/go.sum') }}
|
key: macos-go-${{ hashFiles('**/go.sum') }}
|
||||||
@@ -36,4 +33,4 @@ jobs:
|
|||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
|
run: go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
|
||||||
|
|||||||
40
.github/workflows/golang-test-linux.yml
vendored
40
.github/workflows/golang-test-linux.yml
vendored
@@ -15,17 +15,16 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
arch: ['386','amd64']
|
arch: ['386','amd64']
|
||||||
store: ['jsonfile', 'sqlite']
|
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v4
|
uses: actions/setup-go@v2
|
||||||
with:
|
with:
|
||||||
go-version: "1.21.x"
|
go-version: "1.20.x"
|
||||||
|
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v3
|
uses: actions/cache@v2
|
||||||
with:
|
with:
|
||||||
path: ~/go/pkg/mod
|
path: ~/go/pkg/mod
|
||||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||||
@@ -33,7 +32,7 @@ jobs:
|
|||||||
${{ runner.os }}-go-
|
${{ runner.os }}-go-
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
|
||||||
@@ -42,18 +41,19 @@ jobs:
|
|||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
|
||||||
|
|
||||||
test_client_on_docker:
|
test_client_on_docker:
|
||||||
runs-on: ubuntu-20.04
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v4
|
uses: actions/setup-go@v2
|
||||||
with:
|
with:
|
||||||
go-version: "1.21.x"
|
go-version: "1.20.x"
|
||||||
|
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v3
|
uses: actions/cache@v2
|
||||||
with:
|
with:
|
||||||
path: ~/go/pkg/mod
|
path: ~/go/pkg/mod
|
||||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||||
@@ -61,10 +61,10 @@ jobs:
|
|||||||
${{ runner.os }}-go-
|
${{ runner.os }}-go-
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
@@ -78,11 +78,8 @@ jobs:
|
|||||||
- name: Generate RouteManager Test bin
|
- name: Generate RouteManager Test bin
|
||||||
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
|
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
|
||||||
|
|
||||||
- name: Generate nftables Manager Test bin
|
|
||||||
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
|
|
||||||
|
|
||||||
- name: Generate Engine Test bin
|
- name: Generate Engine Test bin
|
||||||
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
|
run: CGO_ENABLED=0 go test -c -o engine-testing.bin ./client/internal
|
||||||
|
|
||||||
- name: Generate Peer Test bin
|
- name: Generate Peer Test bin
|
||||||
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/...
|
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/...
|
||||||
@@ -95,17 +92,12 @@ jobs:
|
|||||||
- name: Run Iface tests in docker
|
- name: Run Iface tests in docker
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1
|
||||||
|
|
||||||
|
|
||||||
- name: Run RouteManager tests in docker
|
- name: Run RouteManager tests in docker
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
|
||||||
|
|
||||||
- name: Run nftables Manager tests in docker
|
- name: Run Engine tests in docker
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||||
|
|
||||||
- name: Run Engine tests in docker with file store
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
|
||||||
|
|
||||||
- name: Run Engine tests in docker with sqlite store
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
|
||||||
|
|
||||||
- name: Run Peer tests in docker
|
- name: Run Peer tests in docker
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
||||||
8
.github/workflows/golang-test-windows.yml
vendored
8
.github/workflows/golang-test-windows.yml
vendored
@@ -23,13 +23,13 @@ jobs:
|
|||||||
uses: actions/setup-go@v4
|
uses: actions/setup-go@v4
|
||||||
id: go
|
id: go
|
||||||
with:
|
with:
|
||||||
go-version: "1.21.x"
|
go-version: "1.20.x"
|
||||||
|
|
||||||
- name: Download wintun
|
- name: Download wintun
|
||||||
uses: carlosperate/download-file-action@v2
|
uses: carlosperate/download-file-action@v2
|
||||||
id: download-wintun
|
id: download-wintun
|
||||||
with:
|
with:
|
||||||
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
file-url: https://www.wintun.net/builds/wintun-0.14.1.zip
|
||||||
file-name: wintun.zip
|
file-name: wintun.zip
|
||||||
location: ${{ env.downloadPath }}
|
location: ${{ env.downloadPath }}
|
||||||
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
|
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
|
||||||
@@ -39,9 +39,7 @@ jobs:
|
|||||||
|
|
||||||
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'
|
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'
|
||||||
|
|
||||||
- run: choco install -y sysinternals --ignore-checksums
|
- run: choco install -y sysinternals
|
||||||
- run: choco install -y mingw
|
|
||||||
|
|
||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod
|
||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
|
||||||
|
|
||||||
|
|||||||
41
.github/workflows/golangci-lint.yml
vendored
41
.github/workflows/golangci-lint.yml
vendored
@@ -1,48 +1,21 @@
|
|||||||
name: golangci-lint
|
name: golangci-lint
|
||||||
on: [pull_request]
|
on: [pull_request]
|
||||||
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
pull-requests: read
|
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
codespell:
|
golangci:
|
||||||
name: codespell
|
name: lint
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- uses: actions/checkout@v2
|
||||||
uses: actions/checkout@v3
|
|
||||||
- name: codespell
|
|
||||||
uses: codespell-project/actions-codespell@v2
|
|
||||||
with:
|
|
||||||
ignore_words_list: erro,clienta
|
|
||||||
skip: go.mod,go.sum
|
|
||||||
only_warn: 1
|
|
||||||
golangci:
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
os: [macos-latest, windows-latest, ubuntu-latest]
|
|
||||||
name: lint
|
|
||||||
runs-on: ${{ matrix.os }}
|
|
||||||
timeout-minutes: 15
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v4
|
uses: actions/setup-go@v2
|
||||||
with:
|
with:
|
||||||
go-version: "1.21.x"
|
go-version: "1.20.x"
|
||||||
cache: false
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
if: matrix.os == 'ubuntu-latest'
|
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
|
||||||
- name: golangci-lint
|
- name: golangci-lint
|
||||||
uses: golangci/golangci-lint-action@v3
|
uses: golangci/golangci-lint-action@v2
|
||||||
with:
|
with:
|
||||||
version: latest
|
args: --timeout=6m
|
||||||
args: --timeout=12m
|
|
||||||
36
.github/workflows/install-script-test.yml
vendored
36
.github/workflows/install-script-test.yml
vendored
@@ -1,36 +0,0 @@
|
|||||||
name: Test installation
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- main
|
|
||||||
pull_request:
|
|
||||||
paths:
|
|
||||||
- "release_files/install.sh"
|
|
||||||
concurrency:
|
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
|
||||||
cancel-in-progress: true
|
|
||||||
jobs:
|
|
||||||
test-install-script:
|
|
||||||
strategy:
|
|
||||||
max-parallel: 2
|
|
||||||
matrix:
|
|
||||||
os: [ubuntu-latest, macos-latest]
|
|
||||||
skip_ui_mode: [true, false]
|
|
||||||
install_binary: [true, false]
|
|
||||||
runs-on: ${{ matrix.os }}
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
|
|
||||||
- name: run install script
|
|
||||||
env:
|
|
||||||
SKIP_UI_APP: ${{ matrix.skip_ui_mode }}
|
|
||||||
USE_BIN_INSTALL: ${{ matrix.install_binary }}
|
|
||||||
GITHUB_TOKEN: ${{ secrets.RO_API_CALLER_TOKEN }}
|
|
||||||
run: |
|
|
||||||
[ "$SKIP_UI_APP" == "false" ] && export XDG_CURRENT_DESKTOP="none"
|
|
||||||
cat release_files/install.sh | sh -x
|
|
||||||
|
|
||||||
- name: check cli binary
|
|
||||||
run: command -v netbird
|
|
||||||
60
.github/workflows/install-test-darwin.yml
vendored
Normal file
60
.github/workflows/install-test-darwin.yml
vendored
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
name: Test installation Darwin
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- "release_files/install.sh"
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
jobs:
|
||||||
|
install-cli-only:
|
||||||
|
runs-on: macos-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: Rename brew package
|
||||||
|
if: ${{ matrix.check_bin_install }}
|
||||||
|
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
|
||||||
|
|
||||||
|
- name: Run install script
|
||||||
|
run: |
|
||||||
|
sh ./release_files/install.sh
|
||||||
|
env:
|
||||||
|
SKIP_UI_APP: true
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: |
|
||||||
|
if ! command -v netbird &> /dev/null; then
|
||||||
|
echo "Error: netbird is not installed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
install-all:
|
||||||
|
runs-on: macos-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: Rename brew package
|
||||||
|
if: ${{ matrix.check_bin_install }}
|
||||||
|
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
|
||||||
|
|
||||||
|
- name: Run install script
|
||||||
|
run: |
|
||||||
|
sh ./release_files/install.sh
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: |
|
||||||
|
if ! command -v netbird &> /dev/null; then
|
||||||
|
echo "Error: netbird is not installed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ $(mdfind "kMDItemContentType == 'com.apple.application-bundle' && kMDItemFSName == '*NetBird UI.app'") ]]; then
|
||||||
|
echo "Error: NetBird UI is not installed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
38
.github/workflows/install-test-linux.yml
vendored
Normal file
38
.github/workflows/install-test-linux.yml
vendored
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
name: Test installation Linux
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- "release_files/install.sh"
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
jobs:
|
||||||
|
install-cli-only:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
check_bin_install: [true, false]
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: Rename apt package
|
||||||
|
if: ${{ matrix.check_bin_install }}
|
||||||
|
run: |
|
||||||
|
sudo mv /usr/bin/apt /usr/bin/apt.bak
|
||||||
|
sudo mv /usr/bin/apt-get /usr/bin/apt-get.bak
|
||||||
|
|
||||||
|
- name: Run install script
|
||||||
|
run: |
|
||||||
|
sh ./release_files/install.sh
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: |
|
||||||
|
if ! command -v netbird &> /dev/null; then
|
||||||
|
echo "Error: netbird is not installed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
98
.github/workflows/release.yml
vendored
98
.github/workflows/release.yml
vendored
@@ -7,20 +7,9 @@ on:
|
|||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
|
||||||
- 'go.mod'
|
|
||||||
- 'go.sum'
|
|
||||||
- '.goreleaser.yml'
|
|
||||||
- '.goreleaser_ui.yaml'
|
|
||||||
- '.goreleaser_ui_darwin.yaml'
|
|
||||||
- '.github/workflows/release.yml'
|
|
||||||
- 'release_files/**'
|
|
||||||
- '**/Dockerfile'
|
|
||||||
- '**/Dockerfile.*'
|
|
||||||
- 'client/ui/**'
|
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.0.11"
|
SIGN_PIPE_VER: "v0.0.6"
|
||||||
GORELEASER_VER: "v1.14.1"
|
GORELEASER_VER: "v1.14.1"
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
@@ -30,32 +19,25 @@ concurrency:
|
|||||||
jobs:
|
jobs:
|
||||||
release:
|
release:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
env:
|
|
||||||
flags: ""
|
|
||||||
steps:
|
steps:
|
||||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
|
||||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
|
||||||
-
|
-
|
||||||
name: Checkout
|
name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v2
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||||
-
|
-
|
||||||
name: Set up Go
|
name: Set up Go
|
||||||
uses: actions/setup-go@v4
|
uses: actions/setup-go@v2
|
||||||
with:
|
with:
|
||||||
go-version: "1.21"
|
go-version: "1.20"
|
||||||
cache: false
|
|
||||||
-
|
-
|
||||||
name: Cache Go modules
|
name: Cache Go modules
|
||||||
uses: actions/cache@v3
|
uses: actions/cache@v1
|
||||||
with:
|
with:
|
||||||
path: |
|
path: ~/go/pkg/mod
|
||||||
~/go/pkg/mod
|
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||||
~/.cache/go-build
|
|
||||||
key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }}
|
|
||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-go-releaser-
|
${{ runner.os }}-go-
|
||||||
-
|
-
|
||||||
name: Install modules
|
name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
@@ -64,10 +46,10 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
-
|
-
|
||||||
name: Set up QEMU
|
name: Set up QEMU
|
||||||
uses: docker/setup-qemu-action@v2
|
uses: docker/setup-qemu-action@v1
|
||||||
-
|
-
|
||||||
name: Set up Docker Buildx
|
name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v2
|
uses: docker/setup-buildx-action@v1
|
||||||
-
|
-
|
||||||
name: Login to Docker hub
|
name: Login to Docker hub
|
||||||
if: github.event_name != 'pull_request'
|
if: github.event_name != 'pull_request'
|
||||||
@@ -90,10 +72,10 @@ jobs:
|
|||||||
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
|
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
|
||||||
-
|
-
|
||||||
name: Run GoReleaser
|
name: Run GoReleaser
|
||||||
uses: goreleaser/goreleaser-action@v4
|
uses: goreleaser/goreleaser-action@v2
|
||||||
with:
|
with:
|
||||||
version: ${{ env.GORELEASER_VER }}
|
version: ${{ env.GORELEASER_VER }}
|
||||||
args: release --rm-dist ${{ env.flags }}
|
args: release --rm-dist
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||||
@@ -101,7 +83,7 @@ jobs:
|
|||||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||||
-
|
-
|
||||||
name: upload non tags for debug purposes
|
name: upload non tags for debug purposes
|
||||||
uses: actions/upload-artifact@v3
|
uses: actions/upload-artifact@v2
|
||||||
with:
|
with:
|
||||||
name: release
|
name: release
|
||||||
path: dist/
|
path: dist/
|
||||||
@@ -110,27 +92,22 @@ jobs:
|
|||||||
release_ui:
|
release_ui:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
|
||||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v2
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||||
|
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v4
|
uses: actions/setup-go@v2
|
||||||
with:
|
with:
|
||||||
go-version: "1.21"
|
go-version: "1.20"
|
||||||
cache: false
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v3
|
uses: actions/cache@v1
|
||||||
with:
|
with:
|
||||||
path: |
|
path: ~/go/pkg/mod
|
||||||
~/go/pkg/mod
|
key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }}
|
||||||
~/.cache/go-build
|
|
||||||
key: ${{ runner.os }}-ui-go-releaser-${{ hashFiles('**/go.sum') }}
|
|
||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-ui-go-releaser-
|
${{ runner.os }}-ui-go-
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
@@ -139,23 +116,23 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-mingw-w64-x86-64
|
||||||
- name: Install rsrc
|
- name: Install rsrc
|
||||||
run: go install github.com/akavel/rsrc@v0.10.2
|
run: go install github.com/akavel/rsrc@v0.10.2
|
||||||
- name: Generate windows rsrc
|
- name: Generate windows rsrc
|
||||||
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso
|
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso
|
||||||
- name: Run GoReleaser
|
- name: Run GoReleaser
|
||||||
uses: goreleaser/goreleaser-action@v4
|
uses: goreleaser/goreleaser-action@v2
|
||||||
with:
|
with:
|
||||||
version: ${{ env.GORELEASER_VER }}
|
version: ${{ env.GORELEASER_VER }}
|
||||||
args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }}
|
args: release --config .goreleaser_ui.yaml --rm-dist
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||||
- name: upload non tags for debug purposes
|
- name: upload non tags for debug purposes
|
||||||
uses: actions/upload-artifact@v3
|
uses: actions/upload-artifact@v2
|
||||||
with:
|
with:
|
||||||
name: release-ui
|
name: release-ui
|
||||||
path: dist/
|
path: dist/
|
||||||
@@ -164,44 +141,39 @@ jobs:
|
|||||||
release_ui_darwin:
|
release_ui_darwin:
|
||||||
runs-on: macos-11
|
runs-on: macos-11
|
||||||
steps:
|
steps:
|
||||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
|
||||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
|
||||||
-
|
-
|
||||||
name: Checkout
|
name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v2
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||||
-
|
-
|
||||||
name: Set up Go
|
name: Set up Go
|
||||||
uses: actions/setup-go@v4
|
uses: actions/setup-go@v2
|
||||||
with:
|
with:
|
||||||
go-version: "1.21"
|
go-version: "1.20"
|
||||||
cache: false
|
|
||||||
-
|
-
|
||||||
name: Cache Go modules
|
name: Cache Go modules
|
||||||
uses: actions/cache@v3
|
uses: actions/cache@v1
|
||||||
with:
|
with:
|
||||||
path: |
|
path: ~/go/pkg/mod
|
||||||
~/go/pkg/mod
|
key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }}
|
||||||
~/.cache/go-build
|
|
||||||
key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }}
|
|
||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-ui-go-releaser-darwin-
|
${{ runner.os }}-ui-go-
|
||||||
-
|
-
|
||||||
name: Install modules
|
name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
-
|
-
|
||||||
name: Run GoReleaser
|
name: Run GoReleaser
|
||||||
id: goreleaser
|
id: goreleaser
|
||||||
uses: goreleaser/goreleaser-action@v4
|
uses: goreleaser/goreleaser-action@v2
|
||||||
with:
|
with:
|
||||||
version: ${{ env.GORELEASER_VER }}
|
version: ${{ env.GORELEASER_VER }}
|
||||||
args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }}
|
args: release --config .goreleaser_ui_darwin.yaml --rm-dist
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
-
|
-
|
||||||
name: upload non tags for debug purposes
|
name: upload non tags for debug purposes
|
||||||
uses: actions/upload-artifact@v3
|
uses: actions/upload-artifact@v2
|
||||||
with:
|
with:
|
||||||
name: release-ui-darwin
|
name: release-ui-darwin
|
||||||
path: dist/
|
path: dist/
|
||||||
@@ -223,7 +195,7 @@ jobs:
|
|||||||
|
|
||||||
trigger_darwin_signer:
|
trigger_darwin_signer:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
needs: [release,release_ui_darwin]
|
needs: release_ui_darwin
|
||||||
if: startsWith(github.ref, 'refs/tags/')
|
if: startsWith(github.ref, 'refs/tags/')
|
||||||
steps:
|
steps:
|
||||||
- name: Trigger Darwin App binaries sign pipeline
|
- name: Trigger Darwin App binaries sign pipeline
|
||||||
|
|||||||
22
.github/workflows/sync-main.yml
vendored
22
.github/workflows/sync-main.yml
vendored
@@ -1,22 +0,0 @@
|
|||||||
name: sync main
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- main
|
|
||||||
|
|
||||||
concurrency:
|
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
|
||||||
cancel-in-progress: true
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
trigger_sync_main:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Trigger main branch sync
|
|
||||||
uses: benc-uk/workflow-dispatch@v1
|
|
||||||
with:
|
|
||||||
workflow: sync-main.yml
|
|
||||||
repo: ${{ secrets.UPSTREAM_REPO }}
|
|
||||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
|
||||||
inputs: '{ "sha": "${{ github.sha }}" }'
|
|
||||||
23
.github/workflows/sync-tag.yml
vendored
23
.github/workflows/sync-tag.yml
vendored
@@ -1,23 +0,0 @@
|
|||||||
name: sync tag
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
tags:
|
|
||||||
- 'v*'
|
|
||||||
|
|
||||||
concurrency:
|
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
|
||||||
cancel-in-progress: true
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
trigger_sync_tag:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Trigger release tag sync
|
|
||||||
uses: benc-uk/workflow-dispatch@v1
|
|
||||||
with:
|
|
||||||
workflow: sync-tag.yml
|
|
||||||
ref: main
|
|
||||||
repo: ${{ secrets.UPSTREAM_REPO }}
|
|
||||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
|
||||||
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
|
||||||
99
.github/workflows/test-docker-compose-linux.yml
vendored
Normal file
99
.github/workflows/test-docker-compose-linux.yml
vendored
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
name: Test Docker Compose Linux
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
pull_request:
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Install jq
|
||||||
|
run: sudo apt-get install -y jq
|
||||||
|
|
||||||
|
- name: Install curl
|
||||||
|
run: sudo apt-get install -y curl
|
||||||
|
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v2
|
||||||
|
with:
|
||||||
|
go-version: "1.20.x"
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: ~/go/pkg/mod
|
||||||
|
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-go-
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: cp setup.env
|
||||||
|
run: cp infrastructure_files/tests/setup.env infrastructure_files/
|
||||||
|
|
||||||
|
- name: run configure
|
||||||
|
working-directory: infrastructure_files
|
||||||
|
run: bash -x configure.sh
|
||||||
|
env:
|
||||||
|
CI_NETBIRD_DOMAIN: localhost
|
||||||
|
CI_NETBIRD_AUTH_CLIENT_ID: testing.client.id
|
||||||
|
CI_NETBIRD_AUTH_AUDIENCE: testing.ci
|
||||||
|
CI_NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT: https://example.eu.auth0.com/.well-known/openid-configuration
|
||||||
|
CI_NETBIRD_USE_AUTH0: true
|
||||||
|
|
||||||
|
- name: check values
|
||||||
|
working-directory: infrastructure_files
|
||||||
|
env:
|
||||||
|
CI_NETBIRD_DOMAIN: localhost
|
||||||
|
CI_NETBIRD_AUTH_CLIENT_ID: testing.client.id
|
||||||
|
CI_NETBIRD_AUTH_AUDIENCE: testing.ci
|
||||||
|
CI_NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT: https://example.eu.auth0.com/.well-known/openid-configuration
|
||||||
|
CI_NETBIRD_USE_AUTH0: true
|
||||||
|
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
|
||||||
|
CI_NETBIRD_AUTH_AUTHORITY: https://example.eu.auth0.com/
|
||||||
|
CI_NETBIRD_AUTH_JWT_CERTS: https://example.eu.auth0.com/.well-known/jwks.json
|
||||||
|
CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token
|
||||||
|
CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code
|
||||||
|
CI_NETBIRD_AUTH_REDIRECT_URI: "/peers"
|
||||||
|
CI_NETBIRD_TOKEN_SOURCE: "idToken"
|
||||||
|
CI_NETBIRD_AUTH_USER_ID_CLAIM: "email"
|
||||||
|
CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE: "super"
|
||||||
|
CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE: "openid email"
|
||||||
|
|
||||||
|
run: |
|
||||||
|
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
|
||||||
|
grep AUTH_AUTHORITY docker-compose.yml | grep $CI_NETBIRD_AUTH_AUTHORITY
|
||||||
|
grep AUTH_AUDIENCE docker-compose.yml | grep $CI_NETBIRD_AUTH_AUDIENCE
|
||||||
|
grep AUTH_SUPPORTED_SCOPES docker-compose.yml | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
|
||||||
|
grep USE_AUTH0 docker-compose.yml | grep $CI_NETBIRD_USE_AUTH0
|
||||||
|
grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073"
|
||||||
|
grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI
|
||||||
|
grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$'
|
||||||
|
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
|
||||||
|
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
|
||||||
|
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
|
||||||
|
grep -A 1 ProviderConfig management.json | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
|
||||||
|
grep Scope management.json | grep "$CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE"
|
||||||
|
grep UseIDToken management.json | grep false
|
||||||
|
|
||||||
|
- name: run docker compose up
|
||||||
|
working-directory: infrastructure_files
|
||||||
|
run: |
|
||||||
|
docker-compose up -d
|
||||||
|
sleep 5
|
||||||
|
docker-compose ps
|
||||||
|
docker-compose logs --tail=20
|
||||||
|
|
||||||
|
- name: test running containers
|
||||||
|
run: |
|
||||||
|
count=$(docker compose ps --format json | jq '.[] | select(.Project | contains("infrastructure_files")) | .State' | grep -c running)
|
||||||
|
test $count -eq 4
|
||||||
|
working-directory: infrastructure_files
|
||||||
188
.github/workflows/test-infrastructure-files.yml
vendored
188
.github/workflows/test-infrastructure-files.yml
vendored
@@ -1,188 +0,0 @@
|
|||||||
name: Test Infrastructure files
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- main
|
|
||||||
pull_request:
|
|
||||||
paths:
|
|
||||||
- 'infrastructure_files/**'
|
|
||||||
- '.github/workflows/test-infrastructure-files.yml'
|
|
||||||
- 'management/cmd/**'
|
|
||||||
- 'signal/cmd/**'
|
|
||||||
|
|
||||||
concurrency:
|
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
|
||||||
cancel-in-progress: true
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
test-docker-compose:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Install jq
|
|
||||||
run: sudo apt-get install -y jq
|
|
||||||
|
|
||||||
- name: Install curl
|
|
||||||
run: sudo apt-get install -y curl
|
|
||||||
|
|
||||||
- name: Install Go
|
|
||||||
uses: actions/setup-go@v4
|
|
||||||
with:
|
|
||||||
go-version: "1.21.x"
|
|
||||||
|
|
||||||
- name: Cache Go modules
|
|
||||||
uses: actions/cache@v3
|
|
||||||
with:
|
|
||||||
path: ~/go/pkg/mod
|
|
||||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ runner.os }}-go-
|
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
|
|
||||||
- name: cp setup.env
|
|
||||||
run: cp infrastructure_files/tests/setup.env infrastructure_files/
|
|
||||||
|
|
||||||
- name: run configure
|
|
||||||
working-directory: infrastructure_files
|
|
||||||
run: bash -x configure.sh
|
|
||||||
env:
|
|
||||||
CI_NETBIRD_DOMAIN: localhost
|
|
||||||
CI_NETBIRD_AUTH_CLIENT_ID: testing.client.id
|
|
||||||
CI_NETBIRD_AUTH_CLIENT_SECRET: testing.client.secret
|
|
||||||
CI_NETBIRD_AUTH_AUDIENCE: testing.ci
|
|
||||||
CI_NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT: https://example.eu.auth0.com/.well-known/openid-configuration
|
|
||||||
CI_NETBIRD_USE_AUTH0: true
|
|
||||||
CI_NETBIRD_MGMT_IDP: "none"
|
|
||||||
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
|
|
||||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
|
||||||
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
|
|
||||||
CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
|
|
||||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
|
||||||
|
|
||||||
- name: check values
|
|
||||||
working-directory: infrastructure_files/artifacts
|
|
||||||
env:
|
|
||||||
CI_NETBIRD_DOMAIN: localhost
|
|
||||||
CI_NETBIRD_AUTH_CLIENT_ID: testing.client.id
|
|
||||||
CI_NETBIRD_AUTH_CLIENT_SECRET: testing.client.secret
|
|
||||||
CI_NETBIRD_AUTH_AUDIENCE: testing.ci
|
|
||||||
CI_NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT: https://example.eu.auth0.com/.well-known/openid-configuration
|
|
||||||
CI_NETBIRD_USE_AUTH0: true
|
|
||||||
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
|
|
||||||
CI_NETBIRD_AUTH_AUTHORITY: https://example.eu.auth0.com/
|
|
||||||
CI_NETBIRD_AUTH_JWT_CERTS: https://example.eu.auth0.com/.well-known/jwks.json
|
|
||||||
CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token
|
|
||||||
CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code
|
|
||||||
CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT: https://example.eu.auth0.com/authorize
|
|
||||||
CI_NETBIRD_AUTH_REDIRECT_URI: "/peers"
|
|
||||||
CI_NETBIRD_TOKEN_SOURCE: "idToken"
|
|
||||||
CI_NETBIRD_AUTH_USER_ID_CLAIM: "email"
|
|
||||||
CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE: "super"
|
|
||||||
CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE: "openid email"
|
|
||||||
CI_NETBIRD_MGMT_IDP: "none"
|
|
||||||
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
|
|
||||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
|
||||||
CI_NETBIRD_SIGNAL_PORT: 12345
|
|
||||||
CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
|
|
||||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
|
||||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
|
||||||
|
|
||||||
run: |
|
|
||||||
set -x
|
|
||||||
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
|
|
||||||
grep AUTH_CLIENT_SECRET docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_SECRET
|
|
||||||
grep AUTH_AUTHORITY docker-compose.yml | grep $CI_NETBIRD_AUTH_AUTHORITY
|
|
||||||
grep AUTH_AUDIENCE docker-compose.yml | grep $CI_NETBIRD_AUTH_AUDIENCE
|
|
||||||
grep AUTH_SUPPORTED_SCOPES docker-compose.yml | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
|
|
||||||
grep USE_AUTH0 docker-compose.yml | grep $CI_NETBIRD_USE_AUTH0
|
|
||||||
grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073"
|
|
||||||
grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI
|
|
||||||
grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$'
|
|
||||||
grep $CI_NETBIRD_SIGNAL_PORT docker-compose.yml | grep ':80'
|
|
||||||
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
|
|
||||||
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
|
|
||||||
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
|
|
||||||
grep -A 3 DeviceAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
|
|
||||||
grep -A 3 DeviceAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
|
|
||||||
grep Engine management.json | grep "$CI_NETBIRD_STORE_CONFIG_ENGINE"
|
|
||||||
grep IdpSignKeyRefreshEnabled management.json | grep "$CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH"
|
|
||||||
grep UseIDToken management.json | grep false
|
|
||||||
grep -A 1 IdpManagerConfig management.json | grep ManagerType | grep $CI_NETBIRD_MGMT_IDP
|
|
||||||
grep -A 3 IdpManagerConfig management.json | grep -A 1 ClientConfig | grep Issuer | grep $CI_NETBIRD_AUTH_AUTHORITY
|
|
||||||
grep -A 4 IdpManagerConfig management.json | grep -A 2 ClientConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
|
|
||||||
grep -A 5 IdpManagerConfig management.json | grep -A 3 ClientConfig | grep ClientID | grep $CI_NETBIRD_IDP_MGMT_CLIENT_ID
|
|
||||||
grep -A 6 IdpManagerConfig management.json | grep -A 4 ClientConfig | grep ClientSecret | grep $CI_NETBIRD_IDP_MGMT_CLIENT_SECRET
|
|
||||||
grep -A 7 IdpManagerConfig management.json | grep -A 5 ClientConfig | grep GrantType | grep client_credentials
|
|
||||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_AUDIENCE
|
|
||||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep ClientID | grep $CI_NETBIRD_AUTH_CLIENT_ID
|
|
||||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep ClientSecret | grep $CI_NETBIRD_AUTH_CLIENT_SECRET
|
|
||||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep AuthorizationEndpoint | grep $CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
|
|
||||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
|
|
||||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
|
|
||||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
|
|
||||||
grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
|
|
||||||
|
|
||||||
- name: Install modules
|
|
||||||
run: go mod tidy
|
|
||||||
|
|
||||||
- name: Build management binary
|
|
||||||
working-directory: management
|
|
||||||
run: CGO_ENABLED=1 go build -o netbird-mgmt main.go
|
|
||||||
|
|
||||||
- name: Build management docker image
|
|
||||||
working-directory: management
|
|
||||||
run: |
|
|
||||||
docker build -t netbirdio/management:latest .
|
|
||||||
|
|
||||||
- name: Build signal binary
|
|
||||||
working-directory: signal
|
|
||||||
run: CGO_ENABLED=0 go build -o netbird-signal main.go
|
|
||||||
|
|
||||||
- name: Build signal docker image
|
|
||||||
working-directory: signal
|
|
||||||
run: |
|
|
||||||
docker build -t netbirdio/signal:latest .
|
|
||||||
|
|
||||||
- name: run docker compose up
|
|
||||||
working-directory: infrastructure_files/artifacts
|
|
||||||
run: |
|
|
||||||
docker-compose up -d
|
|
||||||
sleep 5
|
|
||||||
docker-compose ps
|
|
||||||
docker-compose logs --tail=20
|
|
||||||
|
|
||||||
- name: test running containers
|
|
||||||
run: |
|
|
||||||
count=$(docker compose ps --format json | jq '. | select(.Name | contains("artifacts")) | .State' | grep -c running)
|
|
||||||
test $count -eq 4
|
|
||||||
working-directory: infrastructure_files/artifacts
|
|
||||||
|
|
||||||
test-getting-started-script:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Install jq
|
|
||||||
run: sudo apt-get install -y jq
|
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
|
|
||||||
- name: run script
|
|
||||||
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
|
|
||||||
|
|
||||||
- name: test Caddy file gen
|
|
||||||
run: test -f Caddyfile
|
|
||||||
- name: test docker-compose file gen
|
|
||||||
run: test -f docker-compose.yml
|
|
||||||
- name: test management.json file gen
|
|
||||||
run: test -f management.json
|
|
||||||
- name: test turnserver.conf file gen
|
|
||||||
run: |
|
|
||||||
set -x
|
|
||||||
test -f turnserver.conf
|
|
||||||
grep external-ip turnserver.conf
|
|
||||||
- name: test zitadel.env file gen
|
|
||||||
run: test -f zitadel.env
|
|
||||||
- name: test dashboard.env file gen
|
|
||||||
run: test -f dashboard.env
|
|
||||||
22
.github/workflows/update-docs.yml
vendored
22
.github/workflows/update-docs.yml
vendored
@@ -1,22 +0,0 @@
|
|||||||
name: update docs
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
tags:
|
|
||||||
- 'v*'
|
|
||||||
paths:
|
|
||||||
- 'management/server/http/api/openapi.yml'
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
trigger_docs_api_update:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
if: startsWith(github.ref, 'refs/tags/')
|
|
||||||
steps:
|
|
||||||
- name: Trigger API pages generation
|
|
||||||
uses: benc-uk/workflow-dispatch@v1
|
|
||||||
with:
|
|
||||||
workflow: generate api pages
|
|
||||||
repo: netbirdio/docs
|
|
||||||
ref: "refs/heads/main"
|
|
||||||
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
|
||||||
inputs: '{ "tag": "${{ github.ref }}" }'
|
|
||||||
22
.gitignore
vendored
22
.gitignore
vendored
@@ -6,27 +6,9 @@ bin/
|
|||||||
.env
|
.env
|
||||||
conf.json
|
conf.json
|
||||||
http-cmds.sh
|
http-cmds.sh
|
||||||
setup.env
|
infrastructure_files/management.json
|
||||||
infrastructure_files/**/Caddyfile
|
infrastructure_files/docker-compose.yml
|
||||||
infrastructure_files/**/dashboard.env
|
|
||||||
infrastructure_files/**/zitadel.env
|
|
||||||
infrastructure_files/**/management.json
|
|
||||||
infrastructure_files/**/management-*.json
|
|
||||||
infrastructure_files/**/docker-compose.yml
|
|
||||||
infrastructure_files/**/openid-configuration.json
|
|
||||||
infrastructure_files/**/turnserver.conf
|
|
||||||
infrastructure_files/**/management.json.bkp.**
|
|
||||||
infrastructure_files/**/management-*.json.bkp.**
|
|
||||||
infrastructure_files/**/docker-compose.yml.bkp.**
|
|
||||||
infrastructure_files/**/openid-configuration.json.bkp.**
|
|
||||||
infrastructure_files/**/turnserver.conf.bkp.**
|
|
||||||
management/management
|
|
||||||
client/client
|
|
||||||
client/client.exe
|
|
||||||
*.syso
|
*.syso
|
||||||
client/.distfiles/
|
client/.distfiles/
|
||||||
infrastructure_files/setup.env
|
infrastructure_files/setup.env
|
||||||
infrastructure_files/setup-*.env
|
|
||||||
.vscode
|
.vscode
|
||||||
.DS_Store
|
|
||||||
*.db
|
|
||||||
123
.golangci.yaml
123
.golangci.yaml
@@ -1,123 +0,0 @@
|
|||||||
run:
|
|
||||||
# Timeout for analysis, e.g. 30s, 5m.
|
|
||||||
# Default: 1m
|
|
||||||
timeout: 6m
|
|
||||||
|
|
||||||
# This file contains only configs which differ from defaults.
|
|
||||||
# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
|
|
||||||
linters-settings:
|
|
||||||
errcheck:
|
|
||||||
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
|
|
||||||
# Such cases aren't reported by default.
|
|
||||||
# Default: false
|
|
||||||
check-type-assertions: false
|
|
||||||
|
|
||||||
gosec:
|
|
||||||
includes:
|
|
||||||
- G101 # Look for hard coded credentials
|
|
||||||
#- G102 # Bind to all interfaces
|
|
||||||
- G103 # Audit the use of unsafe block
|
|
||||||
- G104 # Audit errors not checked
|
|
||||||
- G106 # Audit the use of ssh.InsecureIgnoreHostKey
|
|
||||||
#- G107 # Url provided to HTTP request as taint input
|
|
||||||
- G108 # Profiling endpoint automatically exposed on /debug/pprof
|
|
||||||
- G109 # Potential Integer overflow made by strconv.Atoi result conversion to int16/32
|
|
||||||
- G110 # Potential DoS vulnerability via decompression bomb
|
|
||||||
- G111 # Potential directory traversal
|
|
||||||
#- G112 # Potential slowloris attack
|
|
||||||
- G113 # Usage of Rat.SetString in math/big with an overflow (CVE-2022-23772)
|
|
||||||
#- G114 # Use of net/http serve function that has no support for setting timeouts
|
|
||||||
- G201 # SQL query construction using format string
|
|
||||||
- G202 # SQL query construction using string concatenation
|
|
||||||
- G203 # Use of unescaped data in HTML templates
|
|
||||||
#- G204 # Audit use of command execution
|
|
||||||
- G301 # Poor file permissions used when creating a directory
|
|
||||||
- G302 # Poor file permissions used with chmod
|
|
||||||
- G303 # Creating tempfile using a predictable path
|
|
||||||
- G304 # File path provided as taint input
|
|
||||||
- G305 # File traversal when extracting zip/tar archive
|
|
||||||
- G306 # Poor file permissions used when writing to a new file
|
|
||||||
- G307 # Poor file permissions used when creating a file with os.Create
|
|
||||||
#- G401 # Detect the usage of DES, RC4, MD5 or SHA1
|
|
||||||
#- G402 # Look for bad TLS connection settings
|
|
||||||
- G403 # Ensure minimum RSA key length of 2048 bits
|
|
||||||
#- G404 # Insecure random number source (rand)
|
|
||||||
#- G501 # Import blocklist: crypto/md5
|
|
||||||
- G502 # Import blocklist: crypto/des
|
|
||||||
- G503 # Import blocklist: crypto/rc4
|
|
||||||
- G504 # Import blocklist: net/http/cgi
|
|
||||||
#- G505 # Import blocklist: crypto/sha1
|
|
||||||
- G601 # Implicit memory aliasing of items from a range statement
|
|
||||||
- G602 # Slice access out of bounds
|
|
||||||
|
|
||||||
gocritic:
|
|
||||||
disabled-checks:
|
|
||||||
- commentFormatting
|
|
||||||
- captLocal
|
|
||||||
- deprecatedComment
|
|
||||||
|
|
||||||
govet:
|
|
||||||
# Enable all analyzers.
|
|
||||||
# Default: false
|
|
||||||
enable-all: false
|
|
||||||
enable:
|
|
||||||
- nilness
|
|
||||||
|
|
||||||
tenv:
|
|
||||||
# The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
|
|
||||||
# Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
|
|
||||||
# Default: false
|
|
||||||
all: true
|
|
||||||
|
|
||||||
linters:
|
|
||||||
disable-all: true
|
|
||||||
enable:
|
|
||||||
## enabled by default
|
|
||||||
- errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases
|
|
||||||
- gosimple # specializes in simplifying a code
|
|
||||||
- govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
|
|
||||||
- ineffassign # detects when assignments to existing variables are not used
|
|
||||||
- staticcheck # is a go vet on steroids, applying a ton of static analysis checks
|
|
||||||
- tenv # Tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17.
|
|
||||||
- typecheck # like the front-end of a Go compiler, parses and type-checks Go code
|
|
||||||
- unused # checks for unused constants, variables, functions and types
|
|
||||||
## disable by default but the have interesting results so lets add them
|
|
||||||
- bodyclose # checks whether HTTP response body is closed successfully
|
|
||||||
- dupword # dupword checks for duplicate words in the source code
|
|
||||||
- durationcheck # durationcheck checks for two durations multiplied together
|
|
||||||
- forbidigo # forbidigo forbids identifiers
|
|
||||||
- gocritic # provides diagnostics that check for bugs, performance and style issues
|
|
||||||
- gosec # inspects source code for security problems
|
|
||||||
- mirror # mirror reports wrong mirror patterns of bytes/strings usage
|
|
||||||
- misspell # misspess finds commonly misspelled English words in comments
|
|
||||||
- nilerr # finds the code that returns nil even if it checks that the error is not nil
|
|
||||||
- nilnil # checks that there is no simultaneous return of nil error and an invalid value
|
|
||||||
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
|
|
||||||
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
|
|
||||||
- thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
|
|
||||||
- wastedassign # wastedassign finds wasted assignment statements
|
|
||||||
issues:
|
|
||||||
# Maximum count of issues with the same text.
|
|
||||||
# Set to 0 to disable.
|
|
||||||
# Default: 3
|
|
||||||
max-same-issues: 5
|
|
||||||
|
|
||||||
exclude-rules:
|
|
||||||
# allow fmt
|
|
||||||
- path: management/cmd/root\.go
|
|
||||||
linters: forbidigo
|
|
||||||
- path: signal/cmd/root\.go
|
|
||||||
linters: forbidigo
|
|
||||||
- path: sharedsock/filter\.go
|
|
||||||
linters:
|
|
||||||
- unused
|
|
||||||
- path: client/firewall/iptables/rule\.go
|
|
||||||
linters:
|
|
||||||
- unused
|
|
||||||
- path: test\.go
|
|
||||||
linters:
|
|
||||||
- mirror
|
|
||||||
- gosec
|
|
||||||
- path: mock\.go
|
|
||||||
linters:
|
|
||||||
- nilnil
|
|
||||||
@@ -12,7 +12,11 @@ builds:
|
|||||||
- arm
|
- arm
|
||||||
- amd64
|
- amd64
|
||||||
- arm64
|
- arm64
|
||||||
|
- mips
|
||||||
- 386
|
- 386
|
||||||
|
gomips:
|
||||||
|
- hardfloat
|
||||||
|
- softfloat
|
||||||
ignore:
|
ignore:
|
||||||
- goos: windows
|
- goos: windows
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
@@ -26,26 +30,6 @@ builds:
|
|||||||
tags:
|
tags:
|
||||||
- load_wgnt_from_rsrc
|
- load_wgnt_from_rsrc
|
||||||
|
|
||||||
- id: netbird-static
|
|
||||||
dir: client
|
|
||||||
binary: netbird
|
|
||||||
env: [CGO_ENABLED=0]
|
|
||||||
goos:
|
|
||||||
- linux
|
|
||||||
goarch:
|
|
||||||
- mips
|
|
||||||
- mipsle
|
|
||||||
- mips64
|
|
||||||
- mips64le
|
|
||||||
gomips:
|
|
||||||
- hardfloat
|
|
||||||
- softfloat
|
|
||||||
ldflags:
|
|
||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
|
||||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
|
||||||
tags:
|
|
||||||
- load_wgnt_from_rsrc
|
|
||||||
|
|
||||||
- id: netbird-mgmt
|
- id: netbird-mgmt
|
||||||
dir: management
|
dir: management
|
||||||
env:
|
env:
|
||||||
@@ -83,7 +67,6 @@ builds:
|
|||||||
archives:
|
archives:
|
||||||
- builds:
|
- builds:
|
||||||
- netbird
|
- netbird
|
||||||
- netbird-static
|
|
||||||
|
|
||||||
nfpms:
|
nfpms:
|
||||||
|
|
||||||
@@ -377,13 +360,3 @@ uploads:
|
|||||||
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
|
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
|
||||||
username: dev@wiretrustee.com
|
username: dev@wiretrustee.com
|
||||||
method: PUT
|
method: PUT
|
||||||
|
|
||||||
checksum:
|
|
||||||
extra_files:
|
|
||||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
|
||||||
- glob: ./release_files/install.sh
|
|
||||||
|
|
||||||
release:
|
|
||||||
extra_files:
|
|
||||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
|
||||||
- glob: ./release_files/install.sh
|
|
||||||
@@ -11,8 +11,6 @@ builds:
|
|||||||
- amd64
|
- amd64
|
||||||
ldflags:
|
ldflags:
|
||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
tags:
|
|
||||||
- legacy_appindicator
|
|
||||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||||
|
|
||||||
- id: netbird-ui-windows
|
- id: netbird-ui-windows
|
||||||
@@ -54,9 +52,12 @@ nfpms:
|
|||||||
contents:
|
contents:
|
||||||
- src: client/ui/netbird.desktop
|
- src: client/ui/netbird.desktop
|
||||||
dst: /usr/share/applications/netbird.desktop
|
dst: /usr/share/applications/netbird.desktop
|
||||||
- src: client/ui/netbird-systemtray-default.png
|
- src: client/ui/disconnected.png
|
||||||
dst: /usr/share/pixmaps/netbird.png
|
dst: /usr/share/pixmaps/netbird.png
|
||||||
dependencies:
|
dependencies:
|
||||||
|
- libayatana-appindicator3-1
|
||||||
|
- libgtk-3-dev
|
||||||
|
- libappindicator3-dev
|
||||||
- netbird
|
- netbird
|
||||||
|
|
||||||
- maintainer: Netbird <dev@netbird.io>
|
- maintainer: Netbird <dev@netbird.io>
|
||||||
@@ -71,9 +72,12 @@ nfpms:
|
|||||||
contents:
|
contents:
|
||||||
- src: client/ui/netbird.desktop
|
- src: client/ui/netbird.desktop
|
||||||
dst: /usr/share/applications/netbird.desktop
|
dst: /usr/share/applications/netbird.desktop
|
||||||
- src: client/ui/netbird-systemtray-default.png
|
- src: client/ui/disconnected.png
|
||||||
dst: /usr/share/pixmaps/netbird.png
|
dst: /usr/share/pixmaps/netbird.png
|
||||||
dependencies:
|
dependencies:
|
||||||
|
- libayatana-appindicator3-1
|
||||||
|
- libgtk-3-dev
|
||||||
|
- libappindicator3-dev
|
||||||
- netbird
|
- netbird
|
||||||
|
|
||||||
uploads:
|
uploads:
|
||||||
|
|||||||
@@ -19,11 +19,11 @@ If you haven't already, join our slack workspace [here](https://join.slack.com/t
|
|||||||
- [Development setup](#development-setup)
|
- [Development setup](#development-setup)
|
||||||
- [Requirements](#requirements)
|
- [Requirements](#requirements)
|
||||||
- [Local NetBird setup](#local-netbird-setup)
|
- [Local NetBird setup](#local-netbird-setup)
|
||||||
- [Dev Container Support](#dev-container-support)
|
|
||||||
- [Build and start](#build-and-start)
|
- [Build and start](#build-and-start)
|
||||||
- [Test suite](#test-suite)
|
- [Test suite](#test-suite)
|
||||||
- [Checklist before submitting a PR](#checklist-before-submitting-a-pr)
|
- [Checklist before submitting a PR](#checklist-before-submitting-a-pr)
|
||||||
- [Other project repositories](#other-project-repositories)
|
- [Other project repositories](#other-project-repositories)
|
||||||
|
- [Checklist before submitting a new node](#checklist-before-submitting-a-new-node)
|
||||||
- [Contributor License Agreement](#contributor-license-agreement)
|
- [Contributor License Agreement](#contributor-license-agreement)
|
||||||
|
|
||||||
## Code of conduct
|
## Code of conduct
|
||||||
@@ -70,7 +70,7 @@ dependencies are installed. Here is a short guide on how that can be done.
|
|||||||
|
|
||||||
### Requirements
|
### Requirements
|
||||||
|
|
||||||
#### Go 1.21
|
#### Go 1.19
|
||||||
|
|
||||||
Follow the installation guide from https://go.dev/
|
Follow the installation guide from https://go.dev/
|
||||||
|
|
||||||
@@ -136,61 +136,18 @@ checked out and set up:
|
|||||||
go mod tidy
|
go mod tidy
|
||||||
```
|
```
|
||||||
|
|
||||||
### Dev Container Support
|
|
||||||
|
|
||||||
If you prefer using a dev container for development, NetBird now includes support for dev containers.
|
|
||||||
Dev containers provide a consistent and isolated development environment, making it easier for contributors to get started quickly. Follow the steps below to set up NetBird in a dev container.
|
|
||||||
|
|
||||||
#### 1. Prerequisites:
|
|
||||||
|
|
||||||
* Install Docker on your machine: [Docker Installation Guide](https://docs.docker.com/get-docker/)
|
|
||||||
* Install Visual Studio Code: [VS Code Installation Guide](https://code.visualstudio.com/download)
|
|
||||||
* If you prefer JetBrains Goland please follow this [manual](https://www.jetbrains.com/help/go/connect-to-devcontainer.html)
|
|
||||||
|
|
||||||
#### 2. Clone the Repository:
|
|
||||||
|
|
||||||
Clone the repository following previous [Local NetBird setup](#local-netbird-setup).
|
|
||||||
|
|
||||||
#### 3. Open in project in IDE of your choice:
|
|
||||||
|
|
||||||
**VScode**:
|
|
||||||
|
|
||||||
Open the project folder in Visual Studio Code:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
code .
|
|
||||||
```
|
|
||||||
|
|
||||||
When you open the project in VS Code, it will detect the presence of a dev container configuration.
|
|
||||||
Click on the green "Reopen in Container" button in the bottom-right corner of VS Code.
|
|
||||||
|
|
||||||
**Goland**:
|
|
||||||
|
|
||||||
Open GoLand and select `"File" > "Open"` to open the NetBird project folder.
|
|
||||||
GoLand will detect the dev container configuration and prompt you to open the project in the container. Accept the prompt.
|
|
||||||
|
|
||||||
#### 4. Wait for the Container to Build:
|
|
||||||
|
|
||||||
VsCode or GoLand will use the specified Docker image to build the dev container. This might take some time, depending on your internet connection.
|
|
||||||
|
|
||||||
#### 6. Development:
|
|
||||||
|
|
||||||
Once the container is built, you can start developing within the dev container. All the necessary dependencies and configurations are set up within the container.
|
|
||||||
|
|
||||||
|
|
||||||
### Build and start
|
### Build and start
|
||||||
#### Client
|
#### Client
|
||||||
|
|
||||||
|
> Windows clients have a Wireguard driver requirement. We provide a bash script that can be executed in WLS 2 with docker support [wireguard_nt.sh](/client/wireguard_nt.sh).
|
||||||
|
|
||||||
To start NetBird, execute:
|
To start NetBird, execute:
|
||||||
```
|
```
|
||||||
cd client
|
cd client
|
||||||
CGO_ENABLED=0 go build .
|
# bash wireguard_nt.sh # if windows
|
||||||
|
go build .
|
||||||
```
|
```
|
||||||
|
|
||||||
> Windows clients have a Wireguard driver requirement. You can download the wintun driver from https://www.wintun.net/builds/wintun-0.14.1.zip, after decompressing, you can copy the file `windtun\bin\ARCH\wintun.dll` to the same path as your binary file or to `C:\Windows\System32\wintun.dll`.
|
|
||||||
|
|
||||||
> To test the client GUI application on Windows machines with RDP or vituralized environments (e.g. virtualbox or cloud), you need to download and extract the opengl32.dll from https://fdossena.com/?p=mesa/index.frag next to the built application.
|
|
||||||
|
|
||||||
To start NetBird the client in the foreground:
|
To start NetBird the client in the foreground:
|
||||||
|
|
||||||
```
|
```
|
||||||
@@ -228,42 +185,6 @@ To start NetBird the management service:
|
|||||||
./management management --log-level debug --log-file console --config ./management.json
|
./management management --log-level debug --log-file console --config ./management.json
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Windows Netbird Installer
|
|
||||||
Create dist directory
|
|
||||||
```shell
|
|
||||||
mkdir -p dist/netbird_windows_amd64
|
|
||||||
```
|
|
||||||
|
|
||||||
UI client
|
|
||||||
```shell
|
|
||||||
CC=x86_64-w64-mingw32-gcc CGO_ENABLED=1 GOOS=windows GOARCH=amd64 go build -o netbird-ui.exe -ldflags "-s -w -H windowsgui" ./client/ui
|
|
||||||
mv netbird-ui.exe ./dist/netbird_windows_amd64/
|
|
||||||
```
|
|
||||||
|
|
||||||
Client
|
|
||||||
```shell
|
|
||||||
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o netbird.exe ./client/
|
|
||||||
mv netbird.exe ./dist/netbird_windows_amd64/
|
|
||||||
```
|
|
||||||
> Windows clients have a Wireguard driver requirement. You can download the wintun driver from https://www.wintun.net/builds/wintun-0.14.1.zip, after decompressing, you can copy the file `windtun\bin\ARCH\wintun.dll` to `./dist/netbird_windows_amd64/`.
|
|
||||||
|
|
||||||
NSIS compiler
|
|
||||||
- [Windows-nsis]( https://nsis.sourceforge.io/Download)
|
|
||||||
- [MacOS-makensis](https://formulae.brew.sh/formula/makensis#default)
|
|
||||||
- [Linux-makensis](https://manpages.ubuntu.com/manpages/trusty/man1/makensis.1.html)
|
|
||||||
|
|
||||||
NSIS Plugins. Download and move them to the NSIS plugins folder.
|
|
||||||
- [EnVar](https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip)
|
|
||||||
- [ShellExecAsUser](https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z)
|
|
||||||
|
|
||||||
Windows Installer
|
|
||||||
```shell
|
|
||||||
export APPVER=0.0.0.1
|
|
||||||
makensis -V4 client/installer.nsis
|
|
||||||
```
|
|
||||||
|
|
||||||
The installer `netbird-installer.exe` will be created in root directory.
|
|
||||||
|
|
||||||
### Test suite
|
### Test suite
|
||||||
|
|
||||||
The tests can be started via:
|
The tests can be started via:
|
||||||
|
|||||||
95
README.md
95
README.md
@@ -1,6 +1,6 @@
|
|||||||
<p align="center">
|
<p align="center">
|
||||||
<strong>:hatching_chick: New Release! Self-hosting in under 5 min.</strong>
|
<strong>:hatching_chick: New Release! Peer expiration.</strong>
|
||||||
<a href="https://github.com/netbirdio/netbird#quickstart-with-self-hosted-netbird">
|
<a href="https://github.com/netbirdio/netbird/releases">
|
||||||
Learn more
|
Learn more
|
||||||
</a>
|
</a>
|
||||||
</p>
|
</p>
|
||||||
@@ -24,7 +24,7 @@
|
|||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<strong>
|
<strong>
|
||||||
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
|
Start using NetBird at <a href="https://app.netbird.io/">app.netbird.io</a>
|
||||||
<br/>
|
<br/>
|
||||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||||
<br/>
|
<br/>
|
||||||
@@ -36,62 +36,47 @@
|
|||||||
|
|
||||||
<br>
|
<br>
|
||||||
|
|
||||||
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
|
**NetBird is an open-source VPN management platform built on top of WireGuard® making it easy to create secure private networks for your organization or home.**
|
||||||
|
|
||||||
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
It requires zero configuration effort leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
||||||
|
|
||||||
**Secure.** NetBird enables secure remote access by applying granular access policies, while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
NetBird uses [NAT traversal techniques](https://en.wikipedia.org/wiki/Interactive_Connectivity_Establishment) to automatically create an overlay peer-to-peer network connecting machines regardless of location (home, office, data center, container, cloud, or edge environments), unifying virtual private network management experience.
|
||||||
|
|
||||||
|
**Key features:**
|
||||||
|
- \[x] Automatic IP allocation and network management with a Web UI ([separate repo](https://github.com/netbirdio/dashboard))
|
||||||
|
- \[x] Automatic WireGuard peer (machine) discovery and configuration.
|
||||||
|
- \[x] Encrypted peer-to-peer connections without a central VPN gateway.
|
||||||
|
- \[x] Connection relay fallback in case a peer-to-peer connection is not possible.
|
||||||
|
- \[x] Desktop client applications for Linux, MacOS, and Windows (systray).
|
||||||
|
- \[x] Multiuser support - sharing network between multiple users.
|
||||||
|
- \[x] SSO and MFA support.
|
||||||
|
- \[x] Multicloud and hybrid-cloud support.
|
||||||
|
- \[x] Kernel WireGuard usage when possible.
|
||||||
|
- \[x] Access Controls - groups & rules.
|
||||||
|
- \[x] Remote SSH access without managing SSH keys.
|
||||||
|
- \[x] Network Routes.
|
||||||
|
- \[x] Private DNS.
|
||||||
|
- \[x] Network Activity Monitoring.
|
||||||
|
|
||||||
|
**Coming soon:**
|
||||||
|
- \[ ] Mobile clients.
|
||||||
|
|
||||||
### Secure peer-to-peer VPN with SSO and MFA in minutes
|
### Secure peer-to-peer VPN with SSO and MFA in minutes
|
||||||
|
|
||||||
https://user-images.githubusercontent.com/700848/197345890-2e2cded5-7b7a-436f-a444-94e80dd24f46.mov
|
https://user-images.githubusercontent.com/700848/197345890-2e2cded5-7b7a-436f-a444-94e80dd24f46.mov
|
||||||
|
|
||||||
### Key features
|
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
||||||
|
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
||||||
|
|
||||||
| Connectivity | Management | Automation | Platforms |
|
### Start using NetBird
|
||||||
|---------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------|----------------------------------------------------------------------------|---------------------------------------|
|
- Hosted version: [https://app.netbird.io/](https://app.netbird.io/).
|
||||||
| <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
|
- See our documentation for [Quickstart Guide](https://netbird.io/docs/getting-started/quickstart).
|
||||||
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
|
- If you are looking to self-host NetBird, check our [Self-Hosting Guide](https://netbird.io/docs/getting-started/self-hosting).
|
||||||
| <ul><li> - \[x] Peer-to-peer encryption </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
|
- Step-by-step [Installation Guide](https://netbird.io/docs/getting-started/installation) for different platforms.
|
||||||
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
|
- Web UI [repository](https://github.com/netbirdio/dashboard).
|
||||||
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | | <ul><li> - \[x] iOS </ul></li> |
|
- 5 min [demo video](https://youtu.be/Tu9tPsUWaY0) on YouTube.
|
||||||
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | | <ul><li> - \[x] Docker </ul></li> |
|
|
||||||
| <ul><li> - \[x] Post-quantum-secure connection through [Rosenpass](https://rosenpass.eu) </ul></li> | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
|
|
||||||
| | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | | |
|
|
||||||
| | <ul><li> - \[x] SSH access management </ul></li> | | |
|
|
||||||
|
|
||||||
|
|
||||||
### Quickstart with NetBird Cloud
|
|
||||||
|
|
||||||
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
|
|
||||||
- Follow the steps to sign-up with Google, Microsoft, GitHub or your email address.
|
|
||||||
- Check NetBird [admin UI](https://app.netbird.io/).
|
|
||||||
- Add more machines.
|
|
||||||
|
|
||||||
### Quickstart with self-hosted NetBird
|
|
||||||
|
|
||||||
> This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM.
|
|
||||||
Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IDPs.
|
|
||||||
|
|
||||||
**Infrastructure requirements:**
|
|
||||||
- A Linux VM with at least **1CPU** and **2GB** of memory.
|
|
||||||
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP ports: **3478**, **49152-65535**.
|
|
||||||
- **Public domain** name pointing to the VM.
|
|
||||||
|
|
||||||
**Software requirements:**
|
|
||||||
- Docker installed on the VM with the docker compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
|
|
||||||
- [jq](https://jqlang.github.io/jq/) installed. In most distributions
|
|
||||||
Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`
|
|
||||||
- [curl](https://curl.se/) installed.
|
|
||||||
Usually available in the official repositories and can be installed with `sudo apt install curl` or `sudo yum install curl`
|
|
||||||
|
|
||||||
**Steps**
|
|
||||||
- Download and run the installation script:
|
|
||||||
```bash
|
|
||||||
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started-with-zitadel.sh | bash
|
|
||||||
```
|
|
||||||
- Once finished, you can manage the resources via `docker-compose`
|
|
||||||
|
|
||||||
### A bit on NetBird internals
|
### A bit on NetBird internals
|
||||||
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
|
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
|
||||||
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
|
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
|
||||||
@@ -103,18 +88,18 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
|
|||||||
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
|
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
|
||||||
|
|
||||||
<p float="left" align="middle">
|
<p float="left" align="middle">
|
||||||
<img src="https://docs.netbird.io/docs-static/img/architecture/high-level-dia.png" width="700"/>
|
<img src="https://netbird.io/docs/img/architecture/high-level-dia.png" width="700"/>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
See a complete [architecture overview](https://netbird.io/docs/overview/architecture) for details.
|
||||||
|
|
||||||
|
### Roadmap
|
||||||
|
- [Public Roadmap](https://github.com/netbirdio/netbird/projects/2)
|
||||||
|
|
||||||
### Community projects
|
### Community projects
|
||||||
- [NetBird on OpenWRT](https://github.com/messense/openwrt-netbird)
|
- [NetBird on OpenWRT](https://github.com/messense/openwrt-netbird)
|
||||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||||
|
|
||||||
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
|
||||||
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
|
||||||
|
|
||||||
### Support acknowledgement
|
### Support acknowledgement
|
||||||
|
|
||||||
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by The Federal Ministry of Education and Research of The Federal Republic of Germany. Together with [CISPA Helmholtz Center for Information Security](https://cispa.de/en) NetBird brings the security best practices and simplicity to private networking.
|
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by The Federal Ministry of Education and Research of The Federal Republic of Germany. Together with [CISPA Helmholtz Center for Information Security](https://cispa.de/en) NetBird brings the security best practices and simplicity to private networking.
|
||||||
@@ -122,7 +107,7 @@ In November 2022, NetBird joined the [StartUpSecure program](https://www.forschu
|
|||||||

|

|
||||||
|
|
||||||
### Testimonials
|
### Testimonials
|
||||||
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution).
|
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), and [Coturn](https://github.com/coturn/coturn). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution).
|
||||||
|
|
||||||
### Legal
|
### Legal
|
||||||
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.
|
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.
|
||||||
|
|||||||
@@ -1,58 +0,0 @@
|
|||||||
package base62
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
alphabet = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
|
||||||
base = uint32(len(alphabet))
|
|
||||||
)
|
|
||||||
|
|
||||||
// Encode encodes a uint32 value to a base62 string.
|
|
||||||
func Encode(num uint32) string {
|
|
||||||
if num == 0 {
|
|
||||||
return string(alphabet[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
var encoded strings.Builder
|
|
||||||
|
|
||||||
for num > 0 {
|
|
||||||
remainder := num % base
|
|
||||||
encoded.WriteByte(alphabet[remainder])
|
|
||||||
num /= base
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reverse the encoded string
|
|
||||||
encodedString := encoded.String()
|
|
||||||
reversed := reverse(encodedString)
|
|
||||||
return reversed
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode decodes a base62 string to a uint32 value.
|
|
||||||
func Decode(encoded string) (uint32, error) {
|
|
||||||
var decoded uint32
|
|
||||||
strLen := len(encoded)
|
|
||||||
|
|
||||||
for i, char := range encoded {
|
|
||||||
index := strings.IndexRune(alphabet, char)
|
|
||||||
if index < 0 {
|
|
||||||
return 0, fmt.Errorf("invalid character: %c", char)
|
|
||||||
}
|
|
||||||
|
|
||||||
decoded += uint32(index) * uint32(math.Pow(float64(base), float64(strLen-i-1)))
|
|
||||||
}
|
|
||||||
|
|
||||||
return decoded, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reverse a string.
|
|
||||||
func reverse(s string) string {
|
|
||||||
runes := []rune(s)
|
|
||||||
for i, j := 0, len(runes)-1; i < j; i, j = i+1, j-1 {
|
|
||||||
runes[i], runes[j] = runes[j], runes[i]
|
|
||||||
}
|
|
||||||
return string(runes)
|
|
||||||
}
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
package base62
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestEncodeDecode(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
num uint32
|
|
||||||
}{
|
|
||||||
{0},
|
|
||||||
{1},
|
|
||||||
{42},
|
|
||||||
{12345},
|
|
||||||
{99999},
|
|
||||||
{123456789},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
encoded := Encode(tt.num)
|
|
||||||
decoded, err := Decode(encoded)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Decode error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if decoded != tt.num {
|
|
||||||
t.Errorf("Decode(%v) = %v, want %v", encoded, decoded, tt.num)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
FROM alpine:3.18.5
|
FROM gcr.io/distroless/base:debug
|
||||||
RUN apk add --no-cache ca-certificates iptables ip6tables
|
|
||||||
ENV NB_FOREGROUND_MODE=true
|
ENV NB_FOREGROUND_MODE=true
|
||||||
|
ENV PATH=/sbin:/usr/sbin:/bin:/usr/bin:/busybox
|
||||||
|
SHELL ["/busybox/sh","-c"]
|
||||||
|
RUN sed -i -E 's/(^root:.+)\/sbin\/nologin/\1\/busybox\/sh/g' /etc/passwd
|
||||||
ENTRYPOINT [ "/go/bin/netbird","up"]
|
ENTRYPOINT [ "/go/bin/netbird","up"]
|
||||||
COPY netbird /go/bin/netbird
|
COPY netbird /go/bin/netbird
|
||||||
@@ -7,8 +7,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
@@ -31,47 +29,38 @@ type IFaceDiscover interface {
|
|||||||
stdnet.ExternalIFaceDiscover
|
stdnet.ExternalIFaceDiscover
|
||||||
}
|
}
|
||||||
|
|
||||||
// NetworkChangeListener export internal NetworkChangeListener for mobile
|
|
||||||
type NetworkChangeListener interface {
|
|
||||||
listener.NetworkChangeListener
|
|
||||||
}
|
|
||||||
|
|
||||||
// DnsReadyListener export internal dns ReadyListener for mobile
|
|
||||||
type DnsReadyListener interface {
|
|
||||||
dns.ReadyListener
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
formatter.SetLogcatFormatter(log.StandardLogger())
|
formatter.SetLogcatFormatter(log.StandardLogger())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client struct manage the life circle of background service
|
// Client struct manage the life circle of background service
|
||||||
type Client struct {
|
type Client struct {
|
||||||
cfgFile string
|
cfgFile string
|
||||||
tunAdapter iface.TunAdapter
|
tunAdapter iface.TunAdapter
|
||||||
iFaceDiscover IFaceDiscover
|
iFaceDiscover IFaceDiscover
|
||||||
recorder *peer.Status
|
recorder *peer.Status
|
||||||
ctxCancel context.CancelFunc
|
ctxCancel context.CancelFunc
|
||||||
ctxCancelLock *sync.Mutex
|
ctxCancelLock *sync.Mutex
|
||||||
deviceName string
|
deviceName string
|
||||||
networkChangeListener listener.NetworkChangeListener
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient instantiate a new Client
|
// NewClient instantiate a new Client
|
||||||
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover) *Client {
|
||||||
|
lvl, _ := log.ParseLevel("trace")
|
||||||
|
log.SetLevel(lvl)
|
||||||
|
|
||||||
return &Client{
|
return &Client{
|
||||||
cfgFile: cfgFile,
|
cfgFile: cfgFile,
|
||||||
deviceName: deviceName,
|
deviceName: deviceName,
|
||||||
tunAdapter: tunAdapter,
|
tunAdapter: tunAdapter,
|
||||||
iFaceDiscover: iFaceDiscover,
|
iFaceDiscover: iFaceDiscover,
|
||||||
recorder: peer.NewRecorder(""),
|
recorder: peer.NewRecorder(""),
|
||||||
ctxCancelLock: &sync.Mutex{},
|
ctxCancelLock: &sync.Mutex{},
|
||||||
networkChangeListener: networkChangeListener,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run start the internal client. It is a blocker function
|
// Run start the internal client. It is a blocker function
|
||||||
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
|
func (c *Client) Run(urlOpener URLOpener) error {
|
||||||
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
||||||
ConfigPath: c.cfgFile,
|
ConfigPath: c.cfgFile,
|
||||||
})
|
})
|
||||||
@@ -96,31 +85,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover)
|
||||||
}
|
|
||||||
|
|
||||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
|
||||||
// In this case make no sense handle registration steps.
|
|
||||||
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error {
|
|
||||||
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
|
||||||
ConfigPath: c.cfgFile,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
|
|
||||||
|
|
||||||
var ctx context.Context
|
|
||||||
//nolint
|
|
||||||
ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName)
|
|
||||||
c.ctxCancelLock.Lock()
|
|
||||||
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
|
|
||||||
defer c.ctxCancel()
|
|
||||||
c.ctxCancelLock.Unlock()
|
|
||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
|
||||||
ctx = internal.CtxInitState(ctx)
|
|
||||||
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
@@ -134,11 +99,6 @@ func (c *Client) Stop() {
|
|||||||
c.ctxCancel()
|
c.ctxCancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetTraceLogLevel configure the logger to trace level
|
|
||||||
func (c *Client) SetTraceLogLevel() {
|
|
||||||
log.SetLevel(log.TraceLevel)
|
|
||||||
}
|
|
||||||
|
|
||||||
// PeersList return with the list of the PeerInfos
|
// PeersList return with the list of the PeerInfos
|
||||||
func (c *Client) PeersList() *PeerInfoArray {
|
func (c *Client) PeersList() *PeerInfoArray {
|
||||||
|
|
||||||
@@ -150,23 +110,14 @@ func (c *Client) PeersList() *PeerInfoArray {
|
|||||||
p.IP,
|
p.IP,
|
||||||
p.FQDN,
|
p.FQDN,
|
||||||
p.ConnStatus.String(),
|
p.ConnStatus.String(),
|
||||||
|
p.Direct,
|
||||||
}
|
}
|
||||||
peerInfos[n] = pi
|
peerInfos[n] = pi
|
||||||
}
|
}
|
||||||
|
|
||||||
return &PeerInfoArray{items: peerInfos}
|
return &PeerInfoArray{items: peerInfos}
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnUpdatedHostDNS update the DNS servers addresses for root zones
|
|
||||||
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
|
||||||
dnsServer, err := dns.GetServerDns()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsServer.OnUpdatedHostDNSServer(list.items)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetConnectionListener set the network connection listener
|
// SetConnectionListener set the network connection listener
|
||||||
func (c *Client) SetConnectionListener(listener ConnectionListener) {
|
func (c *Client) SetConnectionListener(listener ConnectionListener) {
|
||||||
c.recorder.SetConnectionListener(listener)
|
c.recorder.SetConnectionListener(listener)
|
||||||
|
|||||||
@@ -1,26 +0,0 @@
|
|||||||
package android
|
|
||||||
|
|
||||||
import "fmt"
|
|
||||||
|
|
||||||
// DNSList is a wrapper of []string
|
|
||||||
type DNSList struct {
|
|
||||||
items []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add new DNS address to the collection
|
|
||||||
func (array *DNSList) Add(s string) {
|
|
||||||
array.items = append(array.items, s)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get return an element of the collection
|
|
||||||
func (array *DNSList) Get(i int) (string, error) {
|
|
||||||
if i >= len(array.items) || i < 0 {
|
|
||||||
return "", fmt.Errorf("out of range")
|
|
||||||
}
|
|
||||||
return array.items[i], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Size return with the size of the collection
|
|
||||||
func (array *DNSList) Size() int {
|
|
||||||
return len(array.items)
|
|
||||||
}
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
package android
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestDNSList_Get(t *testing.T) {
|
|
||||||
l := DNSList{
|
|
||||||
items: make([]string, 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := l.Get(0)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("invalid error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = l.Get(-1)
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("expected error but got nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = l.Get(1)
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("expected error but got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
package android
|
|
||||||
|
|
||||||
import _ "golang.org/x/mobile/bind"
|
|
||||||
|
|
||||||
// to keep our CI/CD that checks go.mod and go.sum files happy, we need to import the package above
|
|
||||||
@@ -6,14 +6,15 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/cmd"
|
"github.com/netbirdio/netbird/client/cmd"
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/auth"
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SSOListener is async listener for mobile framework
|
// SSOListener is async listener for mobile framework
|
||||||
@@ -84,21 +85,11 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
|||||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||||
supportsSSO := true
|
supportsSSO := true
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
err := a.withBackOff(a.ctx, func() (err error) {
|
||||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
supportsSSO = false
|
||||||
s, ok := gstatus.FromError(err)
|
err = nil
|
||||||
if !ok {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
|
|
||||||
supportsSSO = false
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -192,23 +183,35 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
|
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*internal.TokenInfo, error) {
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false)
|
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
s, ok := gstatus.FromError(err)
|
||||||
|
if ok && s.Code() == codes.NotFound {
|
||||||
|
return nil, fmt.Errorf("no SSO provider returned from management. " +
|
||||||
|
"If you are using hosting Netbird see documentation at " +
|
||||||
|
"https://github.com/netbirdio/netbird/tree/main/management for details")
|
||||||
|
} else if ok && s.Code() == codes.Unimplemented {
|
||||||
|
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
||||||
|
"please update your servver or use Setup Keys to login", a.config.ManagementURL)
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig)
|
||||||
|
|
||||||
|
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
return nil, fmt.Errorf("getting a request device code failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
go urlOpener.Open(flowInfo.VerificationURIComplete)
|
go urlOpener.Open(flowInfo.VerificationURIComplete)
|
||||||
|
|
||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
waitTimeout := time.Duration(flowInfo.ExpiresIn)
|
||||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
tokenInfo, err := hostedClient.WaitToken(waitCTX, flowInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ type PeerInfo struct {
|
|||||||
IP string
|
IP string
|
||||||
FQDN string
|
FQDN string
|
||||||
ConnStatus string // Todo replace to enum
|
ConnStatus string // Todo replace to enum
|
||||||
|
Direct bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeerInfoCollection made for Java layer to get non default types as collection
|
// PeerInfoCollection made for Java layer to get non default types as collection
|
||||||
|
|||||||
@@ -57,11 +57,11 @@ func TestPreferences_ReadUncommitedValues(t *testing.T) {
|
|||||||
p.SetManagementURL(exampleString)
|
p.SetManagementURL(exampleString)
|
||||||
resp, err = p.GetManagementURL()
|
resp, err = p.GetManagementURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to read management url: %s", err)
|
t.Fatalf("failed to read managmenet url: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp != exampleString {
|
if resp != exampleString {
|
||||||
t.Errorf("unexpected management url: %s", resp)
|
t.Errorf("unexpected managemenet url: %s", resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
p.SetPreSharedKey(exampleString)
|
p.SetPreSharedKey(exampleString)
|
||||||
@@ -102,11 +102,11 @@ func TestPreferences_Commit(t *testing.T) {
|
|||||||
|
|
||||||
resp, err = p.GetManagementURL()
|
resp, err = p.GetManagementURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to read management url: %s", err)
|
t.Fatalf("failed to read managmenet url: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp != exampleURL {
|
if resp != exampleURL {
|
||||||
t.Errorf("unexpected management url: %s", resp)
|
t.Errorf("unexpected managemenet url: %s", resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err = p.GetPreSharedKey()
|
resp, err = p.GetPreSharedKey()
|
||||||
|
|||||||
@@ -3,20 +3,20 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/skratchdot/open-golang/open"
|
"github.com/skratchdot/open-golang/open"
|
||||||
"github.com/spf13/cobra"
|
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/auth"
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/util"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var loginCmd = &cobra.Command{
|
var loginCmd = &cobra.Command{
|
||||||
@@ -51,7 +51,7 @@ var loginCmd = &cobra.Command{
|
|||||||
AdminURL: adminURL,
|
AdminURL: adminURL,
|
||||||
ConfigPath: configPath,
|
ConfigPath: configPath,
|
||||||
}
|
}
|
||||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
if preSharedKey != "" {
|
||||||
ic.PreSharedKey = &preSharedKey
|
ic.PreSharedKey = &preSharedKey
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -60,7 +60,7 @@ var loginCmd = &cobra.Command{
|
|||||||
return fmt.Errorf("get config file: %v", err)
|
return fmt.Errorf("get config file: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
|
config, _ = internal.UpdateOldManagementPort(ctx, config, configPath)
|
||||||
|
|
||||||
err = foregroundLogin(ctx, cmd, config, setupKey)
|
err = foregroundLogin(ctx, cmd, config, setupKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -81,11 +81,9 @@ var loginCmd = &cobra.Command{
|
|||||||
client := proto.NewDaemonServiceClient(conn)
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
|
||||||
loginRequest := proto.LoginRequest{
|
loginRequest := proto.LoginRequest{
|
||||||
SetupKey: setupKey,
|
SetupKey: setupKey,
|
||||||
PreSharedKey: preSharedKey,
|
PreSharedKey: preSharedKey,
|
||||||
ManagementUrl: managementURL,
|
ManagementUrl: managementURL,
|
||||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
|
||||||
Hostname: hostName,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var loginErr error
|
var loginErr error
|
||||||
@@ -115,7 +113,7 @@ var loginCmd = &cobra.Command{
|
|||||||
if loginResp.NeedsSSOLogin {
|
if loginResp.NeedsSSOLogin {
|
||||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
||||||
|
|
||||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("waiting sso login failed with: %v", err)
|
return fmt.Errorf("waiting sso login failed with: %v", err)
|
||||||
}
|
}
|
||||||
@@ -151,21 +149,13 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
|||||||
jwtToken = tokenInfo.GetTokenToUse()
|
jwtToken = tokenInfo.GetTokenToUse()
|
||||||
}
|
}
|
||||||
|
|
||||||
var lastError error
|
|
||||||
|
|
||||||
err = WithBackOff(func() error {
|
err = WithBackOff(func() error {
|
||||||
err := internal.Login(ctx, config, setupKey, jwtToken)
|
err := internal.Login(ctx, config, setupKey, jwtToken)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||||
lastError = err
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
|
||||||
if lastError != nil {
|
|
||||||
return fmt.Errorf("login failed: %v", lastError)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -173,24 +163,40 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
|
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*internal.TokenInfo, error) {
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop())
|
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
s, ok := gstatus.FromError(err)
|
||||||
|
if ok && s.Code() == codes.NotFound {
|
||||||
|
return nil, fmt.Errorf("no SSO provider returned from management. " +
|
||||||
|
"If you are using hosting Netbird see documentation at " +
|
||||||
|
"https://github.com/netbirdio/netbird/tree/main/management for details")
|
||||||
|
} else if ok && s.Code() == codes.Unimplemented {
|
||||||
|
mgmtURL := managementURL
|
||||||
|
if mgmtURL == "" {
|
||||||
|
mgmtURL = internal.DefaultManagementURL
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
||||||
|
"please update your servver or use Setup Keys to login", mgmtURL)
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig)
|
||||||
|
|
||||||
|
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
return nil, fmt.Errorf("getting a request device code failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||||
|
|
||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
waitTimeout := time.Duration(flowInfo.ExpiresIn)
|
||||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
|
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout*time.Second)
|
||||||
defer c()
|
defer c()
|
||||||
|
|
||||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
tokenInfo, err := hostedClient.WaitToken(waitCTX, flowInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -200,21 +206,15 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
|||||||
|
|
||||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
|
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
|
||||||
var codeMsg string
|
var codeMsg string
|
||||||
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
|
if !strings.Contains(verificationURIComplete, userCode) {
|
||||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Println("Please do the SSO login in your browser. \n" +
|
err := open.Run(verificationURIComplete)
|
||||||
|
cmd.Printf("Please do the SSO login in your browser. \n" +
|
||||||
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
||||||
verificationURIComplete + " " + codeMsg)
|
" " + verificationURIComplete + " " + codeMsg + " \n\n")
|
||||||
cmd.Println("")
|
if err != nil {
|
||||||
if err := open.Run(verificationURIComplete); err != nil {
|
cmd.Printf("Alternatively, you may want to use a setup key, see:\n\n https://www.netbird.io/docs/overview/setup-keys\n")
|
||||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
|
||||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
|
|
||||||
func isLinuxRunningDesktop() bool {
|
|
||||||
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -25,12 +25,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
externalIPMapFlag = "external-ip-map"
|
externalIPMapFlag = "external-ip-map"
|
||||||
dnsResolverAddress = "dns-resolver-address"
|
dnsResolverAddress = "dns-resolver-address"
|
||||||
enableRosenpassFlag = "enable-rosenpass"
|
|
||||||
preSharedKeyFlag = "preshared-key"
|
|
||||||
interfaceNameFlag = "interface-name"
|
|
||||||
wireguardPortFlag = "wireguard-port"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -53,9 +49,6 @@ var (
|
|||||||
preSharedKey string
|
preSharedKey string
|
||||||
natExternalIPs []string
|
natExternalIPs []string
|
||||||
customDNSAddress string
|
customDNSAddress string
|
||||||
rosenpassEnabled bool
|
|
||||||
interfaceName string
|
|
||||||
wireguardPort uint16
|
|
||||||
rootCmd = &cobra.Command{
|
rootCmd = &cobra.Command{
|
||||||
Use: "netbird",
|
Use: "netbird",
|
||||||
Short: "",
|
Short: "",
|
||||||
@@ -99,9 +92,9 @@ func init() {
|
|||||||
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultAdminURL))
|
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultAdminURL))
|
||||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
|
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
|
||||||
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
|
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
|
||||||
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout")
|
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the the log will be output to stdout")
|
||||||
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
|
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
|
||||||
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
|
rootCmd.PersistentFlags().StringVar(&preSharedKey, "preshared-key", "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
|
||||||
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
|
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
|
||||||
rootCmd.AddCommand(serviceCmd)
|
rootCmd.AddCommand(serviceCmd)
|
||||||
rootCmd.AddCommand(upCmd)
|
rootCmd.AddCommand(upCmd)
|
||||||
@@ -125,7 +118,6 @@ func init() {
|
|||||||
`An empty string "" clears the previous configuration. `+
|
`An empty string "" clears the previous configuration. `+
|
||||||
`E.g. --dns-resolver-address 127.0.0.1:5053 or --dns-resolver-address ""`,
|
`E.g. --dns-resolver-address 127.0.0.1:5053 or --dns-resolver-address ""`,
|
||||||
)
|
)
|
||||||
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupCloseHandler handles SIGTERM signal and exits with success
|
// SetupCloseHandler handles SIGTERM signal and exits with success
|
||||||
|
|||||||
@@ -73,8 +73,7 @@ var sshCmd = &cobra.Command{
|
|||||||
go func() {
|
go func() {
|
||||||
// blocking
|
// blocking
|
||||||
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
||||||
log.Debug(err)
|
log.Print(err)
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
cancel()
|
cancel()
|
||||||
}()
|
}()
|
||||||
@@ -93,10 +92,12 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
|
|||||||
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
|
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cmd.Printf("Error: %v\n", err)
|
cmd.Printf("Error: %v\n", err)
|
||||||
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
|
cmd.Printf("Couldn't connect. " +
|
||||||
"You can verify the connection by running:\n\n" +
|
"You might be disconnected from the NetBird network, or the NetBird agent isn't running.\n" +
|
||||||
" netbird status\n\n")
|
"Run the status command: \n\n" +
|
||||||
return err
|
" netbird status\n\n" +
|
||||||
|
"It might also be that the SSH server is disabled on the agent you are trying to connect to.\n")
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
|
|||||||
@@ -66,15 +66,13 @@ type statusOutputOverview struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
detailFlag bool
|
detailFlag bool
|
||||||
ipv4Flag bool
|
ipv4Flag bool
|
||||||
jsonFlag bool
|
jsonFlag bool
|
||||||
yamlFlag bool
|
yamlFlag bool
|
||||||
ipsFilter []string
|
ipsFilter []string
|
||||||
prefixNamesFilter []string
|
statusFilter string
|
||||||
statusFilter string
|
ipsFilterMap map[string]struct{}
|
||||||
ipsFilterMap map[string]struct{}
|
|
||||||
prefixNamesFilterMap map[string]struct{}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var statusCmd = &cobra.Command{
|
var statusCmd = &cobra.Command{
|
||||||
@@ -85,14 +83,12 @@ var statusCmd = &cobra.Command{
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
ipsFilterMap = make(map[string]struct{})
|
ipsFilterMap = make(map[string]struct{})
|
||||||
prefixNamesFilterMap = make(map[string]struct{})
|
|
||||||
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information in human-readable format")
|
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information in human-readable format")
|
||||||
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
|
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
|
||||||
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
|
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
|
||||||
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
|
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
|
||||||
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
|
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
|
||||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
||||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
|
||||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
|
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -113,9 +109,9 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
ctx := internal.CtxInitState(context.Background())
|
ctx := internal.CtxInitState(context.Background())
|
||||||
|
|
||||||
resp, err := getStatus(ctx, cmd)
|
resp, _ := getStatus(ctx, cmd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
|
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
|
||||||
@@ -124,7 +120,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
" netbird up \n\n"+
|
" netbird up \n\n"+
|
||||||
"If you are running a self-hosted version and no SSO provider has been configured in your Management Server,\n"+
|
"If you are running a self-hosted version and no SSO provider has been configured in your Management Server,\n"+
|
||||||
"you can use a setup-key:\n\n netbird up --management-url <YOUR_MANAGEMENT_URL> --setup-key <YOUR_SETUP_KEY>\n\n"+
|
"you can use a setup-key:\n\n netbird up --management-url <YOUR_MANAGEMENT_URL> --setup-key <YOUR_SETUP_KEY>\n\n"+
|
||||||
"More info: https://docs.netbird.io/how-to/register-machines-using-setup-keys\n\n",
|
"More info: https://www.netbird.io/docs/overview/setup-keys\n\n",
|
||||||
resp.GetStatus(),
|
resp.GetStatus(),
|
||||||
)
|
)
|
||||||
return nil
|
return nil
|
||||||
@@ -137,7 +133,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
outputInformationHolder := convertToStatusOutputOverview(resp)
|
outputInformationHolder := convertToStatusOutputOverview(resp)
|
||||||
|
|
||||||
var statusOutputString string
|
statusOutputString := ""
|
||||||
switch {
|
switch {
|
||||||
case detailFlag:
|
case detailFlag:
|
||||||
statusOutputString = parseToFullDetailSummary(outputInformationHolder)
|
statusOutputString = parseToFullDetailSummary(outputInformationHolder)
|
||||||
@@ -176,12 +172,8 @@ func getStatus(ctx context.Context, cmd *cobra.Command) (*proto.StatusResponse,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func parseFilters() error {
|
func parseFilters() error {
|
||||||
|
|
||||||
switch strings.ToLower(statusFilter) {
|
switch strings.ToLower(statusFilter) {
|
||||||
case "", "disconnected", "connected":
|
case "", "disconnected", "connected":
|
||||||
if strings.ToLower(statusFilter) != "" {
|
|
||||||
enableDetailFlagWhenFilterFlag()
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter)
|
return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter)
|
||||||
}
|
}
|
||||||
@@ -193,26 +185,11 @@ func parseFilters() error {
|
|||||||
return fmt.Errorf("got an invalid IP address in the filter: address %s, error %s", addr, err)
|
return fmt.Errorf("got an invalid IP address in the filter: address %s, error %s", addr, err)
|
||||||
}
|
}
|
||||||
ipsFilterMap[addr] = struct{}{}
|
ipsFilterMap[addr] = struct{}{}
|
||||||
enableDetailFlagWhenFilterFlag()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(prefixNamesFilter) > 0 {
|
|
||||||
for _, name := range prefixNamesFilter {
|
|
||||||
prefixNamesFilterMap[strings.ToLower(name)] = struct{}{}
|
|
||||||
}
|
|
||||||
enableDetailFlagWhenFilterFlag()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func enableDetailFlagWhenFilterFlag() {
|
|
||||||
if !detailFlag && !jsonFlag && !yamlFlag {
|
|
||||||
detailFlag = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverview {
|
func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverview {
|
||||||
pbFullStatus := resp.GetFullStatus()
|
pbFullStatus := resp.GetFullStatus()
|
||||||
|
|
||||||
@@ -257,7 +234,7 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if isPeerConnected {
|
if isPeerConnected {
|
||||||
peersConnected++
|
peersConnected = peersConnected + 1
|
||||||
|
|
||||||
localICE = pbPeerState.GetLocalIceCandidateType()
|
localICE = pbPeerState.GetLocalIceCandidateType()
|
||||||
remoteICE = pbPeerState.GetRemoteIceCandidateType()
|
remoteICE = pbPeerState.GetRemoteIceCandidateType()
|
||||||
@@ -430,7 +407,7 @@ func parsePeers(peers peersStateOutput) string {
|
|||||||
peerState.LastStatusUpdate.Format("2006-01-02 15:04:05"),
|
peerState.LastStatusUpdate.Format("2006-01-02 15:04:05"),
|
||||||
)
|
)
|
||||||
|
|
||||||
peersString += peerString
|
peersString = peersString + peerString
|
||||||
}
|
}
|
||||||
return peersString
|
return peersString
|
||||||
}
|
}
|
||||||
@@ -438,7 +415,6 @@ func parsePeers(peers peersStateOutput) string {
|
|||||||
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
||||||
statusEval := false
|
statusEval := false
|
||||||
ipEval := false
|
ipEval := false
|
||||||
nameEval := false
|
|
||||||
|
|
||||||
if statusFilter != "" {
|
if statusFilter != "" {
|
||||||
lowerStatusFilter := strings.ToLower(statusFilter)
|
lowerStatusFilter := strings.ToLower(statusFilter)
|
||||||
@@ -455,15 +431,5 @@ func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
|||||||
ipEval = true
|
ipEval = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return statusEval || ipEval
|
||||||
if len(prefixNamesFilter) > 0 {
|
|
||||||
for prefixNameFilter := range prefixNamesFilterMap {
|
|
||||||
if !strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
|
|
||||||
nameEval = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return statusEval || ipEval || nameEval
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,27 +2,24 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"net"
|
"net"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
|
|
||||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||||
client "github.com/netbirdio/netbird/client/server"
|
client "github.com/netbirdio/netbird/client/server"
|
||||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||||
mgmt "github.com/netbirdio/netbird/management/server"
|
mgmt "github.com/netbirdio/netbird/management/server"
|
||||||
sigProto "github.com/netbirdio/netbird/signal/proto"
|
sigProto "github.com/netbirdio/netbird/signal/proto"
|
||||||
sig "github.com/netbirdio/netbird/signal/server"
|
sig "github.com/netbirdio/netbird/signal/server"
|
||||||
|
"google.golang.org/grpc"
|
||||||
)
|
)
|
||||||
|
|
||||||
func startTestingServices(t *testing.T) string {
|
func startTestingServices(t *testing.T) string {
|
||||||
t.Helper()
|
|
||||||
config := &mgmt.Config{}
|
config := &mgmt.Config{}
|
||||||
_, err := util.ReadJson("../testdata/management.json", config)
|
_, err := util.ReadJson("../testdata/management.json", config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -45,7 +42,6 @@ func startTestingServices(t *testing.T) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
||||||
t.Helper()
|
|
||||||
lis, err := net.Listen("tcp", ":0")
|
lis, err := net.Listen("tcp", ":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -62,29 +58,28 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
|
func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
|
||||||
t.Helper()
|
|
||||||
lis, err := net.Listen("tcp", ":0")
|
lis, err := net.Listen("tcp", ":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
s := grpc.NewServer()
|
s := grpc.NewServer()
|
||||||
store, err := mgmt.NewStoreFromJson(config.Datadir, nil)
|
store, err := mgmt.NewFileStore(config.Datadir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
|
peersUpdateManager := mgmt.NewPeersUpdateManager()
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "",
|
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "",
|
||||||
eventStore, false)
|
eventStore)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||||
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -101,7 +96,6 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
|||||||
func startClientDaemon(
|
func startClientDaemon(
|
||||||
t *testing.T, ctx context.Context, managementURL, configPath string,
|
t *testing.T, ctx context.Context, managementURL, configPath string,
|
||||||
) (*grpc.Server, net.Listener) {
|
) (*grpc.Server, net.Listener) {
|
||||||
t.Helper()
|
|
||||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -17,7 +16,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -38,8 +36,6 @@ var (
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
|
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
|
||||||
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
|
||||||
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func upFunc(cmd *cobra.Command, args []string) error {
|
func upFunc(cmd *cobra.Command, args []string) error {
|
||||||
@@ -89,24 +85,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
NATExternalIPs: natExternalIPs,
|
NATExternalIPs: natExternalIPs,
|
||||||
CustomDNSAddress: customDNSAddressConverted,
|
CustomDNSAddress: customDNSAddressConverted,
|
||||||
}
|
}
|
||||||
|
if preSharedKey != "" {
|
||||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
|
||||||
ic.RosenpassEnabled = &rosenpassEnabled
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
|
||||||
if err := parseInterfaceName(interfaceName); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
ic.InterfaceName = &interfaceName
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(wireguardPortFlag).Changed {
|
|
||||||
p := int(wireguardPort)
|
|
||||||
ic.WireguardPort = &p
|
|
||||||
}
|
|
||||||
|
|
||||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
|
||||||
ic.PreSharedKey = &preSharedKey
|
ic.PreSharedKey = &preSharedKey
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -115,7 +94,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
return fmt.Errorf("get config file: %v", err)
|
return fmt.Errorf("get config file: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
|
config, _ = internal.UpdateOldManagementPort(ctx, config, configPath)
|
||||||
|
|
||||||
err = foregroundLogin(ctx, cmd, config, setupKey)
|
err = foregroundLogin(ctx, cmd, config, setupKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -125,7 +104,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
var cancel context.CancelFunc
|
var cancel context.CancelFunc
|
||||||
ctx, cancel = context.WithCancel(ctx)
|
ctx, cancel = context.WithCancel(ctx)
|
||||||
SetupCloseHandler(ctx, cancel)
|
SetupCloseHandler(ctx, cancel)
|
||||||
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
|
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||||
@@ -144,7 +123,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
defer func() {
|
defer func() {
|
||||||
err := conn.Close()
|
err := conn.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed closing daemon gRPC client connection %v", err)
|
log.Warnf("failed closing dameon gRPC client connection %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -162,31 +141,13 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
loginRequest := proto.LoginRequest{
|
loginRequest := proto.LoginRequest{
|
||||||
SetupKey: setupKey,
|
SetupKey: setupKey,
|
||||||
PreSharedKey: preSharedKey,
|
PreSharedKey: preSharedKey,
|
||||||
ManagementUrl: managementURL,
|
ManagementUrl: managementURL,
|
||||||
AdminURL: adminURL,
|
AdminURL: adminURL,
|
||||||
NatExternalIPs: natExternalIPs,
|
NatExternalIPs: natExternalIPs,
|
||||||
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
||||||
CustomDNSAddress: customDNSAddressConverted,
|
CustomDNSAddress: customDNSAddressConverted,
|
||||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
|
||||||
Hostname: hostName,
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
|
||||||
loginRequest.RosenpassEnabled = &rosenpassEnabled
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
|
||||||
if err := parseInterfaceName(interfaceName); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
loginRequest.InterfaceName = &interfaceName
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(wireguardPortFlag).Changed {
|
|
||||||
wp := int64(wireguardPort)
|
|
||||||
loginRequest.WireguardPort = &wp
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var loginErr error
|
var loginErr error
|
||||||
@@ -217,7 +178,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
|
|
||||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
||||||
|
|
||||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("waiting sso login failed with: %v", err)
|
return fmt.Errorf("waiting sso login failed with: %v", err)
|
||||||
}
|
}
|
||||||
@@ -238,11 +199,11 @@ func validateNATExternalIPs(list []string) error {
|
|||||||
|
|
||||||
subElements := strings.Split(element, "/")
|
subElements := strings.Split(element, "/")
|
||||||
if len(subElements) > 2 {
|
if len(subElements) > 2 {
|
||||||
return fmt.Errorf("%s is not a valid input for %s. it should be formatted as \"String\" or \"String/String\"", element, externalIPMapFlag)
|
return fmt.Errorf("%s is not a valid input for %s. it should be formated as \"String\" or \"String/String\"", element, externalIPMapFlag)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(subElements) == 1 && !isValidIP(subElements[0]) {
|
if len(subElements) == 1 && !isValidIP(subElements[0]) {
|
||||||
return fmt.Errorf("%s is not a valid input for %s. it should be formatted as \"IP\" or \"IP/IP\", or \"IP/Interface Name\"", element, externalIPMapFlag)
|
return fmt.Errorf("%s is not a valid input for %s. it should be formated as \"IP\" or \"IP/IP\", or \"IP/Interface Name\"", element, externalIPMapFlag)
|
||||||
}
|
}
|
||||||
|
|
||||||
last := 0
|
last := 0
|
||||||
@@ -260,18 +221,6 @@ func validateNATExternalIPs(list []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseInterfaceName(name string) error {
|
|
||||||
if runtime.GOOS != "darwin" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasPrefix(name, "utun") {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("invalid interface name %s. Please use the prefix utun followed by a number on MacOS. e.g., utun1 or utun199", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateElement(element string) (int, error) {
|
func validateElement(element string) (int, error) {
|
||||||
if isValidIP(element) {
|
if isValidIP(element) {
|
||||||
return ipInputType, nil
|
return ipInputType, nil
|
||||||
@@ -309,7 +258,7 @@ func parseCustomDNSAddress(modified bool) ([]byte, error) {
|
|||||||
var parsed []byte
|
var parsed []byte
|
||||||
if modified {
|
if modified {
|
||||||
if !isValidAddrPort(customDNSAddress) {
|
if !isValidAddrPort(customDNSAddress) {
|
||||||
return nil, fmt.Errorf("%s is invalid, it should be formatted as IP:Port string or as an empty string like \"\"", customDNSAddress)
|
return nil, fmt.Errorf("%s is invalid, it should be formated as IP:Port string or as an empty string like \"\"", customDNSAddress)
|
||||||
}
|
}
|
||||||
if customDNSAddress == "" && logFile != "console" {
|
if customDNSAddress == "" && logFile != "console" {
|
||||||
parsed = []byte("empty")
|
parsed = []byte("empty")
|
||||||
|
|||||||
@@ -1,32 +0,0 @@
|
|||||||
//go:build !linux || android
|
|
||||||
|
|
||||||
package firewall
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"runtime"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewFirewall creates a firewall manager instance
|
|
||||||
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
|
|
||||||
if !iface.IsUserspaceBind() {
|
|
||||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
|
||||||
}
|
|
||||||
|
|
||||||
// use userspace packet filtering firewall
|
|
||||||
fm, err := uspfilter.Create(iface)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
err = fm.AllowNetbird()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to allow netbird interface traffic: %v", err)
|
|
||||||
}
|
|
||||||
return fm, nil
|
|
||||||
}
|
|
||||||
@@ -1,107 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package firewall
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
|
||||||
"github.com/google/nftables"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
nbiptables "github.com/netbirdio/netbird/client/firewall/iptables"
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// UNKNOWN is the default value for the firewall type for unknown firewall type
|
|
||||||
UNKNOWN FWType = iota
|
|
||||||
// IPTABLES is the value for the iptables firewall type
|
|
||||||
IPTABLES
|
|
||||||
// NFTABLES is the value for the nftables firewall type
|
|
||||||
NFTABLES
|
|
||||||
)
|
|
||||||
|
|
||||||
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
|
|
||||||
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
|
||||||
|
|
||||||
// FWType is the type for the firewall type
|
|
||||||
type FWType int
|
|
||||||
|
|
||||||
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
|
|
||||||
// on the linux system we try to user nftables or iptables
|
|
||||||
// in any case, because we need to allow netbird interface traffic
|
|
||||||
// so we use AllowNetbird traffic from these firewall managers
|
|
||||||
// for the userspace packet filtering firewall
|
|
||||||
var fm firewall.Manager
|
|
||||||
var errFw error
|
|
||||||
|
|
||||||
switch check() {
|
|
||||||
case IPTABLES:
|
|
||||||
log.Debug("creating an iptables firewall manager")
|
|
||||||
fm, errFw = nbiptables.Create(context, iface)
|
|
||||||
if errFw != nil {
|
|
||||||
log.Errorf("failed to create iptables manager: %s", errFw)
|
|
||||||
}
|
|
||||||
case NFTABLES:
|
|
||||||
log.Debug("creating an nftables firewall manager")
|
|
||||||
fm, errFw = nbnftables.Create(context, iface)
|
|
||||||
if errFw != nil {
|
|
||||||
log.Errorf("failed to create nftables manager: %s", errFw)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
errFw = fmt.Errorf("no firewall manager found")
|
|
||||||
log.Debug("no firewall manager found, try to use userspace packet filtering firewall")
|
|
||||||
}
|
|
||||||
|
|
||||||
if iface.IsUserspaceBind() {
|
|
||||||
var errUsp error
|
|
||||||
if errFw == nil {
|
|
||||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
|
|
||||||
} else {
|
|
||||||
fm, errUsp = uspfilter.Create(iface)
|
|
||||||
}
|
|
||||||
if errUsp != nil {
|
|
||||||
log.Debugf("failed to create userspace filtering firewall: %s", errUsp)
|
|
||||||
return nil, errUsp
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := fm.AllowNetbird(); err != nil {
|
|
||||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
|
||||||
}
|
|
||||||
return fm, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if errFw != nil {
|
|
||||||
return nil, errFw
|
|
||||||
}
|
|
||||||
|
|
||||||
return fm, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
|
||||||
func check() FWType {
|
|
||||||
nf := nftables.Conn{}
|
|
||||||
if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
|
|
||||||
return NFTABLES
|
|
||||||
}
|
|
||||||
|
|
||||||
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
|
||||||
if err != nil {
|
|
||||||
return UNKNOWN
|
|
||||||
}
|
|
||||||
if isIptablesClientAvailable(ip) {
|
|
||||||
return IPTABLES
|
|
||||||
}
|
|
||||||
|
|
||||||
return UNKNOWN
|
|
||||||
}
|
|
||||||
|
|
||||||
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
|
||||||
_, err := client.ListChains("filter")
|
|
||||||
return err == nil
|
|
||||||
}
|
|
||||||
57
client/firewall/firewall.go
Normal file
57
client/firewall/firewall.go
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Rule abstraction should be implemented by each firewall manager
|
||||||
|
//
|
||||||
|
// Each firewall type for different OS can use different type
|
||||||
|
// of the properties to hold data of the created rule
|
||||||
|
type Rule interface {
|
||||||
|
// GetRuleID returns the rule id
|
||||||
|
GetRuleID() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Direction is the direction of the traffic
|
||||||
|
type Direction int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DirectionSrc is the direction of the traffic from the source
|
||||||
|
DirectionSrc Direction = iota
|
||||||
|
// DirectionDst is the direction of the traffic from the destination
|
||||||
|
DirectionDst
|
||||||
|
)
|
||||||
|
|
||||||
|
// Action is the action to be taken on a rule
|
||||||
|
type Action int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ActionAccept is the action to accept a packet
|
||||||
|
ActionAccept Action = iota
|
||||||
|
// ActionDrop is the action to drop a packet
|
||||||
|
ActionDrop
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager is the high level abstraction of a firewall manager
|
||||||
|
//
|
||||||
|
// It declares methods which handle actions required by the
|
||||||
|
// Netbird client for ACL and routing functionality
|
||||||
|
type Manager interface {
|
||||||
|
// AddFiltering rule to the firewall
|
||||||
|
AddFiltering(
|
||||||
|
ip net.IP,
|
||||||
|
port *Port,
|
||||||
|
direction Direction,
|
||||||
|
action Action,
|
||||||
|
comment string,
|
||||||
|
) (Rule, error)
|
||||||
|
|
||||||
|
// DeleteRule from the firewall by rule definition
|
||||||
|
DeleteRule(rule Rule) error
|
||||||
|
|
||||||
|
// Reset firewall to the default state
|
||||||
|
Reset() error
|
||||||
|
|
||||||
|
// TODO: migrate routemanager firewal actions to this interface
|
||||||
|
}
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
package firewall
|
|
||||||
|
|
||||||
import "github.com/netbirdio/netbird/iface"
|
|
||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
|
||||||
type IFaceMapper interface {
|
|
||||||
Name() string
|
|
||||||
Address() iface.WGAddress
|
|
||||||
IsUserspaceBind() bool
|
|
||||||
SetFilter(iface.PacketFilter) error
|
|
||||||
}
|
|
||||||
@@ -1,473 +0,0 @@
|
|||||||
package iptables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/nadoo/ipset"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
tableName = "filter"
|
|
||||||
|
|
||||||
// rules chains contains the effective ACL rules
|
|
||||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
|
||||||
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
|
|
||||||
|
|
||||||
postRoutingMark = "0x000007e4"
|
|
||||||
)
|
|
||||||
|
|
||||||
type aclManager struct {
|
|
||||||
iptablesClient *iptables.IPTables
|
|
||||||
wgIface iFaceMapper
|
|
||||||
routeingFwChainName string
|
|
||||||
|
|
||||||
entries map[string][][]string
|
|
||||||
ipsetStore *ipsetStore
|
|
||||||
}
|
|
||||||
|
|
||||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routeingFwChainName string) (*aclManager, error) {
|
|
||||||
m := &aclManager{
|
|
||||||
iptablesClient: iptablesClient,
|
|
||||||
wgIface: wgIface,
|
|
||||||
routeingFwChainName: routeingFwChainName,
|
|
||||||
|
|
||||||
entries: make(map[string][][]string),
|
|
||||||
ipsetStore: newIpsetStore(),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := ipset.Init()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to init ipset: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.seedInitialEntries()
|
|
||||||
|
|
||||||
err = m.cleanChains()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = m.createDefaultChains()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *aclManager) AddFiltering(
|
|
||||||
ip net.IP,
|
|
||||||
protocol firewall.Protocol,
|
|
||||||
sPort *firewall.Port,
|
|
||||||
dPort *firewall.Port,
|
|
||||||
direction firewall.RuleDirection,
|
|
||||||
action firewall.Action,
|
|
||||||
ipsetName string,
|
|
||||||
) ([]firewall.Rule, error) {
|
|
||||||
var dPortVal, sPortVal string
|
|
||||||
if dPort != nil && dPort.Values != nil {
|
|
||||||
// TODO: we support only one port per rule in current implementation of ACLs
|
|
||||||
dPortVal = strconv.Itoa(dPort.Values[0])
|
|
||||||
}
|
|
||||||
if sPort != nil && sPort.Values != nil {
|
|
||||||
sPortVal = strconv.Itoa(sPort.Values[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
var chain string
|
|
||||||
if direction == firewall.RuleDirectionOUT {
|
|
||||||
chain = chainNameOutputRules
|
|
||||||
} else {
|
|
||||||
chain = chainNameInputRules
|
|
||||||
}
|
|
||||||
|
|
||||||
ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal)
|
|
||||||
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
|
|
||||||
if ipsetName != "" {
|
|
||||||
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
|
||||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
|
||||||
}
|
|
||||||
// if ruleset already exists it means we already have the firewall rule
|
|
||||||
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
|
||||||
ipList.addIP(ip.String())
|
|
||||||
return []firewall.Rule{&Rule{
|
|
||||||
ruleID: uuid.New().String(),
|
|
||||||
ipsetName: ipsetName,
|
|
||||||
ip: ip.String(),
|
|
||||||
chain: chain,
|
|
||||||
specs: specs,
|
|
||||||
}}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ipset.Flush(ipsetName); err != nil {
|
|
||||||
log.Errorf("flush ipset %s before use it: %s", ipsetName, err)
|
|
||||||
}
|
|
||||||
if err := ipset.Create(ipsetName); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create ipset: %w", err)
|
|
||||||
}
|
|
||||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ipList := newIpList(ip.String())
|
|
||||||
m.ipsetStore.addIpList(ipsetName, ipList)
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err := m.iptablesClient.Exists("filter", chain, specs...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to check rule: %w", err)
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
return nil, fmt.Errorf("rule already exists")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.iptablesClient.Insert("filter", chain, 1, specs...); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
rule := &Rule{
|
|
||||||
ruleID: uuid.New().String(),
|
|
||||||
specs: specs,
|
|
||||||
ipsetName: ipsetName,
|
|
||||||
ip: ip.String(),
|
|
||||||
chain: chain,
|
|
||||||
}
|
|
||||||
|
|
||||||
if !shouldAddToPrerouting(protocol, dPort, direction) {
|
|
||||||
return []firewall.Rule{rule}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
rulePrerouting, err := m.addPreroutingFilter(ipsetName, string(protocol), dPortVal, ip)
|
|
||||||
if err != nil {
|
|
||||||
return []firewall.Rule{rule}, err
|
|
||||||
}
|
|
||||||
return []firewall.Rule{rule, rulePrerouting}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteRule from the firewall by rule definition
|
|
||||||
func (m *aclManager) DeleteRule(rule firewall.Rule) error {
|
|
||||||
r, ok := rule.(*Rule)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("invalid rule type")
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.chain == "PREROUTING" {
|
|
||||||
goto DELETERULE
|
|
||||||
}
|
|
||||||
|
|
||||||
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
|
|
||||||
// delete IP from ruleset IPs list and ipset
|
|
||||||
if _, ok := ipsetList.ips[r.ip]; ok {
|
|
||||||
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
|
|
||||||
return fmt.Errorf("failed to delete ip from ipset: %w", err)
|
|
||||||
}
|
|
||||||
delete(ipsetList.ips, r.ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
// if after delete, set still contains other IPs,
|
|
||||||
// no need to delete firewall rule and we should exit here
|
|
||||||
if len(ipsetList.ips) != 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// we delete last IP from the set, that means we need to delete
|
|
||||||
// set itself and associated firewall rule too
|
|
||||||
m.ipsetStore.deleteIpset(r.ipsetName)
|
|
||||||
|
|
||||||
if err := ipset.Destroy(r.ipsetName); err != nil {
|
|
||||||
log.Errorf("delete empty ipset: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
DELETERULE:
|
|
||||||
var table string
|
|
||||||
if r.chain == "PREROUTING" {
|
|
||||||
table = "mangle"
|
|
||||||
} else {
|
|
||||||
table = "filter"
|
|
||||||
}
|
|
||||||
err := m.iptablesClient.Delete(table, r.chain, r.specs...)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *aclManager) Reset() error {
|
|
||||||
return m.cleanChains()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *aclManager) addPreroutingFilter(ipsetName string, protocol string, port string, ip net.IP) (*Rule, error) {
|
|
||||||
var src []string
|
|
||||||
if ipsetName != "" {
|
|
||||||
src = []string{"-m", "set", "--set", ipsetName, "src"}
|
|
||||||
} else {
|
|
||||||
src = []string{"-s", ip.String()}
|
|
||||||
}
|
|
||||||
specs := []string{
|
|
||||||
"-d", m.wgIface.Address().IP.String(),
|
|
||||||
"-p", protocol,
|
|
||||||
"--dport", port,
|
|
||||||
"-j", "MARK", "--set-mark", postRoutingMark,
|
|
||||||
}
|
|
||||||
|
|
||||||
specs = append(src, specs...)
|
|
||||||
|
|
||||||
ok, err := m.iptablesClient.Exists("mangle", "PREROUTING", specs...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to check rule: %w", err)
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
return nil, fmt.Errorf("rule already exists")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.iptablesClient.Insert("mangle", "PREROUTING", 1, specs...); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
rule := &Rule{
|
|
||||||
ruleID: uuid.New().String(),
|
|
||||||
specs: specs,
|
|
||||||
ipsetName: ipsetName,
|
|
||||||
ip: ip.String(),
|
|
||||||
chain: "PREROUTING",
|
|
||||||
}
|
|
||||||
return rule, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// todo write less destructive cleanup mechanism
|
|
||||||
func (m *aclManager) cleanChains() error {
|
|
||||||
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to list chains: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
rules := m.entries["OUTPUT"]
|
|
||||||
for _, rule := range rules {
|
|
||||||
err := m.iptablesClient.DeleteIfExists(tableName, "OUTPUT", rule...)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameOutputRules)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to clear and delete %s chain: %s", chainNameOutputRules, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err = m.iptablesClient.ChainExists(tableName, chainNameInputRules)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to list chains: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
for _, rule := range m.entries["INPUT"] {
|
|
||||||
err := m.iptablesClient.DeleteIfExists(tableName, "INPUT", rule...)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rule := range m.entries["FORWARD"] {
|
|
||||||
err := m.iptablesClient.DeleteIfExists(tableName, "FORWARD", rule...)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameInputRules)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to clear and delete %s chain: %s", chainNameInputRules, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to list chains: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
for _, rule := range m.entries["PREROUTING"] {
|
|
||||||
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err = m.iptablesClient.ClearChain("mangle", "PREROUTING")
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to clear %s chain: %s", "PREROUTING", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
|
||||||
if err := ipset.Flush(ipsetName); err != nil {
|
|
||||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
|
||||||
}
|
|
||||||
if err := ipset.Destroy(ipsetName); err != nil {
|
|
||||||
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
|
|
||||||
}
|
|
||||||
m.ipsetStore.deleteIpset(ipsetName)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *aclManager) createDefaultChains() error {
|
|
||||||
// chain netbird-acl-input-rules
|
|
||||||
if err := m.iptablesClient.NewChain(tableName, chainNameInputRules); err != nil {
|
|
||||||
log.Debugf("failed to create '%s' chain: %s", chainNameInputRules, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// chain netbird-acl-output-rules
|
|
||||||
if err := m.iptablesClient.NewChain(tableName, chainNameOutputRules); err != nil {
|
|
||||||
log.Debugf("failed to create '%s' chain: %s", chainNameOutputRules, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for chainName, rules := range m.entries {
|
|
||||||
for _, rule := range rules {
|
|
||||||
if chainName == "FORWARD" {
|
|
||||||
// position 2 because we add it after router's, jump rule
|
|
||||||
if err := m.iptablesClient.InsertUnique(tableName, "FORWARD", 2, rule...); err != nil {
|
|
||||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if err := m.iptablesClient.AppendUnique(tableName, chainName, rule...); err != nil {
|
|
||||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *aclManager) seedInitialEntries() {
|
|
||||||
m.appendToEntries("INPUT",
|
|
||||||
[]string{"-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
|
||||||
|
|
||||||
m.appendToEntries("INPUT",
|
|
||||||
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
|
||||||
|
|
||||||
m.appendToEntries("INPUT",
|
|
||||||
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameInputRules})
|
|
||||||
|
|
||||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
|
||||||
|
|
||||||
m.appendToEntries("OUTPUT",
|
|
||||||
[]string{"-o", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
|
||||||
|
|
||||||
m.appendToEntries("OUTPUT",
|
|
||||||
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
|
||||||
|
|
||||||
m.appendToEntries("OUTPUT",
|
|
||||||
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameOutputRules})
|
|
||||||
|
|
||||||
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
|
|
||||||
|
|
||||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
|
||||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
|
||||||
m.appendToEntries("FORWARD",
|
|
||||||
[]string{"-o", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
|
|
||||||
m.appendToEntries("FORWARD",
|
|
||||||
[]string{"-i", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
|
|
||||||
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", m.routeingFwChainName})
|
|
||||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routeingFwChainName})
|
|
||||||
|
|
||||||
m.appendToEntries("PREROUTING",
|
|
||||||
[]string{"-t", "mangle", "-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().IP.String(), "-m", "mark", "--mark", postRoutingMark})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
|
||||||
m.entries[chainName] = append(m.entries[chainName], spec)
|
|
||||||
}
|
|
||||||
|
|
||||||
// filterRuleSpecs returns the specs of a filtering rule
|
|
||||||
func filterRuleSpecs(
|
|
||||||
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
|
|
||||||
) (specs []string) {
|
|
||||||
matchByIP := true
|
|
||||||
// don't use IP matching if IP is ip 0.0.0.0
|
|
||||||
if ip.String() == "0.0.0.0" {
|
|
||||||
matchByIP = false
|
|
||||||
}
|
|
||||||
switch direction {
|
|
||||||
case firewall.RuleDirectionIN:
|
|
||||||
if matchByIP {
|
|
||||||
if ipsetName != "" {
|
|
||||||
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
|
||||||
} else {
|
|
||||||
specs = append(specs, "-s", ip.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case firewall.RuleDirectionOUT:
|
|
||||||
if matchByIP {
|
|
||||||
if ipsetName != "" {
|
|
||||||
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
|
|
||||||
} else {
|
|
||||||
specs = append(specs, "-d", ip.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if protocol != "all" {
|
|
||||||
specs = append(specs, "-p", protocol)
|
|
||||||
}
|
|
||||||
if sPort != "" {
|
|
||||||
specs = append(specs, "--sport", sPort)
|
|
||||||
}
|
|
||||||
if dPort != "" {
|
|
||||||
specs = append(specs, "--dport", dPort)
|
|
||||||
}
|
|
||||||
return append(specs, "-j", actionToStr(action))
|
|
||||||
}
|
|
||||||
|
|
||||||
func actionToStr(action firewall.Action) string {
|
|
||||||
if action == firewall.ActionAccept {
|
|
||||||
return "ACCEPT"
|
|
||||||
}
|
|
||||||
return "DROP"
|
|
||||||
}
|
|
||||||
|
|
||||||
func transformIPsetName(ipsetName string, sPort, dPort string) string {
|
|
||||||
switch {
|
|
||||||
case ipsetName == "":
|
|
||||||
return ""
|
|
||||||
case sPort != "" && dPort != "":
|
|
||||||
return ipsetName + "-sport-dport"
|
|
||||||
case sPort != "":
|
|
||||||
return ipsetName + "-sport"
|
|
||||||
case dPort != "":
|
|
||||||
return ipsetName + "-dport"
|
|
||||||
default:
|
|
||||||
return ipsetName
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
|
|
||||||
if proto == "all" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if direction != firewall.RuleDirectionIN {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if dPort == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
@@ -1,155 +1,160 @@
|
|||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
log "github.com/sirupsen/logrus"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall"
|
||||||
"github.com/netbirdio/netbird/iface"
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ChainFilterName is the name of the chain that is used for filtering by the Netbird client
|
||||||
|
ChainFilterName = "NETBIRD-ACL"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Manager of iptables firewall
|
// Manager of iptables firewall
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
|
|
||||||
wgIface iFaceMapper
|
|
||||||
|
|
||||||
ipv4Client *iptables.IPTables
|
ipv4Client *iptables.IPTables
|
||||||
aclMgr *aclManager
|
ipv6Client *iptables.IPTables
|
||||||
router *routerManager
|
|
||||||
}
|
|
||||||
|
|
||||||
// iFaceMapper defines subset methods of interface required for manager
|
|
||||||
type iFaceMapper interface {
|
|
||||||
Name() string
|
|
||||||
Address() iface.WGAddress
|
|
||||||
IsUserspaceBind() bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create iptables firewall manager
|
// Create iptables firewall manager
|
||||||
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
func Create() (*Manager, error) {
|
||||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
m := &Manager{}
|
||||||
|
|
||||||
|
// init clients for booth ipv4 and ipv6
|
||||||
|
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
||||||
}
|
}
|
||||||
|
m.ipv4Client = ipv4Client
|
||||||
|
|
||||||
m := &Manager{
|
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||||
wgIface: wgIface,
|
if err != nil {
|
||||||
ipv4Client: iptablesClient,
|
return nil, fmt.Errorf("ip6tables is not installed in the system or not supported")
|
||||||
}
|
}
|
||||||
|
m.ipv6Client = ipv6Client
|
||||||
|
|
||||||
m.router, err = newRouterManager(context, iptablesClient)
|
if err := m.Reset(); err != nil {
|
||||||
if err != nil {
|
return nil, fmt.Errorf("failed to reset firewall: %s", err)
|
||||||
log.Debugf("failed to initialize route related chains: %s", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
m.aclMgr, err = newAclManager(iptablesClient, wgIface, m.router.RouteingFwChainName())
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to initialize ACL manager: %s", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddFiltering rule to the firewall
|
// AddFiltering rule to the firewall
|
||||||
//
|
|
||||||
// Comment will be ignored because some system this feature is not supported
|
|
||||||
func (m *Manager) AddFiltering(
|
func (m *Manager) AddFiltering(
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
protocol firewall.Protocol,
|
port *fw.Port,
|
||||||
sPort *firewall.Port,
|
direction fw.Direction,
|
||||||
dPort *firewall.Port,
|
action fw.Action,
|
||||||
direction firewall.RuleDirection,
|
|
||||||
action firewall.Action,
|
|
||||||
ipsetName string,
|
|
||||||
comment string,
|
comment string,
|
||||||
) ([]firewall.Rule, error) {
|
) (fw.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
client := m.client(ip)
|
||||||
return m.aclMgr.AddFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
|
ok, err := client.ChainExists("filter", ChainFilterName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to check if chain exists: %s", err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
if err := client.NewChain("filter", ChainFilterName); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create chain: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if port == nil || port.Values == nil || (port.IsRange && len(port.Values) != 2) {
|
||||||
|
return nil, fmt.Errorf("invalid port definition")
|
||||||
|
}
|
||||||
|
pv := strconv.Itoa(port.Values[0])
|
||||||
|
if port.IsRange {
|
||||||
|
pv += ":" + strconv.Itoa(port.Values[1])
|
||||||
|
}
|
||||||
|
specs := m.filterRuleSpecs("filter", ChainFilterName, ip, pv, direction, action, comment)
|
||||||
|
if err := client.AppendUnique("filter", ChainFilterName, specs...); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
rule := &Rule{
|
||||||
|
id: uuid.New().String(),
|
||||||
|
specs: specs,
|
||||||
|
v6: ip.To4() == nil,
|
||||||
|
}
|
||||||
|
return rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRule from the firewall by rule definition
|
// DeleteRule from the firewall by rule definition
|
||||||
func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
r, ok := rule.(*Rule)
|
||||||
return m.aclMgr.DeleteRule(rule)
|
if !ok {
|
||||||
}
|
return fmt.Errorf("invalid rule type")
|
||||||
|
}
|
||||||
func (m *Manager) IsServerRouteSupported() bool {
|
client := m.ipv4Client
|
||||||
return true
|
if r.v6 {
|
||||||
}
|
client = m.ipv6Client
|
||||||
|
}
|
||||||
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
return client.Delete("filter", ChainFilterName, r.specs...)
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
return m.router.InsertRoutingRules(pair)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
return m.router.RemoveRoutingRules(pair)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
func (m *Manager) Reset() error {
|
func (m *Manager) Reset() error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
if err := m.reset(m.ipv4Client, "filter", ChainFilterName); err != nil {
|
||||||
errAcl := m.aclMgr.Reset()
|
return fmt.Errorf("clean ipv4 firewall ACL chain: %w", err)
|
||||||
if errAcl != nil {
|
|
||||||
log.Errorf("failed to clean up ACL rules from firewall: %s", errAcl)
|
|
||||||
}
|
}
|
||||||
errMgr := m.router.Reset()
|
if err := m.reset(m.ipv6Client, "filter", ChainFilterName); err != nil {
|
||||||
if errMgr != nil {
|
return fmt.Errorf("clean ipv6 firewall ACL chain: %w", err)
|
||||||
log.Errorf("failed to clean up router rules from firewall: %s", errMgr)
|
|
||||||
return errMgr
|
|
||||||
}
|
}
|
||||||
return errAcl
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic
|
// reset firewall chain, clear it and drop it
|
||||||
func (m *Manager) AllowNetbird() error {
|
func (m *Manager) reset(client *iptables.IPTables, table, chain string) error {
|
||||||
if !m.wgIface.IsUserspaceBind() {
|
ok, err := client.ChainExists(table, chain)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to check if chain exists: %w", err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
if err := client.ClearChain(table, ChainFilterName); err != nil {
|
||||||
_, err := m.AddFiltering(
|
return fmt.Errorf("failed to clear chain: %w", err)
|
||||||
net.ParseIP("0.0.0.0"),
|
|
||||||
"all",
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
firewall.RuleDirectionIN,
|
|
||||||
firewall.ActionAccept,
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
|
|
||||||
}
|
}
|
||||||
_, err = m.AddFiltering(
|
return client.DeleteChain(table, ChainFilterName)
|
||||||
net.ParseIP("0.0.0.0"),
|
|
||||||
"all",
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
firewall.RuleDirectionOUT,
|
|
||||||
firewall.ActionAccept,
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush doesn't need to be implemented for this manager
|
// filterRuleSpecs returns the specs of a filtering rule
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) filterRuleSpecs(
|
||||||
|
table string, chain string, ip net.IP, port string,
|
||||||
|
direction fw.Direction, action fw.Action, comment string,
|
||||||
|
) (specs []string) {
|
||||||
|
if direction == fw.DirectionSrc {
|
||||||
|
specs = append(specs, "-s", ip.String())
|
||||||
|
}
|
||||||
|
specs = append(specs, "-p", "tcp", "--dport", port)
|
||||||
|
specs = append(specs, "-j", m.actionToStr(action))
|
||||||
|
return append(specs, "-m", "comment", "--comment", comment)
|
||||||
|
}
|
||||||
|
|
||||||
|
// client returns corresponding iptables client for the given ip
|
||||||
|
func (m *Manager) client(ip net.IP) *iptables.IPTables {
|
||||||
|
if ip.To4() != nil {
|
||||||
|
return m.ipv4Client
|
||||||
|
}
|
||||||
|
return m.ipv6Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) actionToStr(action fw.Action) string {
|
||||||
|
if action == fw.ActionAccept {
|
||||||
|
return "ACCEPT"
|
||||||
|
}
|
||||||
|
return "DROP"
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,282 +1,105 @@
|
|||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/stretchr/testify/require"
|
fw "github.com/netbirdio/netbird/client/firewall"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// iFaceMapper defines subset methods of interface required for manager
|
func TestNewManager(t *testing.T) {
|
||||||
type iFaceMock struct {
|
|
||||||
NameFunc func() string
|
|
||||||
AddressFunc func() iface.WGAddress
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *iFaceMock) Name() string {
|
|
||||||
if i.NameFunc != nil {
|
|
||||||
return i.NameFunc()
|
|
||||||
}
|
|
||||||
panic("NameFunc is not set")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *iFaceMock) Address() iface.WGAddress {
|
|
||||||
if i.AddressFunc != nil {
|
|
||||||
return i.AddressFunc()
|
|
||||||
}
|
|
||||||
panic("AddressFunc is not set")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
|
||||||
|
|
||||||
func TestIptablesManager(t *testing.T) {
|
|
||||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err)
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
mock := &iFaceMock{
|
|
||||||
NameFunc: func() string {
|
|
||||||
return "lo"
|
|
||||||
},
|
|
||||||
AddressFunc: func() iface.WGAddress {
|
|
||||||
return iface.WGAddress{
|
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
|
||||||
Network: &net.IPNet{
|
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// just check on the local interface
|
manager, err := Create()
|
||||||
manager, err := Create(context.Background(), mock)
|
if err != nil {
|
||||||
require.NoError(t, err)
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
var rule1 fw.Rule
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err := manager.Reset()
|
|
||||||
require.NoError(t, err, "clear the manager state")
|
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}()
|
|
||||||
|
|
||||||
var rule1 []fw.Rule
|
|
||||||
t.Run("add first rule", func(t *testing.T) {
|
t.Run("add first rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.2")
|
ip := net.ParseIP("10.20.0.2")
|
||||||
port := &fw.Port{Values: []int{8080}}
|
port := &fw.Port{Proto: fw.PortProtocolTCP, Values: []int{8080}}
|
||||||
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
rule1, err = manager.AddFiltering(ip, port, fw.DirectionDst, fw.ActionAccept, "accept HTTP traffic")
|
||||||
require.NoError(t, err, "failed to add rule")
|
if err != nil {
|
||||||
|
t.Errorf("failed to add rule: %v", err)
|
||||||
for _, r := range rule1 {
|
|
||||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
checkRuleSpecs(t, ipv4Client, true, rule1.(*Rule).specs...)
|
||||||
})
|
})
|
||||||
|
|
||||||
var rule2 []fw.Rule
|
var rule2 fw.Rule
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
|
Proto: fw.PortProtocolTCP,
|
||||||
Values: []int{8043: 8046},
|
Values: []int{8043: 8046},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddFiltering(
|
rule2, err = manager.AddFiltering(
|
||||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
ip, port, fw.DirectionDst, fw.ActionAccept, "accept HTTPS traffic from ports range")
|
||||||
require.NoError(t, err, "failed to add rule")
|
if err != nil {
|
||||||
|
t.Errorf("failed to add rule: %v", err)
|
||||||
for _, r := range rule2 {
|
|
||||||
rr := r.(*Rule)
|
|
||||||
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
checkRuleSpecs(t, ipv4Client, true, rule2.(*Rule).specs...)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("delete first rule", func(t *testing.T) {
|
t.Run("delete first rule", func(t *testing.T) {
|
||||||
for _, r := range rule1 {
|
if err := manager.DeleteRule(rule1); err != nil {
|
||||||
err := manager.DeleteRule(r)
|
t.Errorf("failed to delete rule: %v", err)
|
||||||
require.NoError(t, err, "failed to delete rule")
|
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
checkRuleSpecs(t, ipv4Client, false, rule1.(*Rule).specs...)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("delete second rule", func(t *testing.T) {
|
t.Run("delete second rule", func(t *testing.T) {
|
||||||
for _, r := range rule2 {
|
if err := manager.DeleteRule(rule2); err != nil {
|
||||||
err := manager.DeleteRule(r)
|
t.Errorf("failed to delete rule: %v", err)
|
||||||
require.NoError(t, err, "failed to delete rule")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
|
checkRuleSpecs(t, ipv4Client, false, rule2.(*Rule).specs...)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("reset check", func(t *testing.T) {
|
t.Run("reset check", func(t *testing.T) {
|
||||||
// add second rule
|
// add second rule
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{Values: []int{5353}}
|
port := &fw.Port{Proto: fw.PortProtocolUDP, Values: []int{5353}}
|
||||||
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
_, err = manager.AddFiltering(ip, port, fw.DirectionDst, fw.ActionAccept, "accept Fake DNS traffic")
|
||||||
require.NoError(t, err, "failed to add rule")
|
if err != nil {
|
||||||
|
t.Errorf("failed to add rule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
err = manager.Reset()
|
if err := manager.Reset(); err != nil {
|
||||||
require.NoError(t, err, "failed to reset")
|
t.Errorf("failed to reset: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
|
ok, err := ipv4Client.ChainExists("filter", ChainFilterName)
|
||||||
require.NoError(t, err, "failed check chain exists")
|
if err != nil {
|
||||||
|
t.Errorf("failed to drop chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
require.NoErrorf(t, err, "chain '%v' still exists after Reset", chainNameInputRules)
|
t.Errorf("chain '%v' still exists after Reset", ChainFilterName)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIptablesManagerIPSet(t *testing.T) {
|
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, mustExists bool, rulespec ...string) {
|
||||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
exists, err := ipv4Client.Exists("filter", ChainFilterName, rulespec...)
|
||||||
require.NoError(t, err)
|
if err != nil {
|
||||||
|
t.Errorf("failed to check rule: %v", err)
|
||||||
mock := &iFaceMock{
|
return
|
||||||
NameFunc: func() string {
|
|
||||||
return "lo"
|
|
||||||
},
|
|
||||||
AddressFunc: func() iface.WGAddress {
|
|
||||||
return iface.WGAddress{
|
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
|
||||||
Network: &net.IPNet{
|
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// just check on the local interface
|
if !exists && mustExists {
|
||||||
manager, err := Create(context.Background(), mock)
|
t.Errorf("rule '%v' does not exist", rulespec)
|
||||||
require.NoError(t, err)
|
return
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err := manager.Reset()
|
|
||||||
require.NoError(t, err, "clear the manager state")
|
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}()
|
|
||||||
|
|
||||||
var rule1 []fw.Rule
|
|
||||||
t.Run("add first rule with set", func(t *testing.T) {
|
|
||||||
ip := net.ParseIP("10.20.0.2")
|
|
||||||
port := &fw.Port{Values: []int{8080}}
|
|
||||||
rule1, err = manager.AddFiltering(
|
|
||||||
ip, "tcp", nil, port, fw.RuleDirectionOUT,
|
|
||||||
fw.ActionAccept, "default", "accept HTTP traffic",
|
|
||||||
)
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
|
||||||
|
|
||||||
for _, r := range rule1 {
|
|
||||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
|
|
||||||
require.Equal(t, r.(*Rule).ipsetName, "default-dport", "ipset name must be set")
|
|
||||||
require.Equal(t, r.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
var rule2 []fw.Rule
|
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
|
||||||
ip := net.ParseIP("10.20.0.3")
|
|
||||||
port := &fw.Port{
|
|
||||||
Values: []int{443},
|
|
||||||
}
|
|
||||||
rule2, err = manager.AddFiltering(
|
|
||||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
|
|
||||||
"default", "accept HTTPS traffic from ports range",
|
|
||||||
)
|
|
||||||
for _, r := range rule2 {
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
|
||||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
|
||||||
require.Equal(t, r.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("delete first rule", func(t *testing.T) {
|
|
||||||
for _, r := range rule1 {
|
|
||||||
err := manager.DeleteRule(r)
|
|
||||||
require.NoError(t, err, "failed to delete rule")
|
|
||||||
|
|
||||||
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("delete second rule", func(t *testing.T) {
|
|
||||||
for _, r := range rule2 {
|
|
||||||
err := manager.DeleteRule(r)
|
|
||||||
require.NoError(t, err, "failed to delete rule")
|
|
||||||
|
|
||||||
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("reset check", func(t *testing.T) {
|
|
||||||
err = manager.Reset()
|
|
||||||
require.NoError(t, err, "failed to reset")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) {
|
|
||||||
t.Helper()
|
|
||||||
exists, err := ipv4Client.Exists("filter", chainName, rulespec...)
|
|
||||||
require.NoError(t, err, "failed to check rule")
|
|
||||||
require.Falsef(t, !exists && mustExists, "rule '%v' does not exist", rulespec)
|
|
||||||
require.Falsef(t, exists && !mustExists, "rule '%v' exist", rulespec)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIptablesCreatePerformance(t *testing.T) {
|
|
||||||
mock := &iFaceMock{
|
|
||||||
NameFunc: func() string {
|
|
||||||
return "lo"
|
|
||||||
},
|
|
||||||
AddressFunc: func() iface.WGAddress {
|
|
||||||
return iface.WGAddress{
|
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
|
||||||
Network: &net.IPNet{
|
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
if exists && !mustExists {
|
||||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
t.Errorf("rule '%v' exist", rulespec)
|
||||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
return
|
||||||
// just check on the local interface
|
|
||||||
manager, err := Create(context.Background(), mock)
|
|
||||||
require.NoError(t, err)
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err := manager.Reset()
|
|
||||||
require.NoError(t, err, "clear the manager state")
|
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}()
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
ip := net.ParseIP("10.20.0.100")
|
|
||||||
start := time.Now()
|
|
||||||
for i := 0; i < testMax; i++ {
|
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
|
||||||
if i%2 == 0 {
|
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
} else {
|
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
}
|
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
|
||||||
}
|
|
||||||
t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax))
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,340 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package iptables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
Ipv4Forwarding = "netbird-rt-forwarding"
|
|
||||||
ipv4Nat = "netbird-rt-nat"
|
|
||||||
)
|
|
||||||
|
|
||||||
// constants needed to manage and create iptable rules
|
|
||||||
const (
|
|
||||||
tableFilter = "filter"
|
|
||||||
tableNat = "nat"
|
|
||||||
chainFORWARD = "FORWARD"
|
|
||||||
chainPOSTROUTING = "POSTROUTING"
|
|
||||||
chainRTNAT = "NETBIRD-RT-NAT"
|
|
||||||
chainRTFWD = "NETBIRD-RT-FWD"
|
|
||||||
routingFinalForwardJump = "ACCEPT"
|
|
||||||
routingFinalNatJump = "MASQUERADE"
|
|
||||||
)
|
|
||||||
|
|
||||||
type routerManager struct {
|
|
||||||
ctx context.Context
|
|
||||||
stop context.CancelFunc
|
|
||||||
iptablesClient *iptables.IPTables
|
|
||||||
rules map[string][]string
|
|
||||||
}
|
|
||||||
|
|
||||||
func newRouterManager(parentCtx context.Context, iptablesClient *iptables.IPTables) (*routerManager, error) {
|
|
||||||
ctx, cancel := context.WithCancel(parentCtx)
|
|
||||||
m := &routerManager{
|
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
iptablesClient: iptablesClient,
|
|
||||||
rules: make(map[string][]string),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := m.cleanUpDefaultForwardRules()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to cleanup routing rules: %s", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
err = m.createContainers()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to create containers for route: %s", err)
|
|
||||||
}
|
|
||||||
return m, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain
|
|
||||||
func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
|
|
||||||
err := i.insertRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, pair)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.insertRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, firewall.GetInPair(pair))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !pair.Masquerade {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.insertRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.insertRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// insertRoutingRule inserts an iptable rule
|
|
||||||
func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
|
||||||
rule := genRuleSpec(jump, ruleKey, pair.Source, pair.Destination)
|
|
||||||
existingRule, found := i.rules[ruleKey]
|
|
||||||
if found {
|
|
||||||
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
|
||||||
}
|
|
||||||
delete(i.rules, ruleKey)
|
|
||||||
}
|
|
||||||
err = i.iptablesClient.Insert(table, chain, 1, rule...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
i.rules[ruleKey] = rule
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains
|
|
||||||
func (i *routerManager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
|
||||||
err := i.removeRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, pair)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.removeRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, firewall.GetInPair(pair))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !pair.Masquerade {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.removeRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, pair)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.removeRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, firewall.GetInPair(pair))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *routerManager) removeRoutingRule(keyFormat, table, chain string, pair firewall.RouterPair) error {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
|
||||||
existingRule, found := i.rules[ruleKey]
|
|
||||||
if found {
|
|
||||||
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
delete(i.rules, ruleKey)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *routerManager) RouteingFwChainName() string {
|
|
||||||
return chainRTFWD
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *routerManager) Reset() error {
|
|
||||||
err := i.cleanUpDefaultForwardRules()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
i.rules = make(map[string][]string)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *routerManager) cleanUpDefaultForwardRules() error {
|
|
||||||
err := i.cleanJumpRules()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug("flushing routing related tables")
|
|
||||||
ok, err := i.iptablesClient.ChainExists(tableFilter, chainRTFWD)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed check chain %s,error: %v", chainRTFWD, err)
|
|
||||||
return err
|
|
||||||
} else if ok {
|
|
||||||
err = i.iptablesClient.ClearAndDeleteChain(tableFilter, chainRTFWD)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed cleaning chain %s,error: %v", chainRTFWD, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err = i.iptablesClient.ChainExists(tableNat, chainRTNAT)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed check chain %s,error: %v", chainRTNAT, err)
|
|
||||||
return err
|
|
||||||
} else if ok {
|
|
||||||
err = i.iptablesClient.ClearAndDeleteChain(tableNat, chainRTNAT)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed cleaning chain %s,error: %v", chainRTNAT, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *routerManager) createContainers() error {
|
|
||||||
if i.rules[Ipv4Forwarding] != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
errMSGFormat := "failed creating chain %s,error: %v"
|
|
||||||
err := i.createChain(tableFilter, chainRTFWD)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf(errMSGFormat, chainRTFWD, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.createChain(tableNat, chainRTNAT)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf(errMSGFormat, chainRTNAT, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.addJumpRules()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error while creating jump rules: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// addJumpRules create jump rules to send packets to NetBird chains
|
|
||||||
func (i *routerManager) addJumpRules() error {
|
|
||||||
rule := []string{"-j", chainRTFWD}
|
|
||||||
err := i.iptablesClient.Insert(tableFilter, chainFORWARD, 1, rule...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
i.rules[Ipv4Forwarding] = rule
|
|
||||||
|
|
||||||
rule = []string{"-j", chainRTNAT}
|
|
||||||
err = i.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
i.rules[ipv4Nat] = rule
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanJumpRules cleans jump rules that was sending packets to NetBird chains
|
|
||||||
func (i *routerManager) cleanJumpRules() error {
|
|
||||||
var err error
|
|
||||||
errMSGFormat := "failed cleaning rule from chain %s,err: %v"
|
|
||||||
rule, found := i.rules[Ipv4Forwarding]
|
|
||||||
if found {
|
|
||||||
err = i.iptablesClient.DeleteIfExists(tableFilter, chainFORWARD, rule...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf(errMSGFormat, chainFORWARD, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rule, found = i.rules[ipv4Nat]
|
|
||||||
if found {
|
|
||||||
err = i.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf(errMSGFormat, chainPOSTROUTING, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
rules, err := i.iptablesClient.List("nat", "POSTROUTING")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to list rules: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ruleString := range rules {
|
|
||||||
if !strings.Contains(ruleString, "NETBIRD") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
rule := strings.Fields(ruleString)
|
|
||||||
err := i.iptablesClient.DeleteIfExists("nat", "POSTROUTING", rule[2:]...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to delete postrouting jump rule: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
rules, err = i.iptablesClient.List(tableFilter, "FORWARD")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to list rules in FORWARD chain: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ruleString := range rules {
|
|
||||||
if !strings.Contains(ruleString, "NETBIRD") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
rule := strings.Fields(ruleString)
|
|
||||||
err := i.iptablesClient.DeleteIfExists(tableFilter, "FORWARD", rule[2:]...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to delete FORWARD jump rule: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *routerManager) createChain(table, newChain string) error {
|
|
||||||
chains, err := i.iptablesClient.ListChains(table)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("couldn't get %s table chains, error: %v", table, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
shouldCreateChain := true
|
|
||||||
for _, chain := range chains {
|
|
||||||
if chain == newChain {
|
|
||||||
shouldCreateChain = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if shouldCreateChain {
|
|
||||||
err = i.iptablesClient.NewChain(table, newChain)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// genRuleSpec generates rule specification with comment identifier
|
|
||||||
func genRuleSpec(jump, id, source, destination string) []string {
|
|
||||||
return []string{"-s", source, "-d", destination, "-j", jump, "-m", "comment", "--comment", id}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getIptablesRuleType(table string) string {
|
|
||||||
ruleType := "forwarding"
|
|
||||||
if table == tableNat {
|
|
||||||
ruleType = "nat"
|
|
||||||
}
|
|
||||||
return ruleType
|
|
||||||
}
|
|
||||||
@@ -1,229 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package iptables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"os/exec"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/test"
|
|
||||||
)
|
|
||||||
|
|
||||||
func isIptablesSupported() bool {
|
|
||||||
_, err4 := exec.LookPath("iptables")
|
|
||||||
return err4 == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|
||||||
if !isIptablesSupported() {
|
|
||||||
t.SkipNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
|
||||||
require.NoError(t, err, "failed to init iptables client")
|
|
||||||
|
|
||||||
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
|
||||||
require.NoError(t, err, "should return a valid iptables manager")
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
_ = manager.Reset()
|
|
||||||
}()
|
|
||||||
|
|
||||||
require.Len(t, manager.rules, 2, "should have created rules map")
|
|
||||||
|
|
||||||
exists, err := manager.iptablesClient.Exists(tableFilter, chainFORWARD, manager.rules[Ipv4Forwarding]...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainFORWARD)
|
|
||||||
require.True(t, exists, "forwarding rule should exist")
|
|
||||||
|
|
||||||
exists, err = manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
|
||||||
require.True(t, exists, "postrouting rule should exist")
|
|
||||||
|
|
||||||
pair := firewall.RouterPair{
|
|
||||||
ID: "abc",
|
|
||||||
Source: "100.100.100.1/32",
|
|
||||||
Destination: "100.100.100.0/24",
|
|
||||||
Masquerade: true,
|
|
||||||
}
|
|
||||||
forward4RuleKey := firewall.GenKey(firewall.ForwardingFormat, pair.ID)
|
|
||||||
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.Source, pair.Destination)
|
|
||||||
|
|
||||||
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
|
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
|
||||||
|
|
||||||
nat4RuleKey := firewall.GenKey(firewall.NatFormat, pair.ID)
|
|
||||||
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.Source, pair.Destination)
|
|
||||||
|
|
||||||
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
|
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
|
||||||
|
|
||||||
err = manager.Reset()
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
|
||||||
|
|
||||||
if !isIptablesSupported() {
|
|
||||||
t.SkipNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range test.InsertRuleTestCases {
|
|
||||||
t.Run(testCase.Name, func(t *testing.T) {
|
|
||||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
|
||||||
require.NoError(t, err, "failed to init iptables client")
|
|
||||||
|
|
||||||
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err := manager.Reset()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to reset iptables manager: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
err = manager.InsertRoutingRules(testCase.InputPair)
|
|
||||||
require.NoError(t, err, "forwarding pair should be inserted")
|
|
||||||
|
|
||||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
|
||||||
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
|
|
||||||
|
|
||||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
|
||||||
require.True(t, exists, "forwarding rule should exist")
|
|
||||||
|
|
||||||
foundRule, found := manager.rules[forwardRuleKey]
|
|
||||||
require.True(t, found, "forwarding rule should exist in the manager map")
|
|
||||||
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
|
|
||||||
|
|
||||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
|
||||||
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
|
||||||
require.True(t, exists, "income forwarding rule should exist")
|
|
||||||
|
|
||||||
foundRule, found = manager.rules[inForwardRuleKey]
|
|
||||||
require.True(t, found, "income forwarding rule should exist in the manager map")
|
|
||||||
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
|
|
||||||
|
|
||||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
|
||||||
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
|
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
|
||||||
if testCase.InputPair.Masquerade {
|
|
||||||
require.True(t, exists, "nat rule should be created")
|
|
||||||
foundNatRule, foundNat := manager.rules[natRuleKey]
|
|
||||||
require.True(t, foundNat, "nat rule should exist in the map")
|
|
||||||
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
|
|
||||||
} else {
|
|
||||||
require.False(t, exists, "nat rule should not be created")
|
|
||||||
_, foundNat := manager.rules[natRuleKey]
|
|
||||||
require.False(t, foundNat, "nat rule should not exist in the map")
|
|
||||||
}
|
|
||||||
|
|
||||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
|
||||||
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
|
||||||
if testCase.InputPair.Masquerade {
|
|
||||||
require.True(t, exists, "income nat rule should be created")
|
|
||||||
foundNatRule, foundNat := manager.rules[inNatRuleKey]
|
|
||||||
require.True(t, foundNat, "income nat rule should exist in the map")
|
|
||||||
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
|
|
||||||
} else {
|
|
||||||
require.False(t, exists, "nat rule should not be created")
|
|
||||||
_, foundNat := manager.rules[inNatRuleKey]
|
|
||||||
require.False(t, foundNat, "income nat rule should not exist in the map")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
|
|
||||||
|
|
||||||
if !isIptablesSupported() {
|
|
||||||
t.SkipNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range test.RemoveRuleTestCases {
|
|
||||||
t.Run(testCase.Name, func(t *testing.T) {
|
|
||||||
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
|
||||||
|
|
||||||
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
defer func() {
|
|
||||||
_ = manager.Reset()
|
|
||||||
}()
|
|
||||||
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
|
||||||
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
|
|
||||||
|
|
||||||
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...)
|
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
|
||||||
|
|
||||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
|
||||||
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
|
||||||
|
|
||||||
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...)
|
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
|
||||||
|
|
||||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
|
||||||
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
|
|
||||||
|
|
||||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
|
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
|
||||||
|
|
||||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
|
||||||
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
|
||||||
|
|
||||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
|
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
|
||||||
|
|
||||||
err = manager.Reset()
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
err = manager.RemoveRoutingRules(testCase.InputPair)
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
|
||||||
require.False(t, exists, "forwarding rule should not exist")
|
|
||||||
|
|
||||||
_, found := manager.rules[forwardRuleKey]
|
|
||||||
require.False(t, found, "forwarding rule should exist in the manager map")
|
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
|
||||||
require.False(t, exists, "income forwarding rule should not exist")
|
|
||||||
|
|
||||||
_, found = manager.rules[inForwardRuleKey]
|
|
||||||
require.False(t, found, "income forwarding rule should exist in the manager map")
|
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
|
||||||
require.False(t, exists, "nat rule should not exist")
|
|
||||||
|
|
||||||
_, found = manager.rules[natRuleKey]
|
|
||||||
require.False(t, found, "nat rule should exist in the manager map")
|
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
|
||||||
require.False(t, exists, "income nat rule should not exist")
|
|
||||||
|
|
||||||
_, found = manager.rules[inNatRuleKey]
|
|
||||||
require.False(t, found, "income nat rule should exist in the manager map")
|
|
||||||
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -2,15 +2,12 @@ package iptables
|
|||||||
|
|
||||||
// Rule to handle management of rules
|
// Rule to handle management of rules
|
||||||
type Rule struct {
|
type Rule struct {
|
||||||
ruleID string
|
id string
|
||||||
ipsetName string
|
|
||||||
|
|
||||||
specs []string
|
specs []string
|
||||||
ip string
|
v6 bool
|
||||||
chain string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
func (r *Rule) GetRuleID() string {
|
func (r *Rule) GetRuleID() string {
|
||||||
return r.ruleID
|
return r.id
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,50 +0,0 @@
|
|||||||
package iptables
|
|
||||||
|
|
||||||
type ipList struct {
|
|
||||||
ips map[string]struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newIpList(ip string) ipList {
|
|
||||||
ips := make(map[string]struct{})
|
|
||||||
ips[ip] = struct{}{}
|
|
||||||
|
|
||||||
return ipList{
|
|
||||||
ips: ips,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipList) addIP(ip string) {
|
|
||||||
s.ips[ip] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
type ipsetStore struct {
|
|
||||||
ipsets map[string]ipList // ipsetName -> ruleset
|
|
||||||
}
|
|
||||||
|
|
||||||
func newIpsetStore() *ipsetStore {
|
|
||||||
return &ipsetStore{
|
|
||||||
ipsets: make(map[string]ipList),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) ipset(ipsetName string) (ipList, bool) {
|
|
||||||
r, ok := s.ipsets[ipsetName]
|
|
||||||
return r, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) addIpList(ipsetName string, list ipList) {
|
|
||||||
s.ipsets[ipsetName] = list
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) deleteIpset(ipsetName string) {
|
|
||||||
s.ipsets[ipsetName] = ipList{}
|
|
||||||
delete(s.ipsets, ipsetName)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) ipsetNames() []string {
|
|
||||||
names := make([]string, 0, len(s.ipsets))
|
|
||||||
for name := range s.ipsets {
|
|
||||||
names = append(names, name)
|
|
||||||
}
|
|
||||||
return names
|
|
||||||
}
|
|
||||||
@@ -1,88 +0,0 @@
|
|||||||
package manager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
NatFormat = "netbird-nat-%s"
|
|
||||||
ForwardingFormat = "netbird-fwd-%s"
|
|
||||||
InNatFormat = "netbird-nat-in-%s"
|
|
||||||
InForwardingFormat = "netbird-fwd-in-%s"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Rule abstraction should be implemented by each firewall manager
|
|
||||||
//
|
|
||||||
// Each firewall type for different OS can use different type
|
|
||||||
// of the properties to hold data of the created rule
|
|
||||||
type Rule interface {
|
|
||||||
// GetRuleID returns the rule id
|
|
||||||
GetRuleID() string
|
|
||||||
}
|
|
||||||
|
|
||||||
// RuleDirection is the traffic direction which a rule is applied
|
|
||||||
type RuleDirection int
|
|
||||||
|
|
||||||
const (
|
|
||||||
// RuleDirectionIN applies to filters that handlers incoming traffic
|
|
||||||
RuleDirectionIN RuleDirection = iota
|
|
||||||
// RuleDirectionOUT applies to filters that handlers outgoing traffic
|
|
||||||
RuleDirectionOUT
|
|
||||||
)
|
|
||||||
|
|
||||||
// Action is the action to be taken on a rule
|
|
||||||
type Action int
|
|
||||||
|
|
||||||
const (
|
|
||||||
// ActionAccept is the action to accept a packet
|
|
||||||
ActionAccept Action = iota
|
|
||||||
// ActionDrop is the action to drop a packet
|
|
||||||
ActionDrop
|
|
||||||
)
|
|
||||||
|
|
||||||
// Manager is the high level abstraction of a firewall manager
|
|
||||||
//
|
|
||||||
// It declares methods which handle actions required by the
|
|
||||||
// Netbird client for ACL and routing functionality
|
|
||||||
type Manager interface {
|
|
||||||
// AllowNetbird allows netbird interface traffic
|
|
||||||
AllowNetbird() error
|
|
||||||
|
|
||||||
// AddFiltering rule to the firewall
|
|
||||||
//
|
|
||||||
// If comment argument is empty firewall manager should set
|
|
||||||
// rule ID as comment for the rule
|
|
||||||
AddFiltering(
|
|
||||||
ip net.IP,
|
|
||||||
proto Protocol,
|
|
||||||
sPort *Port,
|
|
||||||
dPort *Port,
|
|
||||||
direction RuleDirection,
|
|
||||||
action Action,
|
|
||||||
ipsetName string,
|
|
||||||
comment string,
|
|
||||||
) ([]Rule, error)
|
|
||||||
|
|
||||||
// DeleteRule from the firewall by rule definition
|
|
||||||
DeleteRule(rule Rule) error
|
|
||||||
|
|
||||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
|
||||||
IsServerRouteSupported() bool
|
|
||||||
|
|
||||||
// InsertRoutingRules inserts a routing firewall rule
|
|
||||||
InsertRoutingRules(pair RouterPair) error
|
|
||||||
|
|
||||||
// RemoveRoutingRules removes a routing firewall rule
|
|
||||||
RemoveRoutingRules(pair RouterPair) error
|
|
||||||
|
|
||||||
// Reset firewall to the default state
|
|
||||||
Reset() error
|
|
||||||
|
|
||||||
// Flush the changes to firewall controller
|
|
||||||
Flush() error
|
|
||||||
}
|
|
||||||
|
|
||||||
func GenKey(format string, input string) string {
|
|
||||||
return fmt.Sprintf(format, input)
|
|
||||||
}
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
package manager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strconv"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Protocol is the protocol of the port
|
|
||||||
type Protocol string
|
|
||||||
|
|
||||||
const (
|
|
||||||
// ProtocolTCP is the TCP protocol
|
|
||||||
ProtocolTCP Protocol = "tcp"
|
|
||||||
|
|
||||||
// ProtocolUDP is the UDP protocol
|
|
||||||
ProtocolUDP Protocol = "udp"
|
|
||||||
|
|
||||||
// ProtocolICMP is the ICMP protocol
|
|
||||||
ProtocolICMP Protocol = "icmp"
|
|
||||||
|
|
||||||
// ProtocolALL cover all supported protocols
|
|
||||||
ProtocolALL Protocol = "all"
|
|
||||||
|
|
||||||
// ProtocolUnknown unknown protocol
|
|
||||||
ProtocolUnknown Protocol = "unknown"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Port of the address for firewall rule
|
|
||||||
type Port struct {
|
|
||||||
// IsRange is true Values contains two values, the first is the start port, the second is the end port
|
|
||||||
IsRange bool
|
|
||||||
|
|
||||||
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
|
|
||||||
Values []int
|
|
||||||
}
|
|
||||||
|
|
||||||
// String interface implementation
|
|
||||||
func (p *Port) String() string {
|
|
||||||
var ports string
|
|
||||||
for _, port := range p.Values {
|
|
||||||
if ports != "" {
|
|
||||||
ports += ","
|
|
||||||
}
|
|
||||||
ports += strconv.Itoa(port)
|
|
||||||
}
|
|
||||||
return ports
|
|
||||||
}
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
package manager
|
|
||||||
|
|
||||||
type RouterPair struct {
|
|
||||||
ID string
|
|
||||||
Source string
|
|
||||||
Destination string
|
|
||||||
Masquerade bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetInPair(pair RouterPair) RouterPair {
|
|
||||||
return RouterPair{
|
|
||||||
ID: pair.ID,
|
|
||||||
// invert Source/Destination
|
|
||||||
Source: pair.Destination,
|
|
||||||
Destination: pair.Source,
|
|
||||||
Masquerade: pair.Masquerade,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,85 +0,0 @@
|
|||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ipsetStore struct {
|
|
||||||
ipsetReference map[string]int
|
|
||||||
ipsets map[string]map[string]struct{} // ipsetName -> list of ips
|
|
||||||
}
|
|
||||||
|
|
||||||
func newIpsetStore() *ipsetStore {
|
|
||||||
return &ipsetStore{
|
|
||||||
ipsetReference: make(map[string]int),
|
|
||||||
ipsets: make(map[string]map[string]struct{}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) ips(ipsetName string) (map[string]struct{}, bool) {
|
|
||||||
r, ok := s.ipsets[ipsetName]
|
|
||||||
return r, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) newIpset(ipsetName string) map[string]struct{} {
|
|
||||||
s.ipsetReference[ipsetName] = 0
|
|
||||||
ipList := make(map[string]struct{})
|
|
||||||
s.ipsets[ipsetName] = ipList
|
|
||||||
return ipList
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) deleteIpset(ipsetName string) {
|
|
||||||
delete(s.ipsetReference, ipsetName)
|
|
||||||
delete(s.ipsets, ipsetName)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) DeleteIpFromSet(ipsetName string, ip net.IP) {
|
|
||||||
ipList, ok := s.ipsets[ipsetName]
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
delete(ipList, ip.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) AddIpToSet(ipsetName string, ip net.IP) {
|
|
||||||
ipList, ok := s.ipsets[ipsetName]
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ipList[ip.String()] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) IsIpInSet(ipsetName string, ip net.IP) bool {
|
|
||||||
ipList, ok := s.ipsets[ipsetName]
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
_, ok = ipList[ip.String()]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) AddReferenceToIpset(ipsetName string) {
|
|
||||||
s.ipsetReference[ipsetName]++
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) DeleteReferenceFromIpSet(ipsetName string) {
|
|
||||||
r, ok := s.ipsetReference[ipsetName]
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if r == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.ipsetReference[ipsetName]--
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) HasReferenceToSet(ipsetName string) bool {
|
|
||||||
if _, ok := s.ipsetReference[ipsetName]; !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if s.ipsetReference[ipsetName] == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
@@ -1,267 +0,0 @@
|
|||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"github.com/google/nftables/expr"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// tableName is the name of the table that is used for filtering by the Netbird client
|
|
||||||
tableName = "netbird"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Manager of iptables firewall
|
|
||||||
type Manager struct {
|
|
||||||
mutex sync.Mutex
|
|
||||||
rConn *nftables.Conn
|
|
||||||
wgIface iFaceMapper
|
|
||||||
|
|
||||||
router *router
|
|
||||||
aclManager *AclManager
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create nftables firewall manager
|
|
||||||
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
|
||||||
m := &Manager{
|
|
||||||
rConn: &nftables.Conn{},
|
|
||||||
wgIface: wgIface,
|
|
||||||
}
|
|
||||||
|
|
||||||
workTable, err := m.createWorkTable()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
m.router, err = newRouter(context, workTable)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddFiltering rule to the firewall
|
|
||||||
//
|
|
||||||
// If comment argument is empty firewall manager should set
|
|
||||||
// rule ID as comment for the rule
|
|
||||||
func (m *Manager) AddFiltering(
|
|
||||||
ip net.IP,
|
|
||||||
proto firewall.Protocol,
|
|
||||||
sPort *firewall.Port,
|
|
||||||
dPort *firewall.Port,
|
|
||||||
direction firewall.RuleDirection,
|
|
||||||
action firewall.Action,
|
|
||||||
ipsetName string,
|
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
rawIP := ip.To4()
|
|
||||||
if rawIP == nil {
|
|
||||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
return m.aclManager.AddFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteRule from the firewall by rule definition
|
|
||||||
func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
return m.aclManager.DeleteRule(rule)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) IsServerRouteSupported() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
return m.router.InsertRoutingRules(pair)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
return m.router.RemoveRoutingRules(pair)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic
|
|
||||||
func (m *Manager) AllowNetbird() error {
|
|
||||||
if !m.wgIface.IsUserspaceBind() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
err := m.aclManager.createDefaultAllowRules()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create default allow rules: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("list of chains: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var chain *nftables.Chain
|
|
||||||
for _, c := range chains {
|
|
||||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
|
||||||
chain = c
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if chain == nil {
|
|
||||||
log.Debugf("chain INPUT not found. Skipping add allow netbird rule")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
rules, err := m.rConn.GetRules(chain.Table, chain)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get rules for the INPUT chain: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
|
|
||||||
log.Debugf("allow netbird rule already exists: %v", rule)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
m.applyAllowNetbirdRules(chain)
|
|
||||||
|
|
||||||
err = m.rConn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to flush allow input netbird rules: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset firewall to the default state
|
|
||||||
func (m *Manager) Reset() error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
chains, err := m.rConn.ListChains()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("list of chains: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, c := range chains {
|
|
||||||
// delete Netbird allow input traffic rule if it exists
|
|
||||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
|
||||||
rules, err := m.rConn.GetRules(c.Table, c)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, r := range rules {
|
|
||||||
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
|
|
||||||
if err := m.rConn.DelRule(r); err != nil {
|
|
||||||
log.Errorf("delete rule: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
m.router.ResetForwardRules()
|
|
||||||
|
|
||||||
tables, err := m.rConn.ListTables()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("list of tables: %w", err)
|
|
||||||
}
|
|
||||||
for _, t := range tables {
|
|
||||||
if t.Name == tableName {
|
|
||||||
m.rConn.DelTable(t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return m.rConn.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Flush rule/chain/set operations from the buffer
|
|
||||||
//
|
|
||||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
|
||||||
// todo review this method usage
|
|
||||||
func (m *Manager) Flush() error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
return m.aclManager.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
|
||||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("list of tables: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, t := range tables {
|
|
||||||
if t.Name == tableName {
|
|
||||||
m.rConn.DelTable(t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
|
|
||||||
err = m.rConn.Flush()
|
|
||||||
return table, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
|
||||||
rule := &nftables.Rule{
|
|
||||||
Table: chain.Table,
|
|
||||||
Chain: chain,
|
|
||||||
Exprs: []expr.Any{
|
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(m.wgIface.Name()),
|
|
||||||
},
|
|
||||||
&expr.Verdict{
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
UserData: []byte(allowNetbirdInputRuleID),
|
|
||||||
}
|
|
||||||
_ = m.rConn.InsertRule(rule)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
|
|
||||||
ifName := ifname(m.wgIface.Name())
|
|
||||||
for _, rule := range existedRules {
|
|
||||||
if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" {
|
|
||||||
if len(rule.Exprs) < 4 {
|
|
||||||
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return rule
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,205 +0,0 @@
|
|||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"github.com/google/nftables/expr"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
|
||||||
|
|
||||||
// iFaceMapper defines subset methods of interface required for manager
|
|
||||||
type iFaceMock struct {
|
|
||||||
NameFunc func() string
|
|
||||||
AddressFunc func() iface.WGAddress
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *iFaceMock) Name() string {
|
|
||||||
if i.NameFunc != nil {
|
|
||||||
return i.NameFunc()
|
|
||||||
}
|
|
||||||
panic("NameFunc is not set")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *iFaceMock) Address() iface.WGAddress {
|
|
||||||
if i.AddressFunc != nil {
|
|
||||||
return i.AddressFunc()
|
|
||||||
}
|
|
||||||
panic("AddressFunc is not set")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
|
||||||
|
|
||||||
func TestNftablesManager(t *testing.T) {
|
|
||||||
mock := &iFaceMock{
|
|
||||||
NameFunc: func() string {
|
|
||||||
return "lo"
|
|
||||||
},
|
|
||||||
AddressFunc: func() iface.WGAddress {
|
|
||||||
return iface.WGAddress{
|
|
||||||
IP: net.ParseIP("100.96.0.1"),
|
|
||||||
Network: &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.96.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// just check on the local interface
|
|
||||||
manager, err := Create(context.Background(), mock)
|
|
||||||
require.NoError(t, err)
|
|
||||||
time.Sleep(time.Second * 3)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err = manager.Reset()
|
|
||||||
require.NoError(t, err, "failed to reset")
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}()
|
|
||||||
|
|
||||||
ip := net.ParseIP("100.96.0.1")
|
|
||||||
|
|
||||||
testClient := &nftables.Conn{}
|
|
||||||
|
|
||||||
rule, err := manager.AddFiltering(
|
|
||||||
ip,
|
|
||||||
fw.ProtocolTCP,
|
|
||||||
nil,
|
|
||||||
&fw.Port{Values: []int{53}},
|
|
||||||
fw.RuleDirectionIN,
|
|
||||||
fw.ActionDrop,
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
|
||||||
|
|
||||||
err = manager.Flush()
|
|
||||||
require.NoError(t, err, "failed to flush")
|
|
||||||
|
|
||||||
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
|
||||||
require.NoError(t, err, "failed to get rules")
|
|
||||||
|
|
||||||
require.Len(t, rules, 1, "expected 1 rules")
|
|
||||||
|
|
||||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
|
||||||
add := ipToAdd.Unmap()
|
|
||||||
expectedExprs := []expr.Any{
|
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname("lo"),
|
|
||||||
},
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: uint32(9),
|
|
||||||
Len: uint32(1),
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Register: 1,
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Data: []byte{unix.IPPROTO_TCP},
|
|
||||||
},
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: 12,
|
|
||||||
Len: 4,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: add.AsSlice(),
|
|
||||||
},
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseTransportHeader,
|
|
||||||
Offset: 2,
|
|
||||||
Len: 2,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: []byte{0, 53},
|
|
||||||
},
|
|
||||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
|
||||||
}
|
|
||||||
require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
|
|
||||||
|
|
||||||
for _, r := range rule {
|
|
||||||
err = manager.DeleteRule(r)
|
|
||||||
require.NoError(t, err, "failed to delete rule")
|
|
||||||
}
|
|
||||||
|
|
||||||
err = manager.Flush()
|
|
||||||
require.NoError(t, err, "failed to flush")
|
|
||||||
|
|
||||||
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
|
||||||
require.NoError(t, err, "failed to get rules")
|
|
||||||
require.Len(t, rules, 0, "expected 0 rules after deletion")
|
|
||||||
|
|
||||||
err = manager.Reset()
|
|
||||||
require.NoError(t, err, "failed to reset")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNFtablesCreatePerformance(t *testing.T) {
|
|
||||||
mock := &iFaceMock{
|
|
||||||
NameFunc: func() string {
|
|
||||||
return "lo"
|
|
||||||
},
|
|
||||||
AddressFunc: func() iface.WGAddress {
|
|
||||||
return iface.WGAddress{
|
|
||||||
IP: net.ParseIP("100.96.0.1"),
|
|
||||||
Network: &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.96.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
|
||||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
|
||||||
// just check on the local interface
|
|
||||||
manager, err := Create(context.Background(), mock)
|
|
||||||
require.NoError(t, err)
|
|
||||||
time.Sleep(time.Second * 3)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if err := manager.Reset(); err != nil {
|
|
||||||
t.Errorf("clear the manager state: %v", err)
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}()
|
|
||||||
|
|
||||||
ip := net.ParseIP("10.20.0.100")
|
|
||||||
start := time.Now()
|
|
||||||
for i := 0; i < testMax; i++ {
|
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
|
||||||
if i%2 == 0 {
|
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
} else {
|
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
}
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
|
||||||
|
|
||||||
if i%100 == 0 {
|
|
||||||
err = manager.Flush()
|
|
||||||
require.NoError(t, err, "failed to flush")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,413 +0,0 @@
|
|||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"github.com/google/nftables/binaryutil"
|
|
||||||
"github.com/google/nftables/expr"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
chainNameRouteingFw = "netbird-rt-fwd"
|
|
||||||
chainNameRoutingNat = "netbird-rt-nat"
|
|
||||||
|
|
||||||
userDataAcceptForwardRuleSrc = "frwacceptsrc"
|
|
||||||
userDataAcceptForwardRuleDst = "frwacceptdst"
|
|
||||||
)
|
|
||||||
|
|
||||||
// some presets for building nftable rules
|
|
||||||
var (
|
|
||||||
zeroXor = binaryutil.NativeEndian.PutUint32(0)
|
|
||||||
|
|
||||||
exprCounterAccept = []expr.Any{
|
|
||||||
&expr.Counter{},
|
|
||||||
&expr.Verdict{
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
|
|
||||||
)
|
|
||||||
|
|
||||||
type router struct {
|
|
||||||
ctx context.Context
|
|
||||||
stop context.CancelFunc
|
|
||||||
conn *nftables.Conn
|
|
||||||
workTable *nftables.Table
|
|
||||||
filterTable *nftables.Table
|
|
||||||
chains map[string]*nftables.Chain
|
|
||||||
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
|
|
||||||
rules map[string]*nftables.Rule
|
|
||||||
isDefaultFwdRulesEnabled bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) {
|
|
||||||
ctx, cancel := context.WithCancel(parentCtx)
|
|
||||||
|
|
||||||
r := &router{
|
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
conn: &nftables.Conn{},
|
|
||||||
workTable: workTable,
|
|
||||||
chains: make(map[string]*nftables.Chain),
|
|
||||||
rules: make(map[string]*nftables.Rule),
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
r.filterTable, err = r.loadFilterTable()
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, errFilterTableNotFound) {
|
|
||||||
log.Warnf("table 'filter' not found for forward rules")
|
|
||||||
} else {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.cleanUpDefaultForwardRules()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.createContainers()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to create containers for route: %s", err)
|
|
||||||
}
|
|
||||||
return r, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) RouteingFwChainName() string {
|
|
||||||
return chainNameRouteingFw
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetForwardRules cleans existing nftables default forward rules from the system
|
|
||||||
func (r *router) ResetForwardRules() {
|
|
||||||
err := r.cleanUpDefaultForwardRules()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to reset forward rules: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
|
||||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, table := range tables {
|
|
||||||
if table.Name == "filter" {
|
|
||||||
return table, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errFilterTableNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) createContainers() error {
|
|
||||||
|
|
||||||
r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{
|
|
||||||
Name: chainNameRouteingFw,
|
|
||||||
Table: r.workTable,
|
|
||||||
})
|
|
||||||
|
|
||||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
|
||||||
Name: chainNameRoutingNat,
|
|
||||||
Table: r.workTable,
|
|
||||||
Hooknum: nftables.ChainHookPostrouting,
|
|
||||||
Priority: nftables.ChainPriorityNATSource - 1,
|
|
||||||
Type: nftables.ChainTypeNAT,
|
|
||||||
})
|
|
||||||
|
|
||||||
err := r.refreshRulesMap()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.conn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
|
|
||||||
func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
|
|
||||||
err := r.refreshRulesMap()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.insertRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = r.insertRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if pair.Masquerade {
|
|
||||||
err = r.insertRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = r.insertRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.filterTable != nil && !r.isDefaultFwdRulesEnabled {
|
|
||||||
log.Debugf("add default accept forward rule")
|
|
||||||
r.acceptForwardRule(pair.Source)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.conn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.Destination, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// insertRoutingRule inserts a nftable rule to the conn client flush queue
|
|
||||||
func (r *router) insertRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
|
|
||||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
|
||||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
|
||||||
|
|
||||||
var expression []expr.Any
|
|
||||||
if isNat {
|
|
||||||
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) // nolint:gocritic
|
|
||||||
} else {
|
|
||||||
expression = append(sourceExp, append(destExp, exprCounterAccept...)...) // nolint:gocritic
|
|
||||||
}
|
|
||||||
|
|
||||||
ruleKey := manager.GenKey(format, pair.ID)
|
|
||||||
|
|
||||||
_, exists := r.rules[ruleKey]
|
|
||||||
if exists {
|
|
||||||
err := r.removeRoutingRule(format, pair)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
|
|
||||||
Table: r.workTable,
|
|
||||||
Chain: r.chains[chainName],
|
|
||||||
Exprs: expression,
|
|
||||||
UserData: []byte(ruleKey),
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) acceptForwardRule(sourceNetwork string) {
|
|
||||||
src := generateCIDRMatcherExpressions(true, sourceNetwork)
|
|
||||||
dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0")
|
|
||||||
|
|
||||||
var exprs []expr.Any
|
|
||||||
exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
})...)
|
|
||||||
|
|
||||||
rule := &nftables.Rule{
|
|
||||||
Table: r.filterTable,
|
|
||||||
Chain: &nftables.Chain{
|
|
||||||
Name: "FORWARD",
|
|
||||||
Table: r.filterTable,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookForward,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
},
|
|
||||||
Exprs: exprs,
|
|
||||||
UserData: []byte(userDataAcceptForwardRuleSrc),
|
|
||||||
}
|
|
||||||
|
|
||||||
r.conn.AddRule(rule)
|
|
||||||
|
|
||||||
src = generateCIDRMatcherExpressions(true, "0.0.0.0/0")
|
|
||||||
dst = generateCIDRMatcherExpressions(false, sourceNetwork)
|
|
||||||
|
|
||||||
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
})...)
|
|
||||||
|
|
||||||
rule = &nftables.Rule{
|
|
||||||
Table: r.filterTable,
|
|
||||||
Chain: &nftables.Chain{
|
|
||||||
Name: "FORWARD",
|
|
||||||
Table: r.filterTable,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookForward,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
},
|
|
||||||
Exprs: exprs,
|
|
||||||
UserData: []byte(userDataAcceptForwardRuleDst),
|
|
||||||
}
|
|
||||||
r.conn.AddRule(rule)
|
|
||||||
r.isDefaultFwdRulesEnabled = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
|
|
||||||
func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
|
|
||||||
err := r.refreshRulesMap()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.removeRoutingRule(manager.ForwardingFormat, pair)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.removeRoutingRule(manager.InForwardingFormat, manager.GetInPair(pair))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.removeRoutingRule(manager.NatFormat, pair)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.removeRoutingRule(manager.InNatFormat, manager.GetInPair(pair))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(r.rules) == 0 {
|
|
||||||
err := r.cleanUpDefaultForwardRules()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.conn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
|
|
||||||
}
|
|
||||||
log.Debugf("nftables: removed rules for %s", pair.Destination)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
|
|
||||||
func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error {
|
|
||||||
ruleKey := manager.GenKey(format, pair.ID)
|
|
||||||
|
|
||||||
rule, found := r.rules[ruleKey]
|
|
||||||
if found {
|
|
||||||
ruleType := "forwarding"
|
|
||||||
if rule.Chain.Type == nftables.ChainTypeNAT {
|
|
||||||
ruleType = "nat"
|
|
||||||
}
|
|
||||||
|
|
||||||
err := r.conn.DelRule(rule)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.Destination, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination)
|
|
||||||
|
|
||||||
delete(r.rules, ruleKey)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
|
|
||||||
// duplicates and to get missing attributes that we don't have when adding new rules
|
|
||||||
func (r *router) refreshRulesMap() error {
|
|
||||||
for _, chain := range r.chains {
|
|
||||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("nftables: unable to list rules: %v", err)
|
|
||||||
}
|
|
||||||
for _, rule := range rules {
|
|
||||||
if len(rule.UserData) > 0 {
|
|
||||||
r.rules[string(rule.UserData)] = rule
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) cleanUpDefaultForwardRules() error {
|
|
||||||
if r.filterTable == nil {
|
|
||||||
r.isDefaultFwdRulesEnabled = false
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var rules []*nftables.Rule
|
|
||||||
for _, chain := range chains {
|
|
||||||
if chain.Table.Name != r.filterTable.Name {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if chain.Name != "FORWARD" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
rules, err = r.conn.GetRules(r.filterTable, chain)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rule := range rules {
|
|
||||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleSrc)) || bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleDst)) {
|
|
||||||
err := r.conn.DelRule(rule)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
r.isDefaultFwdRulesEnabled = false
|
|
||||||
return r.conn.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
|
||||||
func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any {
|
|
||||||
ip, network, _ := net.ParseCIDR(cidr)
|
|
||||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
|
||||||
add := ipToAdd.Unmap()
|
|
||||||
|
|
||||||
var offSet uint32
|
|
||||||
if source {
|
|
||||||
offSet = 12 // src offset
|
|
||||||
} else {
|
|
||||||
offSet = 16 // dst offset
|
|
||||||
}
|
|
||||||
|
|
||||||
return []expr.Any{
|
|
||||||
// fetch src add
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: offSet,
|
|
||||||
Len: 4,
|
|
||||||
},
|
|
||||||
// net mask
|
|
||||||
&expr.Bitwise{
|
|
||||||
DestRegister: 1,
|
|
||||||
SourceRegister: 1,
|
|
||||||
Len: 4,
|
|
||||||
Mask: network.Mask,
|
|
||||||
Xor: zeroXor,
|
|
||||||
},
|
|
||||||
// net address
|
|
||||||
&expr.Cmp{
|
|
||||||
Register: 1,
|
|
||||||
Data: add.AsSlice(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,280 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"github.com/google/nftables/expr"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/test"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// UNKNOWN is the default value for the firewall type for unknown firewall type
|
|
||||||
UNKNOWN = iota
|
|
||||||
// IPTABLES is the value for the iptables firewall type
|
|
||||||
IPTABLES
|
|
||||||
// NFTABLES is the value for the nftables firewall type
|
|
||||||
NFTABLES
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
|
||||||
if check() != NFTABLES {
|
|
||||||
t.Skip("nftables not supported on this OS")
|
|
||||||
}
|
|
||||||
|
|
||||||
table, err := createWorkTable()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer deleteWorkTable()
|
|
||||||
|
|
||||||
for _, testCase := range test.InsertRuleTestCases {
|
|
||||||
t.Run(testCase.Name, func(t *testing.T) {
|
|
||||||
manager, err := newRouter(context.TODO(), table)
|
|
||||||
require.NoError(t, err, "failed to create router")
|
|
||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
|
||||||
|
|
||||||
defer manager.ResetForwardRules()
|
|
||||||
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
err = manager.InsertRoutingRules(testCase.InputPair)
|
|
||||||
defer func() {
|
|
||||||
_ = manager.RemoveRoutingRules(testCase.InputPair)
|
|
||||||
}()
|
|
||||||
require.NoError(t, err, "forwarding pair should be inserted")
|
|
||||||
|
|
||||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
|
||||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
|
||||||
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
|
|
||||||
fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
|
||||||
|
|
||||||
found := 0
|
|
||||||
for _, chain := range manager.chains {
|
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
|
||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
|
||||||
for _, rule := range rules {
|
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
|
|
||||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
|
|
||||||
found = 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
|
||||||
|
|
||||||
if testCase.InputPair.Masquerade {
|
|
||||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
|
||||||
found := 0
|
|
||||||
for _, chain := range manager.chains {
|
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
|
||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
|
||||||
for _, rule := range rules {
|
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
|
||||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
|
|
||||||
found = 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
|
||||||
}
|
|
||||||
|
|
||||||
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
|
|
||||||
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
|
|
||||||
testingExpression = append(sourceExp, destExp...) //nolint:gocritic
|
|
||||||
inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
|
||||||
|
|
||||||
found = 0
|
|
||||||
for _, chain := range manager.chains {
|
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
|
||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
|
||||||
for _, rule := range rules {
|
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
|
|
||||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
|
|
||||||
found = 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
|
||||||
|
|
||||||
if testCase.InputPair.Masquerade {
|
|
||||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
|
||||||
found := 0
|
|
||||||
for _, chain := range manager.chains {
|
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
|
||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
|
||||||
for _, rule := range rules {
|
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
|
|
||||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
|
|
||||||
found = 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
|
||||||
if check() != NFTABLES {
|
|
||||||
t.Skip("nftables not supported on this OS")
|
|
||||||
}
|
|
||||||
|
|
||||||
table, err := createWorkTable()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer deleteWorkTable()
|
|
||||||
|
|
||||||
for _, testCase := range test.RemoveRuleTestCases {
|
|
||||||
t.Run(testCase.Name, func(t *testing.T) {
|
|
||||||
manager, err := newRouter(context.TODO(), table)
|
|
||||||
require.NoError(t, err, "failed to create router")
|
|
||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
|
||||||
|
|
||||||
defer manager.ResetForwardRules()
|
|
||||||
|
|
||||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
|
||||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
|
||||||
|
|
||||||
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
|
|
||||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
|
||||||
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
||||||
Table: manager.workTable,
|
|
||||||
Chain: manager.chains[chainNameRouteingFw],
|
|
||||||
Exprs: forwardExp,
|
|
||||||
UserData: []byte(forwardRuleKey),
|
|
||||||
})
|
|
||||||
|
|
||||||
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
|
||||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
|
||||||
|
|
||||||
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
||||||
Table: manager.workTable,
|
|
||||||
Chain: manager.chains[chainNameRoutingNat],
|
|
||||||
Exprs: natExp,
|
|
||||||
UserData: []byte(natRuleKey),
|
|
||||||
})
|
|
||||||
|
|
||||||
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
|
|
||||||
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
|
|
||||||
|
|
||||||
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
|
|
||||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
|
||||||
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
||||||
Table: manager.workTable,
|
|
||||||
Chain: manager.chains[chainNameRouteingFw],
|
|
||||||
Exprs: forwardExp,
|
|
||||||
UserData: []byte(inForwardRuleKey),
|
|
||||||
})
|
|
||||||
|
|
||||||
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
|
||||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
|
||||||
|
|
||||||
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
||||||
Table: manager.workTable,
|
|
||||||
Chain: manager.chains[chainNameRoutingNat],
|
|
||||||
Exprs: natExp,
|
|
||||||
UserData: []byte(inNatRuleKey),
|
|
||||||
})
|
|
||||||
|
|
||||||
err = nftablesTestingClient.Flush()
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
manager.ResetForwardRules()
|
|
||||||
|
|
||||||
err = manager.RemoveRoutingRules(testCase.InputPair)
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
for _, chain := range manager.chains {
|
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
|
||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
|
||||||
for _, rule := range rules {
|
|
||||||
if len(rule.UserData) > 0 {
|
|
||||||
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
|
|
||||||
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
|
|
||||||
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
|
|
||||||
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
|
||||||
func check() int {
|
|
||||||
nf := nftables.Conn{}
|
|
||||||
if _, err := nf.ListChains(); err == nil {
|
|
||||||
return NFTABLES
|
|
||||||
}
|
|
||||||
|
|
||||||
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
|
||||||
if err != nil {
|
|
||||||
return UNKNOWN
|
|
||||||
}
|
|
||||||
if isIptablesClientAvailable(ip) {
|
|
||||||
return IPTABLES
|
|
||||||
}
|
|
||||||
|
|
||||||
return UNKNOWN
|
|
||||||
}
|
|
||||||
|
|
||||||
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
|
||||||
_, err := client.ListChains("filter")
|
|
||||||
return err == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func createWorkTable() (*nftables.Table, error) {
|
|
||||||
sConn, err := nftables.New(nftables.AsLasting())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, t := range tables {
|
|
||||||
if t.Name == tableName {
|
|
||||||
sConn.DelTable(t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
|
|
||||||
err = sConn.Flush()
|
|
||||||
|
|
||||||
return table, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func deleteWorkTable() {
|
|
||||||
sConn, err := nftables.New(nftables.AsLasting())
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, t := range tables {
|
|
||||||
if t.Name == tableName {
|
|
||||||
sConn.DelTable(t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,20 +0,0 @@
|
|||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/google/nftables"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Rule to handle management of rules
|
|
||||||
type Rule struct {
|
|
||||||
nftRule *nftables.Rule
|
|
||||||
nftSet *nftables.Set
|
|
||||||
ruleID string
|
|
||||||
ip net.IP
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
|
||||||
func (r *Rule) GetRuleID() string {
|
|
||||||
return r.ruleID
|
|
||||||
}
|
|
||||||
24
client/firewall/port.go
Normal file
24
client/firewall/port.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
// PortProtocol is the protocol of the port
|
||||||
|
type PortProtocol string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// PortProtocolTCP is the TCP protocol
|
||||||
|
PortProtocolTCP PortProtocol = "tcp"
|
||||||
|
|
||||||
|
// PortProtocolUDP is the UDP protocol
|
||||||
|
PortProtocolUDP PortProtocol = "udp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Port of the address for firewall rule
|
||||||
|
type Port struct {
|
||||||
|
// IsRange is true Values contains two values, the first is the start port, the second is the end port
|
||||||
|
IsRange bool
|
||||||
|
|
||||||
|
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
|
||||||
|
Values []int
|
||||||
|
|
||||||
|
// Proto is the protocol of the port
|
||||||
|
Proto PortProtocol
|
||||||
|
}
|
||||||
@@ -1,47 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package test
|
|
||||||
|
|
||||||
import firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
|
|
||||||
var (
|
|
||||||
InsertRuleTestCases = []struct {
|
|
||||||
Name string
|
|
||||||
InputPair firewall.RouterPair
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
Name: "Insert Forwarding IPV4 Rule",
|
|
||||||
InputPair: firewall.RouterPair{
|
|
||||||
ID: "zxa",
|
|
||||||
Source: "100.100.100.1/32",
|
|
||||||
Destination: "100.100.200.0/24",
|
|
||||||
Masquerade: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "Insert Forwarding And Nat IPV4 Rules",
|
|
||||||
InputPair: firewall.RouterPair{
|
|
||||||
ID: "zxa",
|
|
||||||
Source: "100.100.100.1/32",
|
|
||||||
Destination: "100.100.200.0/24",
|
|
||||||
Masquerade: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RemoveRuleTestCases = []struct {
|
|
||||||
Name string
|
|
||||||
InputPair firewall.RouterPair
|
|
||||||
IpVersion string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
Name: "Remove Forwarding And Nat IPV4 Rules",
|
|
||||||
InputPair: firewall.RouterPair{
|
|
||||||
ID: "zxa",
|
|
||||||
Source: "100.100.100.1/32",
|
|
||||||
Destination: "100.100.200.0/24",
|
|
||||||
Masquerade: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
//go:build !windows
|
|
||||||
|
|
||||||
package uspfilter
|
|
||||||
|
|
||||||
// Reset firewall to the default state
|
|
||||||
func (m *Manager) Reset() error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
m.outgoingRules = make(map[string]RuleSet)
|
|
||||||
m.incomingRules = make(map[string]RuleSet)
|
|
||||||
|
|
||||||
if m.nativeFirewall != nil {
|
|
||||||
return m.nativeFirewall.Reset()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic
|
|
||||||
func (m *Manager) AllowNetbird() error {
|
|
||||||
if m.nativeFirewall != nil {
|
|
||||||
return m.nativeFirewall.AllowNetbird()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,94 +0,0 @@
|
|||||||
package uspfilter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os/exec"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
type action string
|
|
||||||
|
|
||||||
const (
|
|
||||||
addRule action = "add"
|
|
||||||
deleteRule action = "delete"
|
|
||||||
firewallRuleName = "Netbird"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Reset firewall to the default state
|
|
||||||
func (m *Manager) Reset() error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
m.outgoingRules = make(map[string]RuleSet)
|
|
||||||
m.incomingRules = make(map[string]RuleSet)
|
|
||||||
|
|
||||||
if !isWindowsFirewallReachable() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isFirewallRuleActive(firewallRuleName) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
|
|
||||||
return fmt.Errorf("couldn't remove windows firewall: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic
|
|
||||||
func (m *Manager) AllowNetbird() error {
|
|
||||||
if !isWindowsFirewallReachable() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if isFirewallRuleActive(firewallRuleName) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return manageFirewallRule(firewallRuleName,
|
|
||||||
addRule,
|
|
||||||
"dir=in",
|
|
||||||
"enable=yes",
|
|
||||||
"action=allow",
|
|
||||||
"profile=any",
|
|
||||||
"localip="+m.wgIface.Address().IP.String(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func manageFirewallRule(ruleName string, action action, extraArgs ...string) error {
|
|
||||||
|
|
||||||
args := []string{"advfirewall", "firewall", string(action), "rule", "name=" + ruleName}
|
|
||||||
if action == addRule {
|
|
||||||
args = append(args, extraArgs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command("netsh", args...)
|
|
||||||
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
|
||||||
return cmd.Run()
|
|
||||||
}
|
|
||||||
|
|
||||||
func isWindowsFirewallReachable() bool {
|
|
||||||
args := []string{"advfirewall", "show", "allprofiles", "state"}
|
|
||||||
cmd := exec.Command("netsh", args...)
|
|
||||||
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
|
||||||
|
|
||||||
_, err := cmd.Output()
|
|
||||||
if err != nil {
|
|
||||||
log.Infof("Windows firewall is not reachable, skipping default rule management. Using only user space rules. Error: %s", err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func isFirewallRuleActive(ruleName string) bool {
|
|
||||||
args := []string{"advfirewall", "firewall", "show", "rule", "name=" + ruleName}
|
|
||||||
|
|
||||||
cmd := exec.Command("netsh", args...)
|
|
||||||
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
|
||||||
_, err := cmd.Output()
|
|
||||||
return err == nil
|
|
||||||
}
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
package uspfilter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Rule to handle management of rules
|
|
||||||
type Rule struct {
|
|
||||||
id string
|
|
||||||
ip net.IP
|
|
||||||
ipLayer gopacket.LayerType
|
|
||||||
matchByIP bool
|
|
||||||
protoLayer gopacket.LayerType
|
|
||||||
direction firewall.RuleDirection
|
|
||||||
sPort uint16
|
|
||||||
dPort uint16
|
|
||||||
drop bool
|
|
||||||
comment string
|
|
||||||
|
|
||||||
udpHook func([]byte) bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
|
||||||
func (r *Rule) GetRuleID() string {
|
|
||||||
return r.id
|
|
||||||
}
|
|
||||||
@@ -1,412 +0,0 @@
|
|||||||
package uspfilter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
|
||||||
|
|
||||||
const layerTypeAll = 0
|
|
||||||
|
|
||||||
var (
|
|
||||||
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
|
|
||||||
)
|
|
||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
|
||||||
type IFaceMapper interface {
|
|
||||||
SetFilter(iface.PacketFilter) error
|
|
||||||
Address() iface.WGAddress
|
|
||||||
}
|
|
||||||
|
|
||||||
// RuleSet is a set of rules grouped by a string key
|
|
||||||
type RuleSet map[string]Rule
|
|
||||||
|
|
||||||
// Manager userspace firewall manager
|
|
||||||
type Manager struct {
|
|
||||||
outgoingRules map[string]RuleSet
|
|
||||||
incomingRules map[string]RuleSet
|
|
||||||
wgNetwork *net.IPNet
|
|
||||||
decoders sync.Pool
|
|
||||||
wgIface IFaceMapper
|
|
||||||
nativeFirewall firewall.Manager
|
|
||||||
|
|
||||||
mutex sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// decoder for packages
|
|
||||||
type decoder struct {
|
|
||||||
eth layers.Ethernet
|
|
||||||
ip4 layers.IPv4
|
|
||||||
ip6 layers.IPv6
|
|
||||||
tcp layers.TCP
|
|
||||||
udp layers.UDP
|
|
||||||
icmp4 layers.ICMPv4
|
|
||||||
icmp6 layers.ICMPv6
|
|
||||||
decoded []gopacket.LayerType
|
|
||||||
parser *gopacket.DecodingLayerParser
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create userspace firewall manager constructor
|
|
||||||
func Create(iface IFaceMapper) (*Manager, error) {
|
|
||||||
return create(iface)
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) {
|
|
||||||
mgr, err := create(iface)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
mgr.nativeFirewall = nativeFirewall
|
|
||||||
return mgr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func create(iface IFaceMapper) (*Manager, error) {
|
|
||||||
m := &Manager{
|
|
||||||
decoders: sync.Pool{
|
|
||||||
New: func() any {
|
|
||||||
d := &decoder{
|
|
||||||
decoded: []gopacket.LayerType{},
|
|
||||||
}
|
|
||||||
d.parser = gopacket.NewDecodingLayerParser(
|
|
||||||
layers.LayerTypeIPv4,
|
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
|
||||||
)
|
|
||||||
d.parser.IgnoreUnsupported = true
|
|
||||||
return d
|
|
||||||
},
|
|
||||||
},
|
|
||||||
outgoingRules: make(map[string]RuleSet),
|
|
||||||
incomingRules: make(map[string]RuleSet),
|
|
||||||
wgIface: iface,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := iface.SetFilter(m); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) IsServerRouteSupported() bool {
|
|
||||||
if m.nativeFirewall == nil {
|
|
||||||
return false
|
|
||||||
} else {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
|
||||||
if m.nativeFirewall == nil {
|
|
||||||
return errRouteNotSupported
|
|
||||||
}
|
|
||||||
return m.nativeFirewall.InsertRoutingRules(pair)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveRoutingRules removes a routing firewall rule
|
|
||||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
|
||||||
if m.nativeFirewall == nil {
|
|
||||||
return errRouteNotSupported
|
|
||||||
}
|
|
||||||
return m.nativeFirewall.RemoveRoutingRules(pair)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddFiltering rule to the firewall
|
|
||||||
//
|
|
||||||
// If comment argument is empty firewall manager should set
|
|
||||||
// rule ID as comment for the rule
|
|
||||||
func (m *Manager) AddFiltering(
|
|
||||||
ip net.IP,
|
|
||||||
proto firewall.Protocol,
|
|
||||||
sPort *firewall.Port,
|
|
||||||
dPort *firewall.Port,
|
|
||||||
direction firewall.RuleDirection,
|
|
||||||
action firewall.Action,
|
|
||||||
ipsetName string,
|
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
|
||||||
r := Rule{
|
|
||||||
id: uuid.New().String(),
|
|
||||||
ip: ip,
|
|
||||||
ipLayer: layers.LayerTypeIPv6,
|
|
||||||
matchByIP: true,
|
|
||||||
direction: direction,
|
|
||||||
drop: action == firewall.ActionDrop,
|
|
||||||
comment: comment,
|
|
||||||
}
|
|
||||||
if ipNormalized := ip.To4(); ipNormalized != nil {
|
|
||||||
r.ipLayer = layers.LayerTypeIPv4
|
|
||||||
r.ip = ipNormalized
|
|
||||||
}
|
|
||||||
|
|
||||||
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
|
|
||||||
r.matchByIP = false
|
|
||||||
}
|
|
||||||
|
|
||||||
if sPort != nil && len(sPort.Values) == 1 {
|
|
||||||
r.sPort = uint16(sPort.Values[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
if dPort != nil && len(dPort.Values) == 1 {
|
|
||||||
r.dPort = uint16(dPort.Values[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
switch proto {
|
|
||||||
case firewall.ProtocolTCP:
|
|
||||||
r.protoLayer = layers.LayerTypeTCP
|
|
||||||
case firewall.ProtocolUDP:
|
|
||||||
r.protoLayer = layers.LayerTypeUDP
|
|
||||||
case firewall.ProtocolICMP:
|
|
||||||
r.protoLayer = layers.LayerTypeICMPv4
|
|
||||||
if r.ipLayer == layers.LayerTypeIPv6 {
|
|
||||||
r.protoLayer = layers.LayerTypeICMPv6
|
|
||||||
}
|
|
||||||
case firewall.ProtocolALL:
|
|
||||||
r.protoLayer = layerTypeAll
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mutex.Lock()
|
|
||||||
if direction == firewall.RuleDirectionIN {
|
|
||||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
|
||||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
|
||||||
}
|
|
||||||
m.incomingRules[r.ip.String()][r.id] = r
|
|
||||||
} else {
|
|
||||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
|
||||||
m.outgoingRules[r.ip.String()] = make(RuleSet)
|
|
||||||
}
|
|
||||||
m.outgoingRules[r.ip.String()][r.id] = r
|
|
||||||
}
|
|
||||||
m.mutex.Unlock()
|
|
||||||
return []firewall.Rule{&r}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteRule from the firewall by rule definition
|
|
||||||
func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
r, ok := rule.(*Rule)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.direction == firewall.RuleDirectionIN {
|
|
||||||
_, ok := m.incomingRules[r.ip.String()][r.id]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
|
||||||
}
|
|
||||||
delete(m.incomingRules[r.ip.String()], r.id)
|
|
||||||
} else {
|
|
||||||
_, ok := m.outgoingRules[r.ip.String()][r.id]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
|
||||||
}
|
|
||||||
delete(m.outgoingRules[r.ip.String()], r.id)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Flush doesn't need to be implemented for this manager
|
|
||||||
func (m *Manager) Flush() error { return nil }
|
|
||||||
|
|
||||||
// DropOutgoing filter outgoing packets
|
|
||||||
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
|
||||||
return m.dropFilter(packetData, m.outgoingRules, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DropIncoming filter incoming packets
|
|
||||||
func (m *Manager) DropIncoming(packetData []byte) bool {
|
|
||||||
return m.dropFilter(packetData, m.incomingRules, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
// dropFilter implements same logic for booth direction of the traffic
|
|
||||||
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool {
|
|
||||||
m.mutex.RLock()
|
|
||||||
defer m.mutex.RUnlock()
|
|
||||||
|
|
||||||
d := m.decoders.Get().(*decoder)
|
|
||||||
defer m.decoders.Put(d)
|
|
||||||
|
|
||||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
|
||||||
log.Tracef("couldn't decode layer, err: %s", err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(d.decoded) < 2 {
|
|
||||||
log.Tracef("not enough levels in network packet")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
ipLayer := d.decoded[0]
|
|
||||||
|
|
||||||
switch ipLayer {
|
|
||||||
case layers.LayerTypeIPv4:
|
|
||||||
if !m.wgNetwork.Contains(d.ip4.SrcIP) || !m.wgNetwork.Contains(d.ip4.DstIP) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
case layers.LayerTypeIPv6:
|
|
||||||
if !m.wgNetwork.Contains(d.ip6.SrcIP) || !m.wgNetwork.Contains(d.ip6.DstIP) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
log.Errorf("unknown layer: %v", d.decoded[0])
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
var ip net.IP
|
|
||||||
switch ipLayer {
|
|
||||||
case layers.LayerTypeIPv4:
|
|
||||||
if isIncomingPacket {
|
|
||||||
ip = d.ip4.SrcIP
|
|
||||||
} else {
|
|
||||||
ip = d.ip4.DstIP
|
|
||||||
}
|
|
||||||
case layers.LayerTypeIPv6:
|
|
||||||
if isIncomingPacket {
|
|
||||||
ip = d.ip6.SrcIP
|
|
||||||
} else {
|
|
||||||
ip = d.ip6.DstIP
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
filter, ok := validateRule(ip, packetData, rules[ip.String()], d)
|
|
||||||
if ok {
|
|
||||||
return filter
|
|
||||||
}
|
|
||||||
filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d)
|
|
||||||
if ok {
|
|
||||||
return filter
|
|
||||||
}
|
|
||||||
filter, ok = validateRule(ip, packetData, rules["::"], d)
|
|
||||||
if ok {
|
|
||||||
return filter
|
|
||||||
}
|
|
||||||
|
|
||||||
// default policy is DROP ALL
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
|
|
||||||
payloadLayer := d.decoded[1]
|
|
||||||
for _, rule := range rules {
|
|
||||||
if rule.matchByIP && !ip.Equal(rule.ip) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if rule.protoLayer == layerTypeAll {
|
|
||||||
return rule.drop, true
|
|
||||||
}
|
|
||||||
|
|
||||||
if payloadLayer != rule.protoLayer {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
switch payloadLayer {
|
|
||||||
case layers.LayerTypeTCP:
|
|
||||||
if rule.sPort == 0 && rule.dPort == 0 {
|
|
||||||
return rule.drop, true
|
|
||||||
}
|
|
||||||
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
|
|
||||||
return rule.drop, true
|
|
||||||
}
|
|
||||||
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
|
|
||||||
return rule.drop, true
|
|
||||||
}
|
|
||||||
case layers.LayerTypeUDP:
|
|
||||||
// if rule has UDP hook (and if we are here we match this rule)
|
|
||||||
// we ignore rule.drop and call this hook
|
|
||||||
if rule.udpHook != nil {
|
|
||||||
return rule.udpHook(packetData), true
|
|
||||||
}
|
|
||||||
|
|
||||||
if rule.sPort == 0 && rule.dPort == 0 {
|
|
||||||
return rule.drop, true
|
|
||||||
}
|
|
||||||
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
|
|
||||||
return rule.drop, true
|
|
||||||
}
|
|
||||||
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
|
|
||||||
return rule.drop, true
|
|
||||||
}
|
|
||||||
return rule.drop, true
|
|
||||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
|
||||||
return rule.drop, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNetwork of the wireguard interface to which filtering applied
|
|
||||||
func (m *Manager) SetNetwork(network *net.IPNet) {
|
|
||||||
m.wgNetwork = network
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
|
||||||
//
|
|
||||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
|
||||||
func (m *Manager) AddUDPPacketHook(
|
|
||||||
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
|
|
||||||
) string {
|
|
||||||
r := Rule{
|
|
||||||
id: uuid.New().String(),
|
|
||||||
ip: ip,
|
|
||||||
protoLayer: layers.LayerTypeUDP,
|
|
||||||
dPort: dPort,
|
|
||||||
ipLayer: layers.LayerTypeIPv6,
|
|
||||||
direction: firewall.RuleDirectionOUT,
|
|
||||||
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
|
|
||||||
udpHook: hook,
|
|
||||||
}
|
|
||||||
|
|
||||||
if ip.To4() != nil {
|
|
||||||
r.ipLayer = layers.LayerTypeIPv4
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mutex.Lock()
|
|
||||||
if in {
|
|
||||||
r.direction = firewall.RuleDirectionIN
|
|
||||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
|
||||||
m.incomingRules[r.ip.String()] = make(map[string]Rule)
|
|
||||||
}
|
|
||||||
m.incomingRules[r.ip.String()][r.id] = r
|
|
||||||
} else {
|
|
||||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
|
||||||
m.outgoingRules[r.ip.String()] = make(map[string]Rule)
|
|
||||||
}
|
|
||||||
m.outgoingRules[r.ip.String()][r.id] = r
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mutex.Unlock()
|
|
||||||
|
|
||||||
return r.id
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemovePacketHook removes packet hook by given ID
|
|
||||||
func (m *Manager) RemovePacketHook(hookID string) error {
|
|
||||||
for _, arr := range m.incomingRules {
|
|
||||||
for _, r := range arr {
|
|
||||||
if r.id == hookID {
|
|
||||||
rule := r
|
|
||||||
return m.DeleteRule(&rule)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, arr := range m.outgoingRules {
|
|
||||||
for _, r := range arr {
|
|
||||||
if r.id == hookID {
|
|
||||||
rule := r
|
|
||||||
return m.DeleteRule(&rule)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fmt.Errorf("hook with given id not found")
|
|
||||||
}
|
|
||||||
@@ -1,419 +0,0 @@
|
|||||||
package uspfilter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
|
||||||
|
|
||||||
type IFaceMock struct {
|
|
||||||
SetFilterFunc func(iface.PacketFilter) error
|
|
||||||
AddressFunc func() iface.WGAddress
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
|
|
||||||
if i.SetFilterFunc == nil {
|
|
||||||
return fmt.Errorf("not implemented")
|
|
||||||
}
|
|
||||||
return i.SetFilterFunc(iface)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *IFaceMock) Address() iface.WGAddress {
|
|
||||||
if i.AddressFunc == nil {
|
|
||||||
return iface.WGAddress{}
|
|
||||||
}
|
|
||||||
return i.AddressFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManagerCreate(t *testing.T) {
|
|
||||||
ifaceMock := &IFaceMock{
|
|
||||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
|
||||||
}
|
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if m == nil {
|
|
||||||
t.Error("Manager is nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManagerAddFiltering(t *testing.T) {
|
|
||||||
isSetFilterCalled := false
|
|
||||||
ifaceMock := &IFaceMock{
|
|
||||||
SetFilterFunc: func(iface.PacketFilter) error {
|
|
||||||
isSetFilterCalled = true
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
|
||||||
proto := fw.ProtocolTCP
|
|
||||||
port := &fw.Port{Values: []int{80}}
|
|
||||||
direction := fw.RuleDirectionOUT
|
|
||||||
action := fw.ActionDrop
|
|
||||||
comment := "Test rule"
|
|
||||||
|
|
||||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if rule == nil {
|
|
||||||
t.Error("Rule is nil")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isSetFilterCalled {
|
|
||||||
t.Error("SetFilter was not called")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManagerDeleteRule(t *testing.T) {
|
|
||||||
ifaceMock := &IFaceMock{
|
|
||||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
|
||||||
}
|
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
|
||||||
proto := fw.ProtocolTCP
|
|
||||||
port := &fw.Port{Values: []int{80}}
|
|
||||||
direction := fw.RuleDirectionOUT
|
|
||||||
action := fw.ActionDrop
|
|
||||||
comment := "Test rule"
|
|
||||||
|
|
||||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ip = net.ParseIP("192.168.1.1")
|
|
||||||
proto = fw.ProtocolTCP
|
|
||||||
port = &fw.Port{Values: []int{80}}
|
|
||||||
direction = fw.RuleDirectionIN
|
|
||||||
action = fw.ActionDrop
|
|
||||||
comment = "Test rule 2"
|
|
||||||
|
|
||||||
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, r := range rule {
|
|
||||||
err = m.DeleteRule(r)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to delete rule: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, r := range rule2 {
|
|
||||||
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
|
|
||||||
t.Errorf("rule2 is not in the incomingRules")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, r := range rule2 {
|
|
||||||
err = m.DeleteRule(r)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to delete rule: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, r := range rule2 {
|
|
||||||
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; ok {
|
|
||||||
t.Errorf("rule2 is not in the incomingRules")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAddUDPPacketHook(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
in bool
|
|
||||||
expDir fw.RuleDirection
|
|
||||||
ip net.IP
|
|
||||||
dPort uint16
|
|
||||||
hook func([]byte) bool
|
|
||||||
expectedID string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Test Outgoing UDP Packet Hook",
|
|
||||||
in: false,
|
|
||||||
expDir: fw.RuleDirectionOUT,
|
|
||||||
ip: net.IPv4(10, 168, 0, 1),
|
|
||||||
dPort: 8000,
|
|
||||||
hook: func([]byte) bool { return true },
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Test Incoming UDP Packet Hook",
|
|
||||||
in: true,
|
|
||||||
expDir: fw.RuleDirectionIN,
|
|
||||||
ip: net.IPv6loopback,
|
|
||||||
dPort: 9000,
|
|
||||||
hook: func([]byte) bool { return false },
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
manager := &Manager{
|
|
||||||
incomingRules: map[string]RuleSet{},
|
|
||||||
outgoingRules: map[string]RuleSet{},
|
|
||||||
}
|
|
||||||
|
|
||||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
|
||||||
|
|
||||||
var addedRule Rule
|
|
||||||
if tt.in {
|
|
||||||
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
|
||||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, rule := range manager.incomingRules[tt.ip.String()] {
|
|
||||||
addedRule = rule
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if len(manager.outgoingRules) != 1 {
|
|
||||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, rule := range manager.outgoingRules[tt.ip.String()] {
|
|
||||||
addedRule = rule
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !tt.ip.Equal(addedRule.ip) {
|
|
||||||
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if tt.dPort != addedRule.dPort {
|
|
||||||
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if layers.LayerTypeUDP != addedRule.protoLayer {
|
|
||||||
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if tt.expDir != addedRule.direction {
|
|
||||||
t.Errorf("expected direction %d, got %d", tt.expDir, addedRule.direction)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if addedRule.udpHook == nil {
|
|
||||||
t.Errorf("expected udpHook to be set")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManagerReset(t *testing.T) {
|
|
||||||
ifaceMock := &IFaceMock{
|
|
||||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
|
||||||
}
|
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
|
||||||
proto := fw.ProtocolTCP
|
|
||||||
port := &fw.Port{Values: []int{80}}
|
|
||||||
direction := fw.RuleDirectionOUT
|
|
||||||
action := fw.ActionDrop
|
|
||||||
comment := "Test rule"
|
|
||||||
|
|
||||||
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = m.Reset()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to reset Manager: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
|
|
||||||
t.Errorf("rules is not empty")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotMatchByIP(t *testing.T) {
|
|
||||||
ifaceMock := &IFaceMock{
|
|
||||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
|
||||||
}
|
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
ip := net.ParseIP("0.0.0.0")
|
|
||||||
proto := fw.ProtocolUDP
|
|
||||||
direction := fw.RuleDirectionOUT
|
|
||||||
action := fw.ActionAccept
|
|
||||||
comment := "Test rule"
|
|
||||||
|
|
||||||
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ipv4 := &layers.IPv4{
|
|
||||||
TTL: 64,
|
|
||||||
Version: 4,
|
|
||||||
SrcIP: net.ParseIP("100.10.0.1"),
|
|
||||||
DstIP: net.ParseIP("100.10.0.100"),
|
|
||||||
Protocol: layers.IPProtocolUDP,
|
|
||||||
}
|
|
||||||
udp := &layers.UDP{
|
|
||||||
SrcPort: 51334,
|
|
||||||
DstPort: 53,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := udp.SetNetworkLayerForChecksum(ipv4); err != nil {
|
|
||||||
t.Errorf("failed to set network layer for checksum: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
payload := gopacket.Payload([]byte("test"))
|
|
||||||
|
|
||||||
buf := gopacket.NewSerializeBuffer()
|
|
||||||
opts := gopacket.SerializeOptions{
|
|
||||||
ComputeChecksums: true,
|
|
||||||
FixLengths: true,
|
|
||||||
}
|
|
||||||
if err = gopacket.SerializeLayers(buf, opts, ipv4, udp, payload); err != nil {
|
|
||||||
t.Errorf("failed to serialize packet: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.dropFilter(buf.Bytes(), m.outgoingRules, false) {
|
|
||||||
t.Errorf("expected packet to be accepted")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = m.Reset(); err != nil {
|
|
||||||
t.Errorf("failed to reset Manager: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRemovePacketHook tests the functionality of the RemovePacketHook method
|
|
||||||
func TestRemovePacketHook(t *testing.T) {
|
|
||||||
// creating mock iface
|
|
||||||
iface := &IFaceMock{
|
|
||||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
|
||||||
}
|
|
||||||
|
|
||||||
// creating manager instance
|
|
||||||
manager, err := Create(iface)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create Manager: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add a UDP packet hook
|
|
||||||
hookFunc := func(data []byte) bool { return true }
|
|
||||||
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc)
|
|
||||||
|
|
||||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
|
||||||
found := false
|
|
||||||
for _, arr := range manager.outgoingRules {
|
|
||||||
for _, rule := range arr {
|
|
||||||
if rule.id == hookID {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !found {
|
|
||||||
t.Fatalf("The hook was not added properly.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now remove the packet hook
|
|
||||||
err = manager.RemovePacketHook(hookID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to remove hook: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assert the hook is removed by checking it in the manager's outgoing rules
|
|
||||||
for _, arr := range manager.outgoingRules {
|
|
||||||
for _, rule := range arr {
|
|
||||||
if rule.id == hookID {
|
|
||||||
t.Fatalf("The hook was not removed properly.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUSPFilterCreatePerformance(t *testing.T) {
|
|
||||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
|
||||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
|
||||||
// just check on the local interface
|
|
||||||
ifaceMock := &IFaceMock{
|
|
||||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
|
||||||
}
|
|
||||||
manager, err := Create(ifaceMock)
|
|
||||||
require.NoError(t, err)
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if err := manager.Reset(); err != nil {
|
|
||||||
t.Errorf("clear the manager state: %v", err)
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}()
|
|
||||||
|
|
||||||
ip := net.ParseIP("10.20.0.100")
|
|
||||||
start := time.Now()
|
|
||||||
for i := 0; i < testMax; i++ {
|
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
|
||||||
if i%2 == 0 {
|
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
} else {
|
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
}
|
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
|
||||||
}
|
|
||||||
t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -166,9 +166,10 @@ WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
|
|||||||
EnVar::SetHKLM
|
EnVar::SetHKLM
|
||||||
EnVar::AddValueEx "path" "$INSTDIR"
|
EnVar::AddValueEx "path" "$INSTDIR"
|
||||||
|
|
||||||
SetShellVarContext all
|
SetShellVarContext current
|
||||||
CreateShortCut "$SMPROGRAMS\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
|
CreateShortCut "$SMPROGRAMS\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
|
||||||
CreateShortCut "$DESKTOP\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
|
CreateShortCut "$DESKTOP\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
|
||||||
|
SetShellVarContext all
|
||||||
SectionEnd
|
SectionEnd
|
||||||
|
|
||||||
Section -Post
|
Section -Post
|
||||||
@@ -193,12 +194,12 @@ Sleep 3000
|
|||||||
Delete "$INSTDIR\${UI_APP_EXE}"
|
Delete "$INSTDIR\${UI_APP_EXE}"
|
||||||
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
||||||
Delete "$INSTDIR\wintun.dll"
|
Delete "$INSTDIR\wintun.dll"
|
||||||
Delete "$INSTDIR\opengl32.dll"
|
|
||||||
RmDir /r "$INSTDIR"
|
RmDir /r "$INSTDIR"
|
||||||
|
|
||||||
SetShellVarContext all
|
SetShellVarContext current
|
||||||
Delete "$DESKTOP\${APP_NAME}.lnk"
|
Delete "$DESKTOP\${APP_NAME}.lnk"
|
||||||
Delete "$SMPROGRAMS\${APP_NAME}.lnk"
|
Delete "$SMPROGRAMS\${APP_NAME}.lnk"
|
||||||
|
SetShellVarContext all
|
||||||
|
|
||||||
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
|
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
|
||||||
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
|
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
|
||||||
@@ -208,7 +209,8 @@ SectionEnd
|
|||||||
|
|
||||||
|
|
||||||
Function LaunchLink
|
Function LaunchLink
|
||||||
SetShellVarContext all
|
SetShellVarContext current
|
||||||
SetOutPath $INSTDIR
|
SetOutPath $INSTDIR
|
||||||
ShellExecAsUser::ShellExecAsUser "" "$DESKTOP\${APP_NAME}.lnk"
|
ShellExecAsUser::ShellExecAsUser "" "$DESKTOP\${APP_NAME}.lnk"
|
||||||
|
SetShellVarContext all
|
||||||
FunctionEnd
|
FunctionEnd
|
||||||
|
|||||||
@@ -1,454 +0,0 @@
|
|||||||
package acl
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/md5"
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Manager is a ACL rules manager
|
|
||||||
type Manager interface {
|
|
||||||
ApplyFiltering(networkMap *mgmProto.NetworkMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DefaultManager uses firewall manager to handle
|
|
||||||
type DefaultManager struct {
|
|
||||||
firewall firewall.Manager
|
|
||||||
ipsetCounter int
|
|
||||||
rulesPairs map[string][]firewall.Rule
|
|
||||||
mutex sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
|
|
||||||
return &DefaultManager{
|
|
||||||
firewall: fm,
|
|
||||||
rulesPairs: make(map[string][]firewall.Rule),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
|
|
||||||
//
|
|
||||||
// If allowByDefault is true it appends allow ALL traffic rules to input and output chains.
|
|
||||||
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
|
||||||
d.mutex.Lock()
|
|
||||||
defer d.mutex.Unlock()
|
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
defer func() {
|
|
||||||
total := 0
|
|
||||||
for _, pairs := range d.rulesPairs {
|
|
||||||
total += len(pairs)
|
|
||||||
}
|
|
||||||
log.Infof(
|
|
||||||
"ACL rules processed in: %v, total rules count: %d",
|
|
||||||
time.Since(start), total)
|
|
||||||
}()
|
|
||||||
|
|
||||||
if d.firewall == nil {
|
|
||||||
log.Debug("firewall manager is not supported, skipping firewall rules")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if err := d.firewall.Flush(); err != nil {
|
|
||||||
log.Error("failed to flush firewall rules: ", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
rules, squashedProtocols := d.squashAcceptRules(networkMap)
|
|
||||||
|
|
||||||
enableSSH := (networkMap.PeerConfig != nil &&
|
|
||||||
networkMap.PeerConfig.SshConfig != nil &&
|
|
||||||
networkMap.PeerConfig.SshConfig.SshEnabled)
|
|
||||||
if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok {
|
|
||||||
enableSSH = enableSSH && !ok
|
|
||||||
}
|
|
||||||
if _, ok := squashedProtocols[mgmProto.FirewallRule_TCP]; ok {
|
|
||||||
enableSSH = enableSSH && !ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// if TCP protocol rules not squashed and SSH enabled
|
|
||||||
// we add default firewall rule which accepts connection to any peer
|
|
||||||
// in the network by SSH (TCP 22 port).
|
|
||||||
if enableSSH {
|
|
||||||
rules = append(rules, &mgmProto.FirewallRule{
|
|
||||||
PeerIP: "0.0.0.0",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_TCP,
|
|
||||||
Port: strconv.Itoa(ssh.DefaultSSHPort),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
|
|
||||||
// we have old version of management without rules handling, we should allow all traffic
|
|
||||||
if len(networkMap.FirewallRules) == 0 && !networkMap.FirewallRulesIsEmpty {
|
|
||||||
log.Warn("this peer is connected to a NetBird Management service with an older version. Allowing all traffic from connected peers")
|
|
||||||
rules = append(rules,
|
|
||||||
&mgmProto.FirewallRule{
|
|
||||||
PeerIP: "0.0.0.0",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
&mgmProto.FirewallRule{
|
|
||||||
PeerIP: "0.0.0.0",
|
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
newRulePairs := make(map[string][]firewall.Rule)
|
|
||||||
ipsetByRuleSelectors := make(map[string]string)
|
|
||||||
|
|
||||||
for _, r := range rules {
|
|
||||||
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
|
||||||
// it's IP address can be used in the ipset for firewall manager which supports it
|
|
||||||
selector := d.getRuleGroupingSelector(r)
|
|
||||||
ipsetName, ok := ipsetByRuleSelectors[selector]
|
|
||||||
if !ok {
|
|
||||||
d.ipsetCounter++
|
|
||||||
ipsetName = fmt.Sprintf("nb%07d", d.ipsetCounter)
|
|
||||||
ipsetByRuleSelectors[selector] = ipsetName
|
|
||||||
}
|
|
||||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
|
||||||
d.rollBack(newRulePairs)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if len(rules) > 0 {
|
|
||||||
d.rulesPairs[pairID] = rulePair
|
|
||||||
newRulePairs[pairID] = rulePair
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for pairID, rules := range d.rulesPairs {
|
|
||||||
if _, ok := newRulePairs[pairID]; !ok {
|
|
||||||
for _, rule := range rules {
|
|
||||||
if err := d.firewall.DeleteRule(rule); err != nil {
|
|
||||||
log.Errorf("failed to delete firewall rule: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
delete(d.rulesPairs, pairID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
d.rulesPairs = newRulePairs
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DefaultManager) protoRuleToFirewallRule(
|
|
||||||
r *mgmProto.FirewallRule,
|
|
||||||
ipsetName string,
|
|
||||||
) (string, []firewall.Rule, error) {
|
|
||||||
ip := net.ParseIP(r.PeerIP)
|
|
||||||
if ip == nil {
|
|
||||||
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
|
||||||
}
|
|
||||||
|
|
||||||
protocol, err := convertToFirewallProtocol(r.Protocol)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
action, err := convertFirewallAction(r.Action)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var port *firewall.Port
|
|
||||||
if r.Port != "" {
|
|
||||||
value, err := strconv.Atoi(r.Port)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, fmt.Errorf("invalid port, skipping firewall rule")
|
|
||||||
}
|
|
||||||
port = &firewall.Port{
|
|
||||||
Values: []int{value},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ruleID := d.getRuleID(ip, protocol, int(r.Direction), port, action, "")
|
|
||||||
if rulesPair, ok := d.rulesPairs[ruleID]; ok {
|
|
||||||
return ruleID, rulesPair, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var rules []firewall.Rule
|
|
||||||
switch r.Direction {
|
|
||||||
case mgmProto.FirewallRule_IN:
|
|
||||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
|
||||||
case mgmProto.FirewallRule_OUT:
|
|
||||||
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
|
||||||
default:
|
|
||||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return ruleID, rules, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DefaultManager) addInRules(
|
|
||||||
ip net.IP,
|
|
||||||
protocol firewall.Protocol,
|
|
||||||
port *firewall.Port,
|
|
||||||
action firewall.Action,
|
|
||||||
ipsetName string,
|
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
|
||||||
var rules []firewall.Rule
|
|
||||||
rule, err := d.firewall.AddFiltering(
|
|
||||||
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
|
||||||
}
|
|
||||||
rules = append(rules, rule...)
|
|
||||||
|
|
||||||
if shouldSkipInvertedRule(protocol, port) {
|
|
||||||
return rules, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
rule, err = d.firewall.AddFiltering(
|
|
||||||
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return append(rules, rule...), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DefaultManager) addOutRules(
|
|
||||||
ip net.IP,
|
|
||||||
protocol firewall.Protocol,
|
|
||||||
port *firewall.Port,
|
|
||||||
action firewall.Action,
|
|
||||||
ipsetName string,
|
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
|
||||||
var rules []firewall.Rule
|
|
||||||
rule, err := d.firewall.AddFiltering(
|
|
||||||
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
|
||||||
}
|
|
||||||
rules = append(rules, rule...)
|
|
||||||
|
|
||||||
if shouldSkipInvertedRule(protocol, port) {
|
|
||||||
return rules, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
rule, err = d.firewall.AddFiltering(
|
|
||||||
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return append(rules, rule...), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getRuleID() returns unique ID for the rule based on its parameters.
|
|
||||||
func (d *DefaultManager) getRuleID(
|
|
||||||
ip net.IP,
|
|
||||||
proto firewall.Protocol,
|
|
||||||
direction int,
|
|
||||||
port *firewall.Port,
|
|
||||||
action firewall.Action,
|
|
||||||
comment string,
|
|
||||||
) string {
|
|
||||||
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
|
|
||||||
if port != nil {
|
|
||||||
idStr += port.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
return hex.EncodeToString(md5.New().Sum([]byte(idStr)))
|
|
||||||
}
|
|
||||||
|
|
||||||
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
|
|
||||||
// to all peers in the network map to one rule which just accepts that type of the traffic.
|
|
||||||
//
|
|
||||||
// NOTE: It will not squash two rules for same protocol if one covers all peers in the network,
|
|
||||||
// but other has port definitions or has drop policy.
|
|
||||||
func (d *DefaultManager) squashAcceptRules(
|
|
||||||
networkMap *mgmProto.NetworkMap,
|
|
||||||
) ([]*mgmProto.FirewallRule, map[mgmProto.FirewallRuleProtocol]struct{}) {
|
|
||||||
totalIPs := 0
|
|
||||||
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
|
|
||||||
for range p.AllowedIps {
|
|
||||||
totalIPs++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type protoMatch map[mgmProto.FirewallRuleProtocol]map[string]int
|
|
||||||
|
|
||||||
in := protoMatch{}
|
|
||||||
out := protoMatch{}
|
|
||||||
|
|
||||||
// trace which type of protocols was squashed
|
|
||||||
squashedRules := []*mgmProto.FirewallRule{}
|
|
||||||
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
|
|
||||||
|
|
||||||
// this function we use to do calculation, can we squash the rules by protocol or not.
|
|
||||||
// We summ amount of Peers IP for given protocol we found in original rules list.
|
|
||||||
// But we zeroed the IP's for protocol if:
|
|
||||||
// 1. Any of the rule has DROP action type.
|
|
||||||
// 2. Any of rule contains Port.
|
|
||||||
//
|
|
||||||
// We zeroed this to notify squash function that this protocol can't be squashed.
|
|
||||||
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols protoMatch) {
|
|
||||||
drop := r.Action == mgmProto.FirewallRule_DROP || r.Port != ""
|
|
||||||
if drop {
|
|
||||||
protocols[r.Protocol] = map[string]int{}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, ok := protocols[r.Protocol]; !ok {
|
|
||||||
protocols[r.Protocol] = map[string]int{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// special case, when we receive this all network IP address
|
|
||||||
// it means that rules for that protocol was already optimized on the
|
|
||||||
// management side
|
|
||||||
if r.PeerIP == "0.0.0.0" {
|
|
||||||
squashedRules = append(squashedRules, r)
|
|
||||||
squashedProtocols[r.Protocol] = struct{}{}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ipset := protocols[r.Protocol]
|
|
||||||
|
|
||||||
if _, ok := ipset[r.PeerIP]; ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ipset[r.PeerIP] = i
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, r := range networkMap.FirewallRules {
|
|
||||||
// calculate squash for different directions
|
|
||||||
if r.Direction == mgmProto.FirewallRule_IN {
|
|
||||||
addRuleToCalculationMap(i, r, in)
|
|
||||||
} else {
|
|
||||||
addRuleToCalculationMap(i, r, out)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// order of squashing by protocol is important
|
|
||||||
// only for their first element ALL, it must be done first
|
|
||||||
protocolOrders := []mgmProto.FirewallRuleProtocol{
|
|
||||||
mgmProto.FirewallRule_ALL,
|
|
||||||
mgmProto.FirewallRule_ICMP,
|
|
||||||
mgmProto.FirewallRule_TCP,
|
|
||||||
mgmProto.FirewallRule_UDP,
|
|
||||||
}
|
|
||||||
|
|
||||||
squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
|
|
||||||
for _, protocol := range protocolOrders {
|
|
||||||
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
|
|
||||||
// don't squash if :
|
|
||||||
// 1. Rules not cover all peers in the network
|
|
||||||
// 2. Rules cover only one peer in the network.
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// add special rule 0.0.0.0 which allows all IP's in our firewall implementations
|
|
||||||
squashedRules = append(squashedRules, &mgmProto.FirewallRule{
|
|
||||||
PeerIP: "0.0.0.0",
|
|
||||||
Direction: direction,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: protocol,
|
|
||||||
})
|
|
||||||
squashedProtocols[protocol] = struct{}{}
|
|
||||||
|
|
||||||
if protocol == mgmProto.FirewallRule_ALL {
|
|
||||||
// if we have ALL traffic type squashed rule
|
|
||||||
// it allows all other type of traffic, so we can stop processing
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
squash(in, mgmProto.FirewallRule_IN)
|
|
||||||
squash(out, mgmProto.FirewallRule_OUT)
|
|
||||||
|
|
||||||
// if all protocol was squashed everything is allow and we can ignore all other rules
|
|
||||||
if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok {
|
|
||||||
return squashedRules, squashedProtocols
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(squashedRules) == 0 {
|
|
||||||
return networkMap.FirewallRules, squashedProtocols
|
|
||||||
}
|
|
||||||
|
|
||||||
var rules []*mgmProto.FirewallRule
|
|
||||||
// filter out rules which was squashed from final list
|
|
||||||
// if we also have other not squashed rules.
|
|
||||||
for i, r := range networkMap.FirewallRules {
|
|
||||||
if _, ok := squashedProtocols[r.Protocol]; ok {
|
|
||||||
if m, ok := in[r.Protocol]; ok && m[r.PeerIP] == i {
|
|
||||||
continue
|
|
||||||
} else if m, ok := out[r.Protocol]; ok && m[r.PeerIP] == i {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rules = append(rules, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
return append(rules, squashedRules...), squashedProtocols
|
|
||||||
}
|
|
||||||
|
|
||||||
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
|
||||||
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
|
||||||
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DefaultManager) rollBack(newRulePairs map[string][]firewall.Rule) {
|
|
||||||
log.Debugf("rollback ACL to previous state")
|
|
||||||
for _, rules := range newRulePairs {
|
|
||||||
for _, rule := range rules {
|
|
||||||
if err := d.firewall.DeleteRule(rule); err != nil {
|
|
||||||
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) (firewall.Protocol, error) {
|
|
||||||
switch protocol {
|
|
||||||
case mgmProto.FirewallRule_TCP:
|
|
||||||
return firewall.ProtocolTCP, nil
|
|
||||||
case mgmProto.FirewallRule_UDP:
|
|
||||||
return firewall.ProtocolUDP, nil
|
|
||||||
case mgmProto.FirewallRule_ICMP:
|
|
||||||
return firewall.ProtocolICMP, nil
|
|
||||||
case mgmProto.FirewallRule_ALL:
|
|
||||||
return firewall.ProtocolALL, nil
|
|
||||||
default:
|
|
||||||
return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) bool {
|
|
||||||
return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertFirewallAction(action mgmProto.FirewallRuleAction) (firewall.Action, error) {
|
|
||||||
switch action {
|
|
||||||
case mgmProto.FirewallRule_ACCEPT:
|
|
||||||
return firewall.ActionAccept, nil
|
|
||||||
case mgmProto.FirewallRule_DROP:
|
|
||||||
return firewall.ActionDrop, nil
|
|
||||||
default:
|
|
||||||
return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,364 +0,0 @@
|
|||||||
package acl
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestDefaultManager(t *testing.T) {
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_TCP,
|
|
||||||
Port: "80",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
|
||||||
Action: mgmProto.FirewallRule_DROP,
|
|
||||||
Protocol: mgmProto.FirewallRule_UDP,
|
|
||||||
Port: "53",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
defer ctrl.Finish()
|
|
||||||
|
|
||||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
|
||||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
|
||||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
|
||||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to parse IP address: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
|
||||||
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
|
|
||||||
IP: ip,
|
|
||||||
Network: network,
|
|
||||||
}).AnyTimes()
|
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
|
||||||
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("create firewall: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func(fw manager.Manager) {
|
|
||||||
_ = fw.Reset()
|
|
||||||
}(fw)
|
|
||||||
acl := NewDefaultManager(fw)
|
|
||||||
|
|
||||||
t.Run("apply firewall rules", func(t *testing.T) {
|
|
||||||
acl.ApplyFiltering(networkMap)
|
|
||||||
|
|
||||||
if len(acl.rulesPairs) != 2 {
|
|
||||||
t.Errorf("firewall rules not applied: %v", acl.rulesPairs)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("add extra rules", func(t *testing.T) {
|
|
||||||
existedPairs := map[string]struct{}{}
|
|
||||||
for id := range acl.rulesPairs {
|
|
||||||
existedPairs[id] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// remove first rule
|
|
||||||
networkMap.FirewallRules = networkMap.FirewallRules[1:]
|
|
||||||
networkMap.FirewallRules = append(
|
|
||||||
networkMap.FirewallRules,
|
|
||||||
&mgmProto.FirewallRule{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_DROP,
|
|
||||||
Protocol: mgmProto.FirewallRule_ICMP,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
acl.ApplyFiltering(networkMap)
|
|
||||||
|
|
||||||
// we should have one old and one new rule in the existed rules
|
|
||||||
if len(acl.rulesPairs) != 2 {
|
|
||||||
t.Errorf("firewall rules not applied")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// check that old rule was removed
|
|
||||||
previousCount := 0
|
|
||||||
for id := range acl.rulesPairs {
|
|
||||||
if _, ok := existedPairs[id]; ok {
|
|
||||||
previousCount++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if previousCount != 1 {
|
|
||||||
t.Errorf("old rule was not removed")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("handle default rules", func(t *testing.T) {
|
|
||||||
networkMap.FirewallRules = networkMap.FirewallRules[:0]
|
|
||||||
|
|
||||||
networkMap.FirewallRulesIsEmpty = true
|
|
||||||
if acl.ApplyFiltering(networkMap); len(acl.rulesPairs) != 0 {
|
|
||||||
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.rulesPairs))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
networkMap.FirewallRulesIsEmpty = false
|
|
||||||
acl.ApplyFiltering(networkMap)
|
|
||||||
if len(acl.rulesPairs) != 2 {
|
|
||||||
t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.rulesPairs))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultManagerSquashRules(t *testing.T) {
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
|
||||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
|
||||||
{AllowedIps: []string{"10.93.0.1"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.2"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.3"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.4"}},
|
|
||||||
},
|
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
manager := &DefaultManager{}
|
|
||||||
rules, _ := manager.squashAcceptRules(networkMap)
|
|
||||||
if len(rules) != 2 {
|
|
||||||
t.Errorf("rules should contain 2, got: %v", rules)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
r := rules[0]
|
|
||||||
switch {
|
|
||||||
case r.PeerIP != "0.0.0.0":
|
|
||||||
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
|
||||||
return
|
|
||||||
case r.Direction != mgmProto.FirewallRule_IN:
|
|
||||||
t.Errorf("direction should be IN, got: %v", r.Direction)
|
|
||||||
return
|
|
||||||
case r.Protocol != mgmProto.FirewallRule_ALL:
|
|
||||||
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
|
||||||
return
|
|
||||||
case r.Action != mgmProto.FirewallRule_ACCEPT:
|
|
||||||
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
r = rules[1]
|
|
||||||
switch {
|
|
||||||
case r.PeerIP != "0.0.0.0":
|
|
||||||
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
|
||||||
return
|
|
||||||
case r.Direction != mgmProto.FirewallRule_OUT:
|
|
||||||
t.Errorf("direction should be OUT, got: %v", r.Direction)
|
|
||||||
return
|
|
||||||
case r.Protocol != mgmProto.FirewallRule_ALL:
|
|
||||||
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
|
||||||
return
|
|
||||||
case r.Action != mgmProto.FirewallRule_ACCEPT:
|
|
||||||
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
|
||||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
|
||||||
{AllowedIps: []string{"10.93.0.1"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.2"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.3"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.4"}},
|
|
||||||
},
|
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_UDP,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
manager := &DefaultManager{}
|
|
||||||
if rules, _ := manager.squashAcceptRules(networkMap); len(rules) != len(networkMap.FirewallRules) {
|
|
||||||
t.Errorf("we should get the same amount of rules as output, got %v", len(rules))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
|
||||||
PeerConfig: &mgmProto.PeerConfig{
|
|
||||||
SshConfig: &mgmProto.SSHConfig{
|
|
||||||
SshEnabled: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
|
||||||
{AllowedIps: []string{"10.93.0.1"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.2"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.3"}},
|
|
||||||
},
|
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
|
||||||
Protocol: mgmProto.FirewallRule_UDP,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
defer ctrl.Finish()
|
|
||||||
|
|
||||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
|
||||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
|
||||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
|
||||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to parse IP address: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
|
||||||
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
|
|
||||||
IP: ip,
|
|
||||||
Network: network,
|
|
||||||
}).AnyTimes()
|
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
|
||||||
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("create firewall: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func(fw manager.Manager) {
|
|
||||||
_ = fw.Reset()
|
|
||||||
}(fw)
|
|
||||||
acl := NewDefaultManager(fw)
|
|
||||||
|
|
||||||
acl.ApplyFiltering(networkMap)
|
|
||||||
|
|
||||||
if len(acl.rulesPairs) != 4 {
|
|
||||||
t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.rulesPairs))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
## Mocks
|
|
||||||
|
|
||||||
To generate (or refresh) mocks from acl package please install [mockgen](https://github.com/golang/mock).
|
|
||||||
Run this command from the `./client/internal/acl` folder to update iface mapper interface mock:
|
|
||||||
```bash
|
|
||||||
mockgen -destination mocks/iface_mapper.go -package mocks . IFaceMapper
|
|
||||||
```
|
|
||||||
@@ -1,91 +0,0 @@
|
|||||||
// Code generated by MockGen. DO NOT EDIT.
|
|
||||||
// Source: github.com/netbirdio/netbird/client/internal/acl (interfaces: IFaceMapper)
|
|
||||||
|
|
||||||
// Package mocks is a generated GoMock package.
|
|
||||||
package mocks
|
|
||||||
|
|
||||||
import (
|
|
||||||
reflect "reflect"
|
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
|
||||||
iface "github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MockIFaceMapper is a mock of IFaceMapper interface.
|
|
||||||
type MockIFaceMapper struct {
|
|
||||||
ctrl *gomock.Controller
|
|
||||||
recorder *MockIFaceMapperMockRecorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// MockIFaceMapperMockRecorder is the mock recorder for MockIFaceMapper.
|
|
||||||
type MockIFaceMapperMockRecorder struct {
|
|
||||||
mock *MockIFaceMapper
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMockIFaceMapper creates a new mock instance.
|
|
||||||
func NewMockIFaceMapper(ctrl *gomock.Controller) *MockIFaceMapper {
|
|
||||||
mock := &MockIFaceMapper{ctrl: ctrl}
|
|
||||||
mock.recorder = &MockIFaceMapperMockRecorder{mock}
|
|
||||||
return mock
|
|
||||||
}
|
|
||||||
|
|
||||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
|
||||||
func (m *MockIFaceMapper) EXPECT() *MockIFaceMapperMockRecorder {
|
|
||||||
return m.recorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// Address mocks base method.
|
|
||||||
func (m *MockIFaceMapper) Address() iface.WGAddress {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "Address")
|
|
||||||
ret0, _ := ret[0].(iface.WGAddress)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Address indicates an expected call of Address.
|
|
||||||
func (mr *MockIFaceMapperMockRecorder) Address() *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Address", reflect.TypeOf((*MockIFaceMapper)(nil).Address))
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsUserspaceBind mocks base method.
|
|
||||||
func (m *MockIFaceMapper) IsUserspaceBind() bool {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "IsUserspaceBind")
|
|
||||||
ret0, _ := ret[0].(bool)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsUserspaceBind indicates an expected call of IsUserspaceBind.
|
|
||||||
func (mr *MockIFaceMapperMockRecorder) IsUserspaceBind() *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsUserspaceBind", reflect.TypeOf((*MockIFaceMapper)(nil).IsUserspaceBind))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Name mocks base method.
|
|
||||||
func (m *MockIFaceMapper) Name() string {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "Name")
|
|
||||||
ret0, _ := ret[0].(string)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Name indicates an expected call of Name.
|
|
||||||
func (mr *MockIFaceMapperMockRecorder) Name() *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockIFaceMapper)(nil).Name))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetFilter mocks base method.
|
|
||||||
func (m *MockIFaceMapper) SetFilter(arg0 iface.PacketFilter) error {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "SetFilter", arg0)
|
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetFilter indicates an expected call of SetFilter.
|
|
||||||
func (mr *MockIFaceMapperMockRecorder) SetFilter(arg0 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFilter", reflect.TypeOf((*MockIFaceMapper)(nil).SetFilter), arg0)
|
|
||||||
}
|
|
||||||
@@ -1,203 +0,0 @@
|
|||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
)
|
|
||||||
|
|
||||||
// HostedGrantType grant type for device flow on Hosted
|
|
||||||
const (
|
|
||||||
HostedGrantType = "urn:ietf:params:oauth:grant-type:device_code"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ OAuthFlow = &DeviceAuthorizationFlow{}
|
|
||||||
|
|
||||||
// DeviceAuthorizationFlow implements the OAuthFlow interface,
|
|
||||||
// for the Device Authorization Flow.
|
|
||||||
type DeviceAuthorizationFlow struct {
|
|
||||||
providerConfig internal.DeviceAuthProviderConfig
|
|
||||||
|
|
||||||
HTTPClient HTTPClient
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestDeviceCodePayload used for request device code payload for auth0
|
|
||||||
type RequestDeviceCodePayload struct {
|
|
||||||
Audience string `json:"audience"`
|
|
||||||
ClientID string `json:"client_id"`
|
|
||||||
Scope string `json:"scope"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TokenRequestPayload used for requesting the auth0 token
|
|
||||||
type TokenRequestPayload struct {
|
|
||||||
GrantType string `json:"grant_type"`
|
|
||||||
DeviceCode string `json:"device_code,omitempty"`
|
|
||||||
ClientID string `json:"client_id"`
|
|
||||||
RefreshToken string `json:"refresh_token,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TokenRequestResponse used for parsing Hosted token's response
|
|
||||||
type TokenRequestResponse struct {
|
|
||||||
Error string `json:"error"`
|
|
||||||
ErrorDescription string `json:"error_description"`
|
|
||||||
TokenInfo
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewDeviceAuthorizationFlow returns device authorization flow client
|
|
||||||
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
|
||||||
httpTransport.MaxIdleConns = 5
|
|
||||||
|
|
||||||
httpClient := &http.Client{
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
Transport: httpTransport,
|
|
||||||
}
|
|
||||||
|
|
||||||
return &DeviceAuthorizationFlow{
|
|
||||||
providerConfig: config,
|
|
||||||
HTTPClient: httpClient,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetClientID returns the provider client id
|
|
||||||
func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
|
|
||||||
return d.providerConfig.ClientID
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestAuthInfo requests a device code login flow information from Hosted
|
|
||||||
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
|
||||||
form := url.Values{}
|
|
||||||
form.Add("client_id", d.providerConfig.ClientID)
|
|
||||||
form.Add("audience", d.providerConfig.Audience)
|
|
||||||
form.Add("scope", d.providerConfig.Scope)
|
|
||||||
req, err := http.NewRequest("POST", d.providerConfig.DeviceAuthEndpoint,
|
|
||||||
strings.NewReader(form.Encode()))
|
|
||||||
if err != nil {
|
|
||||||
return AuthFlowInfo{}, fmt.Errorf("creating request failed with error: %v", err)
|
|
||||||
}
|
|
||||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
|
|
||||||
res, err := d.HTTPClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return AuthFlowInfo{}, fmt.Errorf("doing request failed with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer res.Body.Close()
|
|
||||||
body, err := io.ReadAll(res.Body)
|
|
||||||
if err != nil {
|
|
||||||
return AuthFlowInfo{}, fmt.Errorf("reading body failed with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.StatusCode != 200 {
|
|
||||||
return AuthFlowInfo{}, fmt.Errorf("request device code returned status %d error: %s", res.StatusCode, string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
deviceCode := AuthFlowInfo{}
|
|
||||||
err = json.Unmarshal(body, &deviceCode)
|
|
||||||
if err != nil {
|
|
||||||
return AuthFlowInfo{}, fmt.Errorf("unmarshaling response failed with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback to the verification_uri if the IdP doesn't support verification_uri_complete
|
|
||||||
if deviceCode.VerificationURIComplete == "" {
|
|
||||||
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
|
|
||||||
}
|
|
||||||
|
|
||||||
return deviceCode, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) {
|
|
||||||
form := url.Values{}
|
|
||||||
form.Add("client_id", d.providerConfig.ClientID)
|
|
||||||
form.Add("grant_type", HostedGrantType)
|
|
||||||
form.Add("device_code", info.DeviceCode)
|
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", d.providerConfig.TokenEndpoint, strings.NewReader(form.Encode()))
|
|
||||||
if err != nil {
|
|
||||||
return TokenRequestResponse{}, fmt.Errorf("failed to create request access token: %v", err)
|
|
||||||
}
|
|
||||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
|
|
||||||
res, err := d.HTTPClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return TokenRequestResponse{}, fmt.Errorf("failed to request access token with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err := res.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
body, err := io.ReadAll(res.Body)
|
|
||||||
if err != nil {
|
|
||||||
return TokenRequestResponse{}, fmt.Errorf("failed reading access token response body with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.StatusCode > 499 {
|
|
||||||
return TokenRequestResponse{}, fmt.Errorf("access token response returned code: %s", string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenResponse := TokenRequestResponse{}
|
|
||||||
err = json.Unmarshal(body, &tokenResponse)
|
|
||||||
if err != nil {
|
|
||||||
return TokenRequestResponse{}, fmt.Errorf("parsing token response failed with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokenResponse, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
|
||||||
// it retrieves the access token from Hosted's endpoint and validates it before returning
|
|
||||||
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
|
||||||
interval := time.Duration(info.Interval) * time.Second
|
|
||||||
ticker := time.NewTicker(interval)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return TokenInfo{}, ctx.Err()
|
|
||||||
case <-ticker.C:
|
|
||||||
|
|
||||||
tokenResponse, err := d.requestToken(info)
|
|
||||||
if err != nil {
|
|
||||||
return TokenInfo{}, fmt.Errorf("parsing token response failed with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tokenResponse.Error != "" {
|
|
||||||
if tokenResponse.Error == "authorization_pending" {
|
|
||||||
continue
|
|
||||||
} else if tokenResponse.Error == "slow_down" {
|
|
||||||
interval += (3 * time.Second)
|
|
||||||
ticker.Reset(interval)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenInfo := TokenInfo{
|
|
||||||
AccessToken: tokenResponse.AccessToken,
|
|
||||||
TokenType: tokenResponse.TokenType,
|
|
||||||
RefreshToken: tokenResponse.RefreshToken,
|
|
||||||
IDToken: tokenResponse.IDToken,
|
|
||||||
ExpiresIn: tokenResponse.ExpiresIn,
|
|
||||||
UseIDToken: d.providerConfig.UseIDToken,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = isValidAccessToken(tokenInfo.GetTokenToUse(), d.providerConfig.Audience)
|
|
||||||
if err != nil {
|
|
||||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokenInfo, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,109 +0,0 @@
|
|||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"runtime"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
gstatus "google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
)
|
|
||||||
|
|
||||||
// OAuthFlow represents an interface for authorization using different OAuth 2.0 flows
|
|
||||||
type OAuthFlow interface {
|
|
||||||
RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error)
|
|
||||||
WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error)
|
|
||||||
GetClientID(ctx context.Context) string
|
|
||||||
}
|
|
||||||
|
|
||||||
// HTTPClient http client interface for API calls
|
|
||||||
type HTTPClient interface {
|
|
||||||
Do(req *http.Request) (*http.Response, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AuthFlowInfo holds information for the OAuth 2.0 authorization flow
|
|
||||||
type AuthFlowInfo struct {
|
|
||||||
DeviceCode string `json:"device_code"`
|
|
||||||
UserCode string `json:"user_code"`
|
|
||||||
VerificationURI string `json:"verification_uri"`
|
|
||||||
VerificationURIComplete string `json:"verification_uri_complete"`
|
|
||||||
ExpiresIn int `json:"expires_in"`
|
|
||||||
Interval int `json:"interval"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Claims used when validating the access token
|
|
||||||
type Claims struct {
|
|
||||||
Audience interface{} `json:"aud"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TokenInfo holds information of issued access token
|
|
||||||
type TokenInfo struct {
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
RefreshToken string `json:"refresh_token"`
|
|
||||||
IDToken string `json:"id_token"`
|
|
||||||
TokenType string `json:"token_type"`
|
|
||||||
ExpiresIn int `json:"expires_in"`
|
|
||||||
UseIDToken bool `json:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetTokenToUse returns either the access or id token based on UseIDToken field
|
|
||||||
func (t TokenInfo) GetTokenToUse() string {
|
|
||||||
if t.UseIDToken {
|
|
||||||
return t.IDToken
|
|
||||||
}
|
|
||||||
return t.AccessToken
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
|
|
||||||
//
|
|
||||||
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
|
|
||||||
// and if that also fails, the authentication process is deemed unsuccessful
|
|
||||||
//
|
|
||||||
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
|
||||||
func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopClient bool) (OAuthFlow, error) {
|
|
||||||
if runtime.GOOS == "linux" && !isLinuxDesktopClient {
|
|
||||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
|
||||||
}
|
|
||||||
|
|
||||||
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
|
|
||||||
if err != nil {
|
|
||||||
// fallback to device code flow
|
|
||||||
log.Debugf("failed to initialize pkce authentication with error: %v\n", err)
|
|
||||||
log.Debug("falling back to device code flow")
|
|
||||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
|
||||||
}
|
|
||||||
return pkceFlow, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
|
||||||
func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
|
||||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
|
||||||
}
|
|
||||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
|
||||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
|
||||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
|
||||||
if err != nil {
|
|
||||||
switch s, ok := gstatus.FromError(err); {
|
|
||||||
case ok && s.Code() == codes.NotFound:
|
|
||||||
return nil, fmt.Errorf("no SSO provider returned from management. " +
|
|
||||||
"Please proceed with setting up this device using setup keys " +
|
|
||||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
|
||||||
case ok && s.Code() == codes.Unimplemented:
|
|
||||||
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
|
||||||
"please update your server or use Setup Keys to login", config.ManagementURL)
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
|
||||||
}
|
|
||||||
@@ -1,255 +0,0 @@
|
|||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/sha256"
|
|
||||||
"crypto/subtle"
|
|
||||||
"encoding/base64"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"html/template"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/oauth2"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/templates"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ OAuthFlow = &PKCEAuthorizationFlow{}
|
|
||||||
|
|
||||||
const (
|
|
||||||
queryState = "state"
|
|
||||||
queryCode = "code"
|
|
||||||
queryError = "error"
|
|
||||||
queryErrorDesc = "error_description"
|
|
||||||
defaultPKCETimeoutSeconds = 300
|
|
||||||
)
|
|
||||||
|
|
||||||
// PKCEAuthorizationFlow implements the OAuthFlow interface for
|
|
||||||
// the Authorization Code Flow with PKCE.
|
|
||||||
type PKCEAuthorizationFlow struct {
|
|
||||||
providerConfig internal.PKCEAuthProviderConfig
|
|
||||||
state string
|
|
||||||
codeVerifier string
|
|
||||||
oAuthConfig *oauth2.Config
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
|
|
||||||
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
|
||||||
var availableRedirectURL string
|
|
||||||
|
|
||||||
// find the first available redirect URL
|
|
||||||
for _, redirectURL := range config.RedirectURLs {
|
|
||||||
if !isRedirectURLPortUsed(redirectURL) {
|
|
||||||
availableRedirectURL = redirectURL
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if availableRedirectURL == "" {
|
|
||||||
return nil, fmt.Errorf("no available port found from configured redirect URLs: %q", config.RedirectURLs)
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg := &oauth2.Config{
|
|
||||||
ClientID: config.ClientID,
|
|
||||||
ClientSecret: config.ClientSecret,
|
|
||||||
Endpoint: oauth2.Endpoint{
|
|
||||||
AuthURL: config.AuthorizationEndpoint,
|
|
||||||
TokenURL: config.TokenEndpoint,
|
|
||||||
},
|
|
||||||
RedirectURL: availableRedirectURL,
|
|
||||||
Scopes: strings.Split(config.Scope, " "),
|
|
||||||
}
|
|
||||||
|
|
||||||
return &PKCEAuthorizationFlow{
|
|
||||||
providerConfig: config,
|
|
||||||
oAuthConfig: cfg,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetClientID returns the provider client id
|
|
||||||
func (p *PKCEAuthorizationFlow) GetClientID(_ context.Context) string {
|
|
||||||
return p.providerConfig.ClientID
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestAuthInfo requests a authorization code login flow information.
|
|
||||||
func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
|
||||||
state, err := randomBytesInHex(24)
|
|
||||||
if err != nil {
|
|
||||||
return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err)
|
|
||||||
}
|
|
||||||
p.state = state
|
|
||||||
|
|
||||||
codeVerifier, err := randomBytesInHex(64)
|
|
||||||
if err != nil {
|
|
||||||
return AuthFlowInfo{}, fmt.Errorf("could not create a code verifier: %v", err)
|
|
||||||
}
|
|
||||||
p.codeVerifier = codeVerifier
|
|
||||||
|
|
||||||
codeChallenge := createCodeChallenge(codeVerifier)
|
|
||||||
authURL := p.oAuthConfig.AuthCodeURL(
|
|
||||||
state,
|
|
||||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
|
||||||
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
|
||||||
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
|
||||||
)
|
|
||||||
|
|
||||||
return AuthFlowInfo{
|
|
||||||
VerificationURIComplete: authURL,
|
|
||||||
ExpiresIn: defaultPKCETimeoutSeconds,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
|
|
||||||
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
|
|
||||||
// Once the token is received, it is converted to TokenInfo and validated before returning.
|
|
||||||
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
|
|
||||||
tokenChan := make(chan *oauth2.Token, 1)
|
|
||||||
errChan := make(chan error, 1)
|
|
||||||
|
|
||||||
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
|
|
||||||
if err != nil {
|
|
||||||
return TokenInfo{}, fmt.Errorf("failed to parse redirect URL: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
|
|
||||||
defer func() {
|
|
||||||
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
|
||||||
log.Errorf("failed to close the server: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
go p.startServer(server, tokenChan, errChan)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return TokenInfo{}, ctx.Err()
|
|
||||||
case token := <-tokenChan:
|
|
||||||
return p.parseOAuthToken(token)
|
|
||||||
case err := <-errChan:
|
|
||||||
return TokenInfo{}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) {
|
|
||||||
mux := http.NewServeMux()
|
|
||||||
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
|
|
||||||
token, err := p.handleRequest(req)
|
|
||||||
if err != nil {
|
|
||||||
renderPKCEFlowTmpl(w, err)
|
|
||||||
errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
renderPKCEFlowTmpl(w, nil)
|
|
||||||
tokenChan <- token
|
|
||||||
})
|
|
||||||
|
|
||||||
server.Handler = mux
|
|
||||||
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
||||||
errChan <- err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token, error) {
|
|
||||||
query := req.URL.Query()
|
|
||||||
|
|
||||||
if authError := query.Get(queryError); authError != "" {
|
|
||||||
authErrorDesc := query.Get(queryErrorDesc)
|
|
||||||
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prevent timing attacks on the state
|
|
||||||
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
|
||||||
return nil, fmt.Errorf("invalid state")
|
|
||||||
}
|
|
||||||
|
|
||||||
code := query.Get(queryCode)
|
|
||||||
if code == "" {
|
|
||||||
return nil, fmt.Errorf("missing code")
|
|
||||||
}
|
|
||||||
|
|
||||||
return p.oAuthConfig.Exchange(
|
|
||||||
req.Context(),
|
|
||||||
code,
|
|
||||||
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo, error) {
|
|
||||||
tokenInfo := TokenInfo{
|
|
||||||
AccessToken: token.AccessToken,
|
|
||||||
RefreshToken: token.RefreshToken,
|
|
||||||
TokenType: token.TokenType,
|
|
||||||
ExpiresIn: token.Expiry.Second(),
|
|
||||||
UseIDToken: p.providerConfig.UseIDToken,
|
|
||||||
}
|
|
||||||
if idToken, ok := token.Extra("id_token").(string); ok {
|
|
||||||
tokenInfo.IDToken = idToken
|
|
||||||
}
|
|
||||||
|
|
||||||
// if a provider doesn't support an audience, use the Client ID for token verification
|
|
||||||
audience := p.providerConfig.Audience
|
|
||||||
if audience == "" {
|
|
||||||
audience = p.providerConfig.ClientID
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil {
|
|
||||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokenInfo, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func createCodeChallenge(codeVerifier string) string {
|
|
||||||
sha2 := sha256.Sum256([]byte(codeVerifier))
|
|
||||||
return base64.RawURLEncoding.EncodeToString(sha2[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use.
|
|
||||||
func isRedirectURLPortUsed(redirectURL string) bool {
|
|
||||||
parsedURL, err := url.Parse(redirectURL)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to parse redirect URL: %v", err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
addr := fmt.Sprintf(":%s", parsedURL.Port())
|
|
||||||
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := conn.Close(); err != nil {
|
|
||||||
log.Errorf("error while closing the connection: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) {
|
|
||||||
tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
data := make(map[string]string)
|
|
||||||
if authError != nil {
|
|
||||||
data["Error"] = authError.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tmpl.Execute(w, data); err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,60 +0,0 @@
|
|||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/hex"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func randomBytesInHex(count int) (string, error) {
|
|
||||||
buf := make([]byte, count)
|
|
||||||
_, err := io.ReadFull(rand.Reader, buf)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("could not generate %d random bytes: %v", count, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return hex.EncodeToString(buf), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// isValidAccessToken is a simple validation of the access token
|
|
||||||
func isValidAccessToken(token string, audience string) error {
|
|
||||||
if token == "" {
|
|
||||||
return fmt.Errorf("token received is empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
encodedClaims := strings.Split(token, ".")[1]
|
|
||||||
claimsString, err := base64.RawURLEncoding.DecodeString(encodedClaims)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
claims := Claims{}
|
|
||||||
err = json.Unmarshal(claimsString, &claims)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if claims.Audience == nil {
|
|
||||||
return fmt.Errorf("required token field audience is absent")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Audience claim of JWT can be a string or an array of strings
|
|
||||||
switch aud := claims.Audience.(type) {
|
|
||||||
case string:
|
|
||||||
if aud == audience {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
case []interface{}:
|
|
||||||
for _, audItem := range aud {
|
|
||||||
if audStr, ok := audItem.(string); ok && audStr == audience {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("invalid JWT token audience field")
|
|
||||||
}
|
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
package internal
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
@@ -13,19 +12,16 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
mgm "github.com/netbirdio/netbird/management/client"
|
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// managementLegacyPortString is the port that was used before by the Management gRPC server.
|
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
|
||||||
// It is used for backward compatibility now.
|
// It is used for backward compatibility now.
|
||||||
// NB: hardcoded from github.com/netbirdio/netbird/management/cmd to avoid import
|
// NB: hardcoded from github.com/netbirdio/netbird/management/cmd to avoid import
|
||||||
managementLegacyPortString = "33073"
|
ManagementLegacyPort = 33073
|
||||||
// DefaultManagementURL points to the NetBird's cloud management endpoint
|
// DefaultManagementURL points to the NetBird's cloud management endpoint
|
||||||
DefaultManagementURL = "https://api.netbird.io:443"
|
DefaultManagementURL = "https://api.wiretrustee.com:443"
|
||||||
// oldDefaultManagementURL points to the NetBird's old cloud management endpoint
|
|
||||||
oldDefaultManagementURL = "https://api.wiretrustee.com:443"
|
|
||||||
// DefaultAdminURL points to NetBird's cloud management console
|
// DefaultAdminURL points to NetBird's cloud management console
|
||||||
DefaultAdminURL = "https://app.netbird.io:443"
|
DefaultAdminURL = "https://app.netbird.io:443"
|
||||||
)
|
)
|
||||||
@@ -41,9 +37,6 @@ type ConfigInput struct {
|
|||||||
PreSharedKey *string
|
PreSharedKey *string
|
||||||
NATExternalIPs []string
|
NATExternalIPs []string
|
||||||
CustomDNSAddress []byte
|
CustomDNSAddress []byte
|
||||||
RosenpassEnabled *bool
|
|
||||||
InterfaceName *string
|
|
||||||
WireguardPort *int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config Configuration type
|
// Config Configuration type
|
||||||
@@ -57,11 +50,10 @@ type Config struct {
|
|||||||
WgPort int
|
WgPort int
|
||||||
IFaceBlackList []string
|
IFaceBlackList []string
|
||||||
DisableIPv6Discovery bool
|
DisableIPv6Discovery bool
|
||||||
RosenpassEnabled bool
|
|
||||||
// SSHKey is a private SSH key in a PEM format
|
// SSHKey is a private SSH key in a PEM format
|
||||||
SSHKey string
|
SSHKey string
|
||||||
|
|
||||||
// ExternalIP mappings, if different from the host interface IP
|
// ExternalIP mappings, if different than the host interface IP
|
||||||
//
|
//
|
||||||
// External IP must not be behind a CGNAT and port-forwarding for incoming UDP packets from WgPort on ExternalIP
|
// External IP must not be behind a CGNAT and port-forwarding for incoming UDP packets from WgPort on ExternalIP
|
||||||
// to WgPort on host interface IP must be present. This can take form of single port-forwarding rule, 1:1 DNAT
|
// to WgPort on host interface IP must be present. This can take form of single port-forwarding rule, 1:1 DNAT
|
||||||
@@ -144,10 +136,11 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
config := &Config{
|
config := &Config{
|
||||||
SSHKey: string(pem),
|
SSHKey: string(pem),
|
||||||
PrivateKey: wgKey,
|
PrivateKey: wgKey,
|
||||||
|
WgIface: iface.WgInterfaceDefault,
|
||||||
|
WgPort: iface.DefaultWgPort,
|
||||||
IFaceBlackList: []string{},
|
IFaceBlackList: []string{},
|
||||||
DisableIPv6Discovery: false,
|
DisableIPv6Discovery: false,
|
||||||
NATExternalIPs: input.NATExternalIPs,
|
NATExternalIPs: input.NATExternalIPs,
|
||||||
@@ -168,24 +161,10 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
|||||||
config.ManagementURL = URL
|
config.ManagementURL = URL
|
||||||
}
|
}
|
||||||
|
|
||||||
config.WgPort = iface.DefaultWgPort
|
|
||||||
if input.WireguardPort != nil {
|
|
||||||
config.WgPort = *input.WireguardPort
|
|
||||||
}
|
|
||||||
|
|
||||||
config.WgIface = iface.WgInterfaceDefault
|
|
||||||
if input.InterfaceName != nil {
|
|
||||||
config.WgIface = *input.InterfaceName
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.PreSharedKey != nil {
|
if input.PreSharedKey != nil {
|
||||||
config.PreSharedKey = *input.PreSharedKey
|
config.PreSharedKey = *input.PreSharedKey
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.RosenpassEnabled != nil {
|
|
||||||
config.RosenpassEnabled = *input.RosenpassEnabled
|
|
||||||
}
|
|
||||||
|
|
||||||
defaultAdminURL, err := parseURL("Admin URL", DefaultAdminURL)
|
defaultAdminURL, err := parseURL("Admin URL", DefaultAdminURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -236,7 +215,8 @@ func update(input ConfigInput) (*Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
|
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
|
||||||
log.Infof("new pre-shared key provided, replacing old key")
|
log.Infof("new pre-shared key provided, updated to %s (old value %s)",
|
||||||
|
*input.PreSharedKey, config.PreSharedKey)
|
||||||
config.PreSharedKey = *input.PreSharedKey
|
config.PreSharedKey = *input.PreSharedKey
|
||||||
refresh = true
|
refresh = true
|
||||||
}
|
}
|
||||||
@@ -254,17 +234,6 @@ func update(input ConfigInput) (*Config, error) {
|
|||||||
config.WgPort = iface.DefaultWgPort
|
config.WgPort = iface.DefaultWgPort
|
||||||
refresh = true
|
refresh = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.WireguardPort != nil {
|
|
||||||
config.WgPort = *input.WireguardPort
|
|
||||||
refresh = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.InterfaceName != nil {
|
|
||||||
config.WgIface = *input.InterfaceName
|
|
||||||
refresh = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.NATExternalIPs != nil && len(config.NATExternalIPs) != len(input.NATExternalIPs) {
|
if input.NATExternalIPs != nil && len(config.NATExternalIPs) != len(input.NATExternalIPs) {
|
||||||
config.NATExternalIPs = input.NATExternalIPs
|
config.NATExternalIPs = input.NATExternalIPs
|
||||||
refresh = true
|
refresh = true
|
||||||
@@ -275,11 +244,6 @@ func update(input ConfigInput) (*Config, error) {
|
|||||||
refresh = true
|
refresh = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.RosenpassEnabled != nil {
|
|
||||||
config.RosenpassEnabled = *input.RosenpassEnabled
|
|
||||||
refresh = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if refresh {
|
if refresh {
|
||||||
// since we have new management URL, we need to update config file
|
// since we have new management URL, we need to update config file
|
||||||
if err := util.WriteJson(input.ConfigPath, config); err != nil {
|
if err := util.WriteJson(input.ConfigPath, config); err != nil {
|
||||||
@@ -307,9 +271,9 @@ func parseURL(serviceName, serviceURL string) (*url.URL, error) {
|
|||||||
if parsedMgmtURL.Port() == "" {
|
if parsedMgmtURL.Port() == "" {
|
||||||
switch parsedMgmtURL.Scheme {
|
switch parsedMgmtURL.Scheme {
|
||||||
case "https":
|
case "https":
|
||||||
parsedMgmtURL.Host += ":443"
|
parsedMgmtURL.Host = parsedMgmtURL.Host + ":443"
|
||||||
case "http":
|
case "http":
|
||||||
parsedMgmtURL.Host += ":80"
|
parsedMgmtURL.Host = parsedMgmtURL.Host + ":80"
|
||||||
default:
|
default:
|
||||||
log.Infof("unable to determine a default port for schema %s in URL %s", parsedMgmtURL.Scheme, serviceURL)
|
log.Infof("unable to determine a default port for schema %s in URL %s", parsedMgmtURL.Scheme, serviceURL)
|
||||||
}
|
}
|
||||||
@@ -339,86 +303,3 @@ func configFileIsExists(path string) bool {
|
|||||||
_, err := os.Stat(path)
|
_, err := os.Stat(path)
|
||||||
return !os.IsNotExist(err)
|
return !os.IsNotExist(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain.
|
|
||||||
// If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
|
|
||||||
// The check is performed only for the NetBird's managed version.
|
|
||||||
func UpdateOldManagementURL(ctx context.Context, config *Config, configPath string) (*Config, error) {
|
|
||||||
|
|
||||||
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
parsedOldDefaultManagementURL, err := parseURL("Management URL", oldDefaultManagementURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.ManagementURL.Hostname() != defaultManagementURL.Hostname() &&
|
|
||||||
config.ManagementURL.Hostname() != parsedOldDefaultManagementURL.Hostname() {
|
|
||||||
// only do the check for the NetBird's managed version
|
|
||||||
return config, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var mgmTlsEnabled bool
|
|
||||||
if config.ManagementURL.Scheme == "https" {
|
|
||||||
mgmTlsEnabled = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if !mgmTlsEnabled {
|
|
||||||
// only do the check for HTTPs scheme (the hosted version of the Management service is always HTTPs)
|
|
||||||
return config, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.ManagementURL.Port() != managementLegacyPortString &&
|
|
||||||
config.ManagementURL.Hostname() == defaultManagementURL.Hostname() {
|
|
||||||
return config, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s:%d",
|
|
||||||
config.ManagementURL.Scheme, defaultManagementURL.Hostname(), 443))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// here we check whether we could switch from the legacy 33073 port to the new 443
|
|
||||||
log.Infof("attempting to switch from the legacy Management URL %s to the new one %s",
|
|
||||||
config.ManagementURL.String(), newURL.String())
|
|
||||||
key, err := wgtypes.ParseKey(config.PrivateKey)
|
|
||||||
if err != nil {
|
|
||||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
|
||||||
return config, err
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := mgm.NewClient(ctx, newURL.Host, key, mgmTlsEnabled)
|
|
||||||
if err != nil {
|
|
||||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
|
||||||
return config, err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err = client.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to close the Management service client %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// gRPC check
|
|
||||||
_, err = client.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// everything is alright => update the config
|
|
||||||
newConfig, err := UpdateConfig(ConfigInput{
|
|
||||||
ManagementURL: newURL.String(),
|
|
||||||
ConfigPath: configPath,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
|
||||||
return config, fmt.Errorf("failed updating config file: %v", err)
|
|
||||||
}
|
|
||||||
log.Infof("successfully switched to the new Management URL: %s", newURL.String())
|
|
||||||
|
|
||||||
return newConfig, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,16 +1,13 @@
|
|||||||
package internal
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetConfig(t *testing.T) {
|
func TestGetConfig(t *testing.T) {
|
||||||
@@ -26,6 +23,9 @@ func TestGetConfig(t *testing.T) {
|
|||||||
assert.Equal(t, config.ManagementURL.String(), DefaultManagementURL)
|
assert.Equal(t, config.ManagementURL.String(), DefaultManagementURL)
|
||||||
assert.Equal(t, config.AdminURL.String(), DefaultAdminURL)
|
assert.Equal(t, config.AdminURL.String(), DefaultAdminURL)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
managementURL := "https://test.management.url:33071"
|
managementURL := "https://test.management.url:33071"
|
||||||
adminURL := "https://app.admin.url:443"
|
adminURL := "https://app.admin.url:443"
|
||||||
path := filepath.Join(t.TempDir(), "config.json")
|
path := filepath.Join(t.TempDir(), "config.json")
|
||||||
@@ -122,60 +122,3 @@ func TestHiddenPreSharedKey(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateOldManagementURL(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
previousManagementURL string
|
|
||||||
expectedManagementURL string
|
|
||||||
fileShouldNotChange bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Update old management URL with legacy port",
|
|
||||||
previousManagementURL: "https://api.wiretrustee.com:33073",
|
|
||||||
expectedManagementURL: DefaultManagementURL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Update old management URL",
|
|
||||||
previousManagementURL: oldDefaultManagementURL,
|
|
||||||
expectedManagementURL: DefaultManagementURL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "No update needed when management URL is up to date",
|
|
||||||
previousManagementURL: DefaultManagementURL,
|
|
||||||
expectedManagementURL: DefaultManagementURL,
|
|
||||||
fileShouldNotChange: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "No update needed when not using cloud management",
|
|
||||||
previousManagementURL: "https://netbird.example.com:33073",
|
|
||||||
expectedManagementURL: "https://netbird.example.com:33073",
|
|
||||||
fileShouldNotChange: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
tempDir := t.TempDir()
|
|
||||||
configPath := filepath.Join(tempDir, "config.json")
|
|
||||||
config, err := UpdateOrCreateConfig(ConfigInput{
|
|
||||||
ManagementURL: tt.previousManagementURL,
|
|
||||||
ConfigPath: configPath,
|
|
||||||
})
|
|
||||||
require.NoError(t, err, "failed to create testing config")
|
|
||||||
previousStats, err := os.Stat(configPath)
|
|
||||||
require.NoError(t, err, "failed to create testing config stats")
|
|
||||||
resultConfig, err := UpdateOldManagementURL(context.TODO(), config, configPath)
|
|
||||||
require.NoError(t, err, "got error when updating old management url")
|
|
||||||
require.Equal(t, tt.expectedManagementURL, resultConfig.ManagementURL.String())
|
|
||||||
newStats, err := os.Stat(configPath)
|
|
||||||
require.NoError(t, err, "failed to create testing config stats")
|
|
||||||
switch tt.fileShouldNotChange {
|
|
||||||
case true:
|
|
||||||
require.Equal(t, previousStats.ModTime(), newStats.ModTime(), "file should not change")
|
|
||||||
case false:
|
|
||||||
require.NotEqual(t, previousStats.ModTime(), newStats.ModTime(), "file should have changed")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -12,8 +12,6 @@ import (
|
|||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
@@ -22,39 +20,10 @@ import (
|
|||||||
mgm "github.com/netbirdio/netbird/management/client"
|
mgm "github.com/netbirdio/netbird/management/client"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
signal "github.com/netbirdio/netbird/signal/client"
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
"github.com/netbirdio/netbird/version"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// RunClient with main logic.
|
// RunClient with main logic.
|
||||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
|
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) error {
|
||||||
return runClient(ctx, config, statusRecorder, MobileDependency{})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RunClientMobile with main logic on mobile system
|
|
||||||
func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, networkChangeListener listener.NetworkChangeListener, dnsAddresses []string, dnsReadyListener dns.ReadyListener) error {
|
|
||||||
// in case of non Android os these variables will be nil
|
|
||||||
mobileDependency := MobileDependency{
|
|
||||||
TunAdapter: tunAdapter,
|
|
||||||
IFaceDiscover: iFaceDiscover,
|
|
||||||
NetworkChangeListener: networkChangeListener,
|
|
||||||
HostDNSAddresses: dnsAddresses,
|
|
||||||
DnsReadyListener: dnsReadyListener,
|
|
||||||
}
|
|
||||||
return runClient(ctx, config, statusRecorder, mobileDependency)
|
|
||||||
}
|
|
||||||
|
|
||||||
func RunClientiOS(ctx context.Context, config *Config, statusRecorder *peer.Status, fileDescriptor int32, networkChangeListener listener.NetworkChangeListener, dnsManager dns.IosDnsManager) error {
|
|
||||||
mobileDependency := MobileDependency{
|
|
||||||
FileDescriptor: fileDescriptor,
|
|
||||||
NetworkChangeListener: networkChangeListener,
|
|
||||||
DnsManager: dnsManager,
|
|
||||||
}
|
|
||||||
return runClient(ctx, config, statusRecorder, mobileDependency)
|
|
||||||
}
|
|
||||||
|
|
||||||
func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status, mobileDependency MobileDependency) error {
|
|
||||||
log.Infof("starting NetBird client version %s", version.NetbirdVersion())
|
|
||||||
|
|
||||||
backOff := &backoff.ExponentialBackOff{
|
backOff := &backoff.ExponentialBackOff{
|
||||||
InitialInterval: time.Second,
|
InitialInterval: time.Second,
|
||||||
RandomizationFactor: 1,
|
RandomizationFactor: 1,
|
||||||
@@ -108,7 +77,7 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
|||||||
cancel()
|
cancel()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
log.Debugf("connecting to the Management service %s", config.ManagementURL.Host)
|
log.Debugf("conecting to the Management service %s", config.ManagementURL.Host)
|
||||||
mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
|
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
|
||||||
@@ -181,7 +150,13 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
|||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder)
|
md, err := newMobileDependency(tunAdapter, iFaceDiscover, mgmClient)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
return wrapErr(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, md, statusRecorder)
|
||||||
err = engine.Start()
|
err = engine.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||||
@@ -191,6 +166,8 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
|||||||
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
|
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
|
||||||
state.Set(StatusConnected)
|
state.Set(StatusConnected)
|
||||||
|
|
||||||
|
statusRecorder.ClientStart()
|
||||||
|
|
||||||
<-engineCtx.Done()
|
<-engineCtx.Done()
|
||||||
statusRecorder.ClientTeardown()
|
statusRecorder.ClientTeardown()
|
||||||
|
|
||||||
@@ -211,7 +188,6 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
statusRecorder.ClientStart()
|
|
||||||
err = backoff.Retry(operation, backOff)
|
err = backoff.Retry(operation, backOff)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
||||||
@@ -235,7 +211,6 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
|||||||
SSHKey: []byte(config.SSHKey),
|
SSHKey: []byte(config.SSHKey),
|
||||||
NATExternalIPs: config.NATExternalIPs,
|
NATExternalIPs: config.NATExternalIPs,
|
||||||
CustomDNSAddress: config.CustomDNSAddress,
|
CustomDNSAddress: config.CustomDNSAddress,
|
||||||
RosenpassEnabled: config.RosenpassEnabled,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.PreSharedKey != "" {
|
if config.PreSharedKey != "" {
|
||||||
@@ -284,6 +259,83 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte)
|
|||||||
return loginResp, nil
|
return loginResp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateOldManagementPort checks whether client can switch to the new Management port 443.
|
||||||
|
// If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
|
||||||
|
// The check is performed only for the NetBird's managed version.
|
||||||
|
func UpdateOldManagementPort(ctx context.Context, config *Config, configPath string) (*Config, error) {
|
||||||
|
|
||||||
|
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.ManagementURL.Hostname() != defaultManagementURL.Hostname() {
|
||||||
|
// only do the check for the NetBird's managed version
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var mgmTlsEnabled bool
|
||||||
|
if config.ManagementURL.Scheme == "https" {
|
||||||
|
mgmTlsEnabled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if !mgmTlsEnabled {
|
||||||
|
// only do the check for HTTPs scheme (the hosted version of the Management service is always HTTPs)
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if mgmTlsEnabled && config.ManagementURL.Port() == fmt.Sprintf("%d", ManagementLegacyPort) {
|
||||||
|
|
||||||
|
newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s:%d",
|
||||||
|
config.ManagementURL.Scheme, config.ManagementURL.Hostname(), 443))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// here we check whether we could switch from the legacy 33073 port to the new 443
|
||||||
|
log.Infof("attempting to switch from the legacy Management URL %s to the new one %s",
|
||||||
|
config.ManagementURL.String(), newURL.String())
|
||||||
|
key, err := wgtypes.ParseKey(config.PrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||||
|
return config, err
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := mgm.NewClient(ctx, newURL.Host, key, mgmTlsEnabled)
|
||||||
|
if err != nil {
|
||||||
|
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||||
|
return config, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err = client.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to close the Management service client %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// gRPC check
|
||||||
|
_, err = client.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// everything is alright => update the config
|
||||||
|
newConfig, err := UpdateConfig(ConfigInput{
|
||||||
|
ManagementURL: newURL.String(),
|
||||||
|
ConfigPath: configPath,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||||
|
return config, fmt.Errorf("failed updating config file: %v", err)
|
||||||
|
}
|
||||||
|
log.Infof("successfully switched to the new Management URL: %s", newURL.String())
|
||||||
|
|
||||||
|
return newConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {
|
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {
|
||||||
var sri interface{} = statusRecorder
|
var sri interface{} = statusRecorder
|
||||||
mgmNotifier, _ := sri.(mgm.ConnStateNotifier)
|
mgmNotifier, _ := sri.(mgm.ConnStateNotifier)
|
||||||
|
|||||||
@@ -16,11 +16,11 @@ import (
|
|||||||
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
||||||
type DeviceAuthorizationFlow struct {
|
type DeviceAuthorizationFlow struct {
|
||||||
Provider string
|
Provider string
|
||||||
ProviderConfig DeviceAuthProviderConfig
|
ProviderConfig ProviderConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
// ProviderConfig has all attributes needed to initiate a device authorization flow
|
||||||
type DeviceAuthProviderConfig struct {
|
type ProviderConfig struct {
|
||||||
// ClientID An IDP application client id
|
// ClientID An IDP application client id
|
||||||
ClientID string
|
ClientID string
|
||||||
// ClientSecret An IDP application client secret
|
// ClientSecret An IDP application client secret
|
||||||
@@ -88,7 +88,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
|
|||||||
deviceAuthorizationFlow := DeviceAuthorizationFlow{
|
deviceAuthorizationFlow := DeviceAuthorizationFlow{
|
||||||
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
||||||
|
|
||||||
ProviderConfig: DeviceAuthProviderConfig{
|
ProviderConfig: ProviderConfig{
|
||||||
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
|
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
|
||||||
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
|
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
|
||||||
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
||||||
@@ -105,7 +105,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
|
|||||||
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
|
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
|
||||||
}
|
}
|
||||||
|
|
||||||
err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
err = isProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return DeviceAuthorizationFlow{}, err
|
return DeviceAuthorizationFlow{}, err
|
||||||
}
|
}
|
||||||
@@ -113,7 +113,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
|
|||||||
return deviceAuthorizationFlow, nil
|
return deviceAuthorizationFlow, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
|
func isProviderConfigValid(config ProviderConfig) error {
|
||||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||||
if config.Audience == "" {
|
if config.Audience == "" {
|
||||||
return fmt.Errorf(errorMSGFormat, "Audience")
|
return fmt.Errorf(errorMSGFormat, "Audience")
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,28 +1,29 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
fileGeneratedResolvConfContentHeader = "# Generated by NetBird"
|
fileGeneratedResolvConfContentHeader = "# Generated by NetBird"
|
||||||
fileGeneratedResolvConfContentHeaderNextLine = fileGeneratedResolvConfContentHeader + `
|
fileGeneratedResolvConfSearchBeginContent = "search "
|
||||||
# If needed you can restore the original file by copying back ` + fileDefaultResolvConfBackupLocation + "\n\n"
|
fileGeneratedResolvConfContentFormat = fileGeneratedResolvConfContentHeader +
|
||||||
|
"\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" +
|
||||||
fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"
|
fileGeneratedResolvConfSearchBeginContent + "%s\n"
|
||||||
|
|
||||||
fileMaxLineCharsLimit = 256
|
|
||||||
fileMaxNumberOfSearchDomains = 6
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"
|
||||||
|
fileMaxLineCharsLimit = 256
|
||||||
|
fileMaxNumberOfSearchDomains = 6
|
||||||
|
)
|
||||||
|
|
||||||
|
var fileSearchLineBeginCharCount = len(fileGeneratedResolvConfSearchBeginContent)
|
||||||
|
|
||||||
type fileConfigurator struct {
|
type fileConfigurator struct {
|
||||||
originalPerms os.FileMode
|
originalPerms os.FileMode
|
||||||
}
|
}
|
||||||
@@ -31,18 +32,14 @@ func newFileConfigurator() (hostManager, error) {
|
|||||||
return &fileConfigurator{}, nil
|
return &fileConfigurator{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fileConfigurator) supportCustomPort() bool {
|
func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|
||||||
backupFileExist := false
|
backupFileExist := false
|
||||||
_, err := os.Stat(fileDefaultResolvConfBackupLocation)
|
_, err := os.Stat(fileDefaultResolvConfBackupLocation)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
backupFileExist = true
|
backupFileExist = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if !config.RouteAll {
|
if !config.routeAll {
|
||||||
if backupFileExist {
|
if backupFileExist {
|
||||||
err = f.restore()
|
err = f.restore()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -51,39 +48,53 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
}
|
}
|
||||||
return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
||||||
}
|
}
|
||||||
|
managerType, err := getOSDNSManagerType()
|
||||||
if !backupFileExist {
|
if err != nil {
|
||||||
err = f.backup()
|
return err
|
||||||
if err != nil {
|
}
|
||||||
return fmt.Errorf("unable to backup the resolv.conf file")
|
switch managerType {
|
||||||
|
case fileManager, netbirdManager:
|
||||||
|
if !backupFileExist {
|
||||||
|
err = f.backup()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to backup the resolv.conf file")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
// todo improve this and maybe restart DNS manager from scratch
|
||||||
|
return fmt.Errorf("something happened and file manager is not your prefered host dns configurator, restart the agent")
|
||||||
}
|
}
|
||||||
|
|
||||||
searchDomainList := searchDomains(config)
|
var searchDomains string
|
||||||
|
appendedDomains := 0
|
||||||
|
for _, dConf := range config.domains {
|
||||||
|
if dConf.matchOnly || dConf.disabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if appendedDomains >= fileMaxNumberOfSearchDomains {
|
||||||
|
// lets log all skipped domains
|
||||||
|
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, dConf.domain)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if fileSearchLineBeginCharCount+len(searchDomains) > fileMaxLineCharsLimit {
|
||||||
|
// lets log all skipped domains
|
||||||
|
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, dConf.domain)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
originalSearchDomains, nameServers, others, err := originalDNSConfigs(fileDefaultResolvConfBackupLocation)
|
searchDomains += " " + dConf.domain
|
||||||
if err != nil {
|
appendedDomains++
|
||||||
log.Error(err)
|
|
||||||
}
|
}
|
||||||
|
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains)
|
||||||
searchDomainList = mergeSearchDomains(searchDomainList, originalSearchDomains)
|
err = writeDNSConfig(content, defaultResolvConfPath, f.originalPerms)
|
||||||
|
|
||||||
buf := prepareResolvConfContent(
|
|
||||||
searchDomainList,
|
|
||||||
append([]string{config.ServerIP}, nameServers...),
|
|
||||||
others)
|
|
||||||
|
|
||||||
log.Debugf("creating managed file %s", defaultResolvConfPath)
|
|
||||||
err = os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
restoreErr := f.restore()
|
err = f.restore()
|
||||||
if restoreErr != nil {
|
if err != nil {
|
||||||
log.Errorf("attempt to restore default file failed with error: %s", err)
|
log.Errorf("attempt to restore default file failed with error: %s", err)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("got an creating resolver file %s. Error: %s", defaultResolvConfPath, err)
|
return err
|
||||||
}
|
}
|
||||||
|
log.Infof("created a NetBird managed %s file with your DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, appendedDomains, searchDomains)
|
||||||
log.Infof("created a NetBird managed %s file with your DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -115,138 +126,15 @@ func (f *fileConfigurator) restore() error {
|
|||||||
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
|
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareResolvConfContent(searchDomains, nameServers, others []string) bytes.Buffer {
|
func writeDNSConfig(content, fileName string, permissions os.FileMode) error {
|
||||||
|
log.Debugf("creating managed file %s", fileName)
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
buf.WriteString(fileGeneratedResolvConfContentHeaderNextLine)
|
buf.WriteString(content)
|
||||||
|
err := os.WriteFile(fileName, buf.Bytes(), permissions)
|
||||||
for _, cfgLine := range others {
|
|
||||||
buf.WriteString(cfgLine)
|
|
||||||
buf.WriteString("\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(searchDomains) > 0 {
|
|
||||||
buf.WriteString("search ")
|
|
||||||
buf.WriteString(strings.Join(searchDomains, " "))
|
|
||||||
buf.WriteString("\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ns := range nameServers {
|
|
||||||
buf.WriteString("nameserver ")
|
|
||||||
buf.WriteString(ns)
|
|
||||||
buf.WriteString("\n")
|
|
||||||
}
|
|
||||||
return buf
|
|
||||||
}
|
|
||||||
|
|
||||||
func searchDomains(config HostDNSConfig) []string {
|
|
||||||
listOfDomains := make([]string, 0)
|
|
||||||
for _, dConf := range config.Domains {
|
|
||||||
if dConf.MatchOnly || dConf.Disabled {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
listOfDomains = append(listOfDomains, dConf.Domain)
|
|
||||||
}
|
|
||||||
return listOfDomains
|
|
||||||
}
|
|
||||||
|
|
||||||
func originalDNSConfigs(resolvconfFile string) (searchDomains, nameServers, others []string, err error) {
|
|
||||||
file, err := os.Open(resolvconfFile)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf(`could not read existing resolv.conf`)
|
return fmt.Errorf("got an creating resolver file %s. Error: %s", fileName, err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
defer file.Close()
|
return nil
|
||||||
|
|
||||||
reader := bufio.NewReader(file)
|
|
||||||
|
|
||||||
for {
|
|
||||||
lineBytes, isPrefix, readErr := reader.ReadLine()
|
|
||||||
if readErr != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if isPrefix {
|
|
||||||
err = fmt.Errorf(`resolv.conf line too long`)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
line := strings.TrimSpace(string(lineBytes))
|
|
||||||
|
|
||||||
if strings.HasPrefix(line, "#") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasPrefix(line, "domain") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasPrefix(line, "options") && strings.Contains(line, "rotate") {
|
|
||||||
line = strings.ReplaceAll(line, "rotate", "")
|
|
||||||
splitLines := strings.Fields(line)
|
|
||||||
if len(splitLines) == 1 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
line = strings.Join(splitLines, " ")
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasPrefix(line, "search") {
|
|
||||||
splitLines := strings.Fields(line)
|
|
||||||
if len(splitLines) < 2 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
searchDomains = splitLines[1:]
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasPrefix(line, "nameserver") {
|
|
||||||
splitLines := strings.Fields(line)
|
|
||||||
if len(splitLines) != 2 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
nameServers = append(nameServers, splitLines[1])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
others = append(others, line)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// merge search Domains lists and cut off the list if it is too long
|
|
||||||
func mergeSearchDomains(searchDomains []string, originalSearchDomains []string) []string {
|
|
||||||
lineSize := len("search")
|
|
||||||
searchDomainsList := make([]string, 0, len(searchDomains)+len(originalSearchDomains))
|
|
||||||
|
|
||||||
lineSize = validateAndFillSearchDomains(lineSize, &searchDomainsList, searchDomains)
|
|
||||||
_ = validateAndFillSearchDomains(lineSize, &searchDomainsList, originalSearchDomains)
|
|
||||||
|
|
||||||
return searchDomainsList
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateAndFillSearchDomains checks if the search Domains list is not too long and if the line is not too long
|
|
||||||
// extend s slice with vs elements
|
|
||||||
// return with the number of characters in the searchDomains line
|
|
||||||
func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string) int {
|
|
||||||
for _, sd := range vs {
|
|
||||||
tmpCharsNumber := initialLineChars + 1 + len(sd)
|
|
||||||
if tmpCharsNumber > fileMaxLineCharsLimit {
|
|
||||||
// lets log all skipped Domains
|
|
||||||
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, sd)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
initialLineChars = tmpCharsNumber
|
|
||||||
|
|
||||||
if len(*s) >= fileMaxNumberOfSearchDomains {
|
|
||||||
// lets log all skipped Domains
|
|
||||||
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, sd)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
*s = append(*s, sd)
|
|
||||||
}
|
|
||||||
return initialLineChars
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func copyFile(src, dest string) error {
|
func copyFile(src, dest string) error {
|
||||||
|
|||||||
@@ -1,62 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Test_mergeSearchDomains(t *testing.T) {
|
|
||||||
searchDomains := []string{"a", "b"}
|
|
||||||
originDomains := []string{"a", "b"}
|
|
||||||
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
|
|
||||||
if len(mergedDomains) != 4 {
|
|
||||||
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 4)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_mergeSearchTooMuchDomains(t *testing.T) {
|
|
||||||
searchDomains := []string{"a", "b", "c", "d", "e", "f", "g"}
|
|
||||||
originDomains := []string{"h", "i"}
|
|
||||||
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
|
|
||||||
if len(mergedDomains) != 6 {
|
|
||||||
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 6)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_mergeSearchTooMuchDomainsInOrigin(t *testing.T) {
|
|
||||||
searchDomains := []string{"a", "b"}
|
|
||||||
originDomains := []string{"c", "d", "e", "f", "g"}
|
|
||||||
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
|
|
||||||
if len(mergedDomains) != 6 {
|
|
||||||
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 6)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_mergeSearchTooLongDomain(t *testing.T) {
|
|
||||||
searchDomains := []string{getLongLine()}
|
|
||||||
originDomains := []string{"b"}
|
|
||||||
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
|
|
||||||
if len(mergedDomains) != 1 {
|
|
||||||
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
searchDomains = []string{"b"}
|
|
||||||
originDomains = []string{getLongLine()}
|
|
||||||
|
|
||||||
mergedDomains = mergeSearchDomains(searchDomains, originDomains)
|
|
||||||
if len(mergedDomains) != 1 {
|
|
||||||
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getLongLine() string {
|
|
||||||
x := "search "
|
|
||||||
for {
|
|
||||||
for i := 0; i <= 9; i++ {
|
|
||||||
if len(x) > fileMaxLineCharsLimit {
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
x = fmt.Sprintf("%s%d", x, i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -8,31 +8,29 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type hostManager interface {
|
type hostManager interface {
|
||||||
applyDNSConfig(config HostDNSConfig) error
|
applyDNSConfig(config hostDNSConfig) error
|
||||||
restoreHostDNS() error
|
restoreHostDNS() error
|
||||||
supportCustomPort() bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type HostDNSConfig struct {
|
type hostDNSConfig struct {
|
||||||
Domains []DomainConfig `json:"domains"`
|
domains []domainConfig
|
||||||
RouteAll bool `json:"routeAll"`
|
routeAll bool
|
||||||
ServerIP string `json:"serverIP"`
|
serverIP string
|
||||||
ServerPort int `json:"serverPort"`
|
serverPort int
|
||||||
}
|
}
|
||||||
|
|
||||||
type DomainConfig struct {
|
type domainConfig struct {
|
||||||
Disabled bool `json:"disabled"`
|
disabled bool
|
||||||
Domain string `json:"domain"`
|
domain string
|
||||||
MatchOnly bool `json:"matchOnly"`
|
matchOnly bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockHostConfigurator struct {
|
type mockHostConfigurator struct {
|
||||||
applyDNSConfigFunc func(config HostDNSConfig) error
|
applyDNSConfigFunc func(config hostDNSConfig) error
|
||||||
restoreHostDNSFunc func() error
|
restoreHostDNSFunc func() error
|
||||||
supportCustomPortFunc func() bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
func (m *mockHostConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||||
if m.applyDNSConfigFunc != nil {
|
if m.applyDNSConfigFunc != nil {
|
||||||
return m.applyDNSConfigFunc(config)
|
return m.applyDNSConfigFunc(config)
|
||||||
}
|
}
|
||||||
@@ -46,47 +44,39 @@ func (m *mockHostConfigurator) restoreHostDNS() error {
|
|||||||
return fmt.Errorf("method restoreHostDNS is not implemented")
|
return fmt.Errorf("method restoreHostDNS is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockHostConfigurator) supportCustomPort() bool {
|
|
||||||
if m.supportCustomPortFunc != nil {
|
|
||||||
return m.supportCustomPortFunc()
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func newNoopHostMocker() hostManager {
|
func newNoopHostMocker() hostManager {
|
||||||
return &mockHostConfigurator{
|
return &mockHostConfigurator{
|
||||||
applyDNSConfigFunc: func(config HostDNSConfig) error { return nil },
|
applyDNSConfigFunc: func(config hostDNSConfig) error { return nil },
|
||||||
restoreHostDNSFunc: func() error { return nil },
|
restoreHostDNSFunc: func() error { return nil },
|
||||||
supportCustomPortFunc: func() bool { return true },
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostDNSConfig {
|
func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) hostDNSConfig {
|
||||||
config := HostDNSConfig{
|
config := hostDNSConfig{
|
||||||
RouteAll: false,
|
routeAll: false,
|
||||||
ServerIP: ip,
|
serverIP: ip,
|
||||||
ServerPort: port,
|
serverPort: port,
|
||||||
}
|
}
|
||||||
for _, nsConfig := range dnsConfig.NameServerGroups {
|
for _, nsConfig := range dnsConfig.NameServerGroups {
|
||||||
if len(nsConfig.NameServers) == 0 {
|
if len(nsConfig.NameServers) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if nsConfig.Primary {
|
if nsConfig.Primary {
|
||||||
config.RouteAll = true
|
config.routeAll = true
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, domain := range nsConfig.Domains {
|
for _, domain := range nsConfig.Domains {
|
||||||
config.Domains = append(config.Domains, DomainConfig{
|
config.domains = append(config.domains, domainConfig{
|
||||||
Domain: strings.TrimSuffix(domain, "."),
|
domain: strings.TrimSuffix(domain, "."),
|
||||||
MatchOnly: !nsConfig.SearchDomainsEnabled,
|
matchOnly: true,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, customZone := range dnsConfig.CustomZones {
|
for _, customZone := range dnsConfig.CustomZones {
|
||||||
config.Domains = append(config.Domains, DomainConfig{
|
config.domains = append(config.domains, domainConfig{
|
||||||
Domain: strings.TrimSuffix(customZone.Domain, "."),
|
domain: strings.TrimSuffix(customZone.Domain, "."),
|
||||||
MatchOnly: false,
|
matchOnly: false,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,20 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
type androidHostManager struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
|
||||||
return &androidHostManager{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a androidHostManager) applyDNSConfig(config HostDNSConfig) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a androidHostManager) restoreHostDNS() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a androidHostManager) supportCustomPort() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
@@ -1,5 +1,3 @@
|
|||||||
//go:build !ios
|
|
||||||
|
|
||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -10,6 +8,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -34,21 +33,17 @@ type systemConfigurator struct {
|
|||||||
createdKeys map[string]struct{}
|
createdKeys map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(_ WGIface) (hostManager, error) {
|
func newHostManager(_ *iface.WGIface) (hostManager, error) {
|
||||||
return &systemConfigurator{
|
return &systemConfigurator{
|
||||||
createdKeys: make(map[string]struct{}),
|
createdKeys: make(map[string]struct{}),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) supportCustomPort() bool {
|
func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if config.RouteAll {
|
if config.routeAll {
|
||||||
err = s.addDNSSetupForAll(config.ServerIP, config.ServerPort)
|
err = s.addDNSSetupForAll(config.serverIP, config.serverPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -58,7 +53,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.primaryServiceID = ""
|
s.primaryServiceID = ""
|
||||||
log.Infof("removed %s:%d as main DNS resolver for this peer", config.ServerIP, config.ServerPort)
|
log.Infof("removed %s:%d as main DNS resolver for this peer", config.serverIP, config.serverPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -66,20 +61,20 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
matchDomains []string
|
matchDomains []string
|
||||||
)
|
)
|
||||||
|
|
||||||
for _, dConf := range config.Domains {
|
for _, dConf := range config.domains {
|
||||||
if dConf.Disabled {
|
if dConf.disabled {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if dConf.MatchOnly {
|
if dConf.matchOnly {
|
||||||
matchDomains = append(matchDomains, dConf.Domain)
|
matchDomains = append(matchDomains, dConf.domain)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
searchDomains = append(searchDomains, dConf.Domain)
|
searchDomains = append(searchDomains, dConf.domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||||
if len(matchDomains) != 0 {
|
if len(matchDomains) != 0 {
|
||||||
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
|
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.serverIP, config.serverPort)
|
||||||
} else {
|
} else {
|
||||||
log.Infof("removing match domains from the system")
|
log.Infof("removing match domains from the system")
|
||||||
err = s.removeKeyFromSystemConfig(matchKey)
|
err = s.removeKeyFromSystemConfig(matchKey)
|
||||||
@@ -90,7 +85,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
|
|
||||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||||
if len(searchDomains) != 0 {
|
if len(searchDomains) != 0 {
|
||||||
err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.ServerIP, config.ServerPort)
|
err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.serverIP, config.serverPort)
|
||||||
} else {
|
} else {
|
||||||
log.Infof("removing search domains from the system")
|
log.Infof("removing search domains from the system")
|
||||||
err = s.removeKeyFromSystemConfig(searchKey)
|
err = s.removeKeyFromSystemConfig(searchKey)
|
||||||
@@ -184,11 +179,12 @@ func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error {
|
func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error {
|
||||||
primaryServiceKey, existingNameserver := s.getPrimaryService()
|
primaryServiceKey := s.getPrimaryService()
|
||||||
if primaryServiceKey == "" {
|
if primaryServiceKey == "" {
|
||||||
return fmt.Errorf("couldn't find the primary service key")
|
return fmt.Errorf("couldn't find the primary service key")
|
||||||
}
|
}
|
||||||
err := s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver)
|
|
||||||
|
err := s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -197,32 +193,27 @@ func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) getPrimaryService() (string, string) {
|
func (s *systemConfigurator) getPrimaryService() string {
|
||||||
line := buildCommandLine("show", globalIPv4State, "")
|
line := buildCommandLine("show", globalIPv4State, "")
|
||||||
stdinCommands := wrapCommand(line)
|
stdinCommands := wrapCommand(line)
|
||||||
b, err := runSystemConfigCommand(stdinCommands)
|
b, err := runSystemConfigCommand(stdinCommands)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("got error while sending the command: ", err)
|
log.Error("got error while sending the command: ", err)
|
||||||
return "", ""
|
return ""
|
||||||
}
|
}
|
||||||
scanner := bufio.NewScanner(bytes.NewReader(b))
|
scanner := bufio.NewScanner(bytes.NewReader(b))
|
||||||
primaryService := ""
|
|
||||||
router := ""
|
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
text := scanner.Text()
|
text := scanner.Text()
|
||||||
if strings.Contains(text, "PrimaryService") {
|
if strings.Contains(text, "PrimaryService") {
|
||||||
primaryService = strings.TrimSpace(strings.Split(text, ":")[1])
|
return strings.TrimSpace(strings.Split(text, ":")[1])
|
||||||
}
|
|
||||||
if strings.Contains(text, "Router") {
|
|
||||||
router = strings.TrimSpace(strings.Split(text, ":")[1])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return primaryService, router
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, existingDNSServer string) error {
|
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int) error {
|
||||||
lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0))
|
lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0))
|
||||||
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer+" "+existingDNSServer)
|
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer)
|
||||||
lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port))
|
lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port))
|
||||||
addDomainCommand := buildCreateStateWithOperation(setupKey, lines)
|
addDomainCommand := buildCreateStateWithOperation(setupKey, lines)
|
||||||
stdinCommands := wrapCommand(addDomainCommand)
|
stdinCommands := wrapCommand(addDomainCommand)
|
||||||
|
|||||||
@@ -1,37 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
type iosHostManager struct {
|
|
||||||
dnsManager IosDnsManager
|
|
||||||
config HostDNSConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
func newHostManager(dnsManager IosDnsManager) (hostManager, error) {
|
|
||||||
return &iosHostManager{
|
|
||||||
dnsManager: dnsManager,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a iosHostManager) applyDNSConfig(config HostDNSConfig) error {
|
|
||||||
jsonData, err := json.Marshal(config)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
jsonString := string(jsonData)
|
|
||||||
log.Debugf("Applying DNS settings: %s", jsonString)
|
|
||||||
a.dnsManager.ApplyDns(jsonString)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a iosHostManager) restoreHostDNS() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a iosHostManager) supportCustomPort() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
@@ -1,14 +1,12 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -25,30 +23,13 @@ const (
|
|||||||
|
|
||||||
type osManagerType int
|
type osManagerType int
|
||||||
|
|
||||||
func (t osManagerType) String() string {
|
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
|
||||||
switch t {
|
|
||||||
case netbirdManager:
|
|
||||||
return "netbird"
|
|
||||||
case fileManager:
|
|
||||||
return "file"
|
|
||||||
case networkManager:
|
|
||||||
return "networkManager"
|
|
||||||
case systemdManager:
|
|
||||||
return "systemd"
|
|
||||||
case resolvConfManager:
|
|
||||||
return "resolvconf"
|
|
||||||
default:
|
|
||||||
return "unknown"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
|
||||||
osManager, err := getOSDNSManagerType()
|
osManager, err := getOSDNSManagerType()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("discovered mode is: %s", osManager)
|
log.Debugf("discovered mode is: %d", osManager)
|
||||||
switch osManager {
|
switch osManager {
|
||||||
case networkManager:
|
case networkManager:
|
||||||
return newNetworkManagerDbusConfigurator(wgInterface)
|
return newNetworkManagerDbusConfigurator(wgInterface)
|
||||||
@@ -82,6 +63,7 @@ func getOSDNSManagerType() (osManagerType, error) {
|
|||||||
return netbirdManager, nil
|
return netbirdManager, nil
|
||||||
}
|
}
|
||||||
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||||
|
log.Debugf("is nm running on supported v? %t", isNetworkManagerSupportedVersion())
|
||||||
return networkManager, nil
|
return networkManager, nil
|
||||||
}
|
}
|
||||||
if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
|
if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/windows/registry"
|
"golang.org/x/sys/windows/registry"
|
||||||
)
|
)
|
||||||
@@ -22,14 +23,16 @@ const (
|
|||||||
interfaceConfigPath = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Interfaces"
|
interfaceConfigPath = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Interfaces"
|
||||||
interfaceConfigNameServerKey = "NameServer"
|
interfaceConfigNameServerKey = "NameServer"
|
||||||
interfaceConfigSearchListKey = "SearchList"
|
interfaceConfigSearchListKey = "SearchList"
|
||||||
|
tcpipParametersPath = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
type registryConfigurator struct {
|
type registryConfigurator struct {
|
||||||
guid string
|
guid string
|
||||||
routingAll bool
|
routingAll bool
|
||||||
|
existingSearchDomains []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
|
||||||
guid, err := wgInterface.GetInterfaceGUIDString()
|
guid, err := wgInterface.GetInterfaceGUIDString()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -39,14 +42,10 @@ func newHostManager(wgInterface WGIface) (hostManager, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *registryConfigurator) supportCustomPort() bool {
|
func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|
||||||
var err error
|
var err error
|
||||||
if config.RouteAll {
|
if config.routeAll {
|
||||||
err = r.addDNSSetupForAll(config.ServerIP)
|
err = r.addDNSSetupForAll(config.serverIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -56,7 +55,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
r.routingAll = false
|
r.routingAll = false
|
||||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
log.Infof("removed %s as main DNS forwarder for this peer", config.serverIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -64,18 +63,18 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
matchDomains []string
|
matchDomains []string
|
||||||
)
|
)
|
||||||
|
|
||||||
for _, dConf := range config.Domains {
|
for _, dConf := range config.domains {
|
||||||
if dConf.Disabled {
|
if dConf.disabled {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if !dConf.MatchOnly {
|
if !dConf.matchOnly {
|
||||||
searchDomains = append(searchDomains, dConf.Domain)
|
searchDomains = append(searchDomains, dConf.domain)
|
||||||
}
|
}
|
||||||
matchDomains = append(matchDomains, "."+dConf.Domain)
|
matchDomains = append(matchDomains, "."+dConf.domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(matchDomains) != 0 {
|
if len(matchDomains) != 0 {
|
||||||
err = r.addDNSMatchPolicy(matchDomains, config.ServerIP)
|
err = r.addDNSMatchPolicy(matchDomains, config.serverIP)
|
||||||
} else {
|
} else {
|
||||||
err = removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath)
|
err = removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath)
|
||||||
}
|
}
|
||||||
@@ -146,11 +145,30 @@ func (r *registryConfigurator) restoreHostDNS() error {
|
|||||||
log.Error(err)
|
log.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.deleteInterfaceRegistryKeyProperty(interfaceConfigSearchListKey)
|
return r.updateSearchDomains([]string{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
||||||
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ","))
|
value, err := getLocalMachineRegistryKeyStringValue(tcpipParametersPath, interfaceConfigSearchListKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to get current search domains failed with error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
valueList := strings.Split(value, ",")
|
||||||
|
setExisting := false
|
||||||
|
if len(r.existingSearchDomains) == 0 {
|
||||||
|
r.existingSearchDomains = valueList
|
||||||
|
setExisting = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(domains) == 0 && setExisting {
|
||||||
|
log.Infof("added %d search domains to the registry. Domain list: %s", len(domains), domains)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
newList := append(r.existingSearchDomains, domains...)
|
||||||
|
|
||||||
|
err = setLocalMachineRegistryKeyStringValue(tcpipParametersPath, interfaceConfigSearchListKey, strings.Join(newList, ","))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("adding search domain failed with error: %s", err)
|
return fmt.Errorf("adding search domain failed with error: %s", err)
|
||||||
}
|
}
|
||||||
@@ -214,3 +232,33 @@ func removeRegistryKeyFromDNSPolicyConfig(regKeyPath string) error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getLocalMachineRegistryKeyStringValue(keyPath, key string) (string, error) {
|
||||||
|
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("unable to open existing key from registry, key path: HKEY_LOCAL_MACHINE\\%s, error: %s", keyPath, err)
|
||||||
|
}
|
||||||
|
defer regKey.Close()
|
||||||
|
|
||||||
|
val, _, err := regKey.GetStringValue(key)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("getting %s value for key path HKEY_LOCAL_MACHINE\\%s failed with error: %s", key, keyPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setLocalMachineRegistryKeyStringValue(keyPath, key, value string) error {
|
||||||
|
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.SET_VALUE)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to open existing key from registry, key path: HKEY_LOCAL_MACHINE\\%s, error: %s", keyPath, err)
|
||||||
|
}
|
||||||
|
defer regKey.Close()
|
||||||
|
|
||||||
|
err = regKey.SetStringValue(key, value)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("setting %s value %s for key path HKEY_LOCAL_MACHINE\\%s failed with error: %s", key, value, keyPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,12 +2,10 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
type registrationMap map[string]struct{}
|
type registrationMap map[string]struct{}
|
||||||
@@ -17,12 +15,9 @@ type localResolver struct {
|
|||||||
records sync.Map
|
records sync.Map
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *localResolver) stop() {
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeDNS handles a DNS request
|
// ServeDNS handles a DNS request
|
||||||
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
log.Tracef("received question: %#v", r.Question[0])
|
log.Tracef("received question: %#v\n", r.Question[0])
|
||||||
replyMessage := &dns.Msg{}
|
replyMessage := &dns.Msg{}
|
||||||
replyMessage.SetReply(r)
|
replyMessage.SetReply(r)
|
||||||
replyMessage.RecursionAvailable = true
|
replyMessage.RecursionAvailable = true
|
||||||
|
|||||||
@@ -1,12 +1,10 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestLocalResolver_ServeDNS(t *testing.T) {
|
func TestLocalResolver_ServeDNS(t *testing.T) {
|
||||||
|
|||||||
@@ -2,23 +2,21 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockServer is the mock instance of a dns server
|
// MockServer is the mock instance of a dns server
|
||||||
type MockServer struct {
|
type MockServer struct {
|
||||||
InitializeFunc func() error
|
StartFunc func()
|
||||||
StopFunc func()
|
StopFunc func()
|
||||||
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize mock implementation of Initialize from Server interface
|
// Start mock implementation of Start from Server interface
|
||||||
func (m *MockServer) Initialize() error {
|
func (m *MockServer) Start() {
|
||||||
if m.InitializeFunc != nil {
|
if m.StartFunc != nil {
|
||||||
return m.InitializeFunc()
|
m.StartFunc()
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop mock implementation of Stop from Server interface
|
// Stop mock implementation of Stop from Server interface
|
||||||
@@ -28,15 +26,6 @@ func (m *MockServer) Stop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockServer) DnsIP() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) {
|
|
||||||
// TODO implement me
|
|
||||||
panic("implement me")
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateDNSServer mock implementation of UpdateDNSServer from Server interface
|
// UpdateDNSServer mock implementation of UpdateDNSServer from Server interface
|
||||||
func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||||
if m.UpdateDNSServerFunc != nil {
|
if m.UpdateDNSServerFunc != nil {
|
||||||
@@ -44,7 +33,3 @@ func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
|||||||
}
|
}
|
||||||
return fmt.Errorf("method UpdateDNSServer is not implemented")
|
return fmt.Errorf("method UpdateDNSServer is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockServer) SearchDomains() []string {
|
|
||||||
return make([]string, 0)
|
|
||||||
}
|
|
||||||
@@ -1,5 +1,3 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -7,14 +5,14 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"regexp"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/godbus/dbus/v5"
|
"github.com/godbus/dbus/v5"
|
||||||
"github.com/hashicorp/go-version"
|
"github.com/hashicorp/go-version"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nbversion "github.com/netbirdio/netbird/version"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -71,7 +69,7 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newNetworkManagerDbusConfigurator(wgInterface WGIface) (hostManager, error) {
|
func newNetworkManagerDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
||||||
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -90,11 +88,7 @@ func newNetworkManagerDbusConfigurator(wgInterface WGIface) (hostManager, error)
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *networkManagerDbusConfigurator) supportCustomPort() bool {
|
func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|
||||||
connSettings, configVersion, err := n.getAppliedConnectionSettings()
|
connSettings, configVersion, err := n.getAppliedConnectionSettings()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("got an error while retrieving the applied connection settings, error: %s", err)
|
return fmt.Errorf("got an error while retrieving the applied connection settings, error: %s", err)
|
||||||
@@ -102,7 +96,7 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) er
|
|||||||
|
|
||||||
connSettings.cleanDeprecatedSettings()
|
connSettings.cleanDeprecatedSettings()
|
||||||
|
|
||||||
dnsIP, err := netip.ParseAddr(config.ServerIP)
|
dnsIP, err := netip.ParseAddr(config.serverIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to parse ip address, error: %s", err)
|
return fmt.Errorf("unable to parse ip address, error: %s", err)
|
||||||
}
|
}
|
||||||
@@ -112,33 +106,33 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) er
|
|||||||
searchDomains []string
|
searchDomains []string
|
||||||
matchDomains []string
|
matchDomains []string
|
||||||
)
|
)
|
||||||
for _, dConf := range config.Domains {
|
for _, dConf := range config.domains {
|
||||||
if dConf.Disabled {
|
if dConf.disabled {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if dConf.MatchOnly {
|
if dConf.matchOnly {
|
||||||
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.Domain))
|
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.domain))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
searchDomains = append(searchDomains, dns.Fqdn(dConf.Domain))
|
searchDomains = append(searchDomains, dns.Fqdn(dConf.domain))
|
||||||
}
|
}
|
||||||
|
|
||||||
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic
|
newDomainList := append(searchDomains, matchDomains...)
|
||||||
|
|
||||||
priority := networkManagerDbusSearchDomainOnlyPriority
|
priority := networkManagerDbusSearchDomainOnlyPriority
|
||||||
switch {
|
switch {
|
||||||
case config.RouteAll:
|
case config.routeAll:
|
||||||
priority = networkManagerDbusPrimaryDNSPriority
|
priority = networkManagerDbusPrimaryDNSPriority
|
||||||
newDomainList = append(newDomainList, "~.")
|
newDomainList = append(newDomainList, "~.")
|
||||||
if !n.routingAll {
|
if !n.routingAll {
|
||||||
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort)
|
||||||
}
|
}
|
||||||
case len(matchDomains) > 0:
|
case len(matchDomains) > 0:
|
||||||
priority = networkManagerDbusWithMatchDomainPriority
|
priority = networkManagerDbusWithMatchDomainPriority
|
||||||
}
|
}
|
||||||
|
|
||||||
if priority != networkManagerDbusPrimaryDNSPriority && n.routingAll {
|
if priority != networkManagerDbusPrimaryDNSPriority && n.routingAll {
|
||||||
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort)
|
||||||
n.routingAll = false
|
n.routingAll = false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -290,7 +284,12 @@ func isNetworkManagerSupportedVersion() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func parseVersion(inputVersion string) (*version.Version, error) {
|
func parseVersion(inputVersion string) (*version.Version, error) {
|
||||||
if inputVersion == "" || !nbversion.SemverRegexp.MatchString(inputVersion) {
|
reg, err := regexp.Compile(version.SemverRegexpRaw)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if inputVersion == "" || !reg.MatchString(inputVersion) {
|
||||||
return nil, fmt.Errorf("couldn't parse the provided version: Not SemVer")
|
return nil, fmt.Errorf("couldn't parse the provided version: Not SemVer")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,57 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"sort"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
|
||||||
)
|
|
||||||
|
|
||||||
type notifier struct {
|
|
||||||
listener listener.NetworkChangeListener
|
|
||||||
listenerMux sync.Mutex
|
|
||||||
searchDomains []string
|
|
||||||
}
|
|
||||||
|
|
||||||
func newNotifier(initialSearchDomains []string) *notifier {
|
|
||||||
sort.Strings(initialSearchDomains)
|
|
||||||
return ¬ifier{
|
|
||||||
searchDomains: initialSearchDomains,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *notifier) setListener(listener listener.NetworkChangeListener) {
|
|
||||||
n.listenerMux.Lock()
|
|
||||||
defer n.listenerMux.Unlock()
|
|
||||||
n.listener = listener
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *notifier) onNewSearchDomains(searchDomains []string) {
|
|
||||||
sort.Strings(searchDomains)
|
|
||||||
|
|
||||||
if len(n.searchDomains) != len(searchDomains) {
|
|
||||||
n.searchDomains = searchDomains
|
|
||||||
n.notify()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if reflect.DeepEqual(n.searchDomains, searchDomains) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
n.searchDomains = searchDomains
|
|
||||||
n.notify()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *notifier) notify() {
|
|
||||||
n.listenerMux.Lock()
|
|
||||||
defer n.listenerMux.Unlock()
|
|
||||||
if n.listener == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
go func(l listener.NetworkChangeListener) {
|
|
||||||
l.OnNetworkChanged("")
|
|
||||||
}(n.listener)
|
|
||||||
}
|
|
||||||
@@ -1,12 +1,11 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -14,34 +13,17 @@ const resolvconfCommand = "resolvconf"
|
|||||||
|
|
||||||
type resolvconf struct {
|
type resolvconf struct {
|
||||||
ifaceName string
|
ifaceName string
|
||||||
|
|
||||||
originalSearchDomains []string
|
|
||||||
originalNameServers []string
|
|
||||||
othersConfigs []string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// supported "openresolv" only
|
func newResolvConfConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
||||||
func newResolvConfConfigurator(wgInterface WGIface) (hostManager, error) {
|
|
||||||
originalSearchDomains, nameServers, others, err := originalDNSConfigs("/etc/resolv.conf")
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &resolvconf{
|
return &resolvconf{
|
||||||
ifaceName: wgInterface.Name(),
|
ifaceName: wgInterface.Name(),
|
||||||
originalSearchDomains: originalSearchDomains,
|
|
||||||
originalNameServers: nameServers,
|
|
||||||
othersConfigs: others,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolvconf) supportCustomPort() bool {
|
func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error {
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
|
|
||||||
var err error
|
var err error
|
||||||
if !config.RouteAll {
|
if !config.routeAll {
|
||||||
err = r.restoreHostDNS()
|
err = r.restoreHostDNS()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
@@ -49,20 +31,37 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
return fmt.Errorf("unable to configure DNS for this peer using resolvconf manager without a nameserver group with all domains configured")
|
return fmt.Errorf("unable to configure DNS for this peer using resolvconf manager without a nameserver group with all domains configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
searchDomainList := searchDomains(config)
|
var searchDomains string
|
||||||
searchDomainList = mergeSearchDomains(searchDomainList, r.originalSearchDomains)
|
appendedDomains := 0
|
||||||
|
for _, dConf := range config.domains {
|
||||||
|
if dConf.matchOnly || dConf.disabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
buf := prepareResolvConfContent(
|
if appendedDomains >= fileMaxNumberOfSearchDomains {
|
||||||
searchDomainList,
|
// lets log all skipped domains
|
||||||
append([]string{config.ServerIP}, r.originalNameServers...),
|
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, dConf.domain)
|
||||||
r.othersConfigs)
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
err = r.applyConfig(buf)
|
if fileSearchLineBeginCharCount+len(searchDomains) > fileMaxLineCharsLimit {
|
||||||
|
// lets log all skipped domains
|
||||||
|
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, dConf.domain)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
searchDomains += " " + dConf.domain
|
||||||
|
appendedDomains++
|
||||||
|
}
|
||||||
|
|
||||||
|
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains)
|
||||||
|
|
||||||
|
err = r.applyConfig(content)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("added %d search domains. Search list: %s", len(searchDomainList), searchDomainList)
|
log.Infof("added %d search domains. Search list: %s", appendedDomains, searchDomains)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -75,12 +74,12 @@ func (r *resolvconf) restoreHostDNS() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolvconf) applyConfig(content bytes.Buffer) error {
|
func (r *resolvconf) applyConfig(content string) error {
|
||||||
cmd := exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName)
|
cmd := exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName)
|
||||||
cmd.Stdin = &content
|
cmd.Stdin = strings.NewReader(content)
|
||||||
_, err := cmd.Output()
|
_, err := cmd.Output()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("got an error while applying resolvconf configuration for %s interface, error: %s", r.ifaceName, err)
|
return fmt.Errorf("got an error while appying resolvconf configuration for %s interface, error: %s", r.ifaceName, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,103 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
|
||||||
)
|
|
||||||
|
|
||||||
type responseWriter struct {
|
|
||||||
local net.Addr
|
|
||||||
remote net.Addr
|
|
||||||
packet gopacket.Packet
|
|
||||||
device tun.Device
|
|
||||||
}
|
|
||||||
|
|
||||||
// LocalAddr returns the net.Addr of the server
|
|
||||||
func (r *responseWriter) LocalAddr() net.Addr {
|
|
||||||
return r.local
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoteAddr returns the net.Addr of the client that sent the current request.
|
|
||||||
func (r *responseWriter) RemoteAddr() net.Addr {
|
|
||||||
return r.remote
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteMsg writes a reply back to the client.
|
|
||||||
func (r *responseWriter) WriteMsg(msg *dns.Msg) error {
|
|
||||||
buff, err := msg.Pack()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = r.Write(buff)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write writes a raw buffer back to the client.
|
|
||||||
func (r *responseWriter) Write(data []byte) (int, error) {
|
|
||||||
var ip gopacket.SerializableLayer
|
|
||||||
|
|
||||||
// Get the UDP layer
|
|
||||||
udpLayer := r.packet.Layer(layers.LayerTypeUDP)
|
|
||||||
udp := udpLayer.(*layers.UDP)
|
|
||||||
|
|
||||||
// Swap the source and destination addresses for the response
|
|
||||||
udp.SrcPort, udp.DstPort = udp.DstPort, udp.SrcPort
|
|
||||||
|
|
||||||
// Check if it's an IPv4 packet
|
|
||||||
if ipv4Layer := r.packet.Layer(layers.LayerTypeIPv4); ipv4Layer != nil {
|
|
||||||
ipv4 := ipv4Layer.(*layers.IPv4)
|
|
||||||
ipv4.SrcIP, ipv4.DstIP = ipv4.DstIP, ipv4.SrcIP
|
|
||||||
ip = ipv4
|
|
||||||
} else if ipv6Layer := r.packet.Layer(layers.LayerTypeIPv6); ipv6Layer != nil {
|
|
||||||
ipv6 := ipv6Layer.(*layers.IPv6)
|
|
||||||
ipv6.SrcIP, ipv6.DstIP = ipv6.DstIP, ipv6.SrcIP
|
|
||||||
ip = ipv6
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := udp.SetNetworkLayerForChecksum(ip.(gopacket.NetworkLayer)); err != nil {
|
|
||||||
return 0, fmt.Errorf("failed to set network layer for checksum: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Serialize the packet
|
|
||||||
buffer := gopacket.NewSerializeBuffer()
|
|
||||||
options := gopacket.SerializeOptions{
|
|
||||||
ComputeChecksums: true,
|
|
||||||
FixLengths: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
payload := gopacket.Payload(data)
|
|
||||||
err := gopacket.SerializeLayers(buffer, options, ip, udp, payload)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("failed to serialize packet: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
send := buffer.Bytes()
|
|
||||||
sendBuffer := make([]byte, 40, len(send)+40)
|
|
||||||
sendBuffer = append(sendBuffer, send...)
|
|
||||||
|
|
||||||
return r.device.Write([][]byte{sendBuffer}, 40)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the connection.
|
|
||||||
func (r *responseWriter) Close() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TsigStatus returns the status of the Tsig.
|
|
||||||
func (r *responseWriter) TsigStatus() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TsigTimersOnly sets the tsig timers only boolean.
|
|
||||||
func (r *responseWriter) TsigTimersOnly(bool) {
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hijack lets the caller take over the connection.
|
|
||||||
// After a call to Hijack(), the DNS package will not do anything with the connection.
|
|
||||||
func (r *responseWriter) Hijack() {
|
|
||||||
}
|
|
||||||
@@ -1,93 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/iface/mocks"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestResponseWriterLocalAddr(t *testing.T) {
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
defer ctrl.Finish()
|
|
||||||
|
|
||||||
device := mocks.NewMockDevice(ctrl)
|
|
||||||
device.EXPECT().Write(gomock.Any(), gomock.Any())
|
|
||||||
|
|
||||||
request := &dns.Msg{
|
|
||||||
Question: []dns.Question{{
|
|
||||||
Name: "google.com.",
|
|
||||||
Qtype: dns.TypeA,
|
|
||||||
Qclass: dns.TypeA,
|
|
||||||
}},
|
|
||||||
}
|
|
||||||
|
|
||||||
replyMessage := &dns.Msg{}
|
|
||||||
replyMessage.SetReply(request)
|
|
||||||
replyMessage.RecursionAvailable = true
|
|
||||||
replyMessage.Rcode = dns.RcodeSuccess
|
|
||||||
replyMessage.Answer = []dns.RR{
|
|
||||||
&dns.A{
|
|
||||||
A: net.IPv4(8, 8, 8, 8),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
ipv4 := &layers.IPv4{
|
|
||||||
Protocol: layers.IPProtocolUDP,
|
|
||||||
SrcIP: net.IPv4(127, 0, 0, 1),
|
|
||||||
DstIP: net.IPv4(127, 0, 0, 2),
|
|
||||||
}
|
|
||||||
udp := &layers.UDP{
|
|
||||||
DstPort: 53,
|
|
||||||
SrcPort: 45223,
|
|
||||||
}
|
|
||||||
if err := udp.SetNetworkLayerForChecksum(ipv4); err != nil {
|
|
||||||
t.Error("failed to set network layer for checksum")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Serialize the packet
|
|
||||||
buffer := gopacket.NewSerializeBuffer()
|
|
||||||
options := gopacket.SerializeOptions{
|
|
||||||
ComputeChecksums: true,
|
|
||||||
FixLengths: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
requestData, err := request.Pack()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("got an error while packing the request message, error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
payload := gopacket.Payload(requestData)
|
|
||||||
|
|
||||||
if err := gopacket.SerializeLayers(buffer, options, ipv4, udp, payload); err != nil {
|
|
||||||
t.Errorf("failed to serialize packet: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rw := &responseWriter{
|
|
||||||
local: &net.UDPAddr{
|
|
||||||
IP: net.IPv4(127, 0, 0, 1),
|
|
||||||
Port: 55223,
|
|
||||||
},
|
|
||||||
remote: &net.UDPAddr{
|
|
||||||
IP: net.IPv4(127, 0, 0, 1),
|
|
||||||
Port: 53,
|
|
||||||
},
|
|
||||||
packet: gopacket.NewPacket(
|
|
||||||
buffer.Bytes(),
|
|
||||||
layers.LayerTypeIPv4,
|
|
||||||
gopacket.Default,
|
|
||||||
),
|
|
||||||
device: device,
|
|
||||||
}
|
|
||||||
if err := rw.WriteMsg(replyMessage); err != nil {
|
|
||||||
t.Errorf("got an error while writing the local resolver response, error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,527 +1,10 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
"github.com/mitchellh/hashstructure/v2"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
|
||||||
type ReadyListener interface {
|
|
||||||
OnReady()
|
|
||||||
}
|
|
||||||
|
|
||||||
// IosDnsManager is a dns manager interface for iOS
|
|
||||||
type IosDnsManager interface {
|
|
||||||
ApplyDns(string)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Server is a dns server interface
|
// Server is a dns server interface
|
||||||
type Server interface {
|
type Server interface {
|
||||||
Initialize() error
|
Start()
|
||||||
Stop()
|
Stop()
|
||||||
DnsIP() string
|
|
||||||
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
||||||
OnUpdatedHostDNSServer(strings []string)
|
|
||||||
SearchDomains() []string
|
|
||||||
}
|
|
||||||
|
|
||||||
type registeredHandlerMap map[string]handlerWithStop
|
|
||||||
|
|
||||||
// DefaultServer dns server object
|
|
||||||
type DefaultServer struct {
|
|
||||||
ctx context.Context
|
|
||||||
ctxCancel context.CancelFunc
|
|
||||||
mux sync.Mutex
|
|
||||||
service service
|
|
||||||
dnsMuxMap registeredHandlerMap
|
|
||||||
localResolver *localResolver
|
|
||||||
wgInterface WGIface
|
|
||||||
hostManager hostManager
|
|
||||||
updateSerial uint64
|
|
||||||
previousConfigHash uint64
|
|
||||||
currentConfig HostDNSConfig
|
|
||||||
|
|
||||||
// permanent related properties
|
|
||||||
permanent bool
|
|
||||||
hostsDnsList []string
|
|
||||||
hostsDnsListLock sync.Mutex
|
|
||||||
|
|
||||||
// make sense on mobile only
|
|
||||||
searchDomainNotifier *notifier
|
|
||||||
iosDnsManager IosDnsManager
|
|
||||||
}
|
|
||||||
|
|
||||||
type handlerWithStop interface {
|
|
||||||
dns.Handler
|
|
||||||
stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
type muxUpdate struct {
|
|
||||||
domain string
|
|
||||||
handler handlerWithStop
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewDefaultServer returns a new dns server
|
|
||||||
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string) (*DefaultServer, error) {
|
|
||||||
var addrPort *netip.AddrPort
|
|
||||||
if customAddress != "" {
|
|
||||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err)
|
|
||||||
}
|
|
||||||
addrPort = &parsedAddrPort
|
|
||||||
}
|
|
||||||
|
|
||||||
var dnsService service
|
|
||||||
if wgInterface.IsUserspaceBind() {
|
|
||||||
dnsService = newServiceViaMemory(wgInterface)
|
|
||||||
} else {
|
|
||||||
dnsService = newServiceViaListener(wgInterface, addrPort)
|
|
||||||
}
|
|
||||||
|
|
||||||
return newDefaultServer(ctx, wgInterface, dnsService), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
|
|
||||||
func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string, config nbdns.Config, listener listener.NetworkChangeListener) *DefaultServer {
|
|
||||||
log.Debugf("host dns address list is: %v", hostsDnsList)
|
|
||||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
|
|
||||||
ds.permanent = true
|
|
||||||
ds.hostsDnsList = hostsDnsList
|
|
||||||
ds.addHostRootZone()
|
|
||||||
ds.currentConfig = dnsConfigToHostDNSConfig(config, ds.service.RuntimeIP(), ds.service.RuntimePort())
|
|
||||||
ds.searchDomainNotifier = newNotifier(ds.SearchDomains())
|
|
||||||
ds.searchDomainNotifier.setListener(listener)
|
|
||||||
setServerDns(ds)
|
|
||||||
return ds
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewDefaultServerIos returns a new dns server. It optimized for ios
|
|
||||||
func NewDefaultServerIos(ctx context.Context, wgInterface WGIface, iosDnsManager IosDnsManager) *DefaultServer {
|
|
||||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
|
|
||||||
ds.iosDnsManager = iosDnsManager
|
|
||||||
return ds
|
|
||||||
}
|
|
||||||
|
|
||||||
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer {
|
|
||||||
ctx, stop := context.WithCancel(ctx)
|
|
||||||
defaultServer := &DefaultServer{
|
|
||||||
ctx: ctx,
|
|
||||||
ctxCancel: stop,
|
|
||||||
service: dnsService,
|
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
|
||||||
localResolver: &localResolver{
|
|
||||||
registeredMap: make(registrationMap),
|
|
||||||
},
|
|
||||||
wgInterface: wgInterface,
|
|
||||||
}
|
|
||||||
|
|
||||||
return defaultServer
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize instantiate host manager and the dns service
|
|
||||||
func (s *DefaultServer) Initialize() (err error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
if s.hostManager != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.permanent {
|
|
||||||
err = s.service.Listen()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.hostManager, err = s.initialize()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// DnsIP returns the DNS resolver server IP address
|
|
||||||
//
|
|
||||||
// When kernel space interface used it return real DNS server listener IP address
|
|
||||||
// For bind interface, fake DNS resolver address returned (second last IP address from Nebird network)
|
|
||||||
func (s *DefaultServer) DnsIP() string {
|
|
||||||
return s.service.RuntimeIP()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop stops the server
|
|
||||||
func (s *DefaultServer) Stop() {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
s.ctxCancel()
|
|
||||||
|
|
||||||
if s.hostManager != nil {
|
|
||||||
err := s.hostManager.restoreHostDNS()
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.service.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
|
||||||
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
|
||||||
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
|
||||||
s.hostsDnsListLock.Lock()
|
|
||||||
defer s.hostsDnsListLock.Unlock()
|
|
||||||
|
|
||||||
s.hostsDnsList = hostsDnsList
|
|
||||||
_, ok := s.dnsMuxMap[nbdns.RootZone]
|
|
||||||
if ok {
|
|
||||||
log.Debugf("on new host DNS config but skip to apply it")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Debugf("update host DNS settings: %+v", hostsDnsList)
|
|
||||||
s.addHostRootZone()
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateDNSServer processes an update received from the management service
|
|
||||||
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
|
||||||
select {
|
|
||||||
case <-s.ctx.Done():
|
|
||||||
log.Infof("not updating DNS server as context is closed")
|
|
||||||
return s.ctx.Err()
|
|
||||||
default:
|
|
||||||
if serial < s.updateSerial {
|
|
||||||
return fmt.Errorf("not applying dns update, error: "+
|
|
||||||
"network update is %d behind the last applied update", s.updateSerial-serial)
|
|
||||||
}
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
if s.hostManager == nil {
|
|
||||||
return fmt.Errorf("dns service is not initialized yet")
|
|
||||||
}
|
|
||||||
|
|
||||||
hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{
|
|
||||||
ZeroNil: true,
|
|
||||||
IgnoreZeroValue: true,
|
|
||||||
SlicesAsSets: true,
|
|
||||||
UseStringer: true,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.previousConfigHash == hash {
|
|
||||||
log.Debugf("not applying the dns configuration update as there is nothing new")
|
|
||||||
s.updateSerial = serial
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.applyConfiguration(update); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
s.updateSerial = serial
|
|
||||||
s.previousConfigHash = hash
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) SearchDomains() []string {
|
|
||||||
var searchDomains []string
|
|
||||||
|
|
||||||
for _, dConf := range s.currentConfig.Domains {
|
|
||||||
if dConf.Disabled {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if dConf.MatchOnly {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
searchDomains = append(searchDomains, dConf.Domain)
|
|
||||||
}
|
|
||||||
return searchDomains
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|
||||||
// is the service should be Disabled, we stop the listener or fake resolver
|
|
||||||
// and proceed with a regular update to clean up the handlers and records
|
|
||||||
if update.ServiceEnable {
|
|
||||||
_ = s.service.Listen()
|
|
||||||
} else if !s.permanent {
|
|
||||||
s.service.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
|
||||||
}
|
|
||||||
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
|
||||||
}
|
|
||||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) //nolint:gocritic
|
|
||||||
|
|
||||||
s.updateMux(muxUpdates)
|
|
||||||
s.updateLocalResolver(localRecords)
|
|
||||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
|
||||||
|
|
||||||
hostUpdate := s.currentConfig
|
|
||||||
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
|
|
||||||
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
|
||||||
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
|
|
||||||
hostUpdate.RouteAll = false
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = s.hostManager.applyDNSConfig(hostUpdate); err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.searchDomainNotifier != nil {
|
|
||||||
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
|
|
||||||
var muxUpdates []muxUpdate
|
|
||||||
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
|
||||||
|
|
||||||
for _, customZone := range customZones {
|
|
||||||
|
|
||||||
if len(customZone.Records) == 0 {
|
|
||||||
return nil, nil, fmt.Errorf("received an empty list of records")
|
|
||||||
}
|
|
||||||
|
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
|
||||||
domain: customZone.Domain,
|
|
||||||
handler: s.localResolver,
|
|
||||||
})
|
|
||||||
|
|
||||||
for _, record := range customZone.Records {
|
|
||||||
var class uint16 = dns.ClassINET
|
|
||||||
if record.Class != nbdns.DefaultClass {
|
|
||||||
return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class)
|
|
||||||
}
|
|
||||||
key := buildRecordKey(record.Name, class, uint16(record.Type))
|
|
||||||
localRecords[key] = record
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return muxUpdates, localRecords, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
|
|
||||||
|
|
||||||
var muxUpdates []muxUpdate
|
|
||||||
for _, nsGroup := range nameServerGroups {
|
|
||||||
if len(nsGroup.NameServers) == 0 {
|
|
||||||
log.Warn("received a nameserver group with empty nameserver list")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
handler, err := newUpstreamResolver(s.ctx, s.wgInterface.Name(), s.wgInterface.Address().IP, s.wgInterface.Address().Network)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err)
|
|
||||||
}
|
|
||||||
for _, ns := range nsGroup.NameServers {
|
|
||||||
if ns.NSType != nbdns.UDPNameServerType {
|
|
||||||
log.Warnf("skipping nameserver %s with type %s, this peer supports only %s",
|
|
||||||
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(handler.upstreamServers) == 0 {
|
|
||||||
handler.stop()
|
|
||||||
log.Errorf("received a nameserver group with an invalid nameserver list")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// when upstream fails to resolve domain several times over all it servers
|
|
||||||
// it will calls this hook to exclude self from the configuration and
|
|
||||||
// reapply DNS settings, but it not touch the original configuration and serial number
|
|
||||||
// because it is temporal deactivation until next try
|
|
||||||
//
|
|
||||||
// after some period defined by upstream it tries to reactivate self by calling this hook
|
|
||||||
// everything we need here is just to re-apply current configuration because it already
|
|
||||||
// contains this upstream settings (temporal deactivation not removed it)
|
|
||||||
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)
|
|
||||||
|
|
||||||
if nsGroup.Primary {
|
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
|
||||||
domain: nbdns.RootZone,
|
|
||||||
handler: handler,
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(nsGroup.Domains) == 0 {
|
|
||||||
handler.stop()
|
|
||||||
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, domain := range nsGroup.Domains {
|
|
||||||
if domain == "" {
|
|
||||||
handler.stop()
|
|
||||||
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
|
||||||
}
|
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
|
||||||
domain: domain,
|
|
||||||
handler: handler,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return muxUpdates, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
|
||||||
muxUpdateMap := make(registeredHandlerMap)
|
|
||||||
|
|
||||||
var isContainRootUpdate bool
|
|
||||||
|
|
||||||
for _, update := range muxUpdates {
|
|
||||||
s.service.RegisterMux(update.domain, update.handler)
|
|
||||||
muxUpdateMap[update.domain] = update.handler
|
|
||||||
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
|
||||||
existingHandler.stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
if update.domain == nbdns.RootZone {
|
|
||||||
isContainRootUpdate = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for key, existingHandler := range s.dnsMuxMap {
|
|
||||||
_, found := muxUpdateMap[key]
|
|
||||||
if !found {
|
|
||||||
if !isContainRootUpdate && key == nbdns.RootZone {
|
|
||||||
s.hostsDnsListLock.Lock()
|
|
||||||
s.addHostRootZone()
|
|
||||||
s.hostsDnsListLock.Unlock()
|
|
||||||
existingHandler.stop()
|
|
||||||
} else {
|
|
||||||
existingHandler.stop()
|
|
||||||
s.service.DeregisterMux(key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.dnsMuxMap = muxUpdateMap
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
|
||||||
for key := range s.localResolver.registeredMap {
|
|
||||||
_, found := update[key]
|
|
||||||
if !found {
|
|
||||||
s.localResolver.deleteRecord(key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
updatedMap := make(registrationMap)
|
|
||||||
for key, record := range update {
|
|
||||||
err := s.localResolver.registerRecord(record)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err)
|
|
||||||
}
|
|
||||||
updatedMap[key] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.localResolver.registeredMap = updatedMap
|
|
||||||
}
|
|
||||||
|
|
||||||
func getNSHostPort(ns nbdns.NameServer) string {
|
|
||||||
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
|
||||||
// the upstream resolver from the configuration, the second one is used to
|
|
||||||
// reactivate it. Not allowed to call reactivate before deactivate.
|
|
||||||
func (s *DefaultServer) upstreamCallbacks(
|
|
||||||
nsGroup *nbdns.NameServerGroup,
|
|
||||||
handler dns.Handler,
|
|
||||||
) (deactivate func(), reactivate func()) {
|
|
||||||
var removeIndex map[string]int
|
|
||||||
deactivate = func() {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
|
||||||
l.Info("temporary deactivate nameservers group due timeout")
|
|
||||||
|
|
||||||
removeIndex = make(map[string]int)
|
|
||||||
for _, domain := range nsGroup.Domains {
|
|
||||||
removeIndex[domain] = -1
|
|
||||||
}
|
|
||||||
if nsGroup.Primary {
|
|
||||||
removeIndex[nbdns.RootZone] = -1
|
|
||||||
s.currentConfig.RouteAll = false
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, item := range s.currentConfig.Domains {
|
|
||||||
if _, found := removeIndex[item.Domain]; found {
|
|
||||||
s.currentConfig.Domains[i].Disabled = true
|
|
||||||
s.service.DeregisterMux(item.Domain)
|
|
||||||
removeIndex[item.Domain] = i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
|
||||||
l.WithError(err).Error("fail to apply nameserver deactivation on the host")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
reactivate = func() {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
for domain, i := range removeIndex {
|
|
||||||
if i == -1 || i >= len(s.currentConfig.Domains) || s.currentConfig.Domains[i].Domain != domain {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
s.currentConfig.Domains[i].Disabled = false
|
|
||||||
s.service.RegisterMux(domain, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
|
||||||
l.Debug("reactivate temporary Disabled nameserver group")
|
|
||||||
|
|
||||||
if nsGroup.Primary {
|
|
||||||
s.currentConfig.RouteAll = true
|
|
||||||
}
|
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
|
||||||
l.WithError(err).Error("reactivate temporary Disabled nameserver group, DNS update apply")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) addHostRootZone() {
|
|
||||||
handler, err := newUpstreamResolver(s.ctx, s.wgInterface.Name(), s.wgInterface.Address().IP, s.wgInterface.Address().Network)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to create a new upstream resolver, error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
handler.upstreamServers = make([]string, len(s.hostsDnsList))
|
|
||||||
for n, ua := range s.hostsDnsList {
|
|
||||||
a, err := netip.ParseAddr(ua)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("invalid upstream IP address: %s, error: %s", ua, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
ipString := ua
|
|
||||||
if !a.Is4() {
|
|
||||||
ipString = fmt.Sprintf("[%s]", ua)
|
|
||||||
}
|
|
||||||
|
|
||||||
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ipString)
|
|
||||||
}
|
|
||||||
handler.deactivate = func() {}
|
|
||||||
handler.reactivate = func() {}
|
|
||||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,32 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
func (s *DefaultServer) initialize() (manager hostManager, err error) {
|
import (
|
||||||
return newHostManager(s.wgInterface)
|
"context"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultServer dummy dns server
|
||||||
|
type DefaultServer struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultServer On Android the DNS feature is not supported yet
|
||||||
|
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string) (*DefaultServer, error) {
|
||||||
|
return &DefaultServer{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start dummy implementation
|
||||||
|
func (s DefaultServer) Start() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop dummy implementation
|
||||||
|
func (s DefaultServer) Stop() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDNSServer dummy implementation
|
||||||
|
func (s DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +0,0 @@
|
|||||||
//go:build !ios
|
|
||||||
|
|
||||||
package dns
|
|
||||||
|
|
||||||
func (s *DefaultServer) initialize() (manager hostManager, err error) {
|
|
||||||
return newHostManager(s.wgInterface)
|
|
||||||
}
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user