Merge branch 'main' into local/engine-restart

This commit is contained in:
Pascal Fischer
2023-09-21 16:43:03 +02:00
168 changed files with 3724 additions and 2522 deletions

View File

@@ -0,0 +1,41 @@
name: Android build validation
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Install Go
uses: actions/setup-go@v4
with:
go-version: "1.20.x"
- name: Setup Android SDK
uses: android-actions/setup-android@v2
- name: NDK Cache
id: ndk-cache
uses: actions/cache@v3
with:
path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620
- name: Setup NDK
run: /usr/local/lib/android/sdk/tools/bin/sdkmanager --install "ndk;23.1.7779620"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
- name: gomobile init
run: gomobile init
- name: build android nebtird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
env:
CGO_ENABLED: 0
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620

View File

@@ -15,14 +15,14 @@ jobs:
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v2 uses: actions/setup-go@v4
with: with:
go-version: "1.20.x" go-version: "1.20.x"
- name: Checkout code - name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v3
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v2 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: macos-go-${{ hashFiles('**/go.sum') }} key: macos-go-${{ hashFiles('**/go.sum') }}

View File

@@ -18,13 +18,13 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v2 uses: actions/setup-go@v4
with: with:
go-version: "1.20.x" go-version: "1.20.x"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v2 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -32,7 +32,7 @@ jobs:
${{ runner.os }}-go- ${{ runner.os }}-go-
- name: Checkout code - name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v3
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
@@ -47,13 +47,13 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v2 uses: actions/setup-go@v4
with: with:
go-version: "1.20.x" go-version: "1.20.x"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v2 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -61,7 +61,7 @@ jobs:
${{ runner.os }}-go- ${{ runner.os }}-go-
- name: Checkout code - name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v3
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev

View File

@@ -8,14 +8,13 @@ jobs:
name: lint name: lint
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - name: Checkout code
uses: actions/checkout@v3
- name: Install Go - name: Install Go
uses: actions/setup-go@v2 uses: actions/setup-go@v4
with: with:
go-version: "1.20.x" go-version: "1.20.x"
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
- name: golangci-lint - name: golangci-lint
uses: golangci/golangci-lint-action@v2 uses: golangci/golangci-lint-action@v3
with:
args: --timeout=6m

View File

@@ -0,0 +1,36 @@
name: Test installation
on:
push:
branches:
- main
pull_request:
paths:
- "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test-install-script:
strategy:
max-parallel: 2
matrix:
os: [ubuntu-latest, macos-latest]
skip_ui_mode: [true, false]
install_binary: [true, false]
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: run install script
env:
SKIP_UI_APP: ${{ matrix.skip_ui_mode }}
USE_BIN_INSTALL: ${{ matrix.install_binary }}
GITHUB_TOKEN: ${{ secrets.RO_API_CALLER_TOKEN }}
run: |
[ "$SKIP_UI_APP" == "false" ] && export XDG_CURRENT_DESKTOP="none"
cat release_files/install.sh | sh -x
- name: check cli binary
run: command -v netbird

View File

@@ -1,60 +0,0 @@
name: Test installation Darwin
on:
push:
branches:
- main
pull_request:
paths:
- "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
install-cli-only:
runs-on: macos-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Rename brew package
if: ${{ matrix.check_bin_install }}
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
- name: Run install script
run: |
sh ./release_files/install.sh
env:
SKIP_UI_APP: true
- name: Run tests
run: |
if ! command -v netbird &> /dev/null; then
echo "Error: netbird is not installed"
exit 1
fi
install-all:
runs-on: macos-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Rename brew package
if: ${{ matrix.check_bin_install }}
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
- name: Run install script
run: |
sh ./release_files/install.sh
- name: Run tests
run: |
if ! command -v netbird &> /dev/null; then
echo "Error: netbird is not installed"
exit 1
fi
if [[ $(mdfind "kMDItemContentType == 'com.apple.application-bundle' && kMDItemFSName == '*NetBird UI.app'") ]]; then
echo "Error: NetBird UI is not installed"
exit 1
fi

View File

@@ -1,38 +0,0 @@
name: Test installation Linux
on:
push:
branches:
- main
pull_request:
paths:
- "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
install-cli-only:
runs-on: ubuntu-latest
strategy:
matrix:
check_bin_install: [true, false]
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Rename apt package
if: ${{ matrix.check_bin_install }}
run: |
sudo mv /usr/bin/apt /usr/bin/apt.bak
sudo mv /usr/bin/apt-get /usr/bin/apt-get.bak
- name: Run install script
run: |
sh ./release_files/install.sh
- name: Run tests
run: |
if ! command -v netbird &> /dev/null; then
echo "Error: netbird is not installed"
exit 1
fi

View File

@@ -19,7 +19,7 @@ on:
- '**/Dockerfile.*' - '**/Dockerfile.*'
env: env:
SIGN_PIPE_VER: "v0.0.8" SIGN_PIPE_VER: "v0.0.9"
GORELEASER_VER: "v1.14.1" GORELEASER_VER: "v1.14.1"
concurrency: concurrency:
@@ -29,20 +29,24 @@ concurrency:
jobs: jobs:
release: release:
runs-on: ubuntu-latest runs-on: ubuntu-latest
env:
flags: ""
steps: steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- -
name: Checkout name: Checkout
uses: actions/checkout@v2 uses: actions/checkout@v3
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- -
name: Set up Go name: Set up Go
uses: actions/setup-go@v2 uses: actions/setup-go@v4
with: with:
go-version: "1.20" go-version: "1.20"
- -
name: Cache Go modules name: Cache Go modules
uses: actions/cache@v1 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -56,10 +60,10 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- -
name: Set up QEMU name: Set up QEMU
uses: docker/setup-qemu-action@v1 uses: docker/setup-qemu-action@v2
- -
name: Set up Docker Buildx name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1 uses: docker/setup-buildx-action@v2
- -
name: Login to Docker hub name: Login to Docker hub
if: github.event_name != 'pull_request' if: github.event_name != 'pull_request'
@@ -82,10 +86,10 @@ jobs:
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
- -
name: Run GoReleaser name: Run GoReleaser
uses: goreleaser/goreleaser-action@v2 uses: goreleaser/goreleaser-action@v4
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
args: release --rm-dist args: release --rm-dist ${{ env.flags }}
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
@@ -93,7 +97,7 @@ jobs:
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- -
name: upload non tags for debug purposes name: upload non tags for debug purposes
uses: actions/upload-artifact@v2 uses: actions/upload-artifact@v3
with: with:
name: release name: release
path: dist/ path: dist/
@@ -102,17 +106,19 @@ jobs:
release_ui: release_ui:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout - name: Checkout
uses: actions/checkout@v2 uses: actions/checkout@v3
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v2 uses: actions/setup-go@v4
with: with:
go-version: "1.20" go-version: "1.20"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v1 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }}
@@ -132,17 +138,17 @@ jobs:
- name: Generate windows rsrc - name: Generate windows rsrc
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v2 uses: goreleaser/goreleaser-action@v4
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui.yaml --rm-dist args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }}
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- name: upload non tags for debug purposes - name: upload non tags for debug purposes
uses: actions/upload-artifact@v2 uses: actions/upload-artifact@v3
with: with:
name: release-ui name: release-ui
path: dist/ path: dist/
@@ -151,19 +157,21 @@ jobs:
release_ui_darwin: release_ui_darwin:
runs-on: macos-11 runs-on: macos-11
steps: steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- -
name: Checkout name: Checkout
uses: actions/checkout@v2 uses: actions/checkout@v3
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- -
name: Set up Go name: Set up Go
uses: actions/setup-go@v2 uses: actions/setup-go@v4
with: with:
go-version: "1.20" go-version: "1.20"
- -
name: Cache Go modules name: Cache Go modules
uses: actions/cache@v1 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }}
@@ -175,15 +183,15 @@ jobs:
- -
name: Run GoReleaser name: Run GoReleaser
id: goreleaser id: goreleaser
uses: goreleaser/goreleaser-action@v2 uses: goreleaser/goreleaser-action@v4
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui_darwin.yaml --rm-dist args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }}
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- -
name: upload non tags for debug purposes name: upload non tags for debug purposes
uses: actions/upload-artifact@v2 uses: actions/upload-artifact@v3
with: with:
name: release-ui-darwin name: release-ui-darwin
path: dist/ path: dist/

View File

@@ -9,7 +9,6 @@ on:
- 'infrastructure_files/**' - 'infrastructure_files/**'
- '.github/workflows/test-infrastructure-files.yml' - '.github/workflows/test-infrastructure-files.yml'
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true cancel-in-progress: true
@@ -25,12 +24,12 @@ jobs:
run: sudo apt-get install -y curl run: sudo apt-get install -y curl
- name: Install Go - name: Install Go
uses: actions/setup-go@v2 uses: actions/setup-go@v4
with: with:
go-version: "1.20.x" go-version: "1.20.x"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v2 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -81,6 +80,7 @@ jobs:
CI_NETBIRD_MGMT_IDP: "none" CI_NETBIRD_MGMT_IDP: "none"
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_SIGNAL_PORT: 12345
run: | run: |
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
@@ -92,6 +92,7 @@ jobs:
grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073" grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073"
grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI
grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$' grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$'
grep $CI_NETBIRD_SIGNAL_PORT docker-compose.yml | grep ':80'
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$' grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
@@ -121,7 +122,7 @@ jobs:
- name: test running containers - name: test running containers
run: | run: |
count=$(docker compose ps --format json | jq '.[] | select(.Project | contains("infrastructure_files")) | .State' | grep -c running) count=$(docker compose ps --format json | jq '. | select(.Name | contains("infrastructure_files")) | .State' | grep -c running)
test $count -eq 4 test $count -eq 4
working-directory: infrastructure_files working-directory: infrastructure_files

1
.gitignore vendored
View File

@@ -19,3 +19,4 @@ client/.distfiles/
infrastructure_files/setup.env infrastructure_files/setup.env
infrastructure_files/setup-*.env infrastructure_files/setup-*.env
.vscode .vscode
.DS_Store

54
.golangci.yaml Normal file
View File

@@ -0,0 +1,54 @@
run:
# Timeout for analysis, e.g. 30s, 5m.
# Default: 1m
timeout: 6m
# This file contains only configs which differ from defaults.
# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
linters-settings:
errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
# Such cases aren't reported by default.
# Default: false
check-type-assertions: false
govet:
# Enable all analyzers.
# Default: false
enable-all: false
enable:
- nilness
linters:
disable-all: true
enable:
## enabled by default
- errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases
- gosimple # specializes in simplifying a code
- govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
- ineffassign # detects when assignments to existing variables are not used
- staticcheck # is a go vet on steroids, applying a ton of static analysis checks
- typecheck # like the front-end of a Go compiler, parses and type-checks Go code
- unused # checks for unused constants, variables, functions and types
## disable by default but the have interesting results so lets add them
- bodyclose # checks whether HTTP response body is closed successfully
- nilerr # finds the code that returns nil even if it checks that the error is not nil
- nilnil # checks that there is no simultaneous return of nil error and an invalid value
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
- wastedassign # wastedassign finds wasted assignment statements
issues:
# Maximum count of issues with the same text.
# Set to 0 to disable.
# Default: 3
max-same-issues: 5
exclude-rules:
- path: sharedsock/filter.go
linters:
- unused
- path: client/firewall/iptables/rule.go
linters:
- unused
- path: mock.go
linters:
- nilnil

View File

@@ -1,6 +1,6 @@
<p align="center"> <p align="center">
<strong>:hatching_chick: New Release! Peer expiration.</strong> <strong>:hatching_chick: New Release! Self-hosting in under 5 min.</strong>
<a href="https://github.com/netbirdio/netbird/releases"> <a href="https://github.com/netbirdio/netbird#quickstart-with-self-hosted-netbird">
Learn more Learn more
</a> </a>
</p> </p>
@@ -24,7 +24,7 @@
<p align="center"> <p align="center">
<strong> <strong>
Start using NetBird at <a href="https://app.netbird.io/">app.netbird.io</a> Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
<br/> <br/>
See <a href="https://netbird.io/docs/">Documentation</a> See <a href="https://netbird.io/docs/">Documentation</a>
<br/> <br/>
@@ -40,9 +40,13 @@
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth. **Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
**Secure.** NetBird isolates every machine and device by applying granular access policies, while allowing you to manage them intuitively from a single place. **Secure.** NetBird enables secure remote access by applying granular access policies, while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
**Key features:** ### Secure peer-to-peer VPN with SSO and MFA in minutes
https://user-images.githubusercontent.com/700848/197345890-2e2cded5-7b7a-436f-a444-94e80dd24f46.mov
### Key features
| Connectivity | Management | Automation | Platforms | | Connectivity | Management | Automation | Platforms |
|-------------------------------------------------------------------|--------------------------------------------------------------------------|----------------------------------------------------------------------------|---------------------------------------| |-------------------------------------------------------------------|--------------------------------------------------------------------------|----------------------------------------------------------------------------|---------------------------------------|
@@ -57,10 +61,6 @@
| | <ul><li> - \[x] SSH access management </ul></li> | | | | | <ul><li> - \[x] SSH access management </ul></li> | | |
### Secure peer-to-peer VPN with SSO and MFA in minutes
https://user-images.githubusercontent.com/700848/197345890-2e2cded5-7b7a-436f-a444-94e80dd24f46.mov
### Quickstart with NetBird Cloud ### Quickstart with NetBird Cloud
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install) - Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)

View File

@@ -18,10 +18,9 @@ func Encode(num uint32) string {
} }
var encoded strings.Builder var encoded strings.Builder
remainder := uint32(0)
for num > 0 { for num > 0 {
remainder = num % base remainder := num % base
encoded.WriteByte(alphabet[remainder]) encoded.WriteByte(alphabet[remainder])
num /= base num /= base
} }

View File

@@ -1,7 +1,5 @@
FROM gcr.io/distroless/base:debug FROM alpine:3
RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true ENV NB_FOREGROUND_MODE=true
ENV PATH=/sbin:/usr/sbin:/bin:/usr/bin:/busybox
SHELL ["/busybox/sh","-c"]
RUN sed -i -E 's/(^root:.+)\/sbin\/nologin/\1\/busybox\/sh/g' /etc/passwd
ENTRYPOINT [ "/go/bin/netbird","up"] ENTRYPOINT [ "/go/bin/netbird","up"]
COPY netbird /go/bin/netbird COPY netbird /go/bin/netbird

View File

@@ -55,7 +55,6 @@ type Client struct {
ctxCancelLock *sync.Mutex ctxCancelLock *sync.Mutex
deviceName string deviceName string
routeListener routemanager.RouteListener routeListener routemanager.RouteListener
onHostDnsFn func([]string)
} }
// NewClient instantiate a new Client // NewClient instantiate a new Client
@@ -97,7 +96,30 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
c.onHostDnsFn = func([]string) {} return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
// In this case make no sense handle registration steps.
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error {
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
ConfigPath: c.cfgFile,
})
if err != nil {
return err
}
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
var ctx context.Context
//nolint
ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName)
c.ctxCancelLock.Lock()
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
defer c.ctxCancel()
c.ctxCancelLock.Unlock()
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener) return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener)
} }

View File

@@ -84,10 +84,14 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
func (a *Auth) saveConfigIfSSOSupported() (bool, error) { func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) { err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound { if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound { s, ok := gstatus.FromError(err)
if !ok {
return err
}
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
supportsSSO = false supportsSSO = false
err = nil err = nil
} }

View File

@@ -3,21 +3,19 @@ package cmd
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/netbirdio/netbird/client/internal/auth"
"strings" "strings"
"time" "time"
"github.com/skratchdot/open-golang/open" "github.com/skratchdot/open-golang/open"
"github.com/spf13/cobra"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/util"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/util"
) )
var loginCmd = &cobra.Command{ var loginCmd = &cobra.Command{
@@ -191,17 +189,16 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) { func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
var codeMsg string var codeMsg string
if userCode != "" { if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
if !strings.Contains(verificationURIComplete, userCode) { codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
}
} }
err := open.Run(verificationURIComplete) cmd.Println("Please do the SSO login in your browser. \n" +
cmd.Printf("Please do the SSO login in your browser. \n" +
"If your browser didn't open automatically, use this URL to log in:\n\n" + "If your browser didn't open automatically, use this URL to log in:\n\n" +
" " + verificationURIComplete + " " + codeMsg + " \n\n") verificationURIComplete + " " + codeMsg)
if err != nil { cmd.Println("")
cmd.Printf("Alternatively, you may want to use a setup key, see:\n\n https://www.netbird.io/docs/overview/setup-keys\n") if err := open.Run(verificationURIComplete); err != nil {
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
} }
} }

View File

@@ -109,9 +109,9 @@ func statusFunc(cmd *cobra.Command, args []string) error {
ctx := internal.CtxInitState(context.Background()) ctx := internal.CtxInitState(context.Background())
resp, _ := getStatus(ctx, cmd) resp, err := getStatus(ctx, cmd)
if err != nil { if err != nil {
return nil return err
} }
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) { if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
@@ -120,7 +120,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
" netbird up \n\n"+ " netbird up \n\n"+
"If you are running a self-hosted version and no SSO provider has been configured in your Management Server,\n"+ "If you are running a self-hosted version and no SSO provider has been configured in your Management Server,\n"+
"you can use a setup-key:\n\n netbird up --management-url <YOUR_MANAGEMENT_URL> --setup-key <YOUR_SETUP_KEY>\n\n"+ "you can use a setup-key:\n\n netbird up --management-url <YOUR_MANAGEMENT_URL> --setup-key <YOUR_SETUP_KEY>\n\n"+
"More info: https://www.netbird.io/docs/overview/setup-keys\n\n", "More info: https://docs.netbird.io/how-to/register-machines-using-setup-keys\n\n",
resp.GetStatus(), resp.GetStatus(),
) )
return nil return nil
@@ -133,7 +133,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
outputInformationHolder := convertToStatusOutputOverview(resp) outputInformationHolder := convertToStatusOutputOverview(resp)
statusOutputString := "" var statusOutputString string
switch { switch {
case detailFlag: case detailFlag:
statusOutputString = parseToFullDetailSummary(outputInformationHolder) statusOutputString = parseToFullDetailSummary(outputInformationHolder)

View File

@@ -76,12 +76,12 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
return nil, nil return nil, nil
} }
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "", accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "",
eventStore) eventStore, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil) mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -40,6 +40,9 @@ const (
// It declares methods which handle actions required by the // It declares methods which handle actions required by the
// Netbird client for ACL and routing functionality // Netbird client for ACL and routing functionality
type Manager interface { type Manager interface {
// AllowNetbird allows netbird interface traffic
AllowNetbird() error
// AddFiltering rule to the firewall // AddFiltering rule to the firewall
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set

View File

@@ -44,6 +44,7 @@ type Manager struct {
type iFaceMapper interface { type iFaceMapper interface {
Name() string Name() string
Address() iface.WGAddress Address() iface.WGAddress
IsUserspaceBind() bool
} }
type ruleset struct { type ruleset struct {
@@ -52,7 +53,7 @@ type ruleset struct {
} }
// Create iptables firewall manager // Create iptables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) { func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) {
m := &Manager{ m := &Manager{
wgIface: wgIface, wgIface: wgIface,
inputDefaultRuleSpecs: []string{ inputDefaultRuleSpecs: []string{
@@ -62,26 +63,26 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
rulesets: make(map[string]ruleset), rulesets: make(map[string]ruleset),
} }
if err := ipset.Init(); err != nil { err := ipset.Init()
if err != nil {
return nil, fmt.Errorf("init ipset: %w", err) return nil, fmt.Errorf("init ipset: %w", err)
} }
// init clients for booth ipv4 and ipv6 // init clients for booth ipv4 and ipv6
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) m.ipv4Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil { if err != nil {
return nil, fmt.Errorf("iptables is not installed in the system or not supported") return nil, fmt.Errorf("iptables is not installed in the system or not supported")
} }
if isIptablesClientAvailable(ipv4Client) {
m.ipv4Client = ipv4Client if ipv6Supported {
m.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
log.Warnf("ip6tables is not installed in the system or not supported: %v. Access rules for this protocol won't be applied.", err)
}
} }
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6) if m.ipv4Client == nil && m.ipv6Client == nil {
if err != nil { return nil, fmt.Errorf("iptables is not installed in the system or not enough permissions to use it")
log.Errorf("ip6tables is not installed in the system or not supported: %v", err)
} else {
if isIptablesClientAvailable(ipv6Client) {
m.ipv6Client = ipv6Client
}
} }
if err := m.Reset(); err != nil { if err := m.Reset(); err != nil {
@@ -90,11 +91,6 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
return m, nil return m, nil
} }
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}
// AddFiltering rule to the firewall // AddFiltering rule to the firewall
// //
// If comment is empty rule ID is used as comment // If comment is empty rule ID is used as comment
@@ -276,6 +272,38 @@ func (m *Manager) Reset() error {
return nil return nil
} }
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
if m.wgIface.IsUserspaceBind() {
_, err := m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
fw.RuleDirectionIN,
fw.ActionAccept,
"",
"allow netbird interface traffic",
)
if err != nil {
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
}
_, err = m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
fw.RuleDirectionOUT,
fw.ActionAccept,
"",
"allow netbird interface traffic",
)
return err
}
return nil
}
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil } func (m *Manager) Flush() error { return nil }
@@ -406,7 +434,7 @@ func (m *Manager) client(ip net.IP) (*iptables.IPTables, error) {
return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err) return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err)
} }
if err := client.AppendUnique("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil { if err := client.Insert("filter", "INPUT", 1, m.inputDefaultRuleSpecs...); err != nil {
return nil, fmt.Errorf("failed to create input chain jump rule: %w", err) return nil, fmt.Errorf("failed to create input chain jump rule: %w", err)
} }

View File

@@ -33,6 +33,8 @@ func (i *iFaceMock) Address() iface.WGAddress {
panic("AddressFunc is not set") panic("AddressFunc is not set")
} }
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestIptablesManager(t *testing.T) { func TestIptablesManager(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err) require.NoError(t, err)
@@ -53,7 +55,7 @@ func TestIptablesManager(t *testing.T) {
} }
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(mock, true)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -141,7 +143,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
} }
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(mock, true)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -229,7 +231,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(mock, true)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)

View File

@@ -29,6 +29,8 @@ const (
// FilterOutputChainName is the name of the chain that is used for filtering outgoing packets // FilterOutputChainName is the name of the chain that is used for filtering outgoing packets
FilterOutputChainName = "netbird-acl-output-filter" FilterOutputChainName = "netbird-acl-output-filter"
AllowNetbirdInputRuleID = "allow Netbird incoming traffic"
) )
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
@@ -379,7 +381,7 @@ func (m *Manager) chain(
if c != nil { if c != nil {
return c, nil return c, nil
} }
return m.createChainIfNotExists(tf, name, hook, priority, cType) return m.createChainIfNotExists(tf, FilterTableName, name, hook, priority, cType)
} }
if ip.To4() != nil { if ip.To4() != nil {
@@ -399,13 +401,20 @@ func (m *Manager) chain(
} }
// table returns the table for the given family of the IP address // table returns the table for the given family of the IP address
func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) { func (m *Manager) table(
family nftables.TableFamily, tableName string,
) (*nftables.Table, error) {
// we cache access to Netbird ACL table only
if tableName != FilterTableName {
return m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
}
if family == nftables.TableFamilyIPv4 { if family == nftables.TableFamilyIPv4 {
if m.tableIPv4 != nil { if m.tableIPv4 != nil {
return m.tableIPv4, nil return m.tableIPv4, nil
} }
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4) table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -417,7 +426,7 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
return m.tableIPv6, nil return m.tableIPv6, nil
} }
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6) table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6, tableName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -425,19 +434,21 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
return m.tableIPv6, nil return m.tableIPv6, nil
} }
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) { func (m *Manager) createTableIfNotExists(
family nftables.TableFamily, tableName string,
) (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(family) tables, err := m.rConn.ListTablesOfFamily(family)
if err != nil { if err != nil {
return nil, fmt.Errorf("list of tables: %w", err) return nil, fmt.Errorf("list of tables: %w", err)
} }
for _, t := range tables { for _, t := range tables {
if t.Name == FilterTableName { if t.Name == tableName {
return t, nil return t, nil
} }
} }
table := m.rConn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4}) table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
if err := m.rConn.Flush(); err != nil { if err := m.rConn.Flush(); err != nil {
return nil, err return nil, err
} }
@@ -446,12 +457,13 @@ func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables
func (m *Manager) createChainIfNotExists( func (m *Manager) createChainIfNotExists(
family nftables.TableFamily, family nftables.TableFamily,
tableName string,
name string, name string,
hooknum nftables.ChainHook, hooknum nftables.ChainHook,
priority nftables.ChainPriority, priority nftables.ChainPriority,
chainType nftables.ChainType, chainType nftables.ChainType,
) (*nftables.Chain, error) { ) (*nftables.Chain, error) {
table, err := m.table(family) table, err := m.table(family, tableName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -638,6 +650,22 @@ func (m *Manager) Reset() error {
return fmt.Errorf("list of chains: %w", err) return fmt.Errorf("list of chains: %w", err)
} }
for _, c := range chains { for _, c := range chains {
// delete Netbird allow input traffic rule if it exists
if c.Table.Name == "filter" && c.Name == "INPUT" {
rules, err := m.rConn.GetRules(c.Table, c)
if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err)
continue
}
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(AllowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
}
}
}
}
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName { if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
m.rConn.DelChain(c) m.rConn.DelChain(c)
} }
@@ -702,6 +730,53 @@ func (m *Manager) Flush() error {
return nil return nil
} }
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
m.mutex.Lock()
defer m.mutex.Unlock()
tf := nftables.TableFamilyIPv4
if m.wgIface.Address().IP.To4() == nil {
tf = nftables.TableFamilyIPv6
}
chains, err := m.rConn.ListChainsOfTableFamily(tf)
if err != nil {
return fmt.Errorf("list of chains: %w", err)
}
var chain *nftables.Chain
for _, c := range chains {
if c.Table.Name == "filter" && c.Name == "INPUT" {
chain = c
break
}
}
if chain == nil {
log.Debugf("chain INPUT not found. Skiping add allow netbird rule")
return nil
}
rules, err := m.rConn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("failed to get rules for the INPUT chain: %v", err)
}
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
log.Debugf("allow netbird rule already exists: %v", rule)
return nil
}
m.applyAllowNetbirdRules(chain)
err = m.rConn.Flush()
if err != nil {
return fmt.Errorf("failed to flush allow input netbird rules: %v", err)
}
return nil
}
func (m *Manager) flushWithBackoff() (err error) { func (m *Manager) flushWithBackoff() (err error) {
backoff := 4 backoff := 4
backoffTime := 1000 * time.Millisecond backoffTime := 1000 * time.Millisecond
@@ -745,6 +820,44 @@ func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chai
return nil return nil
} }
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
rule := &nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: []byte(AllowNetbirdInputRuleID),
}
_ = m.rConn.InsertRule(rule)
}
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
ifName := ifname(m.wgIface.Name())
for _, rule := range existedRules {
if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" {
if len(rule.Exprs) < 4 {
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
continue
}
if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) {
continue
}
return rule
}
}
}
return nil
}
func encodePort(port fw.Port) []byte { func encodePort(port fw.Port) []byte {
bs := make([]byte, 2) bs := make([]byte, 2)
binary.BigEndian.PutUint16(bs, uint16(port.Values[0])) binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))

View File

@@ -0,0 +1,19 @@
//go:build !windows && !linux
package uspfilter
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
return nil
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
return nil
}

View File

@@ -0,0 +1,21 @@
package uspfilter
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
return nil
}
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if m.resetHook != nil {
return m.resetHook()
}
return nil
}

View File

@@ -0,0 +1,91 @@
package uspfilter
import (
"errors"
"fmt"
"os/exec"
"strings"
"syscall"
)
type action string
const (
addRule action = "add"
deleteRule action = "delete"
firewallRuleName = "Netbird"
noRulesMatchCriteria = "No rules match the specified criteria"
)
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
return fmt.Errorf("couldn't remove windows firewall: %w", err)
}
return nil
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
return manageFirewallRule(firewallRuleName,
addRule,
"dir=in",
"enable=yes",
"action=allow",
"profile=any",
"localip="+m.wgIface.Address().IP.String(),
)
}
func manageFirewallRule(ruleName string, action action, args ...string) error {
active, err := isFirewallRuleActive(ruleName)
if err != nil {
return err
}
if (action == addRule && !active) || (action == deleteRule && active) {
baseArgs := []string{"advfirewall", "firewall", string(action), "rule", "name=" + ruleName}
args := append(baseArgs, args...)
cmd := exec.Command("netsh", args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
return cmd.Run()
}
return nil
}
func isFirewallRuleActive(ruleName string) (bool, error) {
args := []string{"advfirewall", "firewall", "show", "rule", "name=" + ruleName}
cmd := exec.Command("netsh", args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
output, err := cmd.Output()
if err != nil {
var exitError *exec.ExitError
if errors.As(err, &exitError) {
// if the firewall rule is not active, we expect last exit code to be 1
exitStatus := exitError.Sys().(syscall.WaitStatus).ExitStatus()
if exitStatus == 1 {
if strings.Contains(string(output), noRulesMatchCriteria) {
return false, nil
}
}
}
return false, err
}
if strings.Contains(string(output), noRulesMatchCriteria) {
return false, nil
}
return true, nil
}

View File

@@ -19,6 +19,7 @@ const layerTypeAll = 0
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
SetFilter(iface.PacketFilter) error SetFilter(iface.PacketFilter) error
Address() iface.WGAddress
} }
// RuleSet is a set of rules grouped by a string key // RuleSet is a set of rules grouped by a string key
@@ -30,6 +31,8 @@ type Manager struct {
incomingRules map[string]RuleSet incomingRules map[string]RuleSet
wgNetwork *net.IPNet wgNetwork *net.IPNet
decoders sync.Pool decoders sync.Pool
wgIface IFaceMapper
resetHook func() error
mutex sync.RWMutex mutex sync.RWMutex
} }
@@ -65,6 +68,7 @@ func Create(iface IFaceMapper) (*Manager, error) {
}, },
outgoingRules: make(map[string]RuleSet), outgoingRules: make(map[string]RuleSet),
incomingRules: make(map[string]RuleSet), incomingRules: make(map[string]RuleSet),
wgIface: iface,
} }
if err := iface.SetFilter(m); err != nil { if err := iface.SetFilter(m); err != nil {
@@ -171,17 +175,6 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
return nil return nil
} }
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
return nil
}
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil } func (m *Manager) Flush() error { return nil }
@@ -375,3 +368,8 @@ func (m *Manager) RemovePacketHook(hookID string) error {
} }
return fmt.Errorf("hook with given id not found") return fmt.Errorf("hook with given id not found")
} }
// SetResetHook which will be executed in the end of Reset method
func (m *Manager) SetResetHook(hook func() error) {
m.resetHook = hook
}

View File

@@ -16,6 +16,7 @@ import (
type IFaceMock struct { type IFaceMock struct {
SetFilterFunc func(iface.PacketFilter) error SetFilterFunc func(iface.PacketFilter) error
AddressFunc func() iface.WGAddress
} }
func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error { func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
@@ -25,6 +26,13 @@ func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
return i.SetFilterFunc(iface) return i.SetFilterFunc(iface)
} }
func (i *IFaceMock) Address() iface.WGAddress {
if i.AddressFunc == nil {
return iface.WGAddress{}
}
return i.AddressFunc()
}
func TestManagerCreate(t *testing.T) { func TestManagerCreate(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(iface.PacketFilter) error { return nil }, SetFilterFunc: func(iface.PacketFilter) error { return nil },

View File

@@ -146,12 +146,11 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
// if this rule is member of rule selection with more than DefaultIPsCountForSet // if this rule is member of rule selection with more than DefaultIPsCountForSet
// it's IP address can be used in the ipset for firewall manager which supports it // it's IP address can be used in the ipset for firewall manager which supports it
ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)] ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)]
ipsetName := ""
if ipset.name == "" { if ipset.name == "" {
d.ipsetCounter++ d.ipsetCounter++
ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter) ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter)
} }
ipsetName = ipset.name ipsetName := ipset.name
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName) pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
if err != nil { if err != nil {
log.Errorf("failed to apply firewall rule: %+v, %v", r, err) log.Errorf("failed to apply firewall rule: %+v, %v", r, err)

View File

@@ -1,4 +1,4 @@
//go:build !linux //go:build !linux || android
package acl package acl
@@ -6,6 +6,8 @@ import (
"fmt" "fmt"
"runtime" "runtime"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
) )
@@ -17,6 +19,9 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
return newDefaultManager(fm), nil return newDefaultManager(fm), nil
} }
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)

View File

@@ -1,3 +1,5 @@
//go:build !android
package acl package acl
import ( import (
@@ -7,26 +9,68 @@ import (
"github.com/netbirdio/netbird/client/firewall/iptables" "github.com/netbirdio/netbird/client/firewall/iptables"
"github.com/netbirdio/netbird/client/firewall/nftables" "github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/checkfw"
) )
// Create creates a firewall manager instance for the Linux // Create creates a firewall manager instance for the Linux
func Create(iface IFaceMapper) (manager *DefaultManager, err error) { func Create(iface IFaceMapper) (*DefaultManager, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall
var fm firewall.Manager var fm firewall.Manager
var err error
checkResult := checkfw.Check()
switch checkResult {
case checkfw.IPTABLES, checkfw.IPTABLESWITHV6:
log.Debug("creating an iptables firewall manager for access control")
ipv6Supported := checkResult == checkfw.IPTABLESWITHV6
if fm, err = iptables.Create(iface, ipv6Supported); err != nil {
log.Infof("failed to create iptables manager for access control: %s", err)
}
case checkfw.NFTABLES:
log.Debug("creating an nftables firewall manager for access control")
if fm, err = nftables.Create(iface); err != nil {
log.Debugf("failed to create nftables manager for access control: %s", err)
}
}
var resetHookForUserspace func() error
if fm != nil && err == nil {
// err shadowing is used here, to ignore this error
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
resetHookForUserspace = fm.Reset
}
if iface.IsUserspaceBind() { if iface.IsUserspaceBind() {
// use userspace packet filtering firewall // use userspace packet filtering firewall
if fm, err = uspfilter.Create(iface); err != nil { usfm, err := uspfilter.Create(iface)
if err != nil {
log.Debugf("failed to create userspace filtering firewall: %s", err) log.Debugf("failed to create userspace filtering firewall: %s", err)
return nil, err return nil, err
} }
} else {
if fm, err = nftables.Create(iface); err != nil { // set kernel space firewall Reset as hook for userspace firewall
log.Debugf("failed to create nftables manager: %s", err) // manager Reset method, to clean up
// fallback to iptables if resetHookForUserspace != nil {
if fm, err = iptables.Create(iface); err != nil { usfm.SetResetHook(resetHookForUserspace)
log.Errorf("failed to create iptables manager: %s", err)
return nil, err
}
} }
// to be consistent for any future extensions.
// ignore this error
if err := usfm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
fm = usfm
}
if fm == nil || err != nil {
log.Errorf("failed to create firewall manager: %s", err)
// no firewall manager found or initialized correctly
return nil, err
} }
return newDefaultManager(fm), nil return newDefaultManager(fm), nil

View File

@@ -1,11 +1,13 @@
package acl package acl
import ( import (
"net"
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/iface"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
) )
@@ -32,13 +34,22 @@ func TestDefaultManager(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()
iface := mocks.NewMockIFaceMapper(ctrl) ifaceMock := mocks.NewMockIFaceMapper(ctrl)
iface.EXPECT().IsUserspaceBind().Return(true) ifaceMock.EXPECT().IsUserspaceBind().Return(true)
// iface.EXPECT().Name().Return("lo") ifaceMock.EXPECT().SetFilter(gomock.Any())
iface.EXPECT().SetFilter(gomock.Any()) ip, network, err := net.ParseCIDR("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
IP: ip,
Network: network,
}).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it // we receive one rule from the management so for testing purposes ignore it
acl, err := Create(iface) acl, err := Create(ifaceMock)
if err != nil { if err != nil {
t.Errorf("create ACL manager: %v", err) t.Errorf("create ACL manager: %v", err)
return return
@@ -311,13 +322,22 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()
iface := mocks.NewMockIFaceMapper(ctrl) ifaceMock := mocks.NewMockIFaceMapper(ctrl)
iface.EXPECT().IsUserspaceBind().Return(true) ifaceMock.EXPECT().IsUserspaceBind().Return(true)
// iface.EXPECT().Name().Return("lo") ifaceMock.EXPECT().SetFilter(gomock.Any())
iface.EXPECT().SetFilter(gomock.Any()) ip, network, err := net.ParseCIDR("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
IP: ip,
Network: network,
}).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it // we receive one rule from the management so for testing purposes ignore it
acl, err := Create(iface) acl, err := Create(ifaceMock)
if err != nil { if err != nil {
t.Errorf("create ACL manager: %v", err) t.Errorf("create ACL manager: %v", err)
return return

View File

@@ -4,8 +4,8 @@ import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"runtime"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
@@ -57,34 +57,50 @@ func (t TokenInfo) GetTokenToUse() string {
return t.AccessToken return t.AccessToken
} }
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration. // NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
//
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
// and if that also fails, the authentication process is deemed unsuccessful
//
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) { func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
log.Debug("getting device authorization flow info") if runtime.GOOS == "linux" && !isLinuxRunningDesktop() {
return authenticateWithDeviceCodeFlow(ctx, config)
// Try to initialize the Device Authorization Flow
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err == nil {
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
} }
log.Debugf("getting device authorization flow info failed with error: %v", err) pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
log.Debugf("falling back to pkce authorization flow info") if err != nil {
// fallback to device code flow
return authenticateWithDeviceCodeFlow(ctx, config)
}
return pkceFlow, nil
}
// If Device Authorization Flow failed, try the PKCE Authorization Flow // authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
}
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
}
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
func authenticateWithDeviceCodeFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil { if err != nil {
s, ok := gstatus.FromError(err) s, ok := gstatus.FromError(err)
if ok && s.Code() == codes.NotFound { if ok && s.Code() == codes.NotFound {
return nil, fmt.Errorf("no SSO provider returned from management. " + return nil, fmt.Errorf("no SSO provider returned from management. " +
"If you are using hosting Netbird see documentation at " + "Please proceed with setting up this device using setup keys " +
"https://github.com/netbirdio/netbird/tree/main/management for details") "https://docs.netbird.io/how-to/register-machines-using-setup-keys")
} else if ok && s.Code() == codes.Unimplemented { } else if ok && s.Code() == codes.Unimplemented {
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+ return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
"please update your server or use Setup Keys to login", config.ManagementURL) "please update your server or use Setup Keys to login", config.ManagementURL)
} else { } else {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err) return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
} }
} }
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig) return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
} }

View File

@@ -5,6 +5,7 @@ import (
"crypto/sha256" "crypto/sha256"
"crypto/subtle" "crypto/subtle"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"html/template" "html/template"
"net" "net"
@@ -78,7 +79,7 @@ func (p *PKCEAuthorizationFlow) GetClientID(_ context.Context) string {
} }
// RequestAuthInfo requests a authorization code login flow information. // RequestAuthInfo requests a authorization code login flow information.
func (p *PKCEAuthorizationFlow) RequestAuthInfo(_ context.Context) (AuthFlowInfo, error) { func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
state, err := randomBytesInHex(24) state, err := randomBytesInHex(24)
if err != nil { if err != nil {
return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err) return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err)
@@ -112,60 +113,37 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
tokenChan := make(chan *oauth2.Token, 1) tokenChan := make(chan *oauth2.Token, 1)
errChan := make(chan error, 1) errChan := make(chan error, 1)
go p.startServer(tokenChan, errChan) parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
if err != nil {
return TokenInfo{}, fmt.Errorf("failed to parse redirect URL: %v", err)
}
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
defer func() {
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
log.Errorf("failed to close the server: %v", err)
}
}()
go p.startServer(server, tokenChan, errChan)
select { select {
case <-ctx.Done(): case <-ctx.Done():
return TokenInfo{}, ctx.Err() return TokenInfo{}, ctx.Err()
case token := <-tokenChan: case token := <-tokenChan:
return p.handleOAuthToken(token) return p.parseOAuthToken(token)
case err := <-errChan: case err := <-errChan:
return TokenInfo{}, err return TokenInfo{}, err
} }
} }
func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errChan chan<- error) { func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) {
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL) mux := http.NewServeMux()
if err != nil { mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
errChan <- fmt.Errorf("failed to parse redirect URL: %v", err) token, err := p.handleRequest(req)
return
}
port := parsedURL.Port()
server := http.Server{Addr: fmt.Sprintf(":%s", port)}
defer func() {
if err := server.Shutdown(context.Background()); err != nil {
log.Errorf("error while shutting down pkce flow server: %v", err)
}
}()
http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
tokenValidatorFunc := func() (*oauth2.Token, error) {
query := req.URL.Query()
if authError := query.Get(queryError); authError != "" {
authErrorDesc := query.Get(queryErrorDesc)
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
}
// Prevent timing attacks on state
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
return nil, fmt.Errorf("invalid state")
}
code := query.Get(queryCode)
if code == "" {
return nil, fmt.Errorf("missing code")
}
return p.oAuthConfig.Exchange(
req.Context(),
code,
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
)
}
token, err := tokenValidatorFunc()
if err != nil { if err != nil {
renderPKCEFlowTmpl(w, err) renderPKCEFlowTmpl(w, err)
errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err) errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err)
@@ -176,12 +154,38 @@ func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errC
tokenChan <- token tokenChan <- token
}) })
if err := server.ListenAndServe(); err != nil { server.Handler = mux
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
errChan <- err errChan <- err
} }
} }
func (p *PKCEAuthorizationFlow) handleOAuthToken(token *oauth2.Token) (TokenInfo, error) { func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token, error) {
query := req.URL.Query()
if authError := query.Get(queryError); authError != "" {
authErrorDesc := query.Get(queryErrorDesc)
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
}
// Prevent timing attacks on the state
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
return nil, fmt.Errorf("invalid state")
}
code := query.Get(queryCode)
if code == "" {
return nil, fmt.Errorf("missing code")
}
return p.oAuthConfig.Exchange(
req.Context(),
code,
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
)
}
func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo, error) {
tokenInfo := TokenInfo{ tokenInfo := TokenInfo{
AccessToken: token.AccessToken, AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken, RefreshToken: token.RefreshToken,

View File

@@ -7,6 +7,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"os"
"reflect" "reflect"
"strings" "strings"
) )
@@ -60,3 +61,8 @@ func isValidAccessToken(token string, audience string) error {
return fmt.Errorf("invalid JWT token audience field") return fmt.Errorf("invalid JWT token audience field")
} }
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
func isLinuxRunningDesktop() bool {
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
}

View File

@@ -0,0 +1,3 @@
//go:build !linux || android
package checkfw

View File

@@ -0,0 +1,56 @@
//go:build !android
package checkfw
import (
"os"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
)
const (
// UNKNOWN is the default value for the firewall type for unknown firewall type
UNKNOWN FWType = iota
// IPTABLES is the value for the iptables firewall type
IPTABLES
// IPTABLESWITHV6 is the value for the iptables firewall type with ipv6
IPTABLESWITHV6
// NFTABLES is the value for the nftables firewall type
NFTABLES
)
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int
// Check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func Check() FWType {
nf := nftables.Conn{}
if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
return NFTABLES
}
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err == nil {
if isIptablesClientAvailable(ip) {
ipSupport := IPTABLES
ipv6, ip6Err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if ip6Err == nil {
if isIptablesClientAvailable(ipv6) {
ipSupport = IPTABLESWITHV6
}
}
return ipSupport
}
}
return UNKNOWN
}
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}

View File

@@ -23,9 +23,6 @@ func TestGetConfig(t *testing.T) {
assert.Equal(t, config.ManagementURL.String(), DefaultManagementURL) assert.Equal(t, config.ManagementURL.String(), DefaultManagementURL)
assert.Equal(t, config.AdminURL.String(), DefaultAdminURL) assert.Equal(t, config.AdminURL.String(), DefaultAdminURL)
if err != nil {
return
}
managementURL := "https://test.management.url:33071" managementURL := "https://test.management.url:33071"
adminURL := "https://app.admin.url:443" adminURL := "https://app.admin.url:443"
path := filepath.Join(t.TempDir(), "config.json") path := filepath.Join(t.TempDir(), "config.json")

View File

@@ -192,8 +192,6 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
log.Print("Netbird engine started, my IP is: ", peerConfig.Address) log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
state.Set(StatusConnected) state.Set(StatusConnected)
statusRecorder.ClientStart()
<-engineCtx.Done() <-engineCtx.Done()
statusRecorder.ClientTeardown() statusRecorder.ClientTeardown()
@@ -214,6 +212,7 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
return nil return nil
} }
statusRecorder.ClientStart()
err = backoff.Retry(operation, backOff) err = backoff.Retry(operation, backOff)
if err != nil { if err != nil {
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err) log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)

View File

@@ -238,7 +238,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
hostUpdate := s.currentConfig hostUpdate := s.currentConfig
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() { if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " + log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
"Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver") "Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
hostUpdate.routeAll = false hostUpdate.routeAll = false
} }

View File

@@ -777,7 +777,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
newNet, err := stdnet.NewNet(nil) newNet, err := stdnet.NewNet(nil)
if err != nil { if err != nil {
t.Fatalf("create stdnet: %v", err) t.Fatalf("create stdnet: %v", err)
return nil, nil return nil, err
} }
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", iface.DefaultMTU, nil, newNet) wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", iface.DefaultMTU, nil, newNet)

View File

@@ -11,6 +11,9 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
) )
const ( const (
@@ -24,10 +27,11 @@ type serviceViaListener struct {
dnsMux *dns.ServeMux dnsMux *dns.ServeMux
customAddr *netip.AddrPort customAddr *netip.AddrPort
server *dns.Server server *dns.Server
runtimeIP string listenIP string
runtimePort int listenPort int
listenerIsRunning bool listenerIsRunning bool
listenerFlagLock sync.Mutex listenerFlagLock sync.Mutex
ebpfService ebpfMgr.Manager
} }
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener { func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
@@ -43,6 +47,7 @@ func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *service
UDPSize: 65535, UDPSize: 65535,
}, },
} }
return s return s
} }
@@ -55,13 +60,21 @@ func (s *serviceViaListener) Listen() error {
} }
var err error var err error
s.runtimeIP, s.runtimePort, err = s.evalRuntimeAddress() s.listenIP, s.listenPort, err = s.evalListenAddress()
if err != nil { if err != nil {
log.Errorf("failed to eval runtime address: %s", err) log.Errorf("failed to eval runtime address: %s", err)
return err return err
} }
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort) s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort)
if s.shouldApplyPortFwd() {
s.ebpfService = ebpf.GetEbpfManagerInstance()
err = s.ebpfService.LoadDNSFwd(s.listenIP, s.listenPort)
if err != nil {
log.Warnf("failed to load DNS port forwarder, custom port may not work well on some Linux operating systems: %s", err)
s.ebpfService = nil
}
}
log.Debugf("starting dns on %s", s.server.Addr) log.Debugf("starting dns on %s", s.server.Addr)
go func() { go func() {
s.setListenerStatus(true) s.setListenerStatus(true)
@@ -69,9 +82,10 @@ func (s *serviceViaListener) Listen() error {
err := s.server.ListenAndServe() err := s.server.ListenAndServe()
if err != nil { if err != nil {
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err) log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err)
} }
}() }()
return nil return nil
} }
@@ -90,6 +104,13 @@ func (s *serviceViaListener) Stop() {
if err != nil { if err != nil {
log.Errorf("stopping dns server listener returned an error: %v", err) log.Errorf("stopping dns server listener returned an error: %v", err)
} }
if s.ebpfService != nil {
err = s.ebpfService.FreeDNSFwd()
if err != nil {
log.Errorf("stopping traffic forwarder returned an error: %v", err)
}
}
} }
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) { func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
@@ -101,11 +122,18 @@ func (s *serviceViaListener) DeregisterMux(pattern string) {
} }
func (s *serviceViaListener) RuntimePort() int { func (s *serviceViaListener) RuntimePort() int {
return s.runtimePort s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if s.ebpfService != nil {
return defaultPort
} else {
return s.listenPort
}
} }
func (s *serviceViaListener) RuntimeIP() string { func (s *serviceViaListener) RuntimeIP() string {
return s.runtimeIP return s.listenIP
} }
func (s *serviceViaListener) setListenerStatus(running bool) { func (s *serviceViaListener) setListenerStatus(running bool) {
@@ -136,10 +164,30 @@ func (s *serviceViaListener) getFirstListenerAvailable() (string, int, error) {
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports) return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
} }
func (s *serviceViaListener) evalRuntimeAddress() (string, int, error) { func (s *serviceViaListener) evalListenAddress() (string, int, error) {
if s.customAddr != nil { if s.customAddr != nil {
return s.customAddr.Addr().String(), int(s.customAddr.Port()), nil return s.customAddr.Addr().String(), int(s.customAddr.Port()), nil
} }
return s.getFirstListenerAvailable() return s.getFirstListenerAvailable()
} }
// shouldApplyPortFwd decides whether to apply eBPF program to capture DNS traffic on port 53.
// This is needed because on some operating systems if we start a DNS server not on a default port 53, the domain name
// resolution won't work.
// So, in case we are running on Linux and picked a non-default port (53) we should fall back to the eBPF solution that will capture
// traffic on port 53 and forward it to a local DNS server running on 5053.
func (s *serviceViaListener) shouldApplyPortFwd() bool {
if runtime.GOOS != "linux" {
return false
}
if s.customAddr != nil {
return false
}
if s.listenPort == defaultPort {
return false
}
return true
}

View File

@@ -54,13 +54,16 @@ type bpfSpecs struct {
// //
// It can be passed ebpf.CollectionSpec.Assign. // It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct { type bpfProgramSpecs struct {
NbWgProxy *ebpf.ProgramSpec `ebpf:"nb_wg_proxy"` NbXdpProg *ebpf.ProgramSpec `ebpf:"nb_xdp_prog"`
} }
// bpfMapSpecs contains maps before they are loaded into the kernel. // bpfMapSpecs contains maps before they are loaded into the kernel.
// //
// It can be passed ebpf.CollectionSpec.Assign. // It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct { type bpfMapSpecs struct {
NbFeatures *ebpf.MapSpec `ebpf:"nb_features"`
NbMapDnsIp *ebpf.MapSpec `ebpf:"nb_map_dns_ip"`
NbMapDnsPort *ebpf.MapSpec `ebpf:"nb_map_dns_port"`
NbWgProxySettingsMap *ebpf.MapSpec `ebpf:"nb_wg_proxy_settings_map"` NbWgProxySettingsMap *ebpf.MapSpec `ebpf:"nb_wg_proxy_settings_map"`
} }
@@ -83,11 +86,17 @@ func (o *bpfObjects) Close() error {
// //
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign. // It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct { type bpfMaps struct {
NbFeatures *ebpf.Map `ebpf:"nb_features"`
NbMapDnsIp *ebpf.Map `ebpf:"nb_map_dns_ip"`
NbMapDnsPort *ebpf.Map `ebpf:"nb_map_dns_port"`
NbWgProxySettingsMap *ebpf.Map `ebpf:"nb_wg_proxy_settings_map"` NbWgProxySettingsMap *ebpf.Map `ebpf:"nb_wg_proxy_settings_map"`
} }
func (m *bpfMaps) Close() error { func (m *bpfMaps) Close() error {
return _BpfClose( return _BpfClose(
m.NbFeatures,
m.NbMapDnsIp,
m.NbMapDnsPort,
m.NbWgProxySettingsMap, m.NbWgProxySettingsMap,
) )
} }
@@ -96,12 +105,12 @@ func (m *bpfMaps) Close() error {
// //
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign. // It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct { type bpfPrograms struct {
NbWgProxy *ebpf.Program `ebpf:"nb_wg_proxy"` NbXdpProg *ebpf.Program `ebpf:"nb_xdp_prog"`
} }
func (p *bpfPrograms) Close() error { func (p *bpfPrograms) Close() error {
return _BpfClose( return _BpfClose(
p.NbWgProxy, p.NbXdpProg,
) )
} }

Binary file not shown.

View File

@@ -54,13 +54,16 @@ type bpfSpecs struct {
// //
// It can be passed ebpf.CollectionSpec.Assign. // It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct { type bpfProgramSpecs struct {
NbWgProxy *ebpf.ProgramSpec `ebpf:"nb_wg_proxy"` NbXdpProg *ebpf.ProgramSpec `ebpf:"nb_xdp_prog"`
} }
// bpfMapSpecs contains maps before they are loaded into the kernel. // bpfMapSpecs contains maps before they are loaded into the kernel.
// //
// It can be passed ebpf.CollectionSpec.Assign. // It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct { type bpfMapSpecs struct {
NbFeatures *ebpf.MapSpec `ebpf:"nb_features"`
NbMapDnsIp *ebpf.MapSpec `ebpf:"nb_map_dns_ip"`
NbMapDnsPort *ebpf.MapSpec `ebpf:"nb_map_dns_port"`
NbWgProxySettingsMap *ebpf.MapSpec `ebpf:"nb_wg_proxy_settings_map"` NbWgProxySettingsMap *ebpf.MapSpec `ebpf:"nb_wg_proxy_settings_map"`
} }
@@ -83,11 +86,17 @@ func (o *bpfObjects) Close() error {
// //
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign. // It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct { type bpfMaps struct {
NbFeatures *ebpf.Map `ebpf:"nb_features"`
NbMapDnsIp *ebpf.Map `ebpf:"nb_map_dns_ip"`
NbMapDnsPort *ebpf.Map `ebpf:"nb_map_dns_port"`
NbWgProxySettingsMap *ebpf.Map `ebpf:"nb_wg_proxy_settings_map"` NbWgProxySettingsMap *ebpf.Map `ebpf:"nb_wg_proxy_settings_map"`
} }
func (m *bpfMaps) Close() error { func (m *bpfMaps) Close() error {
return _BpfClose( return _BpfClose(
m.NbFeatures,
m.NbMapDnsIp,
m.NbMapDnsPort,
m.NbWgProxySettingsMap, m.NbWgProxySettingsMap,
) )
} }
@@ -96,12 +105,12 @@ func (m *bpfMaps) Close() error {
// //
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign. // It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct { type bpfPrograms struct {
NbWgProxy *ebpf.Program `ebpf:"nb_wg_proxy"` NbXdpProg *ebpf.Program `ebpf:"nb_xdp_prog"`
} }
func (p *bpfPrograms) Close() error { func (p *bpfPrograms) Close() error {
return _BpfClose( return _BpfClose(
p.NbWgProxy, p.NbXdpProg,
) )
} }

Binary file not shown.

View File

@@ -0,0 +1,51 @@
package ebpf
import (
"encoding/binary"
"net"
log "github.com/sirupsen/logrus"
)
const (
mapKeyDNSIP uint32 = 0
mapKeyDNSPort uint32 = 1
)
func (tf *GeneralManager) LoadDNSFwd(ip string, dnsPort int) error {
log.Debugf("load ebpf DNS forwarder: address: %s:%d", ip, dnsPort)
tf.lock.Lock()
defer tf.lock.Unlock()
err := tf.loadXdp()
if err != nil {
return err
}
err = tf.bpfObjs.NbMapDnsIp.Put(mapKeyDNSIP, ip2int(ip))
if err != nil {
return err
}
err = tf.bpfObjs.NbMapDnsPort.Put(mapKeyDNSPort, uint16(dnsPort))
if err != nil {
return err
}
tf.setFeatureFlag(featureFlagDnsForwarder)
err = tf.bpfObjs.NbFeatures.Put(mapKeyFeatures, tf.featureFlags)
if err != nil {
return err
}
return nil
}
func (tf *GeneralManager) FreeDNSFwd() error {
log.Debugf("free ebpf DNS forwarder")
return tf.unsetFeatureFlag(featureFlagDnsForwarder)
}
func ip2int(ipString string) uint32 {
ip := net.ParseIP(ipString)
return binary.BigEndian.Uint32(ip.To4())
}

View File

@@ -0,0 +1,116 @@
package ebpf
import (
_ "embed"
"net"
"sync"
"github.com/cilium/ebpf/link"
"github.com/cilium/ebpf/rlimit"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/ebpf/manager"
)
const (
mapKeyFeatures uint32 = 0
featureFlagWGProxy = 0b00000001
featureFlagDnsForwarder = 0b00000010
)
var (
singleton manager.Manager
singletonLock = &sync.Mutex{}
)
// required packages libbpf-dev, libc6-dev-i386-amd64-cross
// GeneralManager is used to load multiple eBPF programs with a custom check (if then) done in prog.c
// The manager simply adds a feature (byte) of each program to a map that is shared between the userspace and kernel.
// When packet arrives, the C code checks for each feature (if it is set) and executes each enabled program (e.g., dns_fwd.c and wg_proxy.c).
//
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang-14 bpf src/prog.c -- -I /usr/x86_64-linux-gnu/include
type GeneralManager struct {
lock sync.Mutex
link link.Link
featureFlags uint16
bpfObjs bpfObjects
}
// GetEbpfManagerInstance return a static eBpf Manager instance
func GetEbpfManagerInstance() manager.Manager {
singletonLock.Lock()
defer singletonLock.Unlock()
if singleton != nil {
return singleton
}
singleton = &GeneralManager{}
return singleton
}
func (tf *GeneralManager) setFeatureFlag(feature uint16) {
tf.featureFlags = tf.featureFlags | feature
}
func (tf *GeneralManager) loadXdp() error {
if tf.link != nil {
return nil
}
// it required for Docker
err := rlimit.RemoveMemlock()
if err != nil {
return err
}
iFace, err := net.InterfaceByName("lo")
if err != nil {
return err
}
// load pre-compiled programs into the kernel.
err = loadBpfObjects(&tf.bpfObjs, nil)
if err != nil {
return err
}
tf.link, err = link.AttachXDP(link.XDPOptions{
Program: tf.bpfObjs.NbXdpProg,
Interface: iFace.Index,
})
if err != nil {
_ = tf.bpfObjs.Close()
tf.link = nil
return err
}
return nil
}
func (tf *GeneralManager) unsetFeatureFlag(feature uint16) error {
tf.lock.Lock()
defer tf.lock.Unlock()
tf.featureFlags &^= feature
if tf.link == nil {
return nil
}
if tf.featureFlags == 0 {
return tf.close()
}
return tf.bpfObjs.NbFeatures.Put(mapKeyFeatures, tf.featureFlags)
}
func (tf *GeneralManager) close() error {
log.Debugf("detach ebpf program ")
err := tf.bpfObjs.Close()
if err != nil {
log.Warnf("failed to close eBpf objects: %s", err)
}
err = tf.link.Close()
tf.link = nil
return err
}

View File

@@ -0,0 +1,40 @@
package ebpf
import (
"testing"
)
func TestManager_setFeatureFlag(t *testing.T) {
mgr := GeneralManager{}
mgr.setFeatureFlag(featureFlagWGProxy)
if mgr.featureFlags != 1 {
t.Errorf("invalid faeture state")
}
mgr.setFeatureFlag(featureFlagDnsForwarder)
if mgr.featureFlags != 3 {
t.Errorf("invalid faeture state")
}
}
func TestManager_unsetFeatureFlag(t *testing.T) {
mgr := GeneralManager{}
mgr.setFeatureFlag(featureFlagWGProxy)
mgr.setFeatureFlag(featureFlagDnsForwarder)
err := mgr.unsetFeatureFlag(featureFlagWGProxy)
if err != nil {
t.Errorf("unexpected error: %s", err)
}
if mgr.featureFlags != 2 {
t.Errorf("invalid faeture state, expected: %d, got: %d", 2, mgr.featureFlags)
}
err = mgr.unsetFeatureFlag(featureFlagDnsForwarder)
if err != nil {
t.Errorf("unexpected error: %s", err)
}
if mgr.featureFlags != 0 {
t.Errorf("invalid faeture state, expected: %d, got: %d", 0, mgr.featureFlags)
}
}

View File

@@ -0,0 +1,64 @@
const __u32 map_key_dns_ip = 0;
const __u32 map_key_dns_port = 1;
struct bpf_map_def SEC("maps") nb_map_dns_ip = {
.type = BPF_MAP_TYPE_ARRAY,
.key_size = sizeof(__u32),
.value_size = sizeof(__u32),
.max_entries = 10,
};
struct bpf_map_def SEC("maps") nb_map_dns_port = {
.type = BPF_MAP_TYPE_ARRAY,
.key_size = sizeof(__u32),
.value_size = sizeof(__u16),
.max_entries = 10,
};
__be32 dns_ip = 0;
__be16 dns_port = 0;
// 13568 is 53 in big endian
__be16 GENERAL_DNS_PORT = 13568;
bool read_settings() {
__u16 *port_value;
__u32 *ip_value;
// read dns ip
ip_value = bpf_map_lookup_elem(&nb_map_dns_ip, &map_key_dns_ip);
if(!ip_value) {
return false;
}
dns_ip = htonl(*ip_value);
// read dns port
port_value = bpf_map_lookup_elem(&nb_map_dns_port, &map_key_dns_port);
if (!port_value) {
return false;
}
dns_port = htons(*port_value);
return true;
}
int xdp_dns_fwd(struct iphdr *ip, struct udphdr *udp) {
if (dns_port == 0) {
if(!read_settings()){
return XDP_PASS;
}
bpf_printk("dns port: %d", ntohs(dns_port));
bpf_printk("dns ip: %d", ntohl(dns_ip));
}
if (udp->dest == GENERAL_DNS_PORT && ip->daddr == dns_ip) {
udp->dest = dns_port;
return XDP_PASS;
}
if (udp->source == dns_port && ip->saddr == dns_ip) {
udp->source = GENERAL_DNS_PORT;
return XDP_PASS;
}
return XDP_PASS;
}

View File

@@ -0,0 +1,66 @@
#include <stdbool.h>
#include <linux/if_ether.h> // ETH_P_IP
#include <linux/udp.h>
#include <linux/ip.h>
#include <netinet/in.h>
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
#include "dns_fwd.c"
#include "wg_proxy.c"
#define bpf_printk(fmt, ...) \
({ \
char ____fmt[] = fmt; \
bpf_trace_printk(____fmt, sizeof(____fmt), ##__VA_ARGS__); \
})
const __u16 flag_feature_wg_proxy = 0b01;
const __u16 flag_feature_dns_fwd = 0b10;
const __u32 map_key_features = 0;
struct bpf_map_def SEC("maps") nb_features = {
.type = BPF_MAP_TYPE_ARRAY,
.key_size = sizeof(__u32),
.value_size = sizeof(__u16),
.max_entries = 10,
};
SEC("xdp")
int nb_xdp_prog(struct xdp_md *ctx) {
__u16 *features;
features = bpf_map_lookup_elem(&nb_features, &map_key_features);
if (!features) {
return XDP_PASS;
}
void *data = (void *)(long)ctx->data;
void *data_end = (void *)(long)ctx->data_end;
struct ethhdr *eth = data;
struct iphdr *ip = (data + sizeof(struct ethhdr));
struct udphdr *udp = (data + sizeof(struct ethhdr) + sizeof(struct iphdr));
// return early if not enough data
if (data + sizeof(struct ethhdr) + sizeof(struct iphdr) + sizeof(struct udphdr) > data_end){
return XDP_PASS;
}
// skip non IPv4 packages
if (eth->h_proto != htons(ETH_P_IP)) {
return XDP_PASS;
}
// skip non UPD packages
if (ip->protocol != IPPROTO_UDP) {
return XDP_PASS;
}
if (*features & flag_feature_dns_fwd) {
xdp_dns_fwd(ip, udp);
}
if (*features & flag_feature_wg_proxy) {
xdp_wg_proxy(ip, udp);
}
return XDP_PASS;
}
char _license[] SEC("license") = "GPL";

View File

@@ -0,0 +1,54 @@
const __u32 map_key_proxy_port = 0;
const __u32 map_key_wg_port = 1;
struct bpf_map_def SEC("maps") nb_wg_proxy_settings_map = {
.type = BPF_MAP_TYPE_ARRAY,
.key_size = sizeof(__u32),
.value_size = sizeof(__u16),
.max_entries = 10,
};
__u16 proxy_port = 0;
__u16 wg_port = 0;
bool read_port_settings() {
__u16 *value;
value = bpf_map_lookup_elem(&nb_wg_proxy_settings_map, &map_key_proxy_port);
if (!value) {
return false;
}
proxy_port = *value;
value = bpf_map_lookup_elem(&nb_wg_proxy_settings_map, &map_key_wg_port);
if (!value) {
return false;
}
wg_port = htons(*value);
return true;
}
int xdp_wg_proxy(struct iphdr *ip, struct udphdr *udp) {
if (proxy_port == 0 || wg_port == 0) {
if (!read_port_settings()){
return XDP_PASS;
}
bpf_printk("proxy port: %d, wg port: %d", proxy_port, wg_port);
}
// 2130706433 = 127.0.0.1
if (ip->daddr != htonl(2130706433)) {
return XDP_PASS;
}
if (udp->source != wg_port){
return XDP_PASS;
}
__be16 new_src_port = udp->dest;
__be16 new_dst_port = htons(proxy_port);
udp->dest = new_dst_port;
udp->source = new_src_port;
return XDP_PASS;
}

View File

@@ -0,0 +1,41 @@
package ebpf
import log "github.com/sirupsen/logrus"
const (
mapKeyProxyPort uint32 = 0
mapKeyWgPort uint32 = 1
)
func (tf *GeneralManager) LoadWgProxy(proxyPort, wgPort int) error {
log.Debugf("load ebpf WG proxy")
tf.lock.Lock()
defer tf.lock.Unlock()
err := tf.loadXdp()
if err != nil {
return err
}
err = tf.bpfObjs.NbWgProxySettingsMap.Put(mapKeyProxyPort, uint16(proxyPort))
if err != nil {
return err
}
err = tf.bpfObjs.NbWgProxySettingsMap.Put(mapKeyWgPort, uint16(wgPort))
if err != nil {
return err
}
tf.setFeatureFlag(featureFlagWGProxy)
err = tf.bpfObjs.NbFeatures.Put(mapKeyFeatures, tf.featureFlags)
if err != nil {
return err
}
return nil
}
func (tf *GeneralManager) FreeWGProxy() error {
log.Debugf("free ebpf WG proxy")
return tf.unsetFeatureFlag(featureFlagWGProxy)
}

View File

@@ -0,0 +1,15 @@
//go:build !android
package ebpf
import (
"github.com/netbirdio/netbird/client/internal/ebpf/ebpf"
"github.com/netbirdio/netbird/client/internal/ebpf/manager"
)
// GetEbpfManagerInstance is a wrapper function. This encapsulation is required because if the code import the internal
// ebpf package the Go compiler will include the object files. But it is not supported on Android. It can cause instant
// panic on older Android version.
func GetEbpfManagerInstance() manager.Manager {
return ebpf.GetEbpfManagerInstance()
}

View File

@@ -0,0 +1,10 @@
//go:build !linux || android
package ebpf
import "github.com/netbirdio/netbird/client/internal/ebpf/manager"
// GetEbpfManagerInstance return error because ebpf is not supported on all os
func GetEbpfManagerInstance() manager.Manager {
panic("unsupported os")
}

View File

@@ -0,0 +1,9 @@
package manager
// Manager is used to load multiple eBPF programs. E.g., current DNS programs and WireGuard proxy
type Manager interface {
LoadDNSFwd(ip string, dnsPort int) error
FreeDNSFwd() error
LoadWgProxy(proxyPort, wgPort int) error
FreeWGProxy() error
}

View File

@@ -995,14 +995,12 @@ func (e *Engine) parseNATExternalIPMappings() []string {
log.Warnf("invalid external IP, %s, ignoring external IP mapping '%s'", external, mapping) log.Warnf("invalid external IP, %s, ignoring external IP mapping '%s'", external, mapping)
break break
} }
if externalIP != nil { mappedIP := externalIP.String()
mappedIP := externalIP.String() if internalIP != nil {
if internalIP != nil { mappedIP = mappedIP + "/" + internalIP.String()
mappedIP = mappedIP + "/" + internalIP.String()
}
mappedIPs = append(mappedIPs, mappedIP)
log.Infof("parsed external IP mapping of '%s' as '%s'", mapping, mappedIP)
} }
mappedIPs = append(mappedIPs, mappedIP)
log.Infof("parsed external IP mapping of '%s' as '%s'", mapping, mappedIP)
} }
if len(mappedIPs) != len(e.config.NATExternalIPs) { if len(mappedIPs) != len(e.config.NATExternalIPs) {
log.Warnf("one or more external IP mappings failed to parse, ignoring all mappings") log.Warnf("one or more external IP mappings failed to parse, ignoring all mappings")

View File

@@ -1046,15 +1046,15 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
peersUpdateManager := server.NewPeersUpdateManager() peersUpdateManager := server.NewPeersUpdateManager()
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
if err != nil { if err != nil {
return nil, "", nil return nil, "", err
} }
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "", accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "",
eventStore) eventStore, false)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil) mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -61,7 +61,7 @@ func (n *notifier) clientStart() {
n.serverStateLock.Lock() n.serverStateLock.Lock()
defer n.serverStateLock.Unlock() defer n.serverStateLock.Unlock()
n.currentClientState = true n.currentClientState = true
n.lastNotification = stateConnected n.lastNotification = stateConnecting
n.notify(n.lastNotification) n.notify(n.lastNotification)
} }
@@ -114,7 +114,7 @@ func (n *notifier) calculateState(managementConn, signalConn bool) int {
return stateConnected return stateConnected
} }
if !managementConn && !signalConn { if !managementConn && !signalConn && !n.currentClientState {
return stateDisconnected return stateDisconnected
} }

View File

@@ -155,7 +155,10 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() {
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
state, err := c.statusRecorder.GetPeer(peerKey) state, err := c.statusRecorder.GetPeer(peerKey)
if err != nil || state.ConnStatus != peer.StatusConnected { if err != nil {
return err
}
if state.ConnStatus != peer.StatusConnected {
return nil return nil
} }

View File

@@ -7,6 +7,8 @@ import (
"fmt" "fmt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/checkfw"
) )
const ( const (
@@ -26,15 +28,20 @@ func genKey(format string, input string) string {
return fmt.Sprintf(format, input) return fmt.Sprintf(format, input)
} }
// NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager // newFirewall if supported, returns an iptables manager, otherwise returns a nftables manager
func NewFirewall(parentCTX context.Context) firewallManager { func newFirewall(parentCTX context.Context) (firewallManager, error) {
manager, err := newNFTablesManager(parentCTX) checkResult := checkfw.Check()
if err == nil { switch checkResult {
log.Debugf("nftables firewall manager will be used") case checkfw.IPTABLES, checkfw.IPTABLESWITHV6:
return manager log.Debug("creating an iptables firewall manager for route rules")
ipv6Supported := checkResult == checkfw.IPTABLESWITHV6
return newIptablesManager(parentCTX, ipv6Supported)
case checkfw.NFTABLES:
log.Info("creating an nftables firewall manager for route rules")
return newNFTablesManager(parentCTX), nil
} }
log.Debugf("fallback to iptables firewall manager: %s", err)
return newIptablesManager(parentCTX) return nil, fmt.Errorf("couldn't initialize nftables or iptables clients. Using a dummy firewall manager for route rules")
} }
func getInPair(pair routerPair) routerPair { func getInPair(pair routerPair) routerPair {

View File

@@ -3,24 +3,13 @@
package routemanager package routemanager
import "context" import (
"context"
"fmt"
"runtime"
)
type unimplementedFirewall struct{} // newFirewall returns a nil manager
func newFirewall(context.Context) (firewallManager, error) {
func (unimplementedFirewall) RestoreOrCreateContainers() error { return nil, fmt.Errorf("firewall not supported on %s", runtime.GOOS)
return nil
}
func (unimplementedFirewall) InsertRoutingRules(pair routerPair) error {
return nil
}
func (unimplementedFirewall) RemoveRoutingRules(pair routerPair) error {
return nil
}
func (unimplementedFirewall) CleanRoutingRules() {
}
// NewFirewall returns an unimplemented Firewall manager
func NewFirewall(parentCtx context.Context) firewallManager {
return unimplementedFirewall{}
} }

View File

@@ -49,30 +49,28 @@ type iptablesManager struct {
mux sync.Mutex mux sync.Mutex
} }
func newIptablesManager(parentCtx context.Context) *iptablesManager { func newIptablesManager(parentCtx context.Context, ipv6Supported bool) (*iptablesManager, error) {
ctx, cancel := context.WithCancel(parentCtx)
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil { if err != nil {
log.Debugf("failed to initialize iptables for ipv4: %s", err) return nil, fmt.Errorf("failed to initialize iptables for ipv4: %s", err)
} else if !isIptablesClientAvailable(ipv4Client) {
log.Infof("iptables is missing for ipv4")
ipv4Client = nil
}
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
log.Debugf("failed to initialize iptables for ipv6: %s", err)
} else if !isIptablesClientAvailable(ipv6Client) {
log.Infof("iptables is missing for ipv6")
ipv6Client = nil
} }
return &iptablesManager{ ctx, cancel := context.WithCancel(parentCtx)
manager := &iptablesManager{
ctx: ctx, ctx: ctx,
stop: cancel, stop: cancel,
ipv4Client: ipv4Client, ipv4Client: ipv4Client,
ipv6Client: ipv6Client,
rules: make(map[string]map[string][]string), rules: make(map[string]map[string][]string),
} }
if ipv6Supported {
manager.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
log.Warnf("failed to initialize iptables for ipv6: %s. Routes for this protocol won't be applied.", err)
}
}
return manager, nil
} }
// CleanRoutingRules cleans existing iptables resources that we created by the agent // CleanRoutingRules cleans existing iptables resources that we created by the agent
@@ -395,6 +393,10 @@ func (i *iptablesManager) insertRoutingRule(keyFormat, table, chain, jump string
ipVersion = ipv6 ipVersion = ipv6
} }
if iptablesClient == nil {
return fmt.Errorf("unable to insert iptables routing rules. Iptables client is not initialized")
}
ruleKey := genKey(keyFormat, pair.ID) ruleKey := genKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, ruleKey, pair.source, pair.destination) rule := genRuleSpec(jump, ruleKey, pair.source, pair.destination)
existingRule, found := i.rules[ipVersion][ruleKey] existingRule, found := i.rules[ipVersion][ruleKey]
@@ -459,6 +461,10 @@ func (i *iptablesManager) removeRoutingRule(keyFormat, table, chain string, pair
ipVersion = ipv6 ipVersion = ipv6
} }
if iptablesClient == nil {
return fmt.Errorf("unable to remove iptables routing rules. Iptables client is not initialized")
}
ruleKey := genKey(keyFormat, pair.ID) ruleKey := genKey(keyFormat, pair.ID)
existingRule, found := i.rules[ipVersion][ruleKey] existingRule, found := i.rules[ipVersion][ruleKey]
if found { if found {
@@ -479,8 +485,3 @@ func getIptablesRuleType(table string) string {
} }
return ruleType return ruleType
} }
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}

View File

@@ -16,11 +16,12 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
t.SkipNow() t.SkipNow()
} }
manager := newIptablesManager(context.TODO()) manager, err := newIptablesManager(context.TODO(), true)
require.NoError(t, err, "should return a valid iptables manager")
defer manager.CleanRoutingRules() defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers() err = manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
require.Len(t, manager.rules, 2, "should have created maps for ipv4 and ipv6") require.Len(t, manager.rules, 2, "should have created maps for ipv4 and ipv6")

View File

@@ -27,7 +27,7 @@ type DefaultManager struct {
stop context.CancelFunc stop context.CancelFunc
mux sync.Mutex mux sync.Mutex
clientNetworks map[string]*clientNetwork clientNetworks map[string]*clientNetwork
serverRouter *serverRouter serverRouter serverRouter
statusRecorder *peer.Status statusRecorder *peer.Status
wgInterface *iface.WGIface wgInterface *iface.WGIface
pubKey string pubKey string
@@ -36,13 +36,17 @@ type DefaultManager struct {
// NewManager returns a new route manager // NewManager returns a new route manager
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager { func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx) srvRouter, err := newServerRouter(ctx, wgInterface)
if err != nil {
log.Errorf("server router is not supported: %s", err)
}
mCTX, cancel := context.WithCancel(ctx)
dm := &DefaultManager{ dm := &DefaultManager{
ctx: mCTX, ctx: mCTX,
stop: cancel, stop: cancel,
clientNetworks: make(map[string]*clientNetwork), clientNetworks: make(map[string]*clientNetwork),
serverRouter: newServerRouter(ctx, wgInterface), serverRouter: srvRouter,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgInterface: wgInterface, wgInterface: wgInterface,
pubKey: pubKey, pubKey: pubKey,
@@ -59,7 +63,9 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
// Stop stops the manager watchers and clean firewall rules // Stop stops the manager watchers and clean firewall rules
func (m *DefaultManager) Stop() { func (m *DefaultManager) Stop() {
m.stop() m.stop()
m.serverRouter.cleanUp() if m.serverRouter != nil {
m.serverRouter.cleanUp()
}
m.ctx = nil m.ctx = nil
} }
@@ -77,9 +83,12 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
m.updateClientNetworks(updateSerial, newClientRoutesIDMap) m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
m.notifier.onNewRoutes(newClientRoutesIDMap) m.notifier.onNewRoutes(newClientRoutesIDMap)
err := m.serverRouter.updateRoutes(newServerRoutesMap)
if err != nil { if m.serverRouter != nil {
return err err := m.serverRouter.updateRoutes(newServerRoutesMap)
if err != nil {
return err
}
} }
return nil return nil

View File

@@ -3,11 +3,12 @@ package routemanager
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/pion/transport/v2/stdnet"
"net/netip" "net/netip"
"runtime" "runtime"
"testing" "testing"
"github.com/pion/transport/v2/stdnet"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
@@ -30,7 +31,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
inputInitRoutes []*route.Route inputInitRoutes []*route.Route
inputRoutes []*route.Route inputRoutes []*route.Route
inputSerial uint64 inputSerial uint64
shouldCheckServerRoutes bool removeSrvRouter bool
serverRoutesExpected int serverRoutesExpected int
clientNetworkWatchersExpected int clientNetworkWatchersExpected int
}{ }{
@@ -87,7 +88,6 @@ func TestManagerUpdateRoutes(t *testing.T) {
}, },
}, },
inputSerial: 1, inputSerial: 1,
shouldCheckServerRoutes: runtime.GOOS == "linux",
serverRoutesExpected: 2, serverRoutesExpected: 2,
clientNetworkWatchersExpected: 0, clientNetworkWatchersExpected: 0,
}, },
@@ -116,10 +116,38 @@ func TestManagerUpdateRoutes(t *testing.T) {
}, },
}, },
inputSerial: 1, inputSerial: 1,
shouldCheckServerRoutes: runtime.GOOS == "linux",
serverRoutesExpected: 1, serverRoutesExpected: 1,
clientNetworkWatchersExpected: 1, clientNetworkWatchersExpected: 1,
}, },
{
name: "Should Create 1 Route For Client and Skip Server Route On Empty Server Router",
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.30.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeB",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("8.8.9.9/32"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
removeSrvRouter: true,
serverRoutesExpected: 0,
clientNetworkWatchersExpected: 1,
},
{ {
name: "Should Create 1 HA Route and 1 Standalone", name: "Should Create 1 HA Route and 1 Standalone",
inputRoutes: []*route.Route{ inputRoutes: []*route.Route{
@@ -174,25 +202,6 @@ func TestManagerUpdateRoutes(t *testing.T) {
inputSerial: 1, inputSerial: 1,
clientNetworkWatchersExpected: 0, clientNetworkWatchersExpected: 0,
}, },
{
name: "No Server Routes Should Be Added To Non Linux",
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("1.2.3.4/32"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
shouldCheckServerRoutes: runtime.GOOS != "linux",
serverRoutesExpected: 0,
clientNetworkWatchersExpected: 0,
},
{ {
name: "Remove 1 Client Route", name: "Remove 1 Client Route",
inputInitRoutes: []*route.Route{ inputInitRoutes: []*route.Route{
@@ -335,7 +344,6 @@ func TestManagerUpdateRoutes(t *testing.T) {
}, },
inputRoutes: []*route.Route{}, inputRoutes: []*route.Route{},
inputSerial: 1, inputSerial: 1,
shouldCheckServerRoutes: true,
serverRoutesExpected: 0, serverRoutesExpected: 0,
clientNetworkWatchersExpected: 0, clientNetworkWatchersExpected: 0,
}, },
@@ -384,7 +392,6 @@ func TestManagerUpdateRoutes(t *testing.T) {
}, },
}, },
inputSerial: 1, inputSerial: 1,
shouldCheckServerRoutes: runtime.GOOS == "linux",
serverRoutesExpected: 2, serverRoutesExpected: 2,
clientNetworkWatchersExpected: 1, clientNetworkWatchersExpected: 1,
}, },
@@ -409,6 +416,10 @@ func TestManagerUpdateRoutes(t *testing.T) {
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil) routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
defer routeManager.Stop() defer routeManager.Stop()
if testCase.removeSrvRouter {
routeManager.serverRouter = nil
}
if len(testCase.inputInitRoutes) > 0 { if len(testCase.inputInitRoutes) > 0 {
err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes) err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
require.NoError(t, err, "should update routes with init routes") require.NoError(t, err, "should update routes with init routes")
@@ -419,8 +430,9 @@ func TestManagerUpdateRoutes(t *testing.T) {
require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match") require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match")
if testCase.shouldCheckServerRoutes { if runtime.GOOS == "linux" && routeManager.serverRouter != nil {
require.Len(t, routeManager.serverRouter.routes, testCase.serverRoutesExpected, "server networks size should match") sr := routeManager.serverRouter.(*defaultServerRouter)
require.Len(t, sr.routes, testCase.serverRoutesExpected, "server networks size should match")
} }
}) })
} }

View File

@@ -86,10 +86,10 @@ type nftablesManager struct {
mux sync.Mutex mux sync.Mutex
} }
func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) { func newNFTablesManager(parentCtx context.Context) *nftablesManager {
ctx, cancel := context.WithCancel(parentCtx) ctx, cancel := context.WithCancel(parentCtx)
mgr := &nftablesManager{ return &nftablesManager{
ctx: ctx, ctx: ctx,
stop: cancel, stop: cancel,
conn: &nftables.Conn{}, conn: &nftables.Conn{},
@@ -97,18 +97,6 @@ func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) {
rules: make(map[string]*nftables.Rule), rules: make(map[string]*nftables.Rule),
defaultForwardRules: make([]*nftables.Rule, 2), defaultForwardRules: make([]*nftables.Rule, 2),
} }
err := mgr.isSupported()
if err != nil {
return nil, err
}
err = mgr.readFilterTable()
if err != nil {
return nil, err
}
return mgr, nil
} }
// CleanRoutingRules cleans existing nftables rules from the system // CleanRoutingRules cleans existing nftables rules from the system
@@ -147,6 +135,10 @@ func (n *nftablesManager) RestoreOrCreateContainers() error {
} }
for _, table := range tables { for _, table := range tables {
if table.Name == "filter" {
n.filterTable = table
continue
}
if table.Name == nftablesTable { if table.Name == nftablesTable {
if table.Family == nftables.TableFamilyIPv4 { if table.Family == nftables.TableFamilyIPv4 {
n.tableIPv4 = table n.tableIPv4 = table
@@ -259,21 +251,6 @@ func (n *nftablesManager) refreshRulesMap() error {
return nil return nil
} }
func (n *nftablesManager) readFilterTable() error {
tables, err := n.conn.ListTables()
if err != nil {
return err
}
for _, t := range tables {
if t.Name == "filter" {
n.filterTable = t
return nil
}
}
return nil
}
func (n *nftablesManager) eraseDefaultForwardRule() error { func (n *nftablesManager) eraseDefaultForwardRule() error {
if n.defaultForwardRules[0] == nil { if n.defaultForwardRules[0] == nil {
return nil return nil
@@ -544,14 +521,6 @@ func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) erro
return nil return nil
} }
func (n *nftablesManager) isSupported() error {
_, err := n.conn.ListChains()
if err != nil {
return fmt.Errorf("nftables is not supported: %s", err)
}
return nil
}
// getPayloadDirectives get expression directives based on ip version and direction // getPayloadDirectives get expression directives based on ip version and direction
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) { func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
switch { switch {

View File

@@ -10,20 +10,23 @@ import (
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/checkfw"
) )
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) { func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
manager, err := newNFTablesManager(context.TODO()) if checkfw.Check() != checkfw.NFTABLES {
if err != nil { t.Skip("nftables not supported on this OS")
t.Fatalf("failed to create nftables manager: %s", err)
} }
manager := newNFTablesManager(context.TODO())
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules() defer manager.CleanRoutingRules()
err = manager.RestoreOrCreateContainers() err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6") require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
@@ -126,19 +129,19 @@ func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
} }
func TestNftablesManager_InsertRoutingRules(t *testing.T) { func TestNftablesManager_InsertRoutingRules(t *testing.T) {
if checkfw.Check() != checkfw.NFTABLES {
t.Skip("nftables not supported on this OS")
}
for _, testCase := range insertRuleTestCases { for _, testCase := range insertRuleTestCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
manager, err := newNFTablesManager(context.TODO()) manager := newNFTablesManager(context.TODO())
if err != nil {
t.Fatalf("failed to create nftables manager: %s", err)
}
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules() defer manager.CleanRoutingRules()
err = manager.RestoreOrCreateContainers() err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
err = manager.InsertRoutingRules(testCase.inputPair) err = manager.InsertRoutingRules(testCase.inputPair)
@@ -226,19 +229,19 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
} }
func TestNftablesManager_RemoveRoutingRules(t *testing.T) { func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
if checkfw.Check() != checkfw.NFTABLES {
t.Skip("nftables not supported on this OS")
}
for _, testCase := range removeRuleTestCases { for _, testCase := range removeRuleTestCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
manager, err := newNFTablesManager(context.TODO()) manager := newNFTablesManager(context.TODO())
if err != nil {
t.Fatalf("failed to create nftables manager: %s", err)
}
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules() defer manager.CleanRoutingRules()
err = manager.RestoreOrCreateContainers() err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
table := manager.tableIPv4 table := manager.tableIPv4

View File

@@ -0,0 +1,9 @@
package routemanager
import "github.com/netbirdio/netbird/route"
type serverRouter interface {
updateRoutes(map[string]*route.Route) error
removeFromServerNetwork(*route.Route) error
cleanUp()
}

View File

@@ -2,20 +2,11 @@ package routemanager
import ( import (
"context" "context"
"fmt"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
) )
type serverRouter struct { func newServerRouter(context.Context, *iface.WGIface) (serverRouter, error) {
return nil, fmt.Errorf("server route not supported on this os")
} }
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) *serverRouter {
return &serverRouter{}
}
func (r *serverRouter) updateRoutes(routesMap map[string]*route.Route) error {
return nil
}
func (r *serverRouter) cleanUp() {}

View File

@@ -13,7 +13,7 @@ import (
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
type serverRouter struct { type defaultServerRouter struct {
mux sync.Mutex mux sync.Mutex
ctx context.Context ctx context.Context
routes map[string]*route.Route routes map[string]*route.Route
@@ -21,16 +21,21 @@ type serverRouter struct {
wgInterface *iface.WGIface wgInterface *iface.WGIface
} }
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) *serverRouter { func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) (serverRouter, error) {
return &serverRouter{ firewall, err := newFirewall(ctx)
if err != nil {
return nil, err
}
return &defaultServerRouter{
ctx: ctx, ctx: ctx,
routes: make(map[string]*route.Route), routes: make(map[string]*route.Route),
firewall: NewFirewall(ctx), firewall: firewall,
wgInterface: wgInterface, wgInterface: wgInterface,
} }, nil
} }
func (m *serverRouter) updateRoutes(routesMap map[string]*route.Route) error { func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) error {
serverRoutesToRemove := make([]string, 0) serverRoutesToRemove := make([]string, 0)
if len(routesMap) > 0 { if len(routesMap) > 0 {
@@ -81,7 +86,7 @@ func (m *serverRouter) updateRoutes(routesMap map[string]*route.Route) error {
return nil return nil
} }
func (m *serverRouter) removeFromServerNetwork(route *route.Route) error { func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error {
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
log.Infof("not removing from server network because context is done") log.Infof("not removing from server network because context is done")
@@ -98,7 +103,7 @@ func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
} }
} }
func (m *serverRouter) addToServerNetwork(route *route.Route) error { func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
log.Infof("not adding to server network because context is done") log.Infof("not adding to server network because context is done")
@@ -115,6 +120,6 @@ func (m *serverRouter) addToServerNetwork(route *route.Route) error {
} }
} }
func (m *serverRouter) cleanUp() { func (m *defaultServerRouter) cleanUp() {
m.firewall.CleanRoutingRules() m.firewall.CleanRoutingRules()
} }

View File

@@ -20,7 +20,7 @@ func InterfaceFilter(disallowList []string) func(string) bool {
for _, s := range disallowList { for _, s := range disallowList {
if strings.HasPrefix(iFace, s) { if strings.HasPrefix(iFace, s) {
log.Debugf("ignoring interface %s - it is not allowed", iFace) log.Tracef("ignoring interface %s - it is not allowed", iFace)
return false return false
} }
} }

View File

@@ -1,84 +0,0 @@
//go:build linux && !android
package ebpf
import (
_ "embed"
"net"
"github.com/cilium/ebpf/link"
"github.com/cilium/ebpf/rlimit"
)
const (
mapKeyProxyPort uint32 = 0
mapKeyWgPort uint32 = 1
)
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang-14 bpf src/portreplace.c --
// EBPF is a wrapper for eBPF program
type EBPF struct {
link link.Link
}
// NewEBPF create new EBPF instance
func NewEBPF() *EBPF {
return &EBPF{}
}
// Load load ebpf program
func (l *EBPF) Load(proxyPort, wgPort int) error {
// it required for Docker
err := rlimit.RemoveMemlock()
if err != nil {
return err
}
ifce, err := net.InterfaceByName("lo")
if err != nil {
return err
}
// Load pre-compiled programs into the kernel.
objs := bpfObjects{}
err = loadBpfObjects(&objs, nil)
if err != nil {
return err
}
defer func() {
_ = objs.Close()
}()
err = objs.NbWgProxySettingsMap.Put(mapKeyProxyPort, uint16(proxyPort))
if err != nil {
return err
}
err = objs.NbWgProxySettingsMap.Put(mapKeyWgPort, uint16(wgPort))
if err != nil {
return err
}
defer func() {
_ = objs.NbWgProxySettingsMap.Close()
}()
l.link, err = link.AttachXDP(link.XDPOptions{
Program: objs.NbWgProxy,
Interface: ifce.Index,
})
if err != nil {
return err
}
return err
}
// Free ebpf program
func (l *EBPF) Free() error {
if l.link != nil {
return l.link.Close()
}
return nil
}

View File

@@ -1,18 +0,0 @@
//go:build linux
package ebpf
import (
"testing"
)
func Test_newEBPF(t *testing.T) {
ebpf := NewEBPF()
err := ebpf.Load(1234, 51892)
defer func() {
_ = ebpf.Free()
}()
if err != nil {
t.Errorf("%s", err)
}
}

View File

@@ -1,90 +0,0 @@
#include <stdbool.h>
#include <linux/if_ether.h> // ETH_P_IP
#include <linux/udp.h>
#include <linux/ip.h>
#include <netinet/in.h>
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
#define bpf_printk(fmt, ...) \
({ \
char ____fmt[] = fmt; \
bpf_trace_printk(____fmt, sizeof(____fmt), ##__VA_ARGS__); \
})
const __u32 map_key_proxy_port = 0;
const __u32 map_key_wg_port = 1;
struct bpf_map_def SEC("maps") nb_wg_proxy_settings_map = {
.type = BPF_MAP_TYPE_ARRAY,
.key_size = sizeof(__u32),
.value_size = sizeof(__u16),
.max_entries = 10,
};
__u16 proxy_port = 0;
__u16 wg_port = 0;
bool read_port_settings() {
__u16 *value;
value = bpf_map_lookup_elem(&nb_wg_proxy_settings_map, &map_key_proxy_port);
if(!value) {
return false;
}
proxy_port = *value;
value = bpf_map_lookup_elem(&nb_wg_proxy_settings_map, &map_key_wg_port);
if(!value) {
return false;
}
wg_port = *value;
return true;
}
SEC("xdp")
int nb_wg_proxy(struct xdp_md *ctx) {
if(proxy_port == 0 || wg_port == 0) {
if(!read_port_settings()){
return XDP_PASS;
}
bpf_printk("proxy port: %d, wg port: %d", proxy_port, wg_port);
}
void *data = (void *)(long)ctx->data;
void *data_end = (void *)(long)ctx->data_end;
struct ethhdr *eth = data;
struct iphdr *ip = (data + sizeof(struct ethhdr));
struct udphdr *udp = (data + sizeof(struct ethhdr) + sizeof(struct iphdr));
// return early if not enough data
if (data + sizeof(struct ethhdr) + sizeof(struct iphdr) + sizeof(struct udphdr) > data_end){
return XDP_PASS;
}
// skip non IPv4 packages
if (eth->h_proto != htons(ETH_P_IP)) {
return XDP_PASS;
}
if (ip->protocol != IPPROTO_UDP) {
return XDP_PASS;
}
// 2130706433 = 127.0.0.1
if (ip->daddr != htonl(2130706433)) {
return XDP_PASS;
}
if (udp->source != htons(wg_port)){
return XDP_PASS;
}
__be16 new_src_port = udp->dest;
__be16 new_dst_port = htons(proxy_port);
udp->dest = new_dst_port;
udp->source = new_src_port;
return XDP_PASS;
}
char _license[] SEC("license") = "GPL";

View File

@@ -12,15 +12,15 @@ import (
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
ebpf2 "github.com/netbirdio/netbird/client/internal/wgproxy/ebpf" "github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
) )
// WGEBPFProxy definition for proxy with EBPF support // WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct { type WGEBPFProxy struct {
ebpf *ebpf2.EBPF ebpfManager ebpfMgr.Manager
lastUsedPort uint16 lastUsedPort uint16
localWGListenPort int localWGListenPort int
@@ -36,7 +36,7 @@ func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
log.Debugf("instantiate ebpf proxy") log.Debugf("instantiate ebpf proxy")
wgProxy := &WGEBPFProxy{ wgProxy := &WGEBPFProxy{
localWGListenPort: wgPort, localWGListenPort: wgPort,
ebpf: ebpf2.NewEBPF(), ebpfManager: ebpf.GetEbpfManagerInstance(),
lastUsedPort: 0, lastUsedPort: 0,
turnConnStore: make(map[uint16]net.Conn), turnConnStore: make(map[uint16]net.Conn),
} }
@@ -56,7 +56,7 @@ func (p *WGEBPFProxy) Listen() error {
return err return err
} }
err = p.ebpf.Load(wgPorxyPort, p.localWGListenPort) err = p.ebpfManager.LoadWgProxy(wgPorxyPort, p.localWGListenPort)
if err != nil { if err != nil {
return err return err
} }
@@ -110,7 +110,7 @@ func (p *WGEBPFProxy) Free() error {
err1 = p.conn.Close() err1 = p.conn.Close()
} }
err2 = p.ebpf.Free() err2 = p.ebpfManager.FreeWGProxy()
if p.rawConn != nil { if p.rawConn != nil {
err3 = p.rawConn.Close() err3 = p.rawConn.Close()
} }
@@ -135,6 +135,7 @@ func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err) log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
} }
p.removeTurnConn(endpointPort) p.removeTurnConn(endpointPort)
log.Infof("stop forward turn packages to port: %d. error: %s", endpointPort, err)
return return
} }
err = p.sendPkg(buf[:n], endpointPort) err = p.sendPkg(buf[:n], endpointPort)
@@ -158,7 +159,7 @@ func (p *WGEBPFProxy) proxyToRemote() {
conn, ok := p.turnConnStore[uint16(addr.Port)] conn, ok := p.turnConnStore[uint16(addr.Port)]
p.turnConnMutex.Unlock() p.turnConnMutex.Unlock()
if !ok { if !ok {
log.Errorf("turn conn not found by port: %d", addr.Port) log.Infof("turn conn not found by port: %d", addr.Port)
continue continue
} }

77
client/netbird.wxs Normal file
View File

@@ -0,0 +1,77 @@
<Wix
xmlns="http://wixtoolset.org/schemas/v4/wxs">
<Package Name="NetBird" Version="$(env.NETBIRD_VERSION)" Manufacturer="Wiretrustee UG (haftungsbeschreankt)" Language="1033" UpgradeCode="6456ec4e-3ad6-4b9b-a2be-98e81cb21ccf"
InstallerVersion="500" Compressed="yes" Codepage="utf-8" >
<MediaTemplate EmbedCab="yes" />
<Feature Id="NetbirdFeature" Title="Netbird" Level="1">
<ComponentGroupRef Id="NetbirdFilesComponent" />
</Feature>
<MajorUpgrade AllowSameVersionUpgrades='yes' DowngradeErrorMessage="A newer version of [ProductName] is already installed. Setup will now exit."/>
<StandardDirectory Id="ProgramFiles64Folder">
<Directory Id="NetbirdInstallDir" Name="Netbird">
<Component Id="NetbirdFiles" Guid="db3165de-cc6e-4922-8396-9d892950e23e" Bitness="always64">
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\netbird.exe" KeyPath="yes" />
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\netbird-ui.exe">
<Shortcut Id="NetbirdDesktopShortcut" Directory="DesktopFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
<Shortcut Id="NetbirdStartMenuShortcut" Directory="StartMenuFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
</File>
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\wintun.dll" />
<ServiceInstall
Id="NetBirdService"
Name="NetBird"
DisplayName="NetBird"
Description="A WireGuard-based mesh network that connects your devices into a single private network."
Start="auto" Type="ownProcess"
ErrorControl="normal"
Account="LocalSystem"
Vital="yes"
Interactive="no"
Arguments='service run config [CommonAppDataFolder]Netbird\config.json log-level info'
/>
<ServiceControl Id="NetBirdService" Name="NetBird" Start="install" Stop="both" Remove="uninstall" Wait="yes" />
<Environment Id="UpdatePath" Name="PATH" Value="[NetbirdInstallDir]" Part="last" Action="set" System="yes" />
</Component>
</Directory>
</StandardDirectory>
<ComponentGroup Id="NetbirdFilesComponent">
<ComponentRef Id="NetbirdFiles" />
</ComponentGroup>
<Property Id="cmd" Value="cmd.exe"/>
<CustomAction Id="KillDaemon"
ExeCommand='/c "taskkill /im netbird.exe"'
Execute="deferred"
Property="cmd"
Impersonate="no"
Return="ignore"
/>
<CustomAction Id="KillUI"
ExeCommand='/c "taskkill /im netbird-ui.exe"'
Execute="deferred"
Property="cmd"
Impersonate="no"
Return="ignore"
/>
<InstallExecuteSequence>
<!-- For Uninstallation -->
<Custom Action="KillDaemon" Before="RemoveFiles" Condition="Installed"/>
<Custom Action="KillUI" After="KillDaemon" Condition="Installed"/>
</InstallExecuteSequence>
<!-- Icons -->
<Icon Id="NetbirdIcon" SourceFile=".\client\ui\netbird.ico" />
<Property Id="ARPPRODUCTICON" Value="NetbirdIcon" />
</Package>
</Wix>

View File

@@ -3,10 +3,11 @@ package server
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/netbirdio/netbird/client/internal/auth"
"sync" "sync"
"time" "time"
"github.com/netbirdio/netbird/client/internal/auth"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
@@ -315,7 +316,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
tokenInfo, err := s.oauthAuthFlow.flow.WaitToken(waitCTX, flowInfo) tokenInfo, err := s.oauthAuthFlow.flow.WaitToken(waitCTX, flowInfo)
if err != nil { if err != nil {
if err == context.Canceled { if err == context.Canceled {
return nil, nil return nil, nil //nolint:nilnil
} }
s.mutex.Lock() s.mutex.Lock()
s.oauthAuthFlow.expiresAt = time.Now() s.oauthAuthFlow.expiresAt = time.Now()

View File

@@ -130,16 +130,22 @@ func ParseNameServerURL(nsURL string) (NameServer, error) {
// Copy copies a nameserver group object // Copy copies a nameserver group object
func (g *NameServerGroup) Copy() *NameServerGroup { func (g *NameServerGroup) Copy() *NameServerGroup {
return &NameServerGroup{ nsGroup := &NameServerGroup{
ID: g.ID, ID: g.ID,
Name: g.Name, Name: g.Name,
Description: g.Description, Description: g.Description,
NameServers: g.NameServers, NameServers: make([]NameServer, len(g.NameServers)),
Groups: g.Groups, Groups: make([]string, len(g.Groups)),
Enabled: g.Enabled, Enabled: g.Enabled,
Primary: g.Primary, Primary: g.Primary,
Domains: g.Domains, Domains: make([]string, len(g.Domains)),
} }
copy(nsGroup.NameServers, g.NameServers)
copy(nsGroup.Groups, g.Groups)
copy(nsGroup.Domains, g.Domains)
return nsGroup
} }
// IsEqual compares one nameserver group with the other // IsEqual compares one nameserver group with the other

2
go.mod
View File

@@ -31,7 +31,7 @@ require (
fyne.io/fyne/v2 v2.1.4 fyne.io/fyne/v2 v2.1.4
github.com/c-robinson/iplib v1.0.3 github.com/c-robinson/iplib v1.0.3
github.com/cilium/ebpf v0.10.0 github.com/cilium/ebpf v0.10.0
github.com/coreos/go-iptables v0.6.0 github.com/coreos/go-iptables v0.7.0
github.com/creack/pty v1.1.18 github.com/creack/pty v1.1.18
github.com/eko/gocache/v3 v3.1.1 github.com/eko/gocache/v3 v3.1.1
github.com/getlantern/systray v1.2.1 github.com/getlantern/systray v1.2.1

4
go.sum
View File

@@ -131,8 +131,8 @@ github.com/containerd/typeurl v1.0.2/go.mod h1:9trJWW2sRlGub4wZJRTW83VtbOLS6hwcD
github.com/coocood/freecache v1.2.1 h1:/v1CqMq45NFH9mp/Pt142reundeBM0dVUD3osQBeu/U= github.com/coocood/freecache v1.2.1 h1:/v1CqMq45NFH9mp/Pt142reundeBM0dVUD3osQBeu/U=
github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk=
github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
github.com/coreos/go-iptables v0.6.0 h1:is9qnZMPYjLd8LYqmm/qlE+wwEgJIkTYdhV3rfZo4jk= github.com/coreos/go-iptables v0.7.0 h1:XWM3V+MPRr5/q51NuWSgU0fqMad64Zyxs8ZUoMsamr8=
github.com/coreos/go-iptables v0.6.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= github.com/coreos/go-iptables v0.7.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q=
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/coreos/go-systemd/v22 v22.0.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk= github.com/coreos/go-systemd/v22 v22.0.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk=

View File

@@ -112,7 +112,7 @@ func (w *WGIface) Close() error {
return w.tun.Close() return w.tun.Close()
} }
// SetFilter sets packet filters for the userspace impelemntation // SetFilter sets packet filters for the userspace implementation
func (w *WGIface) SetFilter(filter PacketFilter) error { func (w *WGIface) SetFilter(filter PacketFilter) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()

View File

@@ -161,7 +161,7 @@ func getModulePath(name string) (string, error) {
} }
if err != nil { if err != nil {
// skip broken files // skip broken files
return nil return nil //nolint:nilerr
} }
if !info.Type().IsRegular() { if !info.Type().IsRegular() {

View File

@@ -146,9 +146,6 @@ func (c *wGConfigurer) removeAllowedIP(peerKey string, allowedIP string) error {
} }
} }
if err != nil {
return err
}
peer := wgtypes.PeerConfig{ peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
UpdateOnly: true, UpdateOnly: true,

View File

@@ -36,7 +36,7 @@ services:
volumes: volumes:
- $SIGNAL_VOLUMENAME:/var/lib/netbird - $SIGNAL_VOLUMENAME:/var/lib/netbird
ports: ports:
- 10000:80 - $NETBIRD_SIGNAL_PORT:80
# # port and command for Let's Encrypt validation # # port and command for Let's Encrypt validation
# - 443:443 # - 443:443
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]

View File

@@ -487,7 +487,48 @@ renderCaddyfile() {
} }
} }
(security_headers) {
header * {
# enable HSTS
# https://cheatsheetseries.owasp.org/cheatsheets/HTTP_Headers_Cheat_Sheet.html#strict-transport-security-hsts
# NOTE: Read carefully how this header works before using it.
# If the HSTS header is misconfigured or if there is a problem with
# the SSL/TLS certificate being used, legitimate users might be unable
# to access the website. For example, if the HSTS header is set to a
# very long duration and the SSL/TLS certificate expires or is revoked,
# legitimate users might be unable to access the website until
# the HSTS header duration has expired.
# The recommended value for the max-age is 2 year (63072000 seconds).
# But we are using 1 hour (3600 seconds) for testing purposes
# and ensure that the website is working properly before setting
# to two years.
Strict-Transport-Security "max-age=3600; includeSubDomains; preload"
# disable clients from sniffing the media type
# https://cheatsheetseries.owasp.org/cheatsheets/HTTP_Headers_Cheat_Sheet.html#x-content-type-options
X-Content-Type-Options "nosniff"
# clickjacking protection
# https://cheatsheetseries.owasp.org/cheatsheets/HTTP_Headers_Cheat_Sheet.html#x-frame-options
X-Frame-Options "DENY"
# xss protection
# https://cheatsheetseries.owasp.org/cheatsheets/HTTP_Headers_Cheat_Sheet.html#x-xss-protection
X-XSS-Protection "1; mode=block"
# Remove -Server header, which is an information leak
# Remove Caddy from Headers
-Server
# keep referrer data off of HTTP connections
# https://cheatsheetseries.owasp.org/cheatsheets/HTTP_Headers_Cheat_Sheet.html#referrer-policy
Referrer-Policy strict-origin-when-cross-origin
}
}
:80${CADDY_SECURE_DOMAIN} { :80${CADDY_SECURE_DOMAIN} {
import security_headers
# Signal # Signal
reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000 reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000
# Management # Management

View File

@@ -22,3 +22,4 @@ NETBIRD_AUTH_DEVICE_AUTH_SCOPE="openid email"
NETBIRD_MGMT_IDP=$CI_NETBIRD_MGMT_IDP NETBIRD_MGMT_IDP=$CI_NETBIRD_MGMT_IDP
NETBIRD_IDP_MGMT_CLIENT_ID=$CI_NETBIRD_IDP_MGMT_CLIENT_ID NETBIRD_IDP_MGMT_CLIENT_ID=$CI_NETBIRD_IDP_MGMT_CLIENT_ID
NETBIRD_IDP_MGMT_CLIENT_SECRET=$CI_NETBIRD_IDP_MGMT_CLIENT_SECRET NETBIRD_IDP_MGMT_CLIENT_SECRET=$CI_NETBIRD_IDP_MGMT_CLIENT_SECRET
NETBIRD_SIGNAL_PORT=12345

View File

@@ -61,12 +61,12 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
peersUpdateManager := mgmt.NewPeersUpdateManager() peersUpdateManager := mgmt.NewPeersUpdateManager()
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "", accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "",
eventStore) eventStore, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil) mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -19,28 +19,26 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/h2c" "golang.org/x/net/http2/h2c"
"github.com/netbirdio/netbird/management/server/activity/sqlite"
httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/util"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/management/proto" mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/activity/sqlite"
httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util"
) )
// ManagementLegacyPort is the port that was used before by the Management gRPC server. // ManagementLegacyPort is the port that was used before by the Management gRPC server.
@@ -72,10 +70,15 @@ var (
Use: "management", Use: "management",
Short: "start NetBird Management Server", Short: "start NetBird Management Server",
PreRunE: func(cmd *cobra.Command, args []string) error { PreRunE: func(cmd *cobra.Command, args []string) error {
flag.Parse()
err := util.InitLog(logLevel, logFile)
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
// detect whether user specified a port // detect whether user specified a port
userPort := cmd.Flag("port").Changed userPort := cmd.Flag("port").Changed
var err error
config, err = loadMgmtConfig(mgmtConfig) config, err = loadMgmtConfig(mgmtConfig)
if err != nil { if err != nil {
return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err) return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err)
@@ -104,13 +107,7 @@ var (
return nil return nil
}, },
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
flag.Parse() err := handleRebrand(cmd)
err := util.InitLog(logLevel, logFile)
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
err = handleRebrand(cmd)
if err != nil { if err != nil {
return fmt.Errorf("failed to migrate files %v", err) return fmt.Errorf("failed to migrate files %v", err)
} }
@@ -146,12 +143,22 @@ var (
if disableSingleAccMode { if disableSingleAccMode {
mgmtSingleAccModeDomain = "" mgmtSingleAccModeDomain = ""
} }
eventStore, err := sqlite.NewSQLiteStore(config.Datadir) eventStore, key, err := initEventStore(config.Datadir, config.DataStoreEncryptionKey)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to initialize database: %s", err)
} }
if key != "" {
log.Debugf("update config with activity store key")
config.DataStoreEncryptionKey = key
err := updateMgmtConfig(mgmtConfig, config)
if err != nil {
return fmt.Errorf("failed to write out store encryption key: %s", err)
}
}
accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain, accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
dnsDomain, eventStore) dnsDomain, eventStore, userDeleteFromIDPEnabled)
if err != nil { if err != nil {
return fmt.Errorf("failed to build default manager: %v", err) return fmt.Errorf("failed to build default manager: %v", err)
} }
@@ -202,8 +209,11 @@ var (
return fmt.Errorf("failed creating HTTP API handler: %v", err) return fmt.Errorf("failed creating HTTP API handler: %v", err)
} }
ephemeralManager := server.NewEphemeralManager(store, accountManager)
ephemeralManager.LoadInitialPeers()
gRPCAPIHandler := grpc.NewServer(gRPCOpts...) gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, appMetrics) srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, appMetrics, ephemeralManager)
if err != nil { if err != nil {
return fmt.Errorf("failed creating gRPC API handler: %v", err) return fmt.Errorf("failed creating gRPC API handler: %v", err)
} }
@@ -272,6 +282,7 @@ var (
SetupCloseHandler() SetupCloseHandler()
<-stopCh <-stopCh
ephemeralManager.Stop()
_ = appMetrics.Close() _ = appMetrics.Close()
_ = listener.Close() _ = listener.Close()
if certManager != nil { if certManager != nil {
@@ -287,6 +298,20 @@ var (
} }
) )
func initEventStore(dataDir string, key string) (activity.Store, string, error) {
var err error
if key == "" {
log.Debugf("generate new activity store encryption key")
key, err = sqlite.GenerateKey()
if err != nil {
return nil, "", err
}
}
store, err := sqlite.NewSQLiteStore(dataDir, key)
return store, key, err
}
func notifyStop(msg string) { func notifyStop(msg string) {
select { select {
case stopCh <- 1: case stopCh <- 1:
@@ -440,6 +465,10 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
return loadedConfig, err return loadedConfig, err
} }
func updateMgmtConfig(path string, config *server.Config) error {
return util.WriteJson(path, config)
}
// OIDCConfigResponse used for parsing OIDC config response // OIDCConfigResponse used for parsing OIDC config response
type OIDCConfigResponse struct { type OIDCConfigResponse struct {
Issuer string `json:"issuer"` Issuer string `json:"issuer"`

View File

@@ -24,6 +24,7 @@ var (
disableMetrics bool disableMetrics bool
disableSingleAccMode bool disableSingleAccMode bool
idpSignKeyRefreshEnabled bool idpSignKeyRefreshEnabled bool
userDeleteFromIDPEnabled bool
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "netbird-mgmt", Use: "netbird-mgmt",
@@ -56,6 +57,7 @@ func init() {
mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird") mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird")
mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max lenght is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain)) mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max lenght is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain))
mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, "idp-sign-key-refresh-enabled", false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.") mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, "idp-sign-key-refresh-enabled", false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.")
mgmtCmd.Flags().BoolVar(&userDeleteFromIDPEnabled, "user-delete-from-idp", false, "Allows to delete user from IDP when user is deleted from account")
rootCmd.MarkFlagRequired("config") //nolint rootCmd.MarkFlagRequired("config") //nolint
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "") rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "")

View File

@@ -49,7 +49,7 @@ func cacheEntryExpiration() time.Duration {
type AccountManager interface { type AccountManager interface {
GetOrCreateAccountByUser(userId, domain string) (*Account, error) GetOrCreateAccountByUser(userId, domain string) (*Account, error)
CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
autoGroups []string, usageLimit int, userID string) (*SetupKey, error) autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error)
SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error) SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error)
CreateUser(accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error) CreateUser(accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error)
DeleteUser(accountID, initiatorUserID string, targetUserID string) error DeleteUser(accountID, initiatorUserID string, targetUserID string) error
@@ -80,7 +80,6 @@ type AccountManager interface {
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
GetGroup(accountId, groupID string) (*Group, error) GetGroup(accountId, groupID string) (*Group, error)
SaveGroup(accountID, userID string, group *Group) error SaveGroup(accountID, userID string, group *Group) error
UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error)
DeleteGroup(accountId, userId, groupID string) error DeleteGroup(accountId, userId, groupID string) error
ListGroups(accountId string) ([]*Group, error) ListGroups(accountId string) ([]*Group, error)
GroupAddPeer(accountId, groupID, peerID string) error GroupAddPeer(accountId, groupID, peerID string) error
@@ -93,13 +92,11 @@ type AccountManager interface {
GetRoute(accountID, routeID, userID string) (*route.Route, error) GetRoute(accountID, routeID, userID string) (*route.Route, error)
CreateRoute(accountID string, prefix, peerID, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) CreateRoute(accountID string, prefix, peerID, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
SaveRoute(accountID, userID string, route *route.Route) error SaveRoute(accountID, userID string, route *route.Route) error
UpdateRoute(accountID, routeID string, operations []RouteUpdateOperation) (*route.Route, error)
DeleteRoute(accountID, routeID, userID string) error DeleteRoute(accountID, routeID, userID string) error
ListRoutes(accountID, userID string) ([]*route.Route, error) ListRoutes(accountID, userID string) ([]*route.Route, error)
GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error)
SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
UpdateNameServerGroup(accountID, nsGroupID, userID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
DeleteNameServerGroup(accountID, nsGroupID, userID string) error DeleteNameServerGroup(accountID, nsGroupID, userID string) error
ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error) ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error)
GetDNSDomain() string GetDNSDomain() string
@@ -133,6 +130,9 @@ type DefaultAccountManager struct {
// dnsDomain is used for peer resolution. This is appended to the peer's name // dnsDomain is used for peer resolution. This is appended to the peer's name
dnsDomain string dnsDomain string
peerLoginExpiry Scheduler peerLoginExpiry Scheduler
// userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account
userDeleteFromIDPEnabled bool
} }
// Settings represents Account settings structure that can be modified via API and Dashboard // Settings represents Account settings structure that can be modified via API and Dashboard
@@ -189,14 +189,15 @@ type Account struct {
} }
type UserInfo struct { type UserInfo struct {
ID string `json:"id"` ID string `json:"id"`
Email string `json:"email"` Email string `json:"email"`
Name string `json:"name"` Name string `json:"name"`
Role string `json:"role"` Role string `json:"role"`
AutoGroups []string `json:"auto_groups"` AutoGroups []string `json:"auto_groups"`
Status string `json:"-"` Status string `json:"-"`
IsServiceUser bool `json:"is_service_user"` IsServiceUser bool `json:"is_service_user"`
IsBlocked bool `json:"is_blocked"` IsBlocked bool `json:"is_blocked"`
LastLogin time.Time `json:"last_login"`
} }
// getRoutesToSync returns the enabled routes for the peer ID and the routes // getRoutesToSync returns the enabled routes for the peer ID and the routes
@@ -629,8 +630,8 @@ func (a *Account) GetPeer(peerID string) *Peer {
return a.Peers[peerID] return a.Peers[peerID]
} }
// AddJWTGroups to account and to user autoassigned groups // SetJWTGroups to account and to user autoassigned groups
func (a *Account) AddJWTGroups(userID string, groups []string) bool { func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool {
user, ok := a.Users[userID] user, ok := a.Users[userID]
if !ok { if !ok {
return false return false
@@ -641,13 +642,21 @@ func (a *Account) AddJWTGroups(userID string, groups []string) bool {
existedGroupsByName[group.Name] = group existedGroupsByName[group.Name] = group
} }
autoGroups := make(map[string]struct{}) // remove JWT groups from the autogroups, to sync them again
for _, groupID := range user.AutoGroups { removed := 0
autoGroups[groupID] = struct{}{} jwtAutoGroups := make(map[string]struct{})
for i, id := range user.AutoGroups {
if group, ok := a.Groups[id]; ok && group.Issued == GroupIssuedJWT {
jwtAutoGroups[group.Name] = struct{}{}
user.AutoGroups = append(user.AutoGroups[:i-removed], user.AutoGroups[i-removed+1:]...)
removed++
}
} }
// create JWT groups if they doesn't exist
// and all of them to the autogroups
var modified bool var modified bool
for _, name := range groups { for _, name := range groupsNames {
group, ok := existedGroupsByName[name] group, ok := existedGroupsByName[name]
if !ok { if !ok {
group = &Group{ group = &Group{
@@ -656,16 +665,22 @@ func (a *Account) AddJWTGroups(userID string, groups []string) bool {
Issued: GroupIssuedJWT, Issued: GroupIssuedJWT,
} }
a.Groups[group.ID] = group a.Groups[group.ID] = group
modified = true
} }
if _, ok := autoGroups[group.ID]; !ok { // only JWT groups will be synced
if group.Issued == GroupIssuedJWT { if group.Issued == GroupIssuedJWT {
user.AutoGroups = append(user.AutoGroups, group.ID) user.AutoGroups = append(user.AutoGroups, group.ID)
if _, ok := jwtAutoGroups[name]; !ok {
modified = true modified = true
} }
delete(jwtAutoGroups, name)
} }
} }
// if not empty it means we removed some groups
if len(jwtAutoGroups) > 0 {
modified = true
}
return modified return modified
} }
@@ -723,18 +738,19 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
// BuildManager creates a new DefaultAccountManager with a provided Store // BuildManager creates a new DefaultAccountManager with a provided Store
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, userDeleteFromIDPEnabled bool,
) (*DefaultAccountManager, error) { ) (*DefaultAccountManager, error) {
am := &DefaultAccountManager{ am := &DefaultAccountManager{
Store: store, Store: store,
peersUpdateManager: peersUpdateManager, peersUpdateManager: peersUpdateManager,
idpManager: idpManager, idpManager: idpManager,
ctx: context.Background(), ctx: context.Background(),
cacheMux: sync.Mutex{}, cacheMux: sync.Mutex{},
cacheLoading: map[string]chan struct{}{}, cacheLoading: map[string]chan struct{}{},
dnsDomain: dnsDomain, dnsDomain: dnsDomain,
eventStore: eventStore, eventStore: eventStore,
peerLoginExpiry: NewDefaultScheduler(), peerLoginExpiry: NewDefaultScheduler(),
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
} }
allAccounts := store.GetAllAccounts() allAccounts := store.GetAllAccounts()
// enable single account mode only if configured by user and number of existing accounts is not grater than 1 // enable single account mode only if configured by user and number of existing accounts is not grater than 1
@@ -859,32 +875,19 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func()
return account.GetNextPeerExpiration() return account.GetNextPeerExpiration()
} }
expiredPeers := account.GetExpiredPeers()
var peerIDs []string var peerIDs []string
for _, peer := range account.GetExpiredPeers() { for _, peer := range expiredPeers {
if peer.Status.LoginExpired {
continue
}
peerIDs = append(peerIDs, peer.ID) peerIDs = append(peerIDs, peer.ID)
peer.MarkLoginExpired(true)
account.UpdatePeer(peer)
err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status)
if err != nil {
log.Errorf("failed saving peer status while expiring peer %s", peer.ID)
return account.GetNextPeerExpiration()
}
} }
log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id)
if len(peerIDs) != 0 { if err := am.expireAndUpdatePeers(account, expiredPeers); err != nil {
// this will trigger peer disconnect from the management service log.Errorf("failed updating account peers while expiring peers for account %s", account.Id)
am.peersUpdateManager.CloseChannels(peerIDs) return account.GetNextPeerExpiration()
err = am.updateAccountPeers(account)
if err != nil {
log.Errorf("failed updating account peers while expiring peers for account %s", accountID)
return account.GetNextPeerExpiration()
}
} }
return account.GetNextPeerExpiration() return account.GetNextPeerExpiration()
} }
} }
@@ -1006,7 +1009,7 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountI
} }
} }
return nil, nil return nil, nil //nolint:nilnil
} }
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil // lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
@@ -1029,7 +1032,7 @@ func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Accou
} }
} }
return nil, nil return nil, nil //nolint:nilnil
} }
func (am *DefaultAccountManager) refreshCache(accountID string) ([]*idp.UserData, error) { func (am *DefaultAccountManager) refreshCache(accountID string) ([]*idp.UserData, error) {
@@ -1357,23 +1360,58 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
} }
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok { if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
if slice, ok := claim.([]interface{}); ok { if slice, ok := claim.([]interface{}); ok {
var groups []string var groupsNames []string
for _, item := range slice { for _, item := range slice {
if g, ok := item.(string); ok { if g, ok := item.(string); ok {
groups = append(groups, g) groupsNames = append(groupsNames, g)
} else { } else {
log.Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item) log.Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item)
} }
} }
oldGroups := make([]string, len(user.AutoGroups))
copy(oldGroups, user.AutoGroups)
// if groups were added or modified, save the account // if groups were added or modified, save the account
if account.AddJWTGroups(claims.UserId, groups) { if account.SetJWTGroups(claims.UserId, groupsNames) {
if account.Settings.GroupsPropagationEnabled { if account.Settings.GroupsPropagationEnabled {
if user, err := account.FindUser(claims.UserId); err == nil { if user, err := account.FindUser(claims.UserId); err == nil {
account.UserGroupsAddToPeers(claims.UserId, append(user.AutoGroups, groups...)...) addNewGroups := difference(user.AutoGroups, oldGroups)
removeOldGroups := difference(oldGroups, user.AutoGroups)
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
account.Network.IncSerial()
if err := am.Store.SaveAccount(account); err != nil {
log.Errorf("failed to save account: %v", err)
} else {
if err := am.updateAccountPeers(account); err != nil {
log.Errorf("failed updating account peers while updating user %s", account.Id)
}
for _, g := range addNewGroups {
if group := account.GetGroup(g); group != nil {
am.storeEvent(user.Id, user.Id, account.Id, activity.GroupAddedToUser,
map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": user.IsServiceUser,
"user_name": user.ServiceUserName})
}
}
for _, g := range removeOldGroups {
if group := account.GetGroup(g); group != nil {
am.storeEvent(user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": user.IsServiceUser,
"user_name": user.ServiceUserName})
}
}
}
}
} else {
if err := am.Store.SaveAccount(account); err != nil {
log.Errorf("failed to save account: %v", err)
} }
}
if err := am.Store.SaveAccount(account); err != nil {
log.Errorf("failed to save account: %v", err)
} }
} }
} else { } else {
@@ -1420,7 +1458,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
if _, ok := accountFromID.Users[claims.UserId]; !ok { if _, ok := accountFromID.Users[claims.UserId]; !ok {
return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
} }
if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory { if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain {
return accountFromID, nil return accountFromID, nil
} }
} }
@@ -1554,19 +1592,3 @@ func newAccountWithId(accountID, userID, domain string) *Account {
} }
return acc return acc
} }
func removeFromList(inputList []string, toRemove []string) []string {
toRemoveMap := make(map[string]struct{})
for _, item := range toRemove {
toRemoveMap[item] = struct{}{}
}
var resultList []string
for _, item := range inputList {
_, ok := toRemoveMap[item]
if !ok {
resultList = append(resultList, item)
}
}
return resultList
}

View File

@@ -3,6 +3,7 @@ package server
import ( import (
"crypto/sha256" "crypto/sha256"
b64 "encoding/base64" b64 "encoding/base64"
"encoding/json"
"fmt" "fmt"
"net" "net"
"reflect" "reflect"
@@ -11,6 +12,7 @@ import (
"time" "time"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@@ -781,11 +783,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
serial := account.Network.CurrentSerial() // should be 0 serial := account.Network.CurrentSerial() // should be 0
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID) setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
if err != nil {
return
}
if err != nil { if err != nil {
t.Fatal("error creating setup key") t.Fatal("error creating setup key")
return return
@@ -928,11 +926,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID) setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
if err != nil {
return
}
if err != nil { if err != nil {
t.Fatal("error creating setup key") t.Fatal("error creating setup key")
return return
@@ -1112,11 +1106,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID) setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
if err != nil {
return
}
if err != nil { if err != nil {
t.Fatal("error creating setup key") t.Fatal("error creating setup key")
return return
@@ -1348,6 +1338,11 @@ func TestAccount_Copy(t *testing.T) {
Peers: map[string]*Peer{ Peers: map[string]*Peer{
"peer1": { "peer1": {
Key: "key1", Key: "key1",
Status: &PeerStatus{
LastSeen: time.Now(),
Connected: true,
LoginExpired: false,
},
}, },
}, },
Users: map[string]*User{ Users: map[string]*User{
@@ -1370,28 +1365,36 @@ func TestAccount_Copy(t *testing.T) {
}, },
Groups: map[string]*Group{ Groups: map[string]*Group{
"group1": { "group1": {
ID: "group1", ID: "group1",
Peers: []string{"peer1"},
}, },
}, },
Rules: map[string]*Rule{ Rules: map[string]*Rule{
"rule1": { "rule1": {
ID: "rule1", ID: "rule1",
Destination: []string{},
Source: []string{},
}, },
}, },
Policies: []*Policy{ Policies: []*Policy{
{ {
ID: "policy1", ID: "policy1",
Enabled: true, Enabled: true,
Rules: make([]*PolicyRule, 0),
}, },
}, },
Routes: map[string]*route.Route{ Routes: map[string]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Groups: []string{"group1"},
}, },
}, },
NameServerGroups: map[string]*nbdns.NameServerGroup{ NameServerGroups: map[string]*nbdns.NameServerGroup{
"nsGroup1": { "nsGroup1": {
ID: "nsGroup1", ID: "nsGroup1",
Domains: []string{},
Groups: []string{},
NameServers: []nbdns.NameServer{},
}, },
}, },
DNSSettings: &DNSSettings{DisabledManagementGroups: []string{}}, DNSSettings: &DNSSettings{DisabledManagementGroups: []string{}},
@@ -1402,10 +1405,20 @@ func TestAccount_Copy(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
accountCopy := account.Copy() accountCopy := account.Copy()
assert.Equal(t, account, accountCopy, "account copy returned a different value than expected") accBytes, err := json.Marshal(account)
if err != nil {
t.Fatal(err)
}
account.Peers["peer1"].Status.Connected = false // we change original object to confirm that copy wont change
accCopyBytes, err := json.Marshal(accountCopy)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, string(accBytes), string(accCopyBytes), "account copy returned a different value than expected")
} }
// hasNilField validates pointers, maps and slices if they are nil // hasNilField validates pointers, maps and slices if they are nil
// TODO: make it check nested fields too
func hasNilField(x interface{}) error { func hasNilField(x interface{}) error {
rv := reflect.ValueOf(x) rv := reflect.ValueOf(x)
rv = rv.Elem() rv = rv.Elem()
@@ -1930,7 +1943,7 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
} }
} }
func TestAccount_AddJWTGroups(t *testing.T) { func TestAccount_SetJWTGroups(t *testing.T) {
// create a new account // create a new account
account := &Account{ account := &Account{
Peers: map[string]*Peer{ Peers: map[string]*Peer{
@@ -1951,13 +1964,13 @@ func TestAccount_AddJWTGroups(t *testing.T) {
} }
t.Run("api group already exists", func(t *testing.T) { t.Run("api group already exists", func(t *testing.T) {
updated := account.AddJWTGroups("user1", []string{"group1"}) updated := account.SetJWTGroups("user1", []string{"group1"})
assert.False(t, updated, "account should not be updated") assert.False(t, updated, "account should not be updated")
assert.Empty(t, account.Users["user1"].AutoGroups, "auto groups must be empty") assert.Empty(t, account.Users["user1"].AutoGroups, "auto groups must be empty")
}) })
t.Run("add jwt group", func(t *testing.T) { t.Run("add jwt group", func(t *testing.T) {
updated := account.AddJWTGroups("user1", []string{"group1", "group2"}) updated := account.SetJWTGroups("user1", []string{"group1", "group2"})
assert.True(t, updated, "account should be updated") assert.True(t, updated, "account should be updated")
assert.Len(t, account.Groups, 2, "new group should be added") assert.Len(t, account.Groups, 2, "new group should be added")
assert.Len(t, account.Users["user1"].AutoGroups, 1, "new group should be added") assert.Len(t, account.Users["user1"].AutoGroups, 1, "new group should be added")
@@ -1965,13 +1978,13 @@ func TestAccount_AddJWTGroups(t *testing.T) {
}) })
t.Run("existed group not update", func(t *testing.T) { t.Run("existed group not update", func(t *testing.T) {
updated := account.AddJWTGroups("user1", []string{"group2"}) updated := account.SetJWTGroups("user1", []string{"group2"})
assert.False(t, updated, "account should not be updated") assert.False(t, updated, "account should not be updated")
assert.Len(t, account.Groups, 2, "groups count should not be changed") assert.Len(t, account.Groups, 2, "groups count should not be changed")
}) })
t.Run("add new group", func(t *testing.T) { t.Run("add new group", func(t *testing.T) {
updated := account.AddJWTGroups("user2", []string{"group1", "group3"}) updated := account.SetJWTGroups("user2", []string{"group1", "group3"})
assert.True(t, updated, "account should be updated") assert.True(t, updated, "account should be updated")
assert.Len(t, account.Groups, 3, "new group should be added") assert.Len(t, account.Groups, 3, "new group should be added")
assert.Len(t, account.Users["user2"].AutoGroups, 1, "new group should be added") assert.Len(t, account.Users["user2"].AutoGroups, 1, "new group should be added")
@@ -2050,7 +2063,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err return nil, err
} }
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.cloud", eventStore) return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.cloud", eventStore, false)
} }
func createStore(t *testing.T) (Store, error) { func createStore(t *testing.T) (Store, error) {

View File

@@ -1,5 +1,14 @@
package activity package activity
// Activity that triggered an Event
type Activity int
// Code is an activity string representation
type Code struct {
message string
code string
}
const ( const (
// PeerAddedByUser indicates that a user added a new peer to the system // PeerAddedByUser indicates that a user added a new peer to the system
PeerAddedByUser Activity = iota PeerAddedByUser Activity = iota
@@ -95,316 +104,85 @@ const (
UserBlocked UserBlocked
// UserUnblocked indicates that a user unblocked another user // UserUnblocked indicates that a user unblocked another user
UserUnblocked UserUnblocked
// UserDeleted indicates that a user deleted another user
UserDeleted
// GroupDeleted indicates that a user deleted group // GroupDeleted indicates that a user deleted group
GroupDeleted GroupDeleted
// UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login
UserLoggedInPeer
// PeerLoginExpired indicates that the user peer login has been expired and peer disconnected
PeerLoginExpired
// DashboardLogin indicates that the user logged in to the dashboard
DashboardLogin
) )
const ( var activityMap = map[Activity]Code{
// PeerAddedByUserMessage is a human-readable text message of the PeerAddedByUser activity PeerAddedByUser: {"Peer added", "user.peer.add"},
PeerAddedByUserMessage string = "Peer added" PeerAddedWithSetupKey: {"Peer added", "setupkey.peer.add"},
// PeerAddedWithSetupKeyMessage is a human-readable text message of the PeerAddedWithSetupKey activity UserJoined: {"User joined", "user.join"},
PeerAddedWithSetupKeyMessage = PeerAddedByUserMessage UserInvited: {"User invited", "user.invite"},
// UserJoinedMessage is a human-readable text message of the UserJoined activity AccountCreated: {"Account created", "account.create"},
UserJoinedMessage string = "User joined" PeerRemovedByUser: {"Peer deleted", "user.peer.delete"},
// UserInvitedMessage is a human-readable text message of the UserInvited activity RuleAdded: {"Rule added", "rule.add"},
UserInvitedMessage string = "User invited" RuleUpdated: {"Rule updated", "rule.update"},
// AccountCreatedMessage is a human-readable text message of the AccountCreated activity RuleRemoved: {"Rule deleted", "rule.delete"},
AccountCreatedMessage string = "Account created" PolicyAdded: {"Policy added", "policy.add"},
// PeerRemovedByUserMessage is a human-readable text message of the PeerRemovedByUser activity PolicyUpdated: {"Policy updated", "policy.update"},
PeerRemovedByUserMessage string = "Peer deleted" PolicyRemoved: {"Policy deleted", "policy.delete"},
// RuleAddedMessage is a human-readable text message of the RuleAdded activity SetupKeyCreated: {"Setup key created", "setupkey.add"},
RuleAddedMessage string = "Rule added" SetupKeyUpdated: {"Setup key updated", "setupkey.update"},
// RuleRemovedMessage is a human-readable text message of the RuleRemoved activity SetupKeyRevoked: {"Setup key revoked", "setupkey.revoke"},
RuleRemovedMessage string = "Rule deleted" SetupKeyOverused: {"Setup key overused", "setupkey.overuse"},
// RuleUpdatedMessage is a human-readable text message of the RuleRemoved activity GroupCreated: {"Group created", "group.add"},
RuleUpdatedMessage string = "Rule updated" GroupUpdated: {"Group updated", "group.update"},
// PolicyAddedMessage is a human-readable text message of the PolicyAdded activity GroupAddedToPeer: {"Group added to peer", "peer.group.add"},
PolicyAddedMessage string = "Policy added" GroupRemovedFromPeer: {"Group removed from peer", "peer.group.delete"},
// PolicyRemovedMessage is a human-readable text message of the PolicyRemoved activity GroupAddedToUser: {"Group added to user", "user.group.add"},
PolicyRemovedMessage string = "Policy deleted" GroupRemovedFromUser: {"Group removed from user", "user.group.delete"},
// PolicyUpdatedMessage is a human-readable text message of the PolicyRemoved activity UserRoleUpdated: {"User role updated", "user.role.update"},
PolicyUpdatedMessage string = "Policy updated" GroupAddedToSetupKey: {"Group added to setup key", "setupkey.group.add"},
// SetupKeyCreatedMessage is a human-readable text message of the SetupKeyCreated activity GroupRemovedFromSetupKey: {"Group removed from user setup key", "setupkey.group.delete"},
SetupKeyCreatedMessage string = "Setup key created" GroupAddedToDisabledManagementGroups: {"Group added to disabled management DNS setting", "dns.setting.disabled.management.group.add"},
// SetupKeyUpdatedMessage is a human-readable text message of the SetupKeyUpdated activity GroupRemovedFromDisabledManagementGroups: {"Group removed from disabled management DNS setting", "dns.setting.disabled.management.group.delete"},
SetupKeyUpdatedMessage string = "Setup key updated" RouteCreated: {"Route created", "route.add"},
// SetupKeyRevokedMessage is a human-readable text message of the SetupKeyRevoked activity RouteRemoved: {"Route deleted", "route.delete"},
SetupKeyRevokedMessage string = "Setup key revoked" RouteUpdated: {"Route updated", "route.update"},
// SetupKeyOverusedMessage is a human-readable text message of the SetupKeyOverused activity PeerSSHEnabled: {"Peer SSH server enabled", "peer.ssh.enable"},
SetupKeyOverusedMessage string = "Setup key overused" PeerSSHDisabled: {"Peer SSH server disabled", "peer.ssh.disable"},
// GroupCreatedMessage is a human-readable text message of the GroupCreated activity PeerRenamed: {"Peer renamed", "peer.rename"},
GroupCreatedMessage string = "Group created" PeerLoginExpirationEnabled: {"Peer login expiration enabled", "peer.login.expiration.enable"},
// GroupUpdatedMessage is a human-readable text message of the GroupUpdated activity PeerLoginExpirationDisabled: {"Peer login expiration disabled", "peer.login.expiration.disable"},
GroupUpdatedMessage string = "Group updated" NameserverGroupCreated: {"Nameserver group created", "nameserver.group.add"},
// GroupAddedToPeerMessage is a human-readable text message of the GroupAddedToPeer activity NameserverGroupDeleted: {"Nameserver group deleted", "nameserver.group.delete"},
GroupAddedToPeerMessage string = "Group added to peer" NameserverGroupUpdated: {"Nameserver group updated", "nameserver.group.update"},
// GroupRemovedFromPeerMessage is a human-readable text message of the GroupRemovedFromPeer activity AccountPeerLoginExpirationDurationUpdated: {"Account peer login expiration duration updated", "account.setting.peer.login.expiration.update"},
GroupRemovedFromPeerMessage string = "Group removed from peer" AccountPeerLoginExpirationEnabled: {"Account peer login expiration enabled", "account.setting.peer.login.expiration.enable"},
// GroupAddedToUserMessage is a human-readable text message of the GroupAddedToUser activity AccountPeerLoginExpirationDisabled: {"Account peer login expiration disabled", "account.setting.peer.login.expiration.disable"},
GroupAddedToUserMessage string = "Group added to user" PersonalAccessTokenCreated: {"Personal access token created", "personal.access.token.create"},
// GroupRemovedFromUserMessage is a human-readable text message of the GroupRemovedFromUser activity PersonalAccessTokenDeleted: {"Personal access token deleted", "personal.access.token.delete"},
GroupRemovedFromUserMessage string = "Group removed from user" ServiceUserCreated: {"Service user created", "service.user.create"},
// UserRoleUpdatedMessage is a human-readable text message of the UserRoleUpdatedMessage activity ServiceUserDeleted: {"Service user deleted", "service.user.delete"},
UserRoleUpdatedMessage string = "User role updated" UserBlocked: {"User blocked", "user.block"},
// GroupAddedToSetupKeyMessage is a human-readable text message of the GroupAddedToSetupKey activity UserUnblocked: {"User unblocked", "user.unblock"},
GroupAddedToSetupKeyMessage string = "Group added to setup key" UserDeleted: {"User deleted", "user.delete"},
// GroupRemovedFromSetupKeyMessage is a human-readable text message of the GroupRemovedFromSetupKey activity GroupDeleted: {"Group deleted", "group.delete"},
GroupRemovedFromSetupKeyMessage string = "Group removed from user setup key" UserLoggedInPeer: {"User logged in peer", "user.peer.login"},
// GroupAddedToDisabledManagementGroupsMessage is a human-readable text message of the GroupAddedToDisabledManagementGroups activity PeerLoginExpired: {"Peer login expired", "peer.login.expire"},
GroupAddedToDisabledManagementGroupsMessage string = "Group added to disabled management DNS setting" DashboardLogin: {"Dashboard login", "dashboard.login"},
// GroupRemovedFromDisabledManagementGroupsMessage is a human-readable text message of the GroupRemovedFromDisabledManagementGroups activity
GroupRemovedFromDisabledManagementGroupsMessage string = "Group removed from disabled management DNS setting"
// RouteCreatedMessage is a human-readable text message of the RouteCreated activity
RouteCreatedMessage string = "Route created"
// RouteRemovedMessage is a human-readable text message of the RouteRemoved activity
RouteRemovedMessage string = "Route deleted"
// RouteUpdatedMessage is a human-readable text message of the RouteUpdated activity
RouteUpdatedMessage string = "Route updated"
// PeerSSHEnabledMessage is a human-readable text message of the PeerSSHEnabled activity
PeerSSHEnabledMessage string = "Peer SSH server enabled"
// PeerSSHDisabledMessage is a human-readable text message of the PeerSSHDisabled activity
PeerSSHDisabledMessage string = "Peer SSH server disabled"
// PeerRenamedMessage is a human-readable text message of the PeerRenamed activity
PeerRenamedMessage string = "Peer renamed"
// PeerLoginExpirationDisabledMessage is a human-readable text message of the PeerLoginExpirationDisabled activity
PeerLoginExpirationDisabledMessage string = "Peer login expiration disabled"
// PeerLoginExpirationEnabledMessage is a human-readable text message of the PeerLoginExpirationEnabled activity
PeerLoginExpirationEnabledMessage string = "Peer login expiration enabled"
// NameserverGroupCreatedMessage is a human-readable text message of the NameserverGroupCreated activity
NameserverGroupCreatedMessage string = "Nameserver group created"
// NameserverGroupDeletedMessage is a human-readable text message of the NameserverGroupDeleted activity
NameserverGroupDeletedMessage string = "Nameserver group deleted"
// NameserverGroupUpdatedMessage is a human-readable text message of the NameserverGroupUpdated activity
NameserverGroupUpdatedMessage string = "Nameserver group updated"
// AccountPeerLoginExpirationEnabledMessage is a human-readable text message of the AccountPeerLoginExpirationEnabled activity
AccountPeerLoginExpirationEnabledMessage string = "Peer login expiration enabled for the account"
// AccountPeerLoginExpirationDisabledMessage is a human-readable text message of the AccountPeerLoginExpirationDisabled activity
AccountPeerLoginExpirationDisabledMessage string = "Peer login expiration disabled for the account"
// AccountPeerLoginExpirationDurationUpdatedMessage is a human-readable text message of the AccountPeerLoginExpirationDurationUpdated activity
AccountPeerLoginExpirationDurationUpdatedMessage string = "Peer login expiration duration updated"
// PersonalAccessTokenCreatedMessage is a human-readable text message of the PersonalAccessTokenCreated activity
PersonalAccessTokenCreatedMessage string = "Personal access token created"
// PersonalAccessTokenDeletedMessage is a human-readable text message of the PersonalAccessTokenDeleted activity
PersonalAccessTokenDeletedMessage string = "Personal access token deleted"
// ServiceUserCreatedMessage is a human-readable text message of the ServiceUserCreated activity
ServiceUserCreatedMessage string = "Service user created"
// ServiceUserDeletedMessage is a human-readable text message of the ServiceUserDeleted activity
ServiceUserDeletedMessage string = "Service user deleted"
// UserBlockedMessage is a human-readable text message of the UserBlocked activity
UserBlockedMessage string = "User blocked"
// UserUnblockedMessage is a human-readable text message of the UserUnblocked activity
UserUnblockedMessage string = "User unblocked"
// GroupDeletedMessage is a human-readable text message of the GroupDeleted activity
GroupDeletedMessage string = "Group deleted"
)
// Activity that triggered an Event
type Activity int
// Message returns a string representation of an activity
func (a Activity) Message() string {
switch a {
case PeerAddedByUser:
return PeerAddedByUserMessage
case PeerRemovedByUser:
return PeerRemovedByUserMessage
case PeerAddedWithSetupKey:
return PeerAddedWithSetupKeyMessage
case UserJoined:
return UserJoinedMessage
case UserInvited:
return UserInvitedMessage
case AccountCreated:
return AccountCreatedMessage
case RuleAdded:
return RuleAddedMessage
case RuleRemoved:
return RuleRemovedMessage
case RuleUpdated:
return RuleUpdatedMessage
case PolicyAdded:
return PolicyAddedMessage
case PolicyRemoved:
return PolicyRemovedMessage
case PolicyUpdated:
return PolicyUpdatedMessage
case SetupKeyCreated:
return SetupKeyCreatedMessage
case SetupKeyUpdated:
return SetupKeyUpdatedMessage
case SetupKeyRevoked:
return SetupKeyRevokedMessage
case SetupKeyOverused:
return SetupKeyOverusedMessage
case GroupCreated:
return GroupCreatedMessage
case GroupUpdated:
return GroupUpdatedMessage
case GroupAddedToPeer:
return GroupAddedToPeerMessage
case GroupRemovedFromPeer:
return GroupRemovedFromPeerMessage
case GroupRemovedFromUser:
return GroupRemovedFromUserMessage
case GroupAddedToUser:
return GroupAddedToUserMessage
case UserRoleUpdated:
return UserRoleUpdatedMessage
case GroupAddedToSetupKey:
return GroupAddedToSetupKeyMessage
case GroupRemovedFromSetupKey:
return GroupRemovedFromSetupKeyMessage
case GroupAddedToDisabledManagementGroups:
return GroupAddedToDisabledManagementGroupsMessage
case GroupRemovedFromDisabledManagementGroups:
return GroupRemovedFromDisabledManagementGroupsMessage
case RouteCreated:
return RouteCreatedMessage
case RouteRemoved:
return RouteRemovedMessage
case RouteUpdated:
return RouteUpdatedMessage
case PeerSSHEnabled:
return PeerSSHEnabledMessage
case PeerSSHDisabled:
return PeerSSHDisabledMessage
case PeerLoginExpirationEnabled:
return PeerLoginExpirationEnabledMessage
case PeerLoginExpirationDisabled:
return PeerLoginExpirationDisabledMessage
case PeerRenamed:
return PeerRenamedMessage
case NameserverGroupCreated:
return NameserverGroupCreatedMessage
case NameserverGroupDeleted:
return NameserverGroupDeletedMessage
case NameserverGroupUpdated:
return NameserverGroupUpdatedMessage
case AccountPeerLoginExpirationEnabled:
return AccountPeerLoginExpirationEnabledMessage
case AccountPeerLoginExpirationDisabled:
return AccountPeerLoginExpirationDisabledMessage
case AccountPeerLoginExpirationDurationUpdated:
return AccountPeerLoginExpirationDurationUpdatedMessage
case PersonalAccessTokenCreated:
return PersonalAccessTokenCreatedMessage
case PersonalAccessTokenDeleted:
return PersonalAccessTokenDeletedMessage
case ServiceUserCreated:
return ServiceUserCreatedMessage
case ServiceUserDeleted:
return ServiceUserDeletedMessage
case UserBlocked:
return UserBlockedMessage
case UserUnblocked:
return UserUnblockedMessage
case GroupDeleted:
return GroupDeletedMessage
default:
return "UNKNOWN_ACTIVITY"
}
} }
// StringCode returns a string code of the activity // StringCode returns a string code of the activity
func (a Activity) StringCode() string { func (a Activity) StringCode() string {
switch a { if code, ok := activityMap[a]; ok {
case PeerAddedByUser: return code.code
return "user.peer.add"
case PeerRemovedByUser:
return "user.peer.delete"
case PeerAddedWithSetupKey:
return "setupkey.peer.add"
case UserJoined:
return "user.join"
case UserInvited:
return "user.invite"
case UserBlocked:
return "user.block"
case UserUnblocked:
return "user.unblock"
case AccountCreated:
return "account.create"
case RuleAdded:
return "rule.add"
case RuleRemoved:
return "rule.delete"
case RuleUpdated:
return "rule.update"
case PolicyAdded:
return "policy.add"
case PolicyRemoved:
return "policy.delete"
case PolicyUpdated:
return "policy.update"
case SetupKeyCreated:
return "setupkey.add"
case SetupKeyRevoked:
return "setupkey.revoke"
case SetupKeyOverused:
return "setupkey.overuse"
case SetupKeyUpdated:
return "setupkey.update"
case GroupCreated:
return "group.add"
case GroupUpdated:
return "group.update"
case GroupDeleted:
return "group.delete"
case GroupRemovedFromPeer:
return "peer.group.delete"
case GroupAddedToPeer:
return "peer.group.add"
case GroupAddedToUser:
return "user.group.add"
case GroupRemovedFromUser:
return "user.group.delete"
case UserRoleUpdated:
return "user.role.update"
case GroupAddedToSetupKey:
return "setupkey.group.add"
case GroupRemovedFromSetupKey:
return "setupkey.group.delete"
case GroupAddedToDisabledManagementGroups:
return "dns.setting.disabled.management.group.add"
case GroupRemovedFromDisabledManagementGroups:
return "dns.setting.disabled.management.group.delete"
case RouteCreated:
return "route.add"
case RouteRemoved:
return "route.delete"
case RouteUpdated:
return "route.update"
case PeerRenamed:
return "peer.rename"
case PeerSSHEnabled:
return "peer.ssh.enable"
case PeerSSHDisabled:
return "peer.ssh.disable"
case PeerLoginExpirationDisabled:
return "peer.login.expiration.disable"
case PeerLoginExpirationEnabled:
return "peer.login.expiration.enable"
case NameserverGroupCreated:
return "nameserver.group.add"
case NameserverGroupDeleted:
return "nameserver.group.delete"
case NameserverGroupUpdated:
return "nameserver.group.update"
case AccountPeerLoginExpirationDurationUpdated:
return "account.setting.peer.login.expiration.update"
case AccountPeerLoginExpirationEnabled:
return "account.setting.peer.login.expiration.enable"
case AccountPeerLoginExpirationDisabled:
return "account.setting.peer.login.expiration.disable"
case PersonalAccessTokenCreated:
return "personal.access.token.create"
case PersonalAccessTokenDeleted:
return "personal.access.token.delete"
case ServiceUserCreated:
return "service.user.create"
case ServiceUserDeleted:
return "service.user.delete"
default:
return "UNKNOWN_ACTIVITY"
} }
return "UNKNOWN_ACTIVITY"
}
// Message returns a string representation of an activity
func (a Activity) Message() string {
if code, ok := activityMap[a]; ok {
return code.message
}
return "UNKNOWN_ACTIVITY"
} }

View File

@@ -4,6 +4,10 @@ import (
"time" "time"
) )
const (
SystemInitiator = "sys"
)
// Event represents a network/system activity event. // Event represents a network/system activity event.
type Event struct { type Event struct {
// Timestamp of the event // Timestamp of the event
@@ -14,10 +18,13 @@ type Event struct {
ID uint64 ID uint64
// InitiatorID is the ID of an object that initiated the event (e.g., a user) // InitiatorID is the ID of an object that initiated the event (e.g., a user)
InitiatorID string InitiatorID string
// InitiatorEmail is the email address of an object that initiated the event. This will be set on deleted users only
InitiatorEmail string
// TargetID is the ID of an object that was effected by the event (e.g., a peer) // TargetID is the ID of an object that was effected by the event (e.g., a peer)
TargetID string TargetID string
// AccountID is the ID of an account where the event happened // AccountID is the ID of an account where the event happened
AccountID string AccountID string
// Meta of the event, e.g. deleted peer information like name, IP, etc // Meta of the event, e.g. deleted peer information like name, IP, etc
Meta map[string]any Meta map[string]any
} }
@@ -31,12 +38,13 @@ func (e *Event) Copy() *Event {
} }
return &Event{ return &Event{
Timestamp: e.Timestamp, Timestamp: e.Timestamp,
Activity: e.Activity, Activity: e.Activity,
ID: e.ID, ID: e.ID,
InitiatorID: e.InitiatorID, InitiatorID: e.InitiatorID,
TargetID: e.TargetID, InitiatorEmail: e.InitiatorEmail,
AccountID: e.AccountID, TargetID: e.TargetID,
Meta: meta, AccountID: e.AccountID,
Meta: meta,
} }
} }

View File

@@ -0,0 +1,81 @@
package sqlite
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"fmt"
)
var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
type EmailEncrypt struct {
block cipher.Block
}
func GenerateKey() (string, error) {
key := make([]byte, 32)
_, err := rand.Read(key)
if err != nil {
return "", err
}
readableKey := base64.StdEncoding.EncodeToString(key)
return readableKey, nil
}
func NewEmailEncrypt(key string) (*EmailEncrypt, error) {
binKey, err := base64.StdEncoding.DecodeString(key)
if err != nil {
return nil, err
}
block, err := aes.NewCipher(binKey)
if err != nil {
return nil, err
}
ec := &EmailEncrypt{
block: block,
}
return ec, nil
}
func (ec *EmailEncrypt) Encrypt(payload string) string {
plainText := pkcs5Padding([]byte(payload))
cipherText := make([]byte, len(plainText))
cbc := cipher.NewCBCEncrypter(ec.block, iv)
cbc.CryptBlocks(cipherText, plainText)
return base64.StdEncoding.EncodeToString(cipherText)
}
func (ec *EmailEncrypt) Decrypt(data string) (string, error) {
cipherText, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return "", err
}
cbc := cipher.NewCBCDecrypter(ec.block, iv)
cbc.CryptBlocks(cipherText, cipherText)
payload, err := pkcs5UnPadding(cipherText)
if err != nil {
return "", err
}
return string(payload), nil
}
func pkcs5Padding(ciphertext []byte) []byte {
padding := aes.BlockSize - len(ciphertext)%aes.BlockSize
padText := bytes.Repeat([]byte{byte(padding)}, padding)
return append(ciphertext, padText...)
}
func pkcs5UnPadding(src []byte) ([]byte, error) {
srcLen := len(src)
paddingLen := int(src[srcLen-1])
if paddingLen >= srcLen || paddingLen > aes.BlockSize {
return nil, fmt.Errorf("padding size error")
}
return src[:srcLen-paddingLen], nil
}

View File

@@ -0,0 +1,63 @@
package sqlite
import (
"testing"
)
func TestGenerateKey(t *testing.T) {
testData := "exampl@netbird.io"
key, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err := NewEmailEncrypt(key)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
encrypted := ee.Encrypt(testData)
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
decrypted, err := ee.Decrypt(encrypted)
if err != nil {
t.Fatalf("failed to decrypt data: %s", err)
}
if decrypted != testData {
t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted)
}
}
func TestCorruptKey(t *testing.T) {
testData := "exampl@netbird.io"
key, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err := NewEmailEncrypt(key)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
encrypted := ee.Encrypt(testData)
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
newKey, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err = NewEmailEncrypt(newKey)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
res, err := ee.Decrypt(encrypted)
if err == nil || res == testData {
t.Fatalf("incorrect decryption, the result is: %s", res)
}
}

View File

@@ -4,12 +4,13 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/netbirdio/netbird/management/server/activity"
// sqlite driver
_ "github.com/mattn/go-sqlite3"
"path/filepath" "path/filepath"
"time" "time"
_ "github.com/mattn/go-sqlite3" // sqlite driver
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
) )
const ( const (
@@ -24,44 +25,106 @@ const (
"meta TEXT," + "meta TEXT," +
" target_id TEXT);" " target_id TEXT);"
selectStatement = "SELECT id, activity, timestamp, initiator_id, target_id, account_id, meta" + creatTableAccountEmailQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL);`
" FROM events WHERE account_id = ? ORDER BY timestamp %s LIMIT ? OFFSET ?;"
insertStatement = "INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) " + selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.email as "initiator_email", target_id, t.email as "target_email", account_id, meta
FROM events
LEFT JOIN deleted_users i ON events.initiator_id = i.id
LEFT JOIN deleted_users t ON events.target_id = t.id
WHERE account_id = ?
ORDER BY timestamp DESC LIMIT ? OFFSET ?;`
selectAscQuery = `SELECT events.id, activity, timestamp, initiator_id, i.email as "initiator_email", target_id, t.email as "target_email", account_id, meta
FROM events
LEFT JOIN deleted_users i ON events.initiator_id = i.id
LEFT JOIN deleted_users t ON events.target_id = t.id
WHERE account_id = ?
ORDER BY timestamp ASC LIMIT ? OFFSET ?;`
insertQuery = "INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) " +
"VALUES(?, ?, ?, ?, ?, ?)" "VALUES(?, ?, ?, ?, ?, ?)"
insertDeleteUserQuery = `INSERT INTO deleted_users(id, email) VALUES(?, ?)`
) )
// Store is the implementation of the activity.Store interface backed by SQLite // Store is the implementation of the activity.Store interface backed by SQLite
type Store struct { type Store struct {
db *sql.DB db *sql.DB
emailEncrypt *EmailEncrypt
insertStatement *sql.Stmt
selectAscStatement *sql.Stmt
selectDescStatement *sql.Stmt
deleteUserStmt *sql.Stmt
} }
// NewSQLiteStore creates a new Store with an event table if not exists. // NewSQLiteStore creates a new Store with an event table if not exists.
func NewSQLiteStore(dataDir string) (*Store, error) { func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) {
dbFile := filepath.Join(dataDir, eventSinkDB) dbFile := filepath.Join(dataDir, eventSinkDB)
db, err := sql.Open("sqlite3", dbFile) db, err := sql.Open("sqlite3", dbFile)
if err != nil { if err != nil {
return nil, err return nil, err
} }
crypt, err := NewEmailEncrypt(encryptionKey)
if err != nil {
return nil, err
}
_, err = db.Exec(createTableQuery) _, err = db.Exec(createTableQuery)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &Store{db: db}, nil _, err = db.Exec(creatTableAccountEmailQuery)
if err != nil {
return nil, err
}
insertStmt, err := db.Prepare(insertQuery)
if err != nil {
return nil, err
}
selectDescStmt, err := db.Prepare(selectDescQuery)
if err != nil {
return nil, err
}
selectAscStmt, err := db.Prepare(selectAscQuery)
if err != nil {
return nil, err
}
deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
if err != nil {
return nil, err
}
s := &Store{
db: db,
emailEncrypt: crypt,
insertStatement: insertStmt,
selectDescStatement: selectDescStmt,
selectAscStatement: selectAscStmt,
deleteUserStmt: deleteUserStmt,
}
return s, nil
} }
func processResult(result *sql.Rows) ([]*activity.Event, error) { func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
events := make([]*activity.Event, 0) events := make([]*activity.Event, 0)
for result.Next() { for result.Next() {
var id int64 var id int64
var operation activity.Activity var operation activity.Activity
var timestamp time.Time var timestamp time.Time
var initiator string var initiator string
var initiatorEmail *string
var target string var target string
var targetEmail *string
var account string var account string
var jsonMeta string var jsonMeta string
err := result.Scan(&id, &operation, &timestamp, &initiator, &target, &account, &jsonMeta) err := result.Scan(&id, &operation, &timestamp, &initiator, &initiatorEmail, &target, &targetEmail, &account, &jsonMeta)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -74,7 +137,17 @@ func processResult(result *sql.Rows) ([]*activity.Event, error) {
} }
} }
events = append(events, &activity.Event{ if targetEmail != nil {
email, err := store.emailEncrypt.Decrypt(*targetEmail)
if err != nil {
log.Errorf("failed to decrypt email address for target id: %s", target)
meta["email"] = ""
} else {
meta["email"] = email
}
}
event := &activity.Event{
Timestamp: timestamp, Timestamp: timestamp,
Activity: operation, Activity: operation,
ID: uint64(id), ID: uint64(id),
@@ -82,7 +155,18 @@ func processResult(result *sql.Rows) ([]*activity.Event, error) {
TargetID: target, TargetID: target,
AccountID: account, AccountID: account,
Meta: meta, Meta: meta,
}) }
if initiatorEmail != nil {
email, err := store.emailEncrypt.Decrypt(*initiatorEmail)
if err != nil {
log.Errorf("failed to decrypt email address of initiator: %s", initiator)
} else {
event.InitiatorEmail = email
}
}
events = append(events, event)
} }
return events, nil return events, nil
@@ -90,13 +174,9 @@ func processResult(result *sql.Rows) ([]*activity.Event, error) {
// Get returns "limit" number of events from index ordered descending or ascending by a timestamp // Get returns "limit" number of events from index ordered descending or ascending by a timestamp
func (store *Store) Get(accountID string, offset, limit int, descending bool) ([]*activity.Event, error) { func (store *Store) Get(accountID string, offset, limit int, descending bool) ([]*activity.Event, error) {
order := "DESC" stmt := store.selectDescStatement
if !descending { if !descending {
order = "ASC" stmt = store.selectAscStatement
}
stmt, err := store.db.Prepare(fmt.Sprintf(selectStatement, order))
if err != nil {
return nil, err
} }
result, err := stmt.Query(accountID, limit, offset) result, err := stmt.Query(accountID, limit, offset)
@@ -105,19 +185,18 @@ func (store *Store) Get(accountID string, offset, limit int, descending bool) ([
} }
defer result.Close() //nolint defer result.Close() //nolint
return processResult(result) return store.processResult(result)
} }
// Save an event in the SQLite events table // Save an event in the SQLite events table end encrypt the "email" element in meta map
func (store *Store) Save(event *activity.Event) (*activity.Event, error) { func (store *Store) Save(event *activity.Event) (*activity.Event, error) {
var jsonMeta string
stmt, err := store.db.Prepare(insertStatement) meta, err := store.saveDeletedUserEmailInEncrypted(event)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var jsonMeta string if meta != nil {
if event.Meta != nil {
metaBytes, err := json.Marshal(event.Meta) metaBytes, err := json.Marshal(event.Meta)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -125,7 +204,7 @@ func (store *Store) Save(event *activity.Event) (*activity.Event, error) {
jsonMeta = string(metaBytes) jsonMeta = string(metaBytes)
} }
result, err := stmt.Exec(event.Activity, event.Timestamp, event.InitiatorID, event.TargetID, event.AccountID, jsonMeta) result, err := store.insertStatement.Exec(event.Activity, event.Timestamp, event.InitiatorID, event.TargetID, event.AccountID, jsonMeta)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -140,6 +219,29 @@ func (store *Store) Save(event *activity.Event) (*activity.Event, error) {
return eventCopy, nil return eventCopy, nil
} }
// saveDeletedUserEmailInEncrypted if the meta contains email then store it in encrypted way and delete this item from
// meta map
func (store *Store) saveDeletedUserEmailInEncrypted(event *activity.Event) (map[string]any, error) {
email, ok := event.Meta["email"]
if !ok {
return event.Meta, nil
}
delete(event.Meta, "email")
encrypted := store.emailEncrypt.Encrypt(fmt.Sprintf("%s", email))
_, err := store.deleteUserStmt.Exec(event.TargetID, encrypted)
if err != nil {
return nil, err
}
if len(event.Meta) == 1 {
return nil, nil // nolint
}
delete(event.Meta, "email")
return event.Meta, nil
}
// Close the Store // Close the Store
func (store *Store) Close() error { func (store *Store) Close() error {
if store.db != nil { if store.db != nil {

View File

@@ -12,7 +12,8 @@ import (
func TestNewSQLiteStore(t *testing.T) { func TestNewSQLiteStore(t *testing.T) {
dataDir := t.TempDir() dataDir := t.TempDir()
store, err := NewSQLiteStore(dataDir) key, _ := GenerateKey()
store, err := NewSQLiteStore(dataDir, key)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return

Some files were not shown because too many files have changed in this diff Show More