Compare commits

...

51 Commits

Author SHA1 Message Date
Maycon Santos
fa4b8c1d42 Update ephemeral field on the API response (#1129) 2023-09-06 10:40:45 +02:00
Maycon Santos
7682fe2e45 Account ephemeral setup keys metrics (#1128) 2023-09-05 23:04:14 +02:00
Zoltan Papp
c9b2ce08eb DNS forwarder and common ebpf loader (#1083)
In case the 53 UDP port is not an option to bind then we hijack the DNS traffic with eBPF, and we forward the traffic to the listener on a custom port. With this implementation, we should be able to listen to DNS queries on any address and still set the local host system to send queries to the custom address on port 53.

Because we tried to attach multiple XDP programs to the same interface, I did a refactor in the WG traffic forward code also.
2023-09-05 21:14:02 +02:00
Givi Khojanashvili
246abda46d Add default firewall rule to allow netbird traffic (#1056)
Add a default firewall rule to allow netbird traffic to be handled 
by the access control managers.

Userspace manager behavior:
- When running on Windows, a default rule is add on Windows firewall
- For Linux, we are using one of the Kernel managers to add a single rule
- This PR doesn't handle macOS

Kernel manager behavior:
- For NFtables, if there is a filter table, an INPUT rule is added
- Iptables follows the previous flow if running on kernel mode. If running 
on userspace mode, it adds a single rule for INPUT and OUTPUT chains

A new checkerFW package has been introduced to consolidate checks across
route and access control managers.
It supports a new environment variable to skip nftables and allow iptables tests
2023-09-05 21:07:32 +02:00
Misha Bragin
e4bc76c4de Ignore empty fields in the app metadata when storing on IDP (#1122) 2023-09-05 14:41:50 +02:00
Maycon Santos
bdb8383485 Use github token to read api (#1125)
prevent failing tests by using a github 
token to perform requests in our CI/CD
2023-09-05 14:40:40 +02:00
Yury Gargay
bb40325977 Update GitHub Actions and Enhance golangci-lint (#1075)
This PR showcases the implementation of additional linter rules. I've updated the golangci-lint GitHub Actions to the latest available version. This update makes sure that the tool works the same way locally - assuming being updated regularly - and with the GitHub Actions.

I've also taken care of keeping all the GitHub Actions up to date, which helps our code stay current. But there's one part, goreleaser that's a bit tricky to test on our computers. So, it's important to take a close look at that.

To make it easier to understand what I've done, I've made separate changes for each thing that the new linters found. This should help the people reviewing the changes see what's going on more clearly. Some of the changes might not be obvious at first glance.

Things to consider for the future
CI runs on Ubuntu so the static analysis only happens for Linux. Consider running it for the rest: Darwin, Windows
2023-09-04 17:03:44 +02:00
Fábio C. Barrionuevo da Luz
8524cc75d6 Add safe security headers (#1121)
This pull-request add/changes the HTTP headers to include 
safe defaults to Caddy and get the A+ score on 
the https://observatory.mozilla.org/ test
2023-09-04 15:49:07 +02:00
Zoltan Papp
c1f164c9cb Feature/ephemeral peers (#1100)
The ephemeral manager keep the inactive ephemeral peers in a linked list. The manager schedule a cleanup procedure to the head of the linked list (to the most deprecated peer). At the end of cleanup schedule the next cleanup to the new head.
If a device connect back to the server the manager will remote it from the peers list.
2023-09-04 11:37:39 +02:00
Maycon Santos
4e2d075413 Add Wix file for MSI builds (#1099)
This adds a basic wxs file to build MSI installer

This file was created using docs 
from https://wixtoolset.org/docs/schema/wxs/ and 
examples from gsudo, qemu-shoggoth, and many others.

The main difference between this and the .exe installer
is that we don't use the netbird service command to install
the daemon
2023-09-04 11:15:39 +02:00
pascal-fischer
f89c200ce9 Fix api Auth with PAT when a custom UserIDClaim is configured in management.json (#1120)
The API authentication with PATs was not considering different userIDClaim 
that some of the IdPs are using.
In this PR we read the userIDClaim from the config file 
instead of using the fixed default and only keep 
it as a fallback if none in defined.
2023-09-01 18:09:59 +02:00
Misha Bragin
d51dc4fd33 Add sharedsock example (#1116) 2023-08-31 17:01:32 +02:00
Zoltan Papp
00dddb9458 Fix log formatter initialization in mgm cmd (#1112)
The log format was mixed in the management command.
In this commit put to earlier state the log preparation.
2023-08-30 11:42:03 +02:00
Bethuel Mmbaga
1a9301b684 Close PKCE Listening Port After Authorization (#1110)
Addresses the issue of an open listening port persisting 
after the PKCE authorization flow is completed.
2023-08-29 09:13:27 +02:00
Bethuel Mmbaga
80d9b5fca5 Add auto-update feature in netbird script for binary installation (#1106)
This pull request addresses the need to enhance the installer script by introducing a new parameter --update to trigger updates. The goal is to streamline the update process for binary installations and provide a better experience for users.
2023-08-28 16:21:04 +02:00
Bethuel Mmbaga
ac0b7dc8cb Enhance linux client authentication (#1093)
The change clarifies the message usage, 
indicating that setup keys can alternatively be used 
in the authentication process. 
This approach adds flexibility in scenarios 
where automated authentication is unachievable, 
especially in non-desktop Linux environments.
2023-08-23 20:03:34 +02:00
Yury Gargay
e586eca16c Improve account copying (#1069)
With this fix, all nested slices and pointers will be copied by value.
Also, this fixes tests to compare the original and copy account by their
values by marshaling them to JSON strings.

Before that, they were copying the pointers that also passed the simple `=` compassion
(as the addresses match).
2023-08-22 17:56:39 +02:00
Misha Bragin
892db25021 docs: change get started link (#1098) 2023-08-21 09:11:52 +02:00
pascal-fischer
da75a76d41 Adding dashboard login activity (#1092)
For better auditing this PR adds a dashboard login event to the management service.

For that the user object was extended with a field for last login that is not actively saved to the database but kept in memory until next write. The information about the last login can be extracted from the JWT claims nb_last_login. This timestamp will be stored and compared on each API request. If the value changes we generate an event to inform about a login.
2023-08-18 19:23:11 +02:00
Givi Khojanashvili
3ac32fd78a Send network update when propagate user auto-groups (#1084)
For peer propagation this commit triggers
network map update in two cases:
  1) peer login
  2) user AutoGroups update

Also it issues new activity message about new user group
for peer login process.

Previous implementation only adds JWT groups to user. This fix also
removes JWT groups from user auto assign groups.

Pelase note, it also happen when user works with dashboard.
2023-08-18 15:36:05 +02:00
Bethuel Mmbaga
3aa657599b Switch OAuth flow initialization order (#1089)
Switches the order of initialization in the OAuth flow within 
the NewOAuthFlow method. Instead of initializing the 
Device Authorization Flow first, it now initializes 
the PKCE Authorization Flow first, and falls back 
to the Device Authorization Flow if the PKCE initialization fails.
2023-08-17 14:10:03 +02:00
Misha Bragin
d4e9087f94 Add peer login and expiration activity events (#1090)
Track the even of a user logging in their peer.
Track the event of a peer login expiration.
2023-08-17 14:04:04 +02:00
Zoltan Papp
da8447a67d Update the link to the doc page (#1088) 2023-08-17 12:27:04 +02:00
Misha Bragin
8e3bcd57a2 Specify invited by email when inviting a user (#1087) 2023-08-16 23:05:22 +02:00
Maycon Santos
4572c6c1f8 Avoid categorization on incoming claim (#1086)
This prevents domain categorization on claims of invited users
2023-08-16 16:11:26 +02:00
Maycon Santos
01f2b0ecb7 Add support to force using binary install (#1082)
Check if the USE_BIN_INSTALL variable is set to true and skip package manager discovery
2023-08-16 15:10:57 +02:00
Bethuel Mmbaga
442ba7cbc8 Add domain validation for nameserver groups (#1077)
This change ensures that domain names with uppercase 
letters are also considered valid, 
providing more flexibility in domain naming.
2023-08-16 11:25:38 +02:00
Maycon Santos
6c2b364966 Update client Dockerfile to use Alpine as base image and install necessary packages (#1078) 2023-08-12 16:12:09 +02:00
Zoltan Papp
0f0c7ec2ed Routemgr error handling (#1073)
In case the route management feature is not supported 
then do not create unnecessary firewall and manager instances. 
This can happen if the nftables nor iptables is not available on the host OS.

- Move the error handling to upper layer
- Remove fake, useless implementations of interfaces
- Update go-iptables because In Docker the old version can not 
determine well the path of executable file
- update lib to 0.70
2023-08-12 11:42:36 +02:00
Zoltan Papp
2dec016201 Fix/always on boot (#1062)
In case of 'always-on' feature has switched on, after the reboot the service do not start properly in all cases.
If the device is in offline state (no internet connection) the auth login steps will fail and the service will stop.
For the auth steps make no sense in this case because if the OS start the service we do not have option for
the user interaction.
2023-08-11 11:51:39 +02:00
Misha Bragin
06125acb8d Update new release banner (#1072) 2023-08-10 21:10:12 +02:00
Maycon Santos
a9b9b3fa0a Fix input reading for NetBird domain in getting-started-with-zitadel.sh (#1064) 2023-08-08 20:10:14 +02:00
Zoltan Papp
cdf57275b7 Rename eBPF program to reflect better to NetBird (#1063)
Rename program name and map name
2023-08-08 19:53:51 +02:00
Givi Khojanashvili
e5e69b1f75 Autopropagate peers by JWT groups (#1037)
Enhancements to Peer Group Assignment:

1. Auto-assigned groups are now applied to all peers every time a user logs into the network.
2. Feature activation is available in the account settings.
3. API modifications included to support these changes for account settings updates.
4. If propagation is enabled, updates to a user's auto-assigned groups are immediately reflected across all user peers.
5. With the JWT group sync feature active, auto-assigned groups are forcefully updated whenever a peer logs in using user credentials.
2023-08-07 19:44:51 +04:00
Zoltan Papp
8eca83f3cb Fix/ebpf free (#1057)
* Fix ebpf free call

* Add debug logs
2023-08-07 11:43:32 +02:00
Maycon Santos
973316d194 Validate input of expiration time for setup-keys (#1053)
So far we accepted any value for setup keys, including negative values

Now we are checking if it is less than 1 day or greater than 365 days
2023-08-04 23:54:51 +02:00
Zoltan Papp
a0a6ced148 After add listener automatically trigger peer list change event (#1044)
In case of alway-on start the peer list was invalid on Android UI.
2023-08-04 14:14:08 +02:00
Misha Bragin
0fc6c477a9 Add features links to the features table in README (#1052) 2023-08-04 11:52:11 +02:00
Misha Bragin
401a462398 Update getting started docs (#1049) 2023-08-04 11:05:05 +02:00
Zoltan Papp
a3839a6ef7 Fix error handling in iptables initialization (#1051)
* Fix error handling in iptables initialization

* Change log level
2023-08-03 22:12:36 +02:00
Maycon Santos
8aa4f240c7 Add getting started script with Zitadel (#1005)
add getting started script with zitadel

limit tests for infrastructure file workflow

limit release workflow based on relevant files
2023-08-03 19:19:17 +02:00
Zoltan Papp
d9686bae92 Handle conn store in thread safe way (#1047)
* Handle conn store in thread safe way

* Change log line

* Fix proper error handling
2023-08-03 18:24:23 +02:00
pascal-fischer
24e19ae287 revert systemd changes (#1046) 2023-08-03 00:05:13 +02:00
Maycon Santos
74fde0ea2c Update setup key auto_groups description (#1042)
* Update setup key auto_groups description

* Update setup key auto_groups description
2023-08-02 17:50:00 +02:00
pascal-fischer
890e09b787 Keep confiured nameservers as fallback (#1036)
* keep existing nameserver as fallback when adding netbird resolver

* fix resolvconf

* fix imports
2023-08-01 17:45:44 +02:00
Bethuel Mmbaga
48098c994d Handle authentication errors in PKCE flow (#1039)
* handle authentication errors in PKCE flow

* remove shadowing and replace TokenEndpoint for PKCE config

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2023-07-31 14:22:38 +02:00
Bethuel Mmbaga
64f6343fcc Add html screen for pkce flow (#1034)
* add html screen for pkce flow

* remove unused CSS classes in pkce-auth-msg.html

* remove links to external sources
2023-07-28 18:10:12 +02:00
Maycon Santos
24713fbe59 Move ebpf code to its own package to avoid crash issues in Android (#1033)
* Move ebpf code to its own package to avoid crash issues in Android

Older versions of android crashes because of the bytecode files
Even when they aren't loaded as it was our case

* move c file to own folder

* fix lint
2023-07-27 15:34:27 +02:00
Bethuel Mmbaga
7794b744f8 Add PKCE authorization flow (#1012)
Enhance the user experience by enabling authentication to Netbird using Single Sign-On (SSO) with any Identity Provider (IDP) provider. Current client offers this capability through the Device Authorization Flow, however, is not widely supported by many IDPs, and even some that do support it do not provide a complete verification URL.

To address these challenges, this pull request enable Authorization Code Flow with Proof Key for Code Exchange (PKCE) for client logins, which is a more widely adopted and secure approach to facilitate SSO with various IDP providers.
2023-07-27 11:31:07 +02:00
Maycon Santos
0d0c30c16d Avoid compiling linux NewFactory for Android (#1032) 2023-07-26 16:21:04 +02:00
Zoltan Papp
b0364da67c Wg ebpf proxy (#911)
EBPF proxy between TURN (relay) and WireGuard to reduce number of used ports used by the NetBird agent.
- Separate the wg configuration from the proxy logic
- In case if eBPF type proxy has only one single proxy instance
- In case if the eBPF is not supported fallback to the original proxy Implementation

Between the signature of eBPF type proxy and original proxy has 
differences so this is why the factory structure exists
2023-07-26 14:00:47 +02:00
188 changed files with 6473 additions and 2230 deletions

View File

@@ -15,14 +15,14 @@ jobs:
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v2 uses: actions/setup-go@v4
with: with:
go-version: "1.20.x" go-version: "1.20.x"
- name: Checkout code - name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v3
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v2 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: macos-go-${{ hashFiles('**/go.sum') }} key: macos-go-${{ hashFiles('**/go.sum') }}

View File

@@ -18,13 +18,13 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v2 uses: actions/setup-go@v4
with: with:
go-version: "1.20.x" go-version: "1.20.x"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v2 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -32,7 +32,7 @@ jobs:
${{ runner.os }}-go- ${{ runner.os }}-go-
- name: Checkout code - name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v3
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
@@ -47,13 +47,13 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v2 uses: actions/setup-go@v4
with: with:
go-version: "1.20.x" go-version: "1.20.x"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v2 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -61,7 +61,7 @@ jobs:
${{ runner.os }}-go- ${{ runner.os }}-go-
- name: Checkout code - name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v3
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev

View File

@@ -8,14 +8,13 @@ jobs:
name: lint name: lint
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - name: Checkout code
uses: actions/checkout@v3
- name: Install Go - name: Install Go
uses: actions/setup-go@v2 uses: actions/setup-go@v4
with: with:
go-version: "1.20.x" go-version: "1.20.x"
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
- name: golangci-lint - name: golangci-lint
uses: golangci/golangci-lint-action@v2 uses: golangci/golangci-lint-action@v3
with:
args: --timeout=6m

View File

@@ -0,0 +1,36 @@
name: Test installation
on:
push:
branches:
- main
pull_request:
paths:
- "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test-install-script:
strategy:
max-parallel: 2
matrix:
os: [ubuntu-latest, macos-latest]
skip_ui_mode: [true, false]
install_binary: [true, false]
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: run install script
env:
SKIP_UI_APP: ${{ matrix.skip_ui_mode }}
USE_BIN_INSTALL: ${{ matrix.install_binary }}
GITHUB_TOKEN: ${{ secrets.RO_API_CALLER_TOKEN }}
run: |
[ "$SKIP_UI_APP" == "false" ] && export XDG_CURRENT_DESKTOP="none"
cat release_files/install.sh | sh -x
- name: check cli binary
run: command -v netbird

View File

@@ -1,60 +0,0 @@
name: Test installation Darwin
on:
push:
branches:
- main
pull_request:
paths:
- "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
install-cli-only:
runs-on: macos-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Rename brew package
if: ${{ matrix.check_bin_install }}
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
- name: Run install script
run: |
sh ./release_files/install.sh
env:
SKIP_UI_APP: true
- name: Run tests
run: |
if ! command -v netbird &> /dev/null; then
echo "Error: netbird is not installed"
exit 1
fi
install-all:
runs-on: macos-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Rename brew package
if: ${{ matrix.check_bin_install }}
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
- name: Run install script
run: |
sh ./release_files/install.sh
- name: Run tests
run: |
if ! command -v netbird &> /dev/null; then
echo "Error: netbird is not installed"
exit 1
fi
if [[ $(mdfind "kMDItemContentType == 'com.apple.application-bundle' && kMDItemFSName == '*NetBird UI.app'") ]]; then
echo "Error: NetBird UI is not installed"
exit 1
fi

View File

@@ -1,38 +0,0 @@
name: Test installation Linux
on:
push:
branches:
- main
pull_request:
paths:
- "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
install-cli-only:
runs-on: ubuntu-latest
strategy:
matrix:
check_bin_install: [true, false]
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Rename apt package
if: ${{ matrix.check_bin_install }}
run: |
sudo mv /usr/bin/apt /usr/bin/apt.bak
sudo mv /usr/bin/apt-get /usr/bin/apt-get.bak
- name: Run install script
run: |
sh ./release_files/install.sh
- name: Run tests
run: |
if ! command -v netbird &> /dev/null; then
echo "Error: netbird is not installed"
exit 1
fi

View File

@@ -7,9 +7,19 @@ on:
branches: branches:
- main - main
pull_request: pull_request:
paths:
- 'go.mod'
- 'go.sum'
- '.goreleaser.yml'
- '.goreleaser_ui.yaml'
- '.goreleaser_ui_darwin.yaml'
- '.github/workflows/release.yml'
- 'release_files/**'
- '**/Dockerfile'
- '**/Dockerfile.*'
env: env:
SIGN_PIPE_VER: "v0.0.8" SIGN_PIPE_VER: "v0.0.9"
GORELEASER_VER: "v1.14.1" GORELEASER_VER: "v1.14.1"
concurrency: concurrency:
@@ -19,20 +29,24 @@ concurrency:
jobs: jobs:
release: release:
runs-on: ubuntu-latest runs-on: ubuntu-latest
env:
flags: ""
steps: steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- -
name: Checkout name: Checkout
uses: actions/checkout@v2 uses: actions/checkout@v3
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- -
name: Set up Go name: Set up Go
uses: actions/setup-go@v2 uses: actions/setup-go@v4
with: with:
go-version: "1.20" go-version: "1.20"
- -
name: Cache Go modules name: Cache Go modules
uses: actions/cache@v1 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -46,10 +60,10 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- -
name: Set up QEMU name: Set up QEMU
uses: docker/setup-qemu-action@v1 uses: docker/setup-qemu-action@v2
- -
name: Set up Docker Buildx name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1 uses: docker/setup-buildx-action@v2
- -
name: Login to Docker hub name: Login to Docker hub
if: github.event_name != 'pull_request' if: github.event_name != 'pull_request'
@@ -72,10 +86,10 @@ jobs:
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
- -
name: Run GoReleaser name: Run GoReleaser
uses: goreleaser/goreleaser-action@v2 uses: goreleaser/goreleaser-action@v4
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
args: release --rm-dist args: release --rm-dist ${{ env.flags }}
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
@@ -83,7 +97,7 @@ jobs:
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- -
name: upload non tags for debug purposes name: upload non tags for debug purposes
uses: actions/upload-artifact@v2 uses: actions/upload-artifact@v3
with: with:
name: release name: release
path: dist/ path: dist/
@@ -92,17 +106,19 @@ jobs:
release_ui: release_ui:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout - name: Checkout
uses: actions/checkout@v2 uses: actions/checkout@v3
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v2 uses: actions/setup-go@v4
with: with:
go-version: "1.20" go-version: "1.20"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v1 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }}
@@ -122,17 +138,17 @@ jobs:
- name: Generate windows rsrc - name: Generate windows rsrc
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v2 uses: goreleaser/goreleaser-action@v4
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui.yaml --rm-dist args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }}
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- name: upload non tags for debug purposes - name: upload non tags for debug purposes
uses: actions/upload-artifact@v2 uses: actions/upload-artifact@v3
with: with:
name: release-ui name: release-ui
path: dist/ path: dist/
@@ -141,19 +157,21 @@ jobs:
release_ui_darwin: release_ui_darwin:
runs-on: macos-11 runs-on: macos-11
steps: steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- -
name: Checkout name: Checkout
uses: actions/checkout@v2 uses: actions/checkout@v3
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- -
name: Set up Go name: Set up Go
uses: actions/setup-go@v2 uses: actions/setup-go@v4
with: with:
go-version: "1.20" go-version: "1.20"
- -
name: Cache Go modules name: Cache Go modules
uses: actions/cache@v1 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }}
@@ -165,15 +183,15 @@ jobs:
- -
name: Run GoReleaser name: Run GoReleaser
id: goreleaser id: goreleaser
uses: goreleaser/goreleaser-action@v2 uses: goreleaser/goreleaser-action@v4
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui_darwin.yaml --rm-dist args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }}
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- -
name: upload non tags for debug purposes name: upload non tags for debug purposes
uses: actions/upload-artifact@v2 uses: actions/upload-artifact@v3
with: with:
name: release-ui-darwin name: release-ui-darwin
path: dist/ path: dist/

View File

@@ -1,18 +1,20 @@
name: Test Docker Compose Linux name: Test Infrastructure files
on: on:
push: push:
branches: branches:
- main - main
pull_request: pull_request:
paths:
- 'infrastructure_files/**'
- '.github/workflows/test-infrastructure-files.yml'
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true cancel-in-progress: true
jobs: jobs:
test: test-docker-compose:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Install jq - name: Install jq
@@ -22,12 +24,12 @@ jobs:
run: sudo apt-get install -y curl run: sudo apt-get install -y curl
- name: Install Go - name: Install Go
uses: actions/setup-go@v2 uses: actions/setup-go@v4
with: with:
go-version: "1.20.x" go-version: "1.20.x"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v2 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -35,7 +37,7 @@ jobs:
${{ runner.os }}-go- ${{ runner.os }}-go-
- name: Checkout code - name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v3
- name: cp setup.env - name: cp setup.env
run: cp infrastructure_files/tests/setup.env infrastructure_files/ run: cp infrastructure_files/tests/setup.env infrastructure_files/
@@ -69,6 +71,7 @@ jobs:
CI_NETBIRD_AUTH_JWT_CERTS: https://example.eu.auth0.com/.well-known/jwks.json CI_NETBIRD_AUTH_JWT_CERTS: https://example.eu.auth0.com/.well-known/jwks.json
CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token
CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code
CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT: https://example.eu.auth0.com/authorize
CI_NETBIRD_AUTH_REDIRECT_URI: "/peers" CI_NETBIRD_AUTH_REDIRECT_URI: "/peers"
CI_NETBIRD_TOKEN_SOURCE: "idToken" CI_NETBIRD_TOKEN_SOURCE: "idToken"
CI_NETBIRD_AUTH_USER_ID_CLAIM: "email" CI_NETBIRD_AUTH_USER_ID_CLAIM: "email"
@@ -91,8 +94,8 @@ jobs:
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$' grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
grep -A 1 ProviderConfig management.json | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE grep -A 3 DeviceAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
grep Scope management.json | grep "$CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE" grep -A 8 DeviceAuthorizationFlow management.json | grep -A 6 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE"
grep UseIDToken management.json | grep false grep UseIDToken management.json | grep false
grep -A 1 IdpManagerConfig management.json | grep ManagerType | grep $CI_NETBIRD_MGMT_IDP grep -A 1 IdpManagerConfig management.json | grep ManagerType | grep $CI_NETBIRD_MGMT_IDP
grep -A 3 IdpManagerConfig management.json | grep -A 1 ClientConfig | grep Issuer | grep $CI_NETBIRD_AUTH_AUTHORITY grep -A 3 IdpManagerConfig management.json | grep -A 1 ClientConfig | grep Issuer | grep $CI_NETBIRD_AUTH_AUTHORITY
@@ -100,6 +103,12 @@ jobs:
grep -A 5 IdpManagerConfig management.json | grep -A 3 ClientConfig | grep ClientID | grep $CI_NETBIRD_IDP_MGMT_CLIENT_ID grep -A 5 IdpManagerConfig management.json | grep -A 3 ClientConfig | grep ClientID | grep $CI_NETBIRD_IDP_MGMT_CLIENT_ID
grep -A 6 IdpManagerConfig management.json | grep -A 4 ClientConfig | grep ClientSecret | grep $CI_NETBIRD_IDP_MGMT_CLIENT_SECRET grep -A 6 IdpManagerConfig management.json | grep -A 4 ClientConfig | grep ClientSecret | grep $CI_NETBIRD_IDP_MGMT_CLIENT_SECRET
grep -A 7 IdpManagerConfig management.json | grep -A 5 ClientConfig | grep GrantType | grep client_credentials grep -A 7 IdpManagerConfig management.json | grep -A 5 ClientConfig | grep GrantType | grep client_credentials
grep -A 2 PKCEAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_AUDIENCE
grep -A 3 PKCEAuthorizationFlow management.json | grep -A 2 ProviderConfig | grep ClientID | grep $CI_NETBIRD_AUTH_CLIENT_ID
grep -A 4 PKCEAuthorizationFlow management.json | grep -A 3 ProviderConfig | grep ClientSecret | grep $CI_NETBIRD_AUTH_CLIENT_SECRET
grep -A 5 PKCEAuthorizationFlow management.json | grep -A 4 ProviderConfig | grep AuthorizationEndpoint | grep $CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
grep -A 6 PKCEAuthorizationFlow management.json | grep -A 5 ProviderConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
grep -A 7 PKCEAuthorizationFlow management.json | grep -A 6 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
- name: run docker compose up - name: run docker compose up
working-directory: infrastructure_files working-directory: infrastructure_files
@@ -114,3 +123,28 @@ jobs:
count=$(docker compose ps --format json | jq '.[] | select(.Project | contains("infrastructure_files")) | .State' | grep -c running) count=$(docker compose ps --format json | jq '.[] | select(.Project | contains("infrastructure_files")) | .State' | grep -c running)
test $count -eq 4 test $count -eq 4
working-directory: infrastructure_files working-directory: infrastructure_files
test-getting-started-script:
runs-on: ubuntu-latest
steps:
- name: Install jq
run: sudo apt-get install -y jq
- name: Checkout code
uses: actions/checkout@v3
- name: run script
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
- name: test Caddy file gen
run: test -f Caddyfile
- name: test docker-compose file gen
run: test -f docker-compose.yml
- name: test management.json file gen
run: test -f management.json
- name: test turnserver.conf file gen
run: test -f turnserver.conf
- name: test zitadel.env file gen
run: test -f zitadel.env
- name: test dashboard.env file gen
run: test -f dashboard.env

1
.gitignore vendored
View File

@@ -19,3 +19,4 @@ client/.distfiles/
infrastructure_files/setup.env infrastructure_files/setup.env
infrastructure_files/setup-*.env infrastructure_files/setup-*.env
.vscode .vscode
.DS_Store

54
.golangci.yaml Normal file
View File

@@ -0,0 +1,54 @@
run:
# Timeout for analysis, e.g. 30s, 5m.
# Default: 1m
timeout: 6m
# This file contains only configs which differ from defaults.
# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
linters-settings:
errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
# Such cases aren't reported by default.
# Default: false
check-type-assertions: false
govet:
# Enable all analyzers.
# Default: false
enable-all: false
enable:
- nilness
linters:
disable-all: true
enable:
## enabled by default
- errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases
- gosimple # specializes in simplifying a code
- govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
- ineffassign # detects when assignments to existing variables are not used
- staticcheck # is a go vet on steroids, applying a ton of static analysis checks
- typecheck # like the front-end of a Go compiler, parses and type-checks Go code
- unused # checks for unused constants, variables, functions and types
## disable by default but the have interesting results so lets add them
- bodyclose # checks whether HTTP response body is closed successfully
- nilerr # finds the code that returns nil even if it checks that the error is not nil
- nilnil # checks that there is no simultaneous return of nil error and an invalid value
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
- wastedassign # wastedassign finds wasted assignment statements
issues:
# Maximum count of issues with the same text.
# Set to 0 to disable.
# Default: 3
max-same-issues: 5
exclude-rules:
- path: sharedsock/filter.go
linters:
- unused
- path: client/firewall/iptables/rule.go
linters:
- unused
- path: mock.go
linters:
- nilnil

View File

@@ -377,3 +377,13 @@ uploads:
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }} target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
username: dev@wiretrustee.com username: dev@wiretrustee.com
method: PUT method: PUT
checksum:
extra_files:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
- glob: ./release_files/install.sh
release:
extra_files:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
- glob: ./release_files/install.sh

View File

@@ -1,6 +1,6 @@
<p align="center"> <p align="center">
<strong>:hatching_chick: New Release! Peer expiration.</strong> <strong>:hatching_chick: New Release! Self-hosting in under 5 min.</strong>
<a href="https://github.com/netbirdio/netbird/releases"> <a href="https://github.com/netbirdio/netbird#quickstart-with-self-hosted-netbird">
Learn more Learn more
</a> </a>
</p> </p>
@@ -24,7 +24,7 @@
<p align="center"> <p align="center">
<strong> <strong>
Start using NetBird at <a href="https://app.netbird.io/">app.netbird.io</a> Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
<br/> <br/>
See <a href="https://netbird.io/docs/">Documentation</a> See <a href="https://netbird.io/docs/">Documentation</a>
<br/> <br/>
@@ -36,47 +36,62 @@
<br> <br>
**NetBird is an open-source VPN management platform built on top of WireGuard® making it easy to create secure private networks for your organization or home.** **NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
It requires zero configuration effort leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth. **Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
NetBird uses [NAT traversal techniques](https://en.wikipedia.org/wiki/Interactive_Connectivity_Establishment) to automatically create an overlay peer-to-peer network connecting machines regardless of location (home, office, data center, container, cloud, or edge environments), unifying virtual private network management experience. **Secure.** NetBird enables secure remote access by applying granular access policies, while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
**Key features:**
- \[x] Automatic IP allocation and network management with a Web UI ([separate repo](https://github.com/netbirdio/dashboard))
- \[x] Automatic WireGuard peer (machine) discovery and configuration.
- \[x] Encrypted peer-to-peer connections without a central VPN gateway.
- \[x] Connection relay fallback in case a peer-to-peer connection is not possible.
- \[x] Desktop client applications for Linux, MacOS, and Windows (systray).
- \[x] Multiuser support - sharing network between multiple users.
- \[x] SSO and MFA support.
- \[x] Multicloud and hybrid-cloud support.
- \[x] Kernel WireGuard usage when possible.
- \[x] Access Controls - groups & rules.
- \[x] Remote SSH access without managing SSH keys.
- \[x] Network Routes.
- \[x] Private DNS.
- \[x] Network Activity Monitoring.
**Coming soon:**
- \[ ] Mobile clients.
### Secure peer-to-peer VPN with SSO and MFA in minutes ### Secure peer-to-peer VPN with SSO and MFA in minutes
https://user-images.githubusercontent.com/700848/197345890-2e2cded5-7b7a-436f-a444-94e80dd24f46.mov https://user-images.githubusercontent.com/700848/197345890-2e2cded5-7b7a-436f-a444-94e80dd24f46.mov
**Note**: The `main` branch may be in an *unstable or even broken state* during development. ### Key features
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
### Start using NetBird | Connectivity | Management | Automation | Platforms |
- Hosted version: [https://app.netbird.io/](https://app.netbird.io/). |-------------------------------------------------------------------|--------------------------------------------------------------------------|----------------------------------------------------------------------------|---------------------------------------|
- See our documentation for [Quickstart Guide](https://docs.netbird.io/how-to/getting-started). | <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
- If you are looking to self-host NetBird, check our [Self-Hosting Guide](https://docs.netbird.io/selfhosted/selfhosted-guide). | <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
- Step-by-step [Installation Guide](https://docs.netbird.io/how-to/getting-started#installation) for different platforms. | <ul><li> - \[x] Peer-to-peer encryption </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
- Web UI [repository](https://github.com/netbirdio/dashboard). | <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
- 5 min [demo video](https://youtu.be/Tu9tPsUWaY0) on YouTube. | <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | | <ul><li> - \[ ] iOS </ul></li> |
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | | <ul><li> - \[x] Docker </ul></li> |
| | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
| | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | | |
| | <ul><li> - \[x] SSH access management </ul></li> | | |
### Quickstart with NetBird Cloud
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
- Follow the steps to sign-up with Google, Microsoft, GitHub or your email address.
- Check NetBird [admin UI](https://app.netbird.io/).
- Add more machines.
### Quickstart with self-hosted NetBird
> This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM.
Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IDPs.
**Infrastructure requirements:**
- A Linux VM with at least **1CPU** and **2GB** of memory.
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP ports: **3478**, **49152-65535**.
- **Public domain** name pointing to the VM.
**Software requirements:**
- Docker installed on the VM with the docker compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
- [jq](https://jqlang.github.io/jq/) installed. In most distributions
Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`
- [curl](https://curl.se/) installed.
Usually available in the official repositories and can be installed with `sudo apt install curl` or `sudo yum install curl`
**Steps**
- Download and run the installation script:
```bash
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started-with-zitadel.sh | bash
```
- Once finished, you can manage the resources via `docker-compose`
### A bit on NetBird internals ### A bit on NetBird internals
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard. - Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers). - Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
@@ -88,18 +103,18 @@ For stable versions, see [releases](https://github.com/netbirdio/netbird/release
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups. [Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
<p float="left" align="middle"> <p float="left" align="middle">
<img src="https://netbird.io/docs/img/architecture/high-level-dia.png" width="700"/> <img src="https://docs.netbird.io/docs-static/img/architecture/high-level-dia.png" width="700"/>
</p> </p>
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details. See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
### Roadmap
- [Public Roadmap](https://github.com/netbirdio/netbird/projects/2)
### Community projects ### Community projects
- [NetBird on OpenWRT](https://github.com/messense/openwrt-netbird) - [NetBird on OpenWRT](https://github.com/messense/openwrt-netbird)
- [NetBird installer script](https://github.com/physk/netbird-installer) - [NetBird installer script](https://github.com/physk/netbird-installer)
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
### Support acknowledgement ### Support acknowledgement
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by The Federal Ministry of Education and Research of The Federal Republic of Germany. Together with [CISPA Helmholtz Center for Information Security](https://cispa.de/en) NetBird brings the security best practices and simplicity to private networking. In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by The Federal Ministry of Education and Research of The Federal Republic of Germany. Together with [CISPA Helmholtz Center for Information Security](https://cispa.de/en) NetBird brings the security best practices and simplicity to private networking.
@@ -107,7 +122,7 @@ In November 2022, NetBird joined the [StartUpSecure program](https://www.forschu
![CISPA_Logo_BLACK_EN_RZ_RGB (1)](https://user-images.githubusercontent.com/700848/203091324-c6d311a0-22b5-4b05-a288-91cbc6cdcc46.png) ![CISPA_Logo_BLACK_EN_RZ_RGB (1)](https://user-images.githubusercontent.com/700848/203091324-c6d311a0-22b5-4b05-a288-91cbc6cdcc46.png)
### Testimonials ### Testimonials
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), and [Coturn](https://github.com/coturn/coturn). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution). We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution).
### Legal ### Legal
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld. _WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.

View File

@@ -18,10 +18,9 @@ func Encode(num uint32) string {
} }
var encoded strings.Builder var encoded strings.Builder
remainder := uint32(0)
for num > 0 { for num > 0 {
remainder = num % base remainder := num % base
encoded.WriteByte(alphabet[remainder]) encoded.WriteByte(alphabet[remainder])
num /= base num /= base
} }

View File

@@ -1,7 +1,5 @@
FROM gcr.io/distroless/base:debug FROM alpine:3
RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true ENV NB_FOREGROUND_MODE=true
ENV PATH=/sbin:/usr/sbin:/bin:/usr/bin:/busybox
SHELL ["/busybox/sh","-c"]
RUN sed -i -E 's/(^root:.+)\/sbin\/nologin/\1\/busybox\/sh/g' /etc/passwd
ENTRYPOINT [ "/go/bin/netbird","up"] ENTRYPOINT [ "/go/bin/netbird","up"]
COPY netbird /go/bin/netbird COPY netbird /go/bin/netbird

View File

@@ -55,7 +55,6 @@ type Client struct {
ctxCancelLock *sync.Mutex ctxCancelLock *sync.Mutex
deviceName string deviceName string
routeListener routemanager.RouteListener routeListener routemanager.RouteListener
onHostDnsFn func([]string)
} }
// NewClient instantiate a new Client // NewClient instantiate a new Client
@@ -97,7 +96,30 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
c.onHostDnsFn = func([]string) {} return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
// In this case make no sense handle registration steps.
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error {
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
ConfigPath: c.cfgFile,
})
if err != nil {
return err
}
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
var ctx context.Context
//nolint
ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName)
c.ctxCancelLock.Lock()
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
defer c.ctxCancel()
c.ctxCancelLock.Unlock()
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener) return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener)
} }

View File

@@ -6,15 +6,14 @@ import (
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/cmd" "github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/system"
) )
// SSOListener is async listener for mobile framework // SSOListener is async listener for mobile framework
@@ -86,10 +85,16 @@ func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) { err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound { if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
supportsSSO = false supportsSSO = false
err = nil err = nil
} }
return err
}
return err return err
}) })
@@ -183,27 +188,15 @@ func (a *Auth) login(urlOpener URLOpener) error {
return nil return nil
} }
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*internal.TokenInfo, error) { func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config)
if err != nil { if err != nil {
s, ok := gstatus.FromError(err) return nil, err
if ok && s.Code() == codes.NotFound {
return nil, fmt.Errorf("no SSO provider returned from management. " +
"If you are using hosting Netbird see documentation at " +
"https://github.com/netbirdio/netbird/tree/main/management for details")
} else if ok && s.Code() == codes.Unimplemented {
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
"please update your servver or use Setup Keys to login", a.config.ManagementURL)
} else {
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
}
} }
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig) flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
if err != nil { if err != nil {
return nil, fmt.Errorf("getting a request device code failed: %v", err) return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
} }
go urlOpener.Open(flowInfo.VerificationURIComplete) go urlOpener.Open(flowInfo.VerificationURIComplete)
@@ -211,7 +204,7 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*internal.TokenInfo,
waitTimeout := time.Duration(flowInfo.ExpiresIn) waitTimeout := time.Duration(flowInfo.ExpiresIn)
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout*time.Second) waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout*time.Second)
defer cancel() defer cancel()
tokenInfo, err := hostedClient.WaitToken(waitCTX, flowInfo) tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
if err != nil { if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err) return nil, fmt.Errorf("waiting for browser login failed: %v", err)
} }

View File

@@ -3,20 +3,21 @@ package cmd
import ( import (
"context" "context"
"fmt" "fmt"
"os"
"runtime"
"strings" "strings"
"time" "time"
"github.com/skratchdot/open-golang/open" "github.com/skratchdot/open-golang/open"
"github.com/spf13/cobra"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/util"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/util"
) )
var loginCmd = &cobra.Command{ var loginCmd = &cobra.Command{
@@ -163,31 +164,15 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
return nil return nil
} }
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*internal.TokenInfo, error) { func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) oAuthFlow, err := auth.NewOAuthFlow(ctx, config)
if err != nil { if err != nil {
s, ok := gstatus.FromError(err) return nil, err
if ok && s.Code() == codes.NotFound {
return nil, fmt.Errorf("no SSO provider returned from management. " +
"If you are using hosting Netbird see documentation at " +
"https://github.com/netbirdio/netbird/tree/main/management for details")
} else if ok && s.Code() == codes.Unimplemented {
mgmtURL := managementURL
if mgmtURL == "" {
mgmtURL = internal.DefaultManagementURL
}
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
"please update your servver or use Setup Keys to login", mgmtURL)
} else {
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
}
} }
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig) flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
if err != nil { if err != nil {
return nil, fmt.Errorf("getting a request device code failed: %v", err) return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
} }
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode) openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
@@ -196,7 +181,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout*time.Second) waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout*time.Second)
defer c() defer c()
tokenInfo, err := hostedClient.WaitToken(waitCTX, flowInfo) tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
if err != nil { if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err) return nil, fmt.Errorf("waiting for browser login failed: %v", err)
} }
@@ -206,15 +191,64 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) { func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
var codeMsg string var codeMsg string
if !strings.Contains(verificationURIComplete, userCode) { if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode) codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
} }
err := open.Run(verificationURIComplete) browserAuthMsg := "Please do the SSO login in your browser. \n" +
cmd.Printf("Please do the SSO login in your browser. \n" +
"If your browser didn't open automatically, use this URL to log in:\n\n" + "If your browser didn't open automatically, use this URL to log in:\n\n" +
" " + verificationURIComplete + " " + codeMsg + " \n\n") verificationURIComplete + " " + codeMsg
if err != nil {
cmd.Printf("Alternatively, you may want to use a setup key, see:\n\n https://www.netbird.io/docs/overview/setup-keys\n") setupKeyAuthMsg := "\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys"
authenticateUsingBrowser := func() {
cmd.Println(browserAuthMsg)
cmd.Println("")
if err := open.Run(verificationURIComplete); err != nil {
cmd.Println(setupKeyAuthMsg)
}
}
switch runtime.GOOS {
case "windows", "darwin":
authenticateUsingBrowser()
case "linux":
if isLinuxRunningDesktop() {
authenticateUsingBrowser()
} else {
// If current flow is PKCE, it implies the server is anticipating the redirect to localhost.
// Devices lacking browser support are incompatible with this flow.Therefore,
// these devices will need to resort to setup keys instead.
if isPKCEFlow(verificationURIComplete) {
cmd.Println("Please proceed with setting up this device using setup keys, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
} else {
cmd.Println(browserAuthMsg)
}
}
} }
} }
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment.
func isLinuxRunningDesktop() bool {
for _, env := range os.Environ() {
values := strings.Split(env, "=")
if len(values) == 2 {
key, value := values[0], values[1]
if key == "XDG_CURRENT_DESKTOP" && value != "" {
return true
}
}
}
return false
}
// isPKCEFlow determines if the PKCE flow is active or not,
// by checking the existence of redirect_uri inside the verification URL.
func isPKCEFlow(verificationURL string) bool {
if verificationURL == "" {
return false
}
return strings.Contains(verificationURL, "redirect_uri")
}

View File

@@ -109,9 +109,9 @@ func statusFunc(cmd *cobra.Command, args []string) error {
ctx := internal.CtxInitState(context.Background()) ctx := internal.CtxInitState(context.Background())
resp, _ := getStatus(ctx, cmd) resp, err := getStatus(ctx, cmd)
if err != nil { if err != nil {
return nil return err
} }
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) { if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
@@ -120,7 +120,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
" netbird up \n\n"+ " netbird up \n\n"+
"If you are running a self-hosted version and no SSO provider has been configured in your Management Server,\n"+ "If you are running a self-hosted version and no SSO provider has been configured in your Management Server,\n"+
"you can use a setup-key:\n\n netbird up --management-url <YOUR_MANAGEMENT_URL> --setup-key <YOUR_SETUP_KEY>\n\n"+ "you can use a setup-key:\n\n netbird up --management-url <YOUR_MANAGEMENT_URL> --setup-key <YOUR_SETUP_KEY>\n\n"+
"More info: https://www.netbird.io/docs/overview/setup-keys\n\n", "More info: https://docs.netbird.io/how-to/register-machines-using-setup-keys\n\n",
resp.GetStatus(), resp.GetStatus(),
) )
return nil return nil
@@ -133,7 +133,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
outputInformationHolder := convertToStatusOutputOverview(resp) outputInformationHolder := convertToStatusOutputOverview(resp)
statusOutputString := "" var statusOutputString string
switch { switch {
case detailFlag: case detailFlag:
statusOutputString = parseToFullDetailSummary(outputInformationHolder) statusOutputString = parseToFullDetailSummary(outputInformationHolder)

View File

@@ -81,7 +81,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
t.Fatal(err) t.Fatal(err)
} }
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil) mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -40,6 +40,9 @@ const (
// It declares methods which handle actions required by the // It declares methods which handle actions required by the
// Netbird client for ACL and routing functionality // Netbird client for ACL and routing functionality
type Manager interface { type Manager interface {
// AllowNetbird allows netbird interface traffic
AllowNetbird() error
// AddFiltering rule to the firewall // AddFiltering rule to the firewall
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set

View File

@@ -44,6 +44,7 @@ type Manager struct {
type iFaceMapper interface { type iFaceMapper interface {
Name() string Name() string
Address() iface.WGAddress Address() iface.WGAddress
IsUserspaceBind() bool
} }
type ruleset struct { type ruleset struct {
@@ -52,7 +53,7 @@ type ruleset struct {
} }
// Create iptables firewall manager // Create iptables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) { func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) {
m := &Manager{ m := &Manager{
wgIface: wgIface, wgIface: wgIface,
inputDefaultRuleSpecs: []string{ inputDefaultRuleSpecs: []string{
@@ -62,26 +63,26 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
rulesets: make(map[string]ruleset), rulesets: make(map[string]ruleset),
} }
if err := ipset.Init(); err != nil { err := ipset.Init()
if err != nil {
return nil, fmt.Errorf("init ipset: %w", err) return nil, fmt.Errorf("init ipset: %w", err)
} }
// init clients for booth ipv4 and ipv6 // init clients for booth ipv4 and ipv6
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) m.ipv4Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil { if err != nil {
return nil, fmt.Errorf("iptables is not installed in the system or not supported") return nil, fmt.Errorf("iptables is not installed in the system or not supported")
} }
if isIptablesClientAvailable(ipv4Client) {
m.ipv4Client = ipv4Client if ipv6Supported {
m.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
log.Warnf("ip6tables is not installed in the system or not supported: %v. Access rules for this protocol won't be applied.", err)
}
} }
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6) if m.ipv4Client == nil && m.ipv6Client == nil {
if err != nil { return nil, fmt.Errorf("iptables is not installed in the system or not enough permissions to use it")
log.Errorf("ip6tables is not installed in the system or not supported: %v", err)
} else {
if isIptablesClientAvailable(ipv6Client) {
m.ipv6Client = ipv6Client
}
} }
if err := m.Reset(); err != nil { if err := m.Reset(); err != nil {
@@ -90,11 +91,6 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
return m, nil return m, nil
} }
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}
// AddFiltering rule to the firewall // AddFiltering rule to the firewall
// //
// If comment is empty rule ID is used as comment // If comment is empty rule ID is used as comment
@@ -276,6 +272,38 @@ func (m *Manager) Reset() error {
return nil return nil
} }
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
if m.wgIface.IsUserspaceBind() {
_, err := m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
fw.RuleDirectionIN,
fw.ActionAccept,
"",
"allow netbird interface traffic",
)
if err != nil {
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
}
_, err = m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
fw.RuleDirectionOUT,
fw.ActionAccept,
"",
"allow netbird interface traffic",
)
return err
}
return nil
}
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil } func (m *Manager) Flush() error { return nil }
@@ -406,7 +434,7 @@ func (m *Manager) client(ip net.IP) (*iptables.IPTables, error) {
return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err) return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err)
} }
if err := client.AppendUnique("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil { if err := client.Insert("filter", "INPUT", 1, m.inputDefaultRuleSpecs...); err != nil {
return nil, fmt.Errorf("failed to create input chain jump rule: %w", err) return nil, fmt.Errorf("failed to create input chain jump rule: %w", err)
} }

View File

@@ -33,6 +33,8 @@ func (i *iFaceMock) Address() iface.WGAddress {
panic("AddressFunc is not set") panic("AddressFunc is not set")
} }
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestIptablesManager(t *testing.T) { func TestIptablesManager(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err) require.NoError(t, err)
@@ -53,7 +55,7 @@ func TestIptablesManager(t *testing.T) {
} }
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(mock, true)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -141,7 +143,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
} }
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(mock, true)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -229,7 +231,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(mock, true)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)

View File

@@ -29,6 +29,8 @@ const (
// FilterOutputChainName is the name of the chain that is used for filtering outgoing packets // FilterOutputChainName is the name of the chain that is used for filtering outgoing packets
FilterOutputChainName = "netbird-acl-output-filter" FilterOutputChainName = "netbird-acl-output-filter"
AllowNetbirdInputRuleID = "allow Netbird incoming traffic"
) )
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
@@ -379,7 +381,7 @@ func (m *Manager) chain(
if c != nil { if c != nil {
return c, nil return c, nil
} }
return m.createChainIfNotExists(tf, name, hook, priority, cType) return m.createChainIfNotExists(tf, FilterTableName, name, hook, priority, cType)
} }
if ip.To4() != nil { if ip.To4() != nil {
@@ -399,13 +401,20 @@ func (m *Manager) chain(
} }
// table returns the table for the given family of the IP address // table returns the table for the given family of the IP address
func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) { func (m *Manager) table(
family nftables.TableFamily, tableName string,
) (*nftables.Table, error) {
// we cache access to Netbird ACL table only
if tableName != FilterTableName {
return m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
}
if family == nftables.TableFamilyIPv4 { if family == nftables.TableFamilyIPv4 {
if m.tableIPv4 != nil { if m.tableIPv4 != nil {
return m.tableIPv4, nil return m.tableIPv4, nil
} }
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4) table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -417,7 +426,7 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
return m.tableIPv6, nil return m.tableIPv6, nil
} }
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6) table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6, tableName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -425,19 +434,21 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
return m.tableIPv6, nil return m.tableIPv6, nil
} }
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) { func (m *Manager) createTableIfNotExists(
family nftables.TableFamily, tableName string,
) (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(family) tables, err := m.rConn.ListTablesOfFamily(family)
if err != nil { if err != nil {
return nil, fmt.Errorf("list of tables: %w", err) return nil, fmt.Errorf("list of tables: %w", err)
} }
for _, t := range tables { for _, t := range tables {
if t.Name == FilterTableName { if t.Name == tableName {
return t, nil return t, nil
} }
} }
table := m.rConn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4}) table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
if err := m.rConn.Flush(); err != nil { if err := m.rConn.Flush(); err != nil {
return nil, err return nil, err
} }
@@ -446,12 +457,13 @@ func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables
func (m *Manager) createChainIfNotExists( func (m *Manager) createChainIfNotExists(
family nftables.TableFamily, family nftables.TableFamily,
tableName string,
name string, name string,
hooknum nftables.ChainHook, hooknum nftables.ChainHook,
priority nftables.ChainPriority, priority nftables.ChainPriority,
chainType nftables.ChainType, chainType nftables.ChainType,
) (*nftables.Chain, error) { ) (*nftables.Chain, error) {
table, err := m.table(family) table, err := m.table(family, tableName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -638,6 +650,22 @@ func (m *Manager) Reset() error {
return fmt.Errorf("list of chains: %w", err) return fmt.Errorf("list of chains: %w", err)
} }
for _, c := range chains { for _, c := range chains {
// delete Netbird allow input traffic rule if it exists
if c.Table.Name == "filter" && c.Name == "INPUT" {
rules, err := m.rConn.GetRules(c.Table, c)
if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err)
continue
}
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(AllowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
}
}
}
}
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName { if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
m.rConn.DelChain(c) m.rConn.DelChain(c)
} }
@@ -702,6 +730,53 @@ func (m *Manager) Flush() error {
return nil return nil
} }
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
m.mutex.Lock()
defer m.mutex.Unlock()
tf := nftables.TableFamilyIPv4
if m.wgIface.Address().IP.To4() == nil {
tf = nftables.TableFamilyIPv6
}
chains, err := m.rConn.ListChainsOfTableFamily(tf)
if err != nil {
return fmt.Errorf("list of chains: %w", err)
}
var chain *nftables.Chain
for _, c := range chains {
if c.Table.Name == "filter" && c.Name == "INPUT" {
chain = c
break
}
}
if chain == nil {
log.Debugf("chain INPUT not found. Skiping add allow netbird rule")
return nil
}
rules, err := m.rConn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("failed to get rules for the INPUT chain: %v", err)
}
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
log.Debugf("allow netbird rule already exists: %v", rule)
return nil
}
m.applyAllowNetbirdRules(chain)
err = m.rConn.Flush()
if err != nil {
return fmt.Errorf("failed to flush allow input netbird rules: %v", err)
}
return nil
}
func (m *Manager) flushWithBackoff() (err error) { func (m *Manager) flushWithBackoff() (err error) {
backoff := 4 backoff := 4
backoffTime := 1000 * time.Millisecond backoffTime := 1000 * time.Millisecond
@@ -745,6 +820,44 @@ func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chai
return nil return nil
} }
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
rule := &nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: []byte(AllowNetbirdInputRuleID),
}
_ = m.rConn.InsertRule(rule)
}
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
ifName := ifname(m.wgIface.Name())
for _, rule := range existedRules {
if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" {
if len(rule.Exprs) < 4 {
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
continue
}
if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) {
continue
}
return rule
}
}
}
return nil
}
func encodePort(port fw.Port) []byte { func encodePort(port fw.Port) []byte {
bs := make([]byte, 2) bs := make([]byte, 2)
binary.BigEndian.PutUint16(bs, uint16(port.Values[0])) binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))

View File

@@ -0,0 +1,19 @@
//go:build !windows && !linux
package uspfilter
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
return nil
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
return nil
}

View File

@@ -0,0 +1,21 @@
package uspfilter
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
return nil
}
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if m.resetHook != nil {
return m.resetHook()
}
return nil
}

View File

@@ -0,0 +1,91 @@
package uspfilter
import (
"errors"
"fmt"
"os/exec"
"strings"
"syscall"
)
type action string
const (
addRule action = "add"
deleteRule action = "delete"
firewallRuleName = "Netbird"
noRulesMatchCriteria = "No rules match the specified criteria"
)
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
return fmt.Errorf("couldn't remove windows firewall: %w", err)
}
return nil
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
return manageFirewallRule(firewallRuleName,
addRule,
"dir=in",
"enable=yes",
"action=allow",
"profile=any",
"localip="+m.wgIface.Address().IP.String(),
)
}
func manageFirewallRule(ruleName string, action action, args ...string) error {
active, err := isFirewallRuleActive(ruleName)
if err != nil {
return err
}
if (action == addRule && !active) || (action == deleteRule && active) {
baseArgs := []string{"advfirewall", "firewall", string(action), "rule", "name=" + ruleName}
args := append(baseArgs, args...)
cmd := exec.Command("netsh", args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
return cmd.Run()
}
return nil
}
func isFirewallRuleActive(ruleName string) (bool, error) {
args := []string{"advfirewall", "firewall", "show", "rule", "name=" + ruleName}
cmd := exec.Command("netsh", args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
output, err := cmd.Output()
if err != nil {
var exitError *exec.ExitError
if errors.As(err, &exitError) {
// if the firewall rule is not active, we expect last exit code to be 1
exitStatus := exitError.Sys().(syscall.WaitStatus).ExitStatus()
if exitStatus == 1 {
if strings.Contains(string(output), noRulesMatchCriteria) {
return false, nil
}
}
}
return false, err
}
if strings.Contains(string(output), noRulesMatchCriteria) {
return false, nil
}
return true, nil
}

View File

@@ -19,6 +19,7 @@ const layerTypeAll = 0
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
SetFilter(iface.PacketFilter) error SetFilter(iface.PacketFilter) error
Address() iface.WGAddress
} }
// RuleSet is a set of rules grouped by a string key // RuleSet is a set of rules grouped by a string key
@@ -30,6 +31,8 @@ type Manager struct {
incomingRules map[string]RuleSet incomingRules map[string]RuleSet
wgNetwork *net.IPNet wgNetwork *net.IPNet
decoders sync.Pool decoders sync.Pool
wgIface IFaceMapper
resetHook func() error
mutex sync.RWMutex mutex sync.RWMutex
} }
@@ -65,6 +68,7 @@ func Create(iface IFaceMapper) (*Manager, error) {
}, },
outgoingRules: make(map[string]RuleSet), outgoingRules: make(map[string]RuleSet),
incomingRules: make(map[string]RuleSet), incomingRules: make(map[string]RuleSet),
wgIface: iface,
} }
if err := iface.SetFilter(m); err != nil { if err := iface.SetFilter(m); err != nil {
@@ -171,17 +175,6 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
return nil return nil
} }
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
return nil
}
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil } func (m *Manager) Flush() error { return nil }
@@ -375,3 +368,8 @@ func (m *Manager) RemovePacketHook(hookID string) error {
} }
return fmt.Errorf("hook with given id not found") return fmt.Errorf("hook with given id not found")
} }
// SetResetHook which will be executed in the end of Reset method
func (m *Manager) SetResetHook(hook func() error) {
m.resetHook = hook
}

View File

@@ -16,6 +16,7 @@ import (
type IFaceMock struct { type IFaceMock struct {
SetFilterFunc func(iface.PacketFilter) error SetFilterFunc func(iface.PacketFilter) error
AddressFunc func() iface.WGAddress
} }
func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error { func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
@@ -25,6 +26,13 @@ func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
return i.SetFilterFunc(iface) return i.SetFilterFunc(iface)
} }
func (i *IFaceMock) Address() iface.WGAddress {
if i.AddressFunc == nil {
return iface.WGAddress{}
}
return i.AddressFunc()
}
func TestManagerCreate(t *testing.T) { func TestManagerCreate(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(iface.PacketFilter) error { return nil }, SetFilterFunc: func(iface.PacketFilter) error { return nil },

View File

@@ -146,12 +146,11 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
// if this rule is member of rule selection with more than DefaultIPsCountForSet // if this rule is member of rule selection with more than DefaultIPsCountForSet
// it's IP address can be used in the ipset for firewall manager which supports it // it's IP address can be used in the ipset for firewall manager which supports it
ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)] ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)]
ipsetName := ""
if ipset.name == "" { if ipset.name == "" {
d.ipsetCounter++ d.ipsetCounter++
ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter) ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter)
} }
ipsetName = ipset.name ipsetName := ipset.name
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName) pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
if err != nil { if err != nil {
log.Errorf("failed to apply firewall rule: %+v, %v", r, err) log.Errorf("failed to apply firewall rule: %+v, %v", r, err)

View File

@@ -6,6 +6,8 @@ import (
"fmt" "fmt"
"runtime" "runtime"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
) )
@@ -17,6 +19,9 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
return newDefaultManager(fm), nil return newDefaultManager(fm), nil
} }
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)

View File

@@ -7,27 +7,69 @@ import (
"github.com/netbirdio/netbird/client/firewall/iptables" "github.com/netbirdio/netbird/client/firewall/iptables"
"github.com/netbirdio/netbird/client/firewall/nftables" "github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/checkfw"
) )
// Create creates a firewall manager instance for the Linux // Create creates a firewall manager instance for the Linux
func Create(iface IFaceMapper) (manager *DefaultManager, err error) { func Create(iface IFaceMapper) (*DefaultManager, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall
var fm firewall.Manager var fm firewall.Manager
var err error
checkResult := checkfw.Check()
switch checkResult {
case checkfw.IPTABLES, checkfw.IPTABLESWITHV6:
log.Debug("creating an iptables firewall manager for access control")
ipv6Supported := checkResult == checkfw.IPTABLESWITHV6
if fm, err = iptables.Create(iface, ipv6Supported); err != nil {
log.Infof("failed to create iptables manager for access control: %s", err)
}
case checkfw.NFTABLES:
log.Debug("creating an nftables firewall manager for access control")
if fm, err = nftables.Create(iface); err != nil {
log.Debugf("failed to create nftables manager for access control: %s", err)
}
}
var resetHookForUserspace func() error
if fm != nil && err == nil {
// err shadowing is used here, to ignore this error
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
resetHookForUserspace = fm.Reset
}
if iface.IsUserspaceBind() { if iface.IsUserspaceBind() {
// use userspace packet filtering firewall // use userspace packet filtering firewall
if fm, err = uspfilter.Create(iface); err != nil { usfm, err := uspfilter.Create(iface)
if err != nil {
log.Debugf("failed to create userspace filtering firewall: %s", err) log.Debugf("failed to create userspace filtering firewall: %s", err)
return nil, err return nil, err
} }
} else {
if fm, err = nftables.Create(iface); err != nil { // set kernel space firewall Reset as hook for userspace firewall
log.Debugf("failed to create nftables manager: %s", err) // manager Reset method, to clean up
// fallback to iptables if resetHookForUserspace != nil {
if fm, err = iptables.Create(iface); err != nil { usfm.SetResetHook(resetHookForUserspace)
log.Errorf("failed to create iptables manager: %s", err) }
// to be consistent for any future extensions.
// ignore this error
if err := usfm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
fm = usfm
}
if fm == nil || err != nil {
log.Errorf("failed to create firewall manager: %s", err)
// no firewall manager found or initialized correctly
return nil, err return nil, err
} }
}
}
return newDefaultManager(fm), nil return newDefaultManager(fm), nil
} }

View File

@@ -1,11 +1,13 @@
package acl package acl
import ( import (
"net"
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/iface"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
) )
@@ -32,13 +34,22 @@ func TestDefaultManager(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()
iface := mocks.NewMockIFaceMapper(ctrl) ifaceMock := mocks.NewMockIFaceMapper(ctrl)
iface.EXPECT().IsUserspaceBind().Return(true) ifaceMock.EXPECT().IsUserspaceBind().Return(true)
// iface.EXPECT().Name().Return("lo") ifaceMock.EXPECT().SetFilter(gomock.Any())
iface.EXPECT().SetFilter(gomock.Any()) ip, network, err := net.ParseCIDR("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
IP: ip,
Network: network,
}).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it // we receive one rule from the management so for testing purposes ignore it
acl, err := Create(iface) acl, err := Create(ifaceMock)
if err != nil { if err != nil {
t.Errorf("create ACL manager: %v", err) t.Errorf("create ACL manager: %v", err)
return return
@@ -311,13 +322,22 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()
iface := mocks.NewMockIFaceMapper(ctrl) ifaceMock := mocks.NewMockIFaceMapper(ctrl)
iface.EXPECT().IsUserspaceBind().Return(true) ifaceMock.EXPECT().IsUserspaceBind().Return(true)
// iface.EXPECT().Name().Return("lo") ifaceMock.EXPECT().SetFilter(gomock.Any())
iface.EXPECT().SetFilter(gomock.Any()) ip, network, err := net.ParseCIDR("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
IP: ip,
Network: network,
}).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it // we receive one rule from the management so for testing purposes ignore it
acl, err := Create(iface) acl, err := Create(ifaceMock)
if err != nil { if err != nil {
t.Errorf("create ACL manager: %v", err) t.Errorf("create ACL manager: %v", err)
return return

View File

@@ -0,0 +1,202 @@
package auth
import (
"context"
"encoding/json"
"fmt"
"github.com/netbirdio/netbird/client/internal"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// HostedGrantType grant type for device flow on Hosted
const (
HostedGrantType = "urn:ietf:params:oauth:grant-type:device_code"
)
var _ OAuthFlow = &DeviceAuthorizationFlow{}
// DeviceAuthorizationFlow implements the OAuthFlow interface,
// for the Device Authorization Flow.
type DeviceAuthorizationFlow struct {
providerConfig internal.DeviceAuthProviderConfig
HTTPClient HTTPClient
}
// RequestDeviceCodePayload used for request device code payload for auth0
type RequestDeviceCodePayload struct {
Audience string `json:"audience"`
ClientID string `json:"client_id"`
Scope string `json:"scope"`
}
// TokenRequestPayload used for requesting the auth0 token
type TokenRequestPayload struct {
GrantType string `json:"grant_type"`
DeviceCode string `json:"device_code,omitempty"`
ClientID string `json:"client_id"`
RefreshToken string `json:"refresh_token,omitempty"`
}
// TokenRequestResponse used for parsing Hosted token's response
type TokenRequestResponse struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
TokenInfo
}
// NewDeviceAuthorizationFlow returns device authorization flow client
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Transport: httpTransport,
}
return &DeviceAuthorizationFlow{
providerConfig: config,
HTTPClient: httpClient,
}, nil
}
// GetClientID returns the provider client id
func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
return d.providerConfig.ClientID
}
// RequestAuthInfo requests a device code login flow information from Hosted
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
form := url.Values{}
form.Add("client_id", d.providerConfig.ClientID)
form.Add("audience", d.providerConfig.Audience)
form.Add("scope", d.providerConfig.Scope)
req, err := http.NewRequest("POST", d.providerConfig.DeviceAuthEndpoint,
strings.NewReader(form.Encode()))
if err != nil {
return AuthFlowInfo{}, fmt.Errorf("creating request failed with error: %v", err)
}
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
res, err := d.HTTPClient.Do(req)
if err != nil {
return AuthFlowInfo{}, fmt.Errorf("doing request failed with error: %v", err)
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
return AuthFlowInfo{}, fmt.Errorf("reading body failed with error: %v", err)
}
if res.StatusCode != 200 {
return AuthFlowInfo{}, fmt.Errorf("request device code returned status %d error: %s", res.StatusCode, string(body))
}
deviceCode := AuthFlowInfo{}
err = json.Unmarshal(body, &deviceCode)
if err != nil {
return AuthFlowInfo{}, fmt.Errorf("unmarshaling response failed with error: %v", err)
}
// Fallback to the verification_uri if the IdP doesn't support verification_uri_complete
if deviceCode.VerificationURIComplete == "" {
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
}
return deviceCode, err
}
func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) {
form := url.Values{}
form.Add("client_id", d.providerConfig.ClientID)
form.Add("grant_type", HostedGrantType)
form.Add("device_code", info.DeviceCode)
req, err := http.NewRequest("POST", d.providerConfig.TokenEndpoint, strings.NewReader(form.Encode()))
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("failed to create request access token: %v", err)
}
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
res, err := d.HTTPClient.Do(req)
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("failed to request access token with error: %v", err)
}
defer func() {
err := res.Body.Close()
if err != nil {
return
}
}()
body, err := io.ReadAll(res.Body)
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("failed reading access token response body with error: %v", err)
}
if res.StatusCode > 499 {
return TokenRequestResponse{}, fmt.Errorf("access token response returned code: %s", string(body))
}
tokenResponse := TokenRequestResponse{}
err = json.Unmarshal(body, &tokenResponse)
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("parsing token response failed with error: %v", err)
}
return tokenResponse, nil
}
// WaitToken waits user's login and authorize the app. Once the user's authorize
// it retrieves the access token from Hosted's endpoint and validates it before returning
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
interval := time.Duration(info.Interval) * time.Second
ticker := time.NewTicker(interval)
for {
select {
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case <-ticker.C:
tokenResponse, err := d.requestToken(info)
if err != nil {
return TokenInfo{}, fmt.Errorf("parsing token response failed with error: %v", err)
}
if tokenResponse.Error != "" {
if tokenResponse.Error == "authorization_pending" {
continue
} else if tokenResponse.Error == "slow_down" {
interval = interval + (3 * time.Second)
ticker.Reset(interval)
continue
}
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
}
tokenInfo := TokenInfo{
AccessToken: tokenResponse.AccessToken,
TokenType: tokenResponse.TokenType,
RefreshToken: tokenResponse.RefreshToken,
IDToken: tokenResponse.IDToken,
ExpiresIn: tokenResponse.ExpiresIn,
UseIDToken: d.providerConfig.UseIDToken,
}
err = isValidAccessToken(tokenInfo.GetTokenToUse(), d.providerConfig.Audience)
if err != nil {
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
}
return tokenInfo, err
}
}
}

View File

@@ -1,17 +1,17 @@
package internal package auth
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/client/internal"
"github.com/stretchr/testify/require"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/golang-jwt/jwt"
"github.com/stretchr/testify/require"
) )
type mockHTTPClient struct { type mockHTTPClient struct {
@@ -53,7 +53,7 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
testingErrFunc require.ErrorAssertionFunc testingErrFunc require.ErrorAssertionFunc
expectedErrorMSG string expectedErrorMSG string
testingFunc require.ComparisonAssertionFunc testingFunc require.ComparisonAssertionFunc
expectedOut DeviceAuthInfo expectedOut AuthFlowInfo
expectedMSG string expectedMSG string
expectPayload string expectPayload string
} }
@@ -92,7 +92,7 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
testingFunc: require.EqualValues, testingFunc: require.EqualValues,
expectPayload: expectPayload, expectPayload: expectPayload,
} }
testCase4Out := DeviceAuthInfo{ExpiresIn: 10} testCase4Out := AuthFlowInfo{ExpiresIn: 10}
testCase4 := test{ testCase4 := test{
name: "Got Device Code", name: "Got Device Code",
inputResBody: fmt.Sprintf("{\"expires_in\":%d}", testCase4Out.ExpiresIn), inputResBody: fmt.Sprintf("{\"expires_in\":%d}", testCase4Out.ExpiresIn),
@@ -113,8 +113,8 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
err: testCase.inputReqError, err: testCase.inputReqError,
} }
hosted := Hosted{ deviceFlow := &DeviceAuthorizationFlow{
providerConfig: ProviderConfig{ providerConfig: internal.DeviceAuthProviderConfig{
Audience: expectedAudience, Audience: expectedAudience,
ClientID: expectedClientID, ClientID: expectedClientID,
Scope: expectedScope, Scope: expectedScope,
@@ -125,7 +125,7 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
HTTPClient: &httpClient, HTTPClient: &httpClient,
} }
authInfo, err := hosted.RequestDeviceCode(context.TODO()) authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG) testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
require.EqualValues(t, expectPayload, httpClient.reqBody, "payload should match") require.EqualValues(t, expectPayload, httpClient.reqBody, "payload should match")
@@ -145,7 +145,7 @@ func TestHosted_WaitToken(t *testing.T) {
inputMaxReqs int inputMaxReqs int
inputCountResBody string inputCountResBody string
inputTimeout time.Duration inputTimeout time.Duration
inputInfo DeviceAuthInfo inputInfo AuthFlowInfo
inputAudience string inputAudience string
testingErrFunc require.ErrorAssertionFunc testingErrFunc require.ErrorAssertionFunc
expectedErrorMSG string expectedErrorMSG string
@@ -155,7 +155,7 @@ func TestHosted_WaitToken(t *testing.T) {
expectPayload string expectPayload string
} }
defaultInfo := DeviceAuthInfo{ defaultInfo := AuthFlowInfo{
DeviceCode: "test", DeviceCode: "test",
ExpiresIn: 10, ExpiresIn: 10,
Interval: 1, Interval: 1,
@@ -278,8 +278,8 @@ func TestHosted_WaitToken(t *testing.T) {
countResBody: testCase.inputCountResBody, countResBody: testCase.inputCountResBody,
} }
hosted := Hosted{ deviceFlow := DeviceAuthorizationFlow{
providerConfig: ProviderConfig{ providerConfig: internal.DeviceAuthProviderConfig{
Audience: testCase.inputAudience, Audience: testCase.inputAudience,
ClientID: clientID, ClientID: clientID,
TokenEndpoint: "test.hosted.com/token", TokenEndpoint: "test.hosted.com/token",
@@ -287,11 +287,12 @@ func TestHosted_WaitToken(t *testing.T) {
Scope: "openid", Scope: "openid",
UseIDToken: false, UseIDToken: false,
}, },
HTTPClient: &httpClient} HTTPClient: &httpClient,
}
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout) ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
defer cancel() defer cancel()
tokenInfo, err := hosted.WaitToken(ctx, testCase.inputInfo) tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG) testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
require.EqualValues(t, testCase.expectPayload, httpClient.reqBody, "payload should match") require.EqualValues(t, testCase.expectPayload, httpClient.reqBody, "payload should match")

View File

@@ -0,0 +1,88 @@
package auth
import (
"context"
"fmt"
"net/http"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
)
// OAuthFlow represents an interface for authorization using different OAuth 2.0 flows
type OAuthFlow interface {
RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error)
WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error)
GetClientID(ctx context.Context) string
}
// HTTPClient http client interface for API calls
type HTTPClient interface {
Do(req *http.Request) (*http.Response, error)
}
// AuthFlowInfo holds information for the OAuth 2.0 authorization flow
type AuthFlowInfo struct {
DeviceCode string `json:"device_code"`
UserCode string `json:"user_code"`
VerificationURI string `json:"verification_uri"`
VerificationURIComplete string `json:"verification_uri_complete"`
ExpiresIn int `json:"expires_in"`
Interval int `json:"interval"`
}
// Claims used when validating the access token
type Claims struct {
Audience interface{} `json:"aud"`
}
// TokenInfo holds information of issued access token
type TokenInfo struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
UseIDToken bool `json:"-"`
}
// GetTokenToUse returns either the access or id token based on UseIDToken field
func (t TokenInfo) GetTokenToUse() string {
if t.UseIDToken {
return t.IDToken
}
return t.AccessToken
}
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration.
func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
log.Debug("loading pkce authorization flow info")
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err == nil {
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
}
log.Debugf("loading pkce authorization flow info failed with error: %v", err)
log.Debugf("falling back to device authorization flow info")
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
s, ok := gstatus.FromError(err)
if ok && s.Code() == codes.NotFound {
return nil, fmt.Errorf("no SSO provider returned from management. " +
"If you are using hosting Netbird see documentation at " +
"https://github.com/netbirdio/netbird/tree/main/management for details")
} else if ok && s.Code() == codes.Unimplemented {
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
"please update your server or use Setup Keys to login", config.ManagementURL)
} else {
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
}
}
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
}

View File

@@ -0,0 +1,252 @@
package auth
import (
"context"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"errors"
"fmt"
"html/template"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/templates"
)
var _ OAuthFlow = &PKCEAuthorizationFlow{}
const (
queryState = "state"
queryCode = "code"
queryError = "error"
queryErrorDesc = "error_description"
defaultPKCETimeoutSeconds = 300
)
// PKCEAuthorizationFlow implements the OAuthFlow interface for
// the Authorization Code Flow with PKCE.
type PKCEAuthorizationFlow struct {
providerConfig internal.PKCEAuthProviderConfig
state string
codeVerifier string
oAuthConfig *oauth2.Config
}
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
var availableRedirectURL string
// find the first available redirect URL
for _, redirectURL := range config.RedirectURLs {
if !isRedirectURLPortUsed(redirectURL) {
availableRedirectURL = redirectURL
break
}
}
if availableRedirectURL == "" {
return nil, fmt.Errorf("no available port found from configured redirect URLs: %q", config.RedirectURLs)
}
cfg := &oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: config.AuthorizationEndpoint,
TokenURL: config.TokenEndpoint,
},
RedirectURL: availableRedirectURL,
Scopes: strings.Split(config.Scope, " "),
}
return &PKCEAuthorizationFlow{
providerConfig: config,
oAuthConfig: cfg,
}, nil
}
// GetClientID returns the provider client id
func (p *PKCEAuthorizationFlow) GetClientID(_ context.Context) string {
return p.providerConfig.ClientID
}
// RequestAuthInfo requests a authorization code login flow information.
func (p *PKCEAuthorizationFlow) RequestAuthInfo(_ context.Context) (AuthFlowInfo, error) {
state, err := randomBytesInHex(24)
if err != nil {
return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err)
}
p.state = state
codeVerifier, err := randomBytesInHex(64)
if err != nil {
return AuthFlowInfo{}, fmt.Errorf("could not create a code verifier: %v", err)
}
p.codeVerifier = codeVerifier
codeChallenge := createCodeChallenge(codeVerifier)
authURL := p.oAuthConfig.AuthCodeURL(
state,
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
)
return AuthFlowInfo{
VerificationURIComplete: authURL,
ExpiresIn: defaultPKCETimeoutSeconds,
}, nil
}
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
// Once the token is received, it is converted to TokenInfo and validated before returning.
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
tokenChan := make(chan *oauth2.Token, 1)
errChan := make(chan error, 1)
go p.startServer(tokenChan, errChan)
select {
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case token := <-tokenChan:
return p.handleOAuthToken(token)
case err := <-errChan:
return TokenInfo{}, err
}
}
func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errChan chan<- error) {
var wg sync.WaitGroup
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
if err != nil {
errChan <- fmt.Errorf("failed to parse redirect URL: %v", err)
return
}
server := http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
go func() {
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
errChan <- err
}
}()
wg.Add(1)
http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
defer wg.Done()
tokenValidatorFunc := func() (*oauth2.Token, error) {
query := req.URL.Query()
if authError := query.Get(queryError); authError != "" {
authErrorDesc := query.Get(queryErrorDesc)
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
}
// Prevent timing attacks on state
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
return nil, fmt.Errorf("invalid state")
}
code := query.Get(queryCode)
if code == "" {
return nil, fmt.Errorf("missing code")
}
return p.oAuthConfig.Exchange(
req.Context(),
code,
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
)
}
token, err := tokenValidatorFunc()
if err != nil {
renderPKCEFlowTmpl(w, err)
errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err)
return
}
renderPKCEFlowTmpl(w, nil)
tokenChan <- token
})
wg.Wait()
if err := server.Shutdown(context.Background()); err != nil {
log.Errorf("error while shutting down pkce flow server: %v", err)
}
}
func (p *PKCEAuthorizationFlow) handleOAuthToken(token *oauth2.Token) (TokenInfo, error) {
tokenInfo := TokenInfo{
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
TokenType: token.TokenType,
ExpiresIn: token.Expiry.Second(),
UseIDToken: p.providerConfig.UseIDToken,
}
if idToken, ok := token.Extra("id_token").(string); ok {
tokenInfo.IDToken = idToken
}
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), p.providerConfig.Audience); err != nil {
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
}
return tokenInfo, nil
}
func createCodeChallenge(codeVerifier string) string {
sha2 := sha256.Sum256([]byte(codeVerifier))
return base64.RawURLEncoding.EncodeToString(sha2[:])
}
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use.
func isRedirectURLPortUsed(redirectURL string) bool {
parsedURL, err := url.Parse(redirectURL)
if err != nil {
log.Errorf("failed to parse redirect URL: %v", err)
return true
}
addr := fmt.Sprintf(":%s", parsedURL.Port())
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
if err != nil {
return false
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf("error while closing the connection: %v", err)
}
}()
return true
}
func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) {
tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
data := make(map[string]string)
if authError != nil {
data["Error"] = authError.Error()
}
if err := tmpl.Execute(w, data); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}

View File

@@ -0,0 +1,62 @@
package auth
import (
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"reflect"
"strings"
)
func randomBytesInHex(count int) (string, error) {
buf := make([]byte, count)
_, err := io.ReadFull(rand.Reader, buf)
if err != nil {
return "", fmt.Errorf("could not generate %d random bytes: %v", count, err)
}
return hex.EncodeToString(buf), nil
}
// isValidAccessToken is a simple validation of the access token
func isValidAccessToken(token string, audience string) error {
if token == "" {
return fmt.Errorf("token received is empty")
}
encodedClaims := strings.Split(token, ".")[1]
claimsString, err := base64.RawURLEncoding.DecodeString(encodedClaims)
if err != nil {
return err
}
claims := Claims{}
err = json.Unmarshal(claimsString, &claims)
if err != nil {
return err
}
if claims.Audience == nil {
return fmt.Errorf("required token field audience is absent")
}
// Audience claim of JWT can be a string or an array of strings
typ := reflect.TypeOf(claims.Audience)
switch typ.Kind() {
case reflect.String:
if claims.Audience == audience {
return nil
}
case reflect.Slice:
for _, aud := range claims.Audience.([]interface{}) {
if audience == aud {
return nil
}
}
}
return fmt.Errorf("invalid JWT token audience field")
}

View File

@@ -0,0 +1,3 @@
//go:build !linux
package checkfw

View File

@@ -0,0 +1,56 @@
//go:build !android
package checkfw
import (
"os"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
)
const (
// UNKNOWN is the default value for the firewall type for unknown firewall type
UNKNOWN FWType = iota
// IPTABLES is the value for the iptables firewall type
IPTABLES
// IPTABLESWITHV6 is the value for the iptables firewall type with ipv6
IPTABLESWITHV6
// NFTABLES is the value for the nftables firewall type
NFTABLES
)
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int
// Check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func Check() FWType {
nf := nftables.Conn{}
if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
return NFTABLES
}
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err == nil {
if isIptablesClientAvailable(ip) {
ipSupport := IPTABLES
ipv6, ip6Err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if ip6Err == nil {
if isIptablesClientAvailable(ipv6) {
ipSupport = IPTABLESWITHV6
}
}
return ipSupport
}
}
return UNKNOWN
}
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}

View File

@@ -23,9 +23,6 @@ func TestGetConfig(t *testing.T) {
assert.Equal(t, config.ManagementURL.String(), DefaultManagementURL) assert.Equal(t, config.ManagementURL.String(), DefaultManagementURL)
assert.Equal(t, config.AdminURL.String(), DefaultAdminURL) assert.Equal(t, config.AdminURL.String(), DefaultAdminURL)
if err != nil {
return
}
managementURL := "https://test.management.url:33071" managementURL := "https://test.management.url:33071"
adminURL := "https://app.admin.url:443" adminURL := "https://app.admin.url:443"
path := filepath.Join(t.TempDir(), "config.json") path := filepath.Join(t.TempDir(), "config.json")

View File

@@ -179,8 +179,6 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
log.Print("Netbird engine started, my IP is: ", peerConfig.Address) log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
state.Set(StatusConnected) state.Set(StatusConnected)
statusRecorder.ClientStart()
<-engineCtx.Done() <-engineCtx.Done()
statusRecorder.ClientTeardown() statusRecorder.ClientTeardown()
@@ -201,6 +199,7 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
return nil return nil
} }
statusRecorder.ClientStart()
err = backoff.Retry(operation, backOff) err = backoff.Retry(operation, backOff)
if err != nil { if err != nil {
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err) log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)

View File

@@ -16,11 +16,11 @@ import (
// DeviceAuthorizationFlow represents Device Authorization Flow information // DeviceAuthorizationFlow represents Device Authorization Flow information
type DeviceAuthorizationFlow struct { type DeviceAuthorizationFlow struct {
Provider string Provider string
ProviderConfig ProviderConfig ProviderConfig DeviceAuthProviderConfig
} }
// ProviderConfig has all attributes needed to initiate a device authorization flow // DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
type ProviderConfig struct { type DeviceAuthProviderConfig struct {
// ClientID An IDP application client id // ClientID An IDP application client id
ClientID string ClientID string
// ClientSecret An IDP application client secret // ClientSecret An IDP application client secret
@@ -88,7 +88,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
deviceAuthorizationFlow := DeviceAuthorizationFlow{ deviceAuthorizationFlow := DeviceAuthorizationFlow{
Provider: protoDeviceAuthorizationFlow.Provider.String(), Provider: protoDeviceAuthorizationFlow.Provider.String(),
ProviderConfig: ProviderConfig{ ProviderConfig: DeviceAuthProviderConfig{
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(), Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(), ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(), ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
@@ -105,7 +105,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
deviceAuthorizationFlow.ProviderConfig.Scope = "openid" deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
} }
err = isProviderConfigValid(deviceAuthorizationFlow.ProviderConfig) err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
if err != nil { if err != nil {
return DeviceAuthorizationFlow{}, err return DeviceAuthorizationFlow{}, err
} }
@@ -113,7 +113,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
return deviceAuthorizationFlow, nil return deviceAuthorizationFlow, nil
} }
func isProviderConfigValid(config ProviderConfig) error { func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator" errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.Audience == "" { if config.Audience == "" {
return fmt.Errorf(errorMSGFormat, "Audience") return fmt.Errorf(errorMSGFormat, "Audience")

View File

@@ -15,7 +15,8 @@ const (
fileGeneratedResolvConfSearchBeginContent = "search " fileGeneratedResolvConfSearchBeginContent = "search "
fileGeneratedResolvConfContentFormat = fileGeneratedResolvConfContentHeader + fileGeneratedResolvConfContentFormat = fileGeneratedResolvConfContentHeader +
"\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" + "\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" +
fileGeneratedResolvConfSearchBeginContent + "%s\n" fileGeneratedResolvConfSearchBeginContent + "%s\n\n" +
"%s\n"
) )
const ( const (
@@ -91,7 +92,12 @@ func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error {
searchDomains += " " + dConf.domain searchDomains += " " + dConf.domain
appendedDomains++ appendedDomains++
} }
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains)
originalContent, err := os.ReadFile(fileDefaultResolvConfBackupLocation)
if err != nil {
log.Errorf("Could not read existing resolv.conf")
}
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains, string(originalContent))
err = writeDNSConfig(content, defaultResolvConfPath, f.originalPerms) err = writeDNSConfig(content, defaultResolvConfPath, f.originalPerms)
if err != nil { if err != nil {
err = f.restore() err = f.restore()

View File

@@ -182,12 +182,11 @@ func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port
} }
func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error { func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error {
primaryServiceKey := s.getPrimaryService() primaryServiceKey, existingNameserver := s.getPrimaryService()
if primaryServiceKey == "" { if primaryServiceKey == "" {
return fmt.Errorf("couldn't find the primary service key") return fmt.Errorf("couldn't find the primary service key")
} }
err := s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver)
err := s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port)
if err != nil { if err != nil {
return err return err
} }
@@ -196,27 +195,32 @@ func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error
return nil return nil
} }
func (s *systemConfigurator) getPrimaryService() string { func (s *systemConfigurator) getPrimaryService() (string, string) {
line := buildCommandLine("show", globalIPv4State, "") line := buildCommandLine("show", globalIPv4State, "")
stdinCommands := wrapCommand(line) stdinCommands := wrapCommand(line)
b, err := runSystemConfigCommand(stdinCommands) b, err := runSystemConfigCommand(stdinCommands)
if err != nil { if err != nil {
log.Error("got error while sending the command: ", err) log.Error("got error while sending the command: ", err)
return "" return "", ""
} }
scanner := bufio.NewScanner(bytes.NewReader(b)) scanner := bufio.NewScanner(bytes.NewReader(b))
primaryService := ""
router := ""
for scanner.Scan() { for scanner.Scan() {
text := scanner.Text() text := scanner.Text()
if strings.Contains(text, "PrimaryService") { if strings.Contains(text, "PrimaryService") {
return strings.TrimSpace(strings.Split(text, ":")[1]) primaryService = strings.TrimSpace(strings.Split(text, ":")[1])
}
if strings.Contains(text, "Router") {
router = strings.TrimSpace(strings.Split(text, ":")[1])
} }
} }
return "" return primaryService, router
} }
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int) error { func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, existingDNSServer string) error {
lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0)) lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0))
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer) lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer+" "+existingDNSServer)
lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port)) lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port))
addDomainCommand := buildCreateStateWithOperation(setupKey, lines) addDomainCommand := buildCreateStateWithOperation(setupKey, lines)
stdinCommands := wrapCommand(addDomainCommand) stdinCommands := wrapCommand(addDomainCommand)

View File

@@ -4,6 +4,7 @@ package dns
import ( import (
"fmt" "fmt"
"os"
"os/exec" "os/exec"
"strings" "strings"
@@ -59,7 +60,11 @@ func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error {
appendedDomains++ appendedDomains++
} }
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains) originalContent, err := os.ReadFile(fileDefaultResolvConfBackupLocation)
if err != nil {
log.Errorf("Could not read existing resolv.conf")
}
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains, string(originalContent))
err = r.applyConfig(content) err = r.applyConfig(content)
if err != nil { if err != nil {

View File

@@ -238,7 +238,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
hostUpdate := s.currentConfig hostUpdate := s.currentConfig
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() { if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " + log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
"Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver") "Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
hostUpdate.routeAll = false hostUpdate.routeAll = false
} }

View File

@@ -777,7 +777,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
newNet, err := stdnet.NewNet(nil) newNet, err := stdnet.NewNet(nil)
if err != nil { if err != nil {
t.Fatalf("create stdnet: %v", err) t.Fatalf("create stdnet: %v", err)
return nil, nil return nil, err
} }
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", iface.DefaultMTU, nil, newNet) wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", iface.DefaultMTU, nil, newNet)

View File

@@ -11,6 +11,9 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
) )
const ( const (
@@ -24,10 +27,11 @@ type serviceViaListener struct {
dnsMux *dns.ServeMux dnsMux *dns.ServeMux
customAddr *netip.AddrPort customAddr *netip.AddrPort
server *dns.Server server *dns.Server
runtimeIP string listenIP string
runtimePort int listenPort int
listenerIsRunning bool listenerIsRunning bool
listenerFlagLock sync.Mutex listenerFlagLock sync.Mutex
ebpfService ebpfMgr.Manager
} }
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener { func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
@@ -43,6 +47,7 @@ func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *service
UDPSize: 65535, UDPSize: 65535,
}, },
} }
return s return s
} }
@@ -55,13 +60,21 @@ func (s *serviceViaListener) Listen() error {
} }
var err error var err error
s.runtimeIP, s.runtimePort, err = s.evalRuntimeAddress() s.listenIP, s.listenPort, err = s.evalListenAddress()
if err != nil { if err != nil {
log.Errorf("failed to eval runtime address: %s", err) log.Errorf("failed to eval runtime address: %s", err)
return err return err
} }
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort) s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort)
if s.shouldApplyPortFwd() {
s.ebpfService = ebpf.GetEbpfManagerInstance()
err = s.ebpfService.LoadDNSFwd(s.listenIP, s.listenPort)
if err != nil {
log.Warnf("failed to load DNS port forwarder, custom port may not work well on some Linux operating systems: %s", err)
s.ebpfService = nil
}
}
log.Debugf("starting dns on %s", s.server.Addr) log.Debugf("starting dns on %s", s.server.Addr)
go func() { go func() {
s.setListenerStatus(true) s.setListenerStatus(true)
@@ -69,9 +82,10 @@ func (s *serviceViaListener) Listen() error {
err := s.server.ListenAndServe() err := s.server.ListenAndServe()
if err != nil { if err != nil {
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err) log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err)
} }
}() }()
return nil return nil
} }
@@ -90,6 +104,13 @@ func (s *serviceViaListener) Stop() {
if err != nil { if err != nil {
log.Errorf("stopping dns server listener returned an error: %v", err) log.Errorf("stopping dns server listener returned an error: %v", err)
} }
if s.ebpfService != nil {
err = s.ebpfService.FreeDNSFwd()
if err != nil {
log.Errorf("stopping traffic forwarder returned an error: %v", err)
}
}
} }
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) { func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
@@ -101,11 +122,18 @@ func (s *serviceViaListener) DeregisterMux(pattern string) {
} }
func (s *serviceViaListener) RuntimePort() int { func (s *serviceViaListener) RuntimePort() int {
return s.runtimePort s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if s.ebpfService != nil {
return defaultPort
} else {
return s.listenPort
}
} }
func (s *serviceViaListener) RuntimeIP() string { func (s *serviceViaListener) RuntimeIP() string {
return s.runtimeIP return s.listenIP
} }
func (s *serviceViaListener) setListenerStatus(running bool) { func (s *serviceViaListener) setListenerStatus(running bool) {
@@ -136,10 +164,30 @@ func (s *serviceViaListener) getFirstListenerAvailable() (string, int, error) {
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports) return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
} }
func (s *serviceViaListener) evalRuntimeAddress() (string, int, error) { func (s *serviceViaListener) evalListenAddress() (string, int, error) {
if s.customAddr != nil { if s.customAddr != nil {
return s.customAddr.Addr().String(), int(s.customAddr.Port()), nil return s.customAddr.Addr().String(), int(s.customAddr.Port()), nil
} }
return s.getFirstListenerAvailable() return s.getFirstListenerAvailable()
} }
// shouldApplyPortFwd decides whether to apply eBPF program to capture DNS traffic on port 53.
// This is needed because on some operating systems if we start a DNS server not on a default port 53, the domain name
// resolution won't work.
// So, in case we are running on Linux and picked a non-default port (53) we should fall back to the eBPF solution that will capture
// traffic on port 53 and forward it to a local DNS server running on 5053.
func (s *serviceViaListener) shouldApplyPortFwd() bool {
if runtime.GOOS != "linux" {
return false
}
if s.customAddr != nil {
return false
}
if s.listenPort == defaultPort {
return false
}
return true
}

View File

@@ -0,0 +1,129 @@
// Code generated by bpf2go; DO NOT EDIT.
//go:build arm64be || armbe || mips || mips64 || mips64p32 || ppc64 || s390 || s390x || sparc || sparc64
// +build arm64be armbe mips mips64 mips64p32 ppc64 s390 s390x sparc sparc64
package ebpf
import (
"bytes"
_ "embed"
"fmt"
"io"
"github.com/cilium/ebpf"
)
// loadBpf returns the embedded CollectionSpec for bpf.
func loadBpf() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_BpfBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load bpf: %w", err)
}
return spec, err
}
// loadBpfObjects loads bpf and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *bpfObjects
// *bpfPrograms
// *bpfMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := loadBpf()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// bpfSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfSpecs struct {
bpfProgramSpecs
bpfMapSpecs
}
// bpfSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct {
NbXdpProg *ebpf.ProgramSpec `ebpf:"nb_xdp_prog"`
}
// bpfMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct {
NbFeatures *ebpf.MapSpec `ebpf:"nb_features"`
NbMapDnsIp *ebpf.MapSpec `ebpf:"nb_map_dns_ip"`
NbMapDnsPort *ebpf.MapSpec `ebpf:"nb_map_dns_port"`
NbWgProxySettingsMap *ebpf.MapSpec `ebpf:"nb_wg_proxy_settings_map"`
}
// bpfObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfObjects struct {
bpfPrograms
bpfMaps
}
func (o *bpfObjects) Close() error {
return _BpfClose(
&o.bpfPrograms,
&o.bpfMaps,
)
}
// bpfMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct {
NbFeatures *ebpf.Map `ebpf:"nb_features"`
NbMapDnsIp *ebpf.Map `ebpf:"nb_map_dns_ip"`
NbMapDnsPort *ebpf.Map `ebpf:"nb_map_dns_port"`
NbWgProxySettingsMap *ebpf.Map `ebpf:"nb_wg_proxy_settings_map"`
}
func (m *bpfMaps) Close() error {
return _BpfClose(
m.NbFeatures,
m.NbMapDnsIp,
m.NbMapDnsPort,
m.NbWgProxySettingsMap,
)
}
// bpfPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct {
NbXdpProg *ebpf.Program `ebpf:"nb_xdp_prog"`
}
func (p *bpfPrograms) Close() error {
return _BpfClose(
p.NbXdpProg,
)
}
func _BpfClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// Do not access this directly.
//
//go:embed bpf_bpfeb.o
var _BpfBytes []byte

Binary file not shown.

View File

@@ -0,0 +1,129 @@
// Code generated by bpf2go; DO NOT EDIT.
//go:build 386 || amd64 || amd64p32 || arm || arm64 || mips64le || mips64p32le || mipsle || ppc64le || riscv64
// +build 386 amd64 amd64p32 arm arm64 mips64le mips64p32le mipsle ppc64le riscv64
package ebpf
import (
"bytes"
_ "embed"
"fmt"
"io"
"github.com/cilium/ebpf"
)
// loadBpf returns the embedded CollectionSpec for bpf.
func loadBpf() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_BpfBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load bpf: %w", err)
}
return spec, err
}
// loadBpfObjects loads bpf and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *bpfObjects
// *bpfPrograms
// *bpfMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := loadBpf()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// bpfSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfSpecs struct {
bpfProgramSpecs
bpfMapSpecs
}
// bpfSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct {
NbXdpProg *ebpf.ProgramSpec `ebpf:"nb_xdp_prog"`
}
// bpfMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct {
NbFeatures *ebpf.MapSpec `ebpf:"nb_features"`
NbMapDnsIp *ebpf.MapSpec `ebpf:"nb_map_dns_ip"`
NbMapDnsPort *ebpf.MapSpec `ebpf:"nb_map_dns_port"`
NbWgProxySettingsMap *ebpf.MapSpec `ebpf:"nb_wg_proxy_settings_map"`
}
// bpfObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfObjects struct {
bpfPrograms
bpfMaps
}
func (o *bpfObjects) Close() error {
return _BpfClose(
&o.bpfPrograms,
&o.bpfMaps,
)
}
// bpfMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct {
NbFeatures *ebpf.Map `ebpf:"nb_features"`
NbMapDnsIp *ebpf.Map `ebpf:"nb_map_dns_ip"`
NbMapDnsPort *ebpf.Map `ebpf:"nb_map_dns_port"`
NbWgProxySettingsMap *ebpf.Map `ebpf:"nb_wg_proxy_settings_map"`
}
func (m *bpfMaps) Close() error {
return _BpfClose(
m.NbFeatures,
m.NbMapDnsIp,
m.NbMapDnsPort,
m.NbWgProxySettingsMap,
)
}
// bpfPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct {
NbXdpProg *ebpf.Program `ebpf:"nb_xdp_prog"`
}
func (p *bpfPrograms) Close() error {
return _BpfClose(
p.NbXdpProg,
)
}
func _BpfClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// Do not access this directly.
//
//go:embed bpf_bpfel.o
var _BpfBytes []byte

Binary file not shown.

View File

@@ -0,0 +1,51 @@
package ebpf
import (
"encoding/binary"
"net"
log "github.com/sirupsen/logrus"
)
const (
mapKeyDNSIP uint32 = 0
mapKeyDNSPort uint32 = 1
)
func (tf *GeneralManager) LoadDNSFwd(ip string, dnsPort int) error {
log.Debugf("load ebpf DNS forwarder: address: %s:%d", ip, dnsPort)
tf.lock.Lock()
defer tf.lock.Unlock()
err := tf.loadXdp()
if err != nil {
return err
}
err = tf.bpfObjs.NbMapDnsIp.Put(mapKeyDNSIP, ip2int(ip))
if err != nil {
return err
}
err = tf.bpfObjs.NbMapDnsPort.Put(mapKeyDNSPort, uint16(dnsPort))
if err != nil {
return err
}
tf.setFeatureFlag(featureFlagDnsForwarder)
err = tf.bpfObjs.NbFeatures.Put(mapKeyFeatures, tf.featureFlags)
if err != nil {
return err
}
return nil
}
func (tf *GeneralManager) FreeDNSFwd() error {
log.Debugf("free ebpf DNS forwarder")
return tf.unsetFeatureFlag(featureFlagDnsForwarder)
}
func ip2int(ipString string) uint32 {
ip := net.ParseIP(ipString)
return binary.BigEndian.Uint32(ip.To4())
}

View File

@@ -0,0 +1,116 @@
package ebpf
import (
_ "embed"
"net"
"sync"
"github.com/cilium/ebpf/link"
"github.com/cilium/ebpf/rlimit"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/ebpf/manager"
)
const (
mapKeyFeatures uint32 = 0
featureFlagWGProxy = 0b00000001
featureFlagDnsForwarder = 0b00000010
)
var (
singleton manager.Manager
singletonLock = &sync.Mutex{}
)
// required packages libbpf-dev, libc6-dev-i386-amd64-cross
// GeneralManager is used to load multiple eBPF programs with a custom check (if then) done in prog.c
// The manager simply adds a feature (byte) of each program to a map that is shared between the userspace and kernel.
// When packet arrives, the C code checks for each feature (if it is set) and executes each enabled program (e.g., dns_fwd.c and wg_proxy.c).
//
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang-14 bpf src/prog.c -- -I /usr/x86_64-linux-gnu/include
type GeneralManager struct {
lock sync.Mutex
link link.Link
featureFlags uint16
bpfObjs bpfObjects
}
// GetEbpfManagerInstance return a static eBpf Manager instance
func GetEbpfManagerInstance() manager.Manager {
singletonLock.Lock()
defer singletonLock.Unlock()
if singleton != nil {
return singleton
}
singleton = &GeneralManager{}
return singleton
}
func (tf *GeneralManager) setFeatureFlag(feature uint16) {
tf.featureFlags = tf.featureFlags | feature
}
func (tf *GeneralManager) loadXdp() error {
if tf.link != nil {
return nil
}
// it required for Docker
err := rlimit.RemoveMemlock()
if err != nil {
return err
}
iFace, err := net.InterfaceByName("lo")
if err != nil {
return err
}
// load pre-compiled programs into the kernel.
err = loadBpfObjects(&tf.bpfObjs, nil)
if err != nil {
return err
}
tf.link, err = link.AttachXDP(link.XDPOptions{
Program: tf.bpfObjs.NbXdpProg,
Interface: iFace.Index,
})
if err != nil {
_ = tf.bpfObjs.Close()
tf.link = nil
return err
}
return nil
}
func (tf *GeneralManager) unsetFeatureFlag(feature uint16) error {
tf.lock.Lock()
defer tf.lock.Unlock()
tf.featureFlags &^= feature
if tf.link == nil {
return nil
}
if tf.featureFlags == 0 {
return tf.close()
}
return tf.bpfObjs.NbFeatures.Put(mapKeyFeatures, tf.featureFlags)
}
func (tf *GeneralManager) close() error {
log.Debugf("detach ebpf program ")
err := tf.bpfObjs.Close()
if err != nil {
log.Warnf("failed to close eBpf objects: %s", err)
}
err = tf.link.Close()
tf.link = nil
return err
}

View File

@@ -0,0 +1,40 @@
package ebpf
import (
"testing"
)
func TestManager_setFeatureFlag(t *testing.T) {
mgr := GeneralManager{}
mgr.setFeatureFlag(featureFlagWGProxy)
if mgr.featureFlags != 1 {
t.Errorf("invalid faeture state")
}
mgr.setFeatureFlag(featureFlagDnsForwarder)
if mgr.featureFlags != 3 {
t.Errorf("invalid faeture state")
}
}
func TestManager_unsetFeatureFlag(t *testing.T) {
mgr := GeneralManager{}
mgr.setFeatureFlag(featureFlagWGProxy)
mgr.setFeatureFlag(featureFlagDnsForwarder)
err := mgr.unsetFeatureFlag(featureFlagWGProxy)
if err != nil {
t.Errorf("unexpected error: %s", err)
}
if mgr.featureFlags != 2 {
t.Errorf("invalid faeture state, expected: %d, got: %d", 2, mgr.featureFlags)
}
err = mgr.unsetFeatureFlag(featureFlagDnsForwarder)
if err != nil {
t.Errorf("unexpected error: %s", err)
}
if mgr.featureFlags != 0 {
t.Errorf("invalid faeture state, expected: %d, got: %d", 0, mgr.featureFlags)
}
}

View File

@@ -0,0 +1,64 @@
const __u32 map_key_dns_ip = 0;
const __u32 map_key_dns_port = 1;
struct bpf_map_def SEC("maps") nb_map_dns_ip = {
.type = BPF_MAP_TYPE_ARRAY,
.key_size = sizeof(__u32),
.value_size = sizeof(__u32),
.max_entries = 10,
};
struct bpf_map_def SEC("maps") nb_map_dns_port = {
.type = BPF_MAP_TYPE_ARRAY,
.key_size = sizeof(__u32),
.value_size = sizeof(__u16),
.max_entries = 10,
};
__be32 dns_ip = 0;
__be16 dns_port = 0;
// 13568 is 53 in big endian
__be16 GENERAL_DNS_PORT = 13568;
bool read_settings() {
__u16 *port_value;
__u32 *ip_value;
// read dns ip
ip_value = bpf_map_lookup_elem(&nb_map_dns_ip, &map_key_dns_ip);
if(!ip_value) {
return false;
}
dns_ip = htonl(*ip_value);
// read dns port
port_value = bpf_map_lookup_elem(&nb_map_dns_port, &map_key_dns_port);
if (!port_value) {
return false;
}
dns_port = htons(*port_value);
return true;
}
int xdp_dns_fwd(struct iphdr *ip, struct udphdr *udp) {
if (dns_port == 0) {
if(!read_settings()){
return XDP_PASS;
}
bpf_printk("dns port: %d", ntohs(dns_port));
bpf_printk("dns ip: %d", ntohl(dns_ip));
}
if (udp->dest == GENERAL_DNS_PORT && ip->daddr == dns_ip) {
udp->dest = dns_port;
return XDP_PASS;
}
if (udp->source == dns_port && ip->saddr == dns_ip) {
udp->source = GENERAL_DNS_PORT;
return XDP_PASS;
}
return XDP_PASS;
}

View File

@@ -0,0 +1,66 @@
#include <stdbool.h>
#include <linux/if_ether.h> // ETH_P_IP
#include <linux/udp.h>
#include <linux/ip.h>
#include <netinet/in.h>
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
#include "dns_fwd.c"
#include "wg_proxy.c"
#define bpf_printk(fmt, ...) \
({ \
char ____fmt[] = fmt; \
bpf_trace_printk(____fmt, sizeof(____fmt), ##__VA_ARGS__); \
})
const __u16 flag_feature_wg_proxy = 0b01;
const __u16 flag_feature_dns_fwd = 0b10;
const __u32 map_key_features = 0;
struct bpf_map_def SEC("maps") nb_features = {
.type = BPF_MAP_TYPE_ARRAY,
.key_size = sizeof(__u32),
.value_size = sizeof(__u16),
.max_entries = 10,
};
SEC("xdp")
int nb_xdp_prog(struct xdp_md *ctx) {
__u16 *features;
features = bpf_map_lookup_elem(&nb_features, &map_key_features);
if (!features) {
return XDP_PASS;
}
void *data = (void *)(long)ctx->data;
void *data_end = (void *)(long)ctx->data_end;
struct ethhdr *eth = data;
struct iphdr *ip = (data + sizeof(struct ethhdr));
struct udphdr *udp = (data + sizeof(struct ethhdr) + sizeof(struct iphdr));
// return early if not enough data
if (data + sizeof(struct ethhdr) + sizeof(struct iphdr) + sizeof(struct udphdr) > data_end){
return XDP_PASS;
}
// skip non IPv4 packages
if (eth->h_proto != htons(ETH_P_IP)) {
return XDP_PASS;
}
// skip non UPD packages
if (ip->protocol != IPPROTO_UDP) {
return XDP_PASS;
}
if (*features & flag_feature_dns_fwd) {
xdp_dns_fwd(ip, udp);
}
if (*features & flag_feature_wg_proxy) {
xdp_wg_proxy(ip, udp);
}
return XDP_PASS;
}
char _license[] SEC("license") = "GPL";

View File

@@ -0,0 +1,54 @@
const __u32 map_key_proxy_port = 0;
const __u32 map_key_wg_port = 1;
struct bpf_map_def SEC("maps") nb_wg_proxy_settings_map = {
.type = BPF_MAP_TYPE_ARRAY,
.key_size = sizeof(__u32),
.value_size = sizeof(__u16),
.max_entries = 10,
};
__u16 proxy_port = 0;
__u16 wg_port = 0;
bool read_port_settings() {
__u16 *value;
value = bpf_map_lookup_elem(&nb_wg_proxy_settings_map, &map_key_proxy_port);
if (!value) {
return false;
}
proxy_port = *value;
value = bpf_map_lookup_elem(&nb_wg_proxy_settings_map, &map_key_wg_port);
if (!value) {
return false;
}
wg_port = htons(*value);
return true;
}
int xdp_wg_proxy(struct iphdr *ip, struct udphdr *udp) {
if (proxy_port == 0 || wg_port == 0) {
if (!read_port_settings()){
return XDP_PASS;
}
bpf_printk("proxy port: %d, wg port: %d", proxy_port, wg_port);
}
// 2130706433 = 127.0.0.1
if (ip->daddr != htonl(2130706433)) {
return XDP_PASS;
}
if (udp->source != wg_port){
return XDP_PASS;
}
__be16 new_src_port = udp->dest;
__be16 new_dst_port = htons(proxy_port);
udp->dest = new_dst_port;
udp->source = new_src_port;
return XDP_PASS;
}

View File

@@ -0,0 +1,41 @@
package ebpf
import log "github.com/sirupsen/logrus"
const (
mapKeyProxyPort uint32 = 0
mapKeyWgPort uint32 = 1
)
func (tf *GeneralManager) LoadWgProxy(proxyPort, wgPort int) error {
log.Debugf("load ebpf WG proxy")
tf.lock.Lock()
defer tf.lock.Unlock()
err := tf.loadXdp()
if err != nil {
return err
}
err = tf.bpfObjs.NbWgProxySettingsMap.Put(mapKeyProxyPort, uint16(proxyPort))
if err != nil {
return err
}
err = tf.bpfObjs.NbWgProxySettingsMap.Put(mapKeyWgPort, uint16(wgPort))
if err != nil {
return err
}
tf.setFeatureFlag(featureFlagWGProxy)
err = tf.bpfObjs.NbFeatures.Put(mapKeyFeatures, tf.featureFlags)
if err != nil {
return err
}
return nil
}
func (tf *GeneralManager) FreeWGProxy() error {
log.Debugf("free ebpf WG proxy")
return tf.unsetFeatureFlag(featureFlagWGProxy)
}

View File

@@ -0,0 +1,15 @@
//go:build !android
package ebpf
import (
"github.com/netbirdio/netbird/client/internal/ebpf/ebpf"
"github.com/netbirdio/netbird/client/internal/ebpf/manager"
)
// GetEbpfManagerInstance is a wrapper function. This encapsulation is required because if the code import the internal
// ebpf package the Go compiler will include the object files. But it is not supported on Android. It can cause instant
// panic on older Android version.
func GetEbpfManagerInstance() manager.Manager {
return ebpf.GetEbpfManagerInstance()
}

View File

@@ -0,0 +1,10 @@
//go:build !linux || android
package ebpf
import "github.com/netbirdio/netbird/client/internal/ebpf/manager"
// GetEbpfManagerInstance return error because ebpf is not supported on all os
func GetEbpfManagerInstance() manager.Manager {
panic("unsupported os")
}

View File

@@ -0,0 +1,9 @@
package manager
// Manager is used to load multiple eBPF programs. E.g., current DNS programs and WireGuard proxy
type Manager interface {
LoadDNSFwd(ip string, dnsPort int) error
FreeDNSFwd() error
LoadWgProxy(proxyPort, wgPort int) error
FreeWGProxy() error
}

View File

@@ -20,8 +20,8 @@ import (
"github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/proxy"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/wgproxy"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
@@ -102,6 +102,7 @@ type Engine struct {
ctx context.Context ctx context.Context
wgInterface *iface.WGIface wgInterface *iface.WGIface
wgProxyFactory *wgproxy.Factory
udpMux *bind.UniversalUDPMuxDefault udpMux *bind.UniversalUDPMuxDefault
udpMuxConn io.Closer udpMuxConn io.Closer
@@ -132,6 +133,7 @@ func NewEngine(
signalClient signal.Client, mgmClient mgm.Client, signalClient signal.Client, mgmClient mgm.Client,
config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status,
) *Engine { ) *Engine {
return &Engine{ return &Engine{
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
@@ -146,6 +148,7 @@ func NewEngine(
networkSerial: 0, networkSerial: 0,
sshServerFunc: nbssh.DefaultSSHServer, sshServerFunc: nbssh.DefaultSSHServer,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgProxyFactory: wgproxy.NewFactory(config.WgPort),
} }
} }
@@ -282,7 +285,7 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
for _, p := range peersUpdate { for _, p := range peersUpdate {
peerPubKey := p.GetWgPubKey() peerPubKey := p.GetWgPubKey()
if peerConn, ok := e.peerConns[peerPubKey]; ok { if peerConn, ok := e.peerConns[peerPubKey]; ok {
if peerConn.GetConf().ProxyConfig.AllowedIps != strings.Join(p.AllowedIps, ",") { if peerConn.WgConfig().AllowedIps != strings.Join(p.AllowedIps, ",") {
modified = append(modified, p) modified = append(modified, p)
continue continue
} }
@@ -795,9 +798,7 @@ func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
// we might have received new STUN and TURN servers meanwhile, so update them // we might have received new STUN and TURN servers meanwhile, so update them
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
conf := conn.GetConf() conn.UpdateStunTurn(append(e.STUNs, e.TURNs...))
conf.StunTurn = append(e.STUNs, e.TURNs...)
conn.UpdateConf(conf)
e.syncMsgMux.Unlock() e.syncMsgMux.Unlock()
err := conn.Open() err := conn.Open()
@@ -826,9 +827,9 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
stunTurn = append(stunTurn, e.STUNs...) stunTurn = append(stunTurn, e.STUNs...)
stunTurn = append(stunTurn, e.TURNs...) stunTurn = append(stunTurn, e.TURNs...)
proxyConfig := proxy.Config{ wgConfig := peer.WgConfig{
RemoteKey: pubKey, RemoteKey: pubKey,
WgListenAddr: fmt.Sprintf("127.0.0.1:%d", e.config.WgPort), WgListenPort: e.config.WgPort,
WgInterface: e.wgInterface, WgInterface: e.wgInterface,
AllowedIps: allowedIPs, AllowedIps: allowedIPs,
PreSharedKey: e.config.PreSharedKey, PreSharedKey: e.config.PreSharedKey,
@@ -845,13 +846,13 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
Timeout: timeout, Timeout: timeout,
UDPMux: e.udpMux.UDPMuxDefault, UDPMux: e.udpMux.UDPMuxDefault,
UDPMuxSrflx: e.udpMux, UDPMuxSrflx: e.udpMux,
ProxyConfig: proxyConfig, WgConfig: wgConfig,
LocalWgPort: e.config.WgPort, LocalWgPort: e.config.WgPort,
NATExternalIPs: e.parseNATExternalIPMappings(), NATExternalIPs: e.parseNATExternalIPMappings(),
UserspaceBind: e.wgInterface.IsUserspaceBind(), UserspaceBind: e.wgInterface.IsUserspaceBind(),
} }
peerConn, err := peer.NewConn(config, e.statusRecorder, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover) peerConn, err := peer.NewConn(config, e.statusRecorder, e.wgProxyFactory, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -991,7 +992,6 @@ func (e *Engine) parseNATExternalIPMappings() []string {
log.Warnf("invalid external IP, %s, ignoring external IP mapping '%s'", external, mapping) log.Warnf("invalid external IP, %s, ignoring external IP mapping '%s'", external, mapping)
break break
} }
if externalIP != nil {
mappedIP := externalIP.String() mappedIP := externalIP.String()
if internalIP != nil { if internalIP != nil {
mappedIP = mappedIP + "/" + internalIP.String() mappedIP = mappedIP + "/" + internalIP.String()
@@ -999,7 +999,6 @@ func (e *Engine) parseNATExternalIPMappings() []string {
mappedIPs = append(mappedIPs, mappedIP) mappedIPs = append(mappedIPs, mappedIP)
log.Infof("parsed external IP mapping of '%s' as '%s'", mapping, mappedIP) log.Infof("parsed external IP mapping of '%s' as '%s'", mapping, mappedIP)
} }
}
if len(mappedIPs) != len(e.config.NATExternalIPs) { if len(mappedIPs) != len(e.config.NATExternalIPs) {
log.Warnf("one or more external IP mappings failed to parse, ignoring all mappings") log.Warnf("one or more external IP mappings failed to parse, ignoring all mappings")
return nil return nil
@@ -1008,6 +1007,10 @@ func (e *Engine) parseNATExternalIPMappings() []string {
} }
func (e *Engine) close() { func (e *Engine) close() {
if err := e.wgProxyFactory.Free(); err != nil {
log.Errorf("failed closing ebpf proxy: %s", err)
}
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
if e.wgInterface != nil { if e.wgInterface != nil {
if err := e.wgInterface.Close(); err != nil { if err := e.wgInterface.Close(); err != nil {

View File

@@ -367,9 +367,9 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
t.Errorf("expecting Engine.peerConns to contain peer %s", p) t.Errorf("expecting Engine.peerConns to contain peer %s", p)
} }
expectedAllowedIPs := strings.Join(p.AllowedIps, ",") expectedAllowedIPs := strings.Join(p.AllowedIps, ",")
if conn.GetConf().ProxyConfig.AllowedIps != expectedAllowedIPs { if conn.WgConfig().AllowedIps != expectedAllowedIPs {
t.Errorf("expecting peer %s to have AllowedIPs= %s, got %s", p.GetWgPubKey(), t.Errorf("expecting peer %s to have AllowedIPs= %s, got %s", p.GetWgPubKey(),
expectedAllowedIPs, conn.GetConf().ProxyConfig.AllowedIps) expectedAllowedIPs, conn.WgConfig().AllowedIps)
} }
} }
}) })
@@ -1046,7 +1046,7 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
peersUpdateManager := server.NewPeersUpdateManager() peersUpdateManager := server.NewPeersUpdateManager()
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
if err != nil { if err != nil {
return nil, "", nil return nil, "", err
} }
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "", accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "",
eventStore) eventStore)
@@ -1054,7 +1054,7 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
return nil, "", err return nil, "", err
} }
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil) mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -1,286 +0,0 @@
package internal
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"reflect"
"strings"
"time"
)
// OAuthClient is a OAuth client interface for various idp providers
type OAuthClient interface {
RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error)
GetClientID(ctx context.Context) string
}
// HTTPClient http client interface for API calls
type HTTPClient interface {
Do(req *http.Request) (*http.Response, error)
}
// DeviceAuthInfo holds information for the OAuth device login flow
type DeviceAuthInfo struct {
DeviceCode string `json:"device_code"`
UserCode string `json:"user_code"`
VerificationURI string `json:"verification_uri"`
VerificationURIComplete string `json:"verification_uri_complete"`
ExpiresIn int `json:"expires_in"`
Interval int `json:"interval"`
}
// HostedGrantType grant type for device flow on Hosted
const (
HostedGrantType = "urn:ietf:params:oauth:grant-type:device_code"
HostedRefreshGrant = "refresh_token"
)
// Hosted client
type Hosted struct {
providerConfig ProviderConfig
HTTPClient HTTPClient
}
// RequestDeviceCodePayload used for request device code payload for auth0
type RequestDeviceCodePayload struct {
Audience string `json:"audience"`
ClientID string `json:"client_id"`
Scope string `json:"scope"`
}
// TokenRequestPayload used for requesting the auth0 token
type TokenRequestPayload struct {
GrantType string `json:"grant_type"`
DeviceCode string `json:"device_code,omitempty"`
ClientID string `json:"client_id"`
RefreshToken string `json:"refresh_token,omitempty"`
}
// TokenRequestResponse used for parsing Hosted token's response
type TokenRequestResponse struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
TokenInfo
}
// Claims used when validating the access token
type Claims struct {
Audience interface{} `json:"aud"`
}
// TokenInfo holds information of issued access token
type TokenInfo struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
UseIDToken bool `json:"-"`
}
// GetTokenToUse returns either the access or id token based on UseIDToken field
func (t TokenInfo) GetTokenToUse() string {
if t.UseIDToken {
return t.IDToken
}
return t.AccessToken
}
// NewHostedDeviceFlow returns an Hosted OAuth client
func NewHostedDeviceFlow(config ProviderConfig) *Hosted {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Transport: httpTransport,
}
return &Hosted{
providerConfig: config,
HTTPClient: httpClient,
}
}
// GetClientID returns the provider client id
func (h *Hosted) GetClientID(ctx context.Context) string {
return h.providerConfig.ClientID
}
// RequestDeviceCode requests a device code login flow information from Hosted
func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) {
form := url.Values{}
form.Add("client_id", h.providerConfig.ClientID)
form.Add("audience", h.providerConfig.Audience)
form.Add("scope", h.providerConfig.Scope)
req, err := http.NewRequest("POST", h.providerConfig.DeviceAuthEndpoint,
strings.NewReader(form.Encode()))
if err != nil {
return DeviceAuthInfo{}, fmt.Errorf("creating request failed with error: %v", err)
}
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
res, err := h.HTTPClient.Do(req)
if err != nil {
return DeviceAuthInfo{}, fmt.Errorf("doing request failed with error: %v", err)
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
return DeviceAuthInfo{}, fmt.Errorf("reading body failed with error: %v", err)
}
if res.StatusCode != 200 {
return DeviceAuthInfo{}, fmt.Errorf("request device code returned status %d error: %s", res.StatusCode, string(body))
}
deviceCode := DeviceAuthInfo{}
err = json.Unmarshal(body, &deviceCode)
if err != nil {
return DeviceAuthInfo{}, fmt.Errorf("unmarshaling response failed with error: %v", err)
}
// Fallback to the verification_uri if the IdP doesn't support verification_uri_complete
if deviceCode.VerificationURIComplete == "" {
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
}
return deviceCode, err
}
func (h *Hosted) requestToken(info DeviceAuthInfo) (TokenRequestResponse, error) {
form := url.Values{}
form.Add("client_id", h.providerConfig.ClientID)
form.Add("grant_type", HostedGrantType)
form.Add("device_code", info.DeviceCode)
req, err := http.NewRequest("POST", h.providerConfig.TokenEndpoint, strings.NewReader(form.Encode()))
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("failed to create request access token: %v", err)
}
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
res, err := h.HTTPClient.Do(req)
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("failed to request access token with error: %v", err)
}
defer func() {
err := res.Body.Close()
if err != nil {
return
}
}()
body, err := io.ReadAll(res.Body)
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("failed reading access token response body with error: %v", err)
}
if res.StatusCode > 499 {
return TokenRequestResponse{}, fmt.Errorf("access token response returned code: %s", string(body))
}
tokenResponse := TokenRequestResponse{}
err = json.Unmarshal(body, &tokenResponse)
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("parsing token response failed with error: %v", err)
}
return tokenResponse, nil
}
// WaitToken waits user's login and authorize the app. Once the user's authorize
// it retrieves the access token from Hosted's endpoint and validates it before returning
func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error) {
interval := time.Duration(info.Interval) * time.Second
ticker := time.NewTicker(interval)
for {
select {
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case <-ticker.C:
tokenResponse, err := h.requestToken(info)
if err != nil {
return TokenInfo{}, fmt.Errorf("parsing token response failed with error: %v", err)
}
if tokenResponse.Error != "" {
if tokenResponse.Error == "authorization_pending" {
continue
} else if tokenResponse.Error == "slow_down" {
interval = interval + (3 * time.Second)
ticker.Reset(interval)
continue
}
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
}
tokenInfo := TokenInfo{
AccessToken: tokenResponse.AccessToken,
TokenType: tokenResponse.TokenType,
RefreshToken: tokenResponse.RefreshToken,
IDToken: tokenResponse.IDToken,
ExpiresIn: tokenResponse.ExpiresIn,
UseIDToken: h.providerConfig.UseIDToken,
}
err = isValidAccessToken(tokenInfo.GetTokenToUse(), h.providerConfig.Audience)
if err != nil {
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
}
return tokenInfo, err
}
}
}
// isValidAccessToken is a simple validation of the access token
func isValidAccessToken(token string, audience string) error {
if token == "" {
return fmt.Errorf("token received is empty")
}
encodedClaims := strings.Split(token, ".")[1]
claimsString, err := base64.RawURLEncoding.DecodeString(encodedClaims)
if err != nil {
return err
}
claims := Claims{}
err = json.Unmarshal(claimsString, &claims)
if err != nil {
return err
}
if claims.Audience == nil {
return fmt.Errorf("required token field audience is absent")
}
// Audience claim of JWT can be a string or an array of strings
typ := reflect.TypeOf(claims.Audience)
switch typ.Kind() {
case reflect.String:
if claims.Audience == audience {
return nil
}
case reflect.Slice:
for _, aud := range claims.Audience.([]interface{}) {
if audience == aud {
return nil
}
}
}
return fmt.Errorf("invalid JWT token audience field")
}

View File

@@ -10,9 +10,10 @@ import (
"github.com/pion/ice/v2" "github.com/pion/ice/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/internal/proxy"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/iface/bind"
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
@@ -23,8 +24,18 @@ import (
const ( const (
iceKeepAliveDefault = 4 * time.Second iceKeepAliveDefault = 4 * time.Second
iceDisconnectedTimeoutDefault = 6 * time.Second iceDisconnectedTimeoutDefault = 6 * time.Second
defaultWgKeepAlive = 25 * time.Second
) )
type WgConfig struct {
WgListenPort int
RemoteKey string
WgInterface *iface.WGIface
AllowedIps string
PreSharedKey *wgtypes.Key
}
// ConnConfig is a peer Connection configuration // ConnConfig is a peer Connection configuration
type ConnConfig struct { type ConnConfig struct {
@@ -43,7 +54,7 @@ type ConnConfig struct {
Timeout time.Duration Timeout time.Duration
ProxyConfig proxy.Config WgConfig WgConfig
UDPMux ice.UDPMux UDPMux ice.UDPMux
UDPMuxSrflx ice.UniversalUDPMux UDPMuxSrflx ice.UniversalUDPMux
@@ -98,7 +109,9 @@ type Conn struct {
statusRecorder *Status statusRecorder *Status
proxy proxy.Proxy wgProxyFactory *wgproxy.Factory
wgProxy wgproxy.Proxy
remoteModeCh chan ModeMessage remoteModeCh chan ModeMessage
meta meta meta meta
@@ -122,14 +135,19 @@ func (conn *Conn) GetConf() ConnConfig {
return conn.config return conn.config
} }
// UpdateConf updates the connection config // WgConfig returns the WireGuard config
func (conn *Conn) UpdateConf(conf ConnConfig) { func (conn *Conn) WgConfig() WgConfig {
conn.config = conf return conn.config.WgConfig
}
// UpdateStunTurn update the turn and stun addresses
func (conn *Conn) UpdateStunTurn(turnStun []*ice.URL) {
conn.config.StunTurn = turnStun
} }
// NewConn creates a new not opened Conn to the remote peer. // NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open // To establish a connection run Conn.Open
func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) { func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) {
return &Conn{ return &Conn{
config: config, config: config,
mu: sync.Mutex{}, mu: sync.Mutex{},
@@ -139,6 +157,7 @@ func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter
remoteAnswerCh: make(chan OfferAnswer), remoteAnswerCh: make(chan OfferAnswer),
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
remoteModeCh: make(chan ModeMessage, 1), remoteModeCh: make(chan ModeMessage, 1),
wgProxyFactory: wgProxyFactory,
adapter: adapter, adapter: adapter,
iFaceDiscover: iFaceDiscover, iFaceDiscover: iFaceDiscover,
}, nil }, nil
@@ -215,12 +234,12 @@ func (conn *Conn) candidateTypes() []ice.CandidateType {
func (conn *Conn) Open() error { func (conn *Conn) Open() error {
log.Debugf("trying to connect to peer %s", conn.config.Key) log.Debugf("trying to connect to peer %s", conn.config.Key)
peerState := State{PubKey: conn.config.Key} peerState := State{
PubKey: conn.config.Key,
peerState.IP = strings.Split(conn.config.ProxyConfig.AllowedIps, "/")[0] IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
peerState.ConnStatusUpdate = time.Now() ConnStatusUpdate: time.Now(),
peerState.ConnStatus = conn.status ConnStatus: conn.status,
}
err := conn.statusRecorder.UpdatePeerState(peerState) err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil { if err != nil {
log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err) log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err)
@@ -275,10 +294,11 @@ func (conn *Conn) Open() error {
defer conn.notifyDisconnected() defer conn.notifyDisconnected()
conn.mu.Unlock() conn.mu.Unlock()
peerState = State{PubKey: conn.config.Key} peerState = State{
PubKey: conn.config.Key,
peerState.ConnStatus = conn.status ConnStatus: conn.status,
peerState.ConnStatusUpdate = time.Now() ConnStatusUpdate: time.Now(),
}
err = conn.statusRecorder.UpdatePeerState(peerState) err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil { if err != nil {
log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err) log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err)
@@ -309,19 +329,12 @@ func (conn *Conn) Open() error {
remoteWgPort = remoteOfferAnswer.WgListenPort remoteWgPort = remoteOfferAnswer.WgListenPort
} }
// the ice connection has been established successfully so we are ready to start the proxy // the ice connection has been established successfully so we are ready to start the proxy
err = conn.startProxy(remoteConn, remoteWgPort) remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort)
if err != nil { if err != nil {
return err return err
} }
if conn.proxy.Type() == proxy.TypeDirectNoProxy { log.Infof("connected to peer %s, endpoint address: %s", conn.config.Key, remoteAddr.String())
host, _, _ := net.SplitHostPort(remoteConn.LocalAddr().String())
rhost, _, _ := net.SplitHostPort(remoteConn.RemoteAddr().String())
// direct Wireguard connection
log.Infof("directly connected to peer %s [laddr <-> raddr] [%s:%d <-> %s:%d]", conn.config.Key, host, conn.config.LocalWgPort, rhost, remoteWgPort)
} else {
log.Infof("connected to peer %s [laddr <-> raddr] [%s <-> %s]", conn.config.Key, remoteConn.LocalAddr().String(), remoteConn.RemoteAddr().String())
}
// wait until connection disconnected or has been closed externally (upper layer, e.g. engine) // wait until connection disconnected or has been closed externally (upper layer, e.g. engine)
select { select {
@@ -338,54 +351,60 @@ func isRelayCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay return candidate.Type() == ice.CandidateTypeRelay
} }
// startProxy starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error { func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int) (net.Addr, error) {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
var pair *ice.CandidatePair
pair, err := conn.agent.GetSelectedCandidatePair() pair, err := conn.agent.GetSelectedCandidatePair()
if err != nil { if err != nil {
return err return nil, err
} }
peerState := State{PubKey: conn.config.Key} var endpoint net.Addr
p := conn.getProxy(pair, remoteWgPort) if isRelayCandidate(pair.Local) {
conn.proxy = p log.Debugf("setup relay connection")
err = p.Start(remoteConn) conn.wgProxy = conn.wgProxyFactory.GetProxy()
endpoint, err = conn.wgProxy.AddTurnConn(remoteConn)
if err != nil { if err != nil {
return err return nil, err
}
} else {
// To support old version's with direct mode we attempt to punch an additional role with the remote wireguard port
go conn.punchRemoteWGPort(pair, remoteWgPort)
endpoint = remoteConn.RemoteAddr()
}
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
if err != nil {
if conn.wgProxy != nil {
_ = conn.wgProxy.CloseConn()
}
return nil, err
} }
conn.status = StatusConnected conn.status = StatusConnected
peerState.ConnStatus = conn.status peerState := State{
peerState.ConnStatusUpdate = time.Now() PubKey: conn.config.Key,
peerState.LocalIceCandidateType = pair.Local.Type().String() ConnStatus: conn.status,
peerState.RemoteIceCandidateType = pair.Remote.Type().String() ConnStatusUpdate: time.Now(),
LocalIceCandidateType: pair.Local.Type().String(),
RemoteIceCandidateType: pair.Remote.Type().String(),
Direct: !isRelayCandidate(pair.Local),
}
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay { if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
peerState.Relayed = true peerState.Relayed = true
} }
peerState.Direct = p.Type() == proxy.TypeDirectNoProxy || p.Type() == proxy.TypeNoProxy
err = conn.statusRecorder.UpdatePeerState(peerState) err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil { if err != nil {
log.Warnf("unable to save peer's state, got error: %v", err) log.Warnf("unable to save peer's state, got error: %v", err)
} }
return nil return endpoint, nil
}
// todo rename this method and the proxy package to something more appropriate
func (conn *Conn) getProxy(pair *ice.CandidatePair, remoteWgPort int) proxy.Proxy {
if isRelayCandidate(pair.Local) {
return proxy.NewWireGuardProxy(conn.config.ProxyConfig)
}
// To support old version's with direct mode we attempt to punch an additional role with the remote wireguard port
go conn.punchRemoteWGPort(pair, remoteWgPort)
return proxy.NewNoProxy(conn.config.ProxyConfig)
} }
func (conn *Conn) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) { func (conn *Conn) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
@@ -414,22 +433,22 @@ func (conn *Conn) cleanup() error {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
var err1, err2, err3 error
if conn.agent != nil { if conn.agent != nil {
err := conn.agent.Close() err1 = conn.agent.Close()
if err != nil { if err1 == nil {
return err
}
conn.agent = nil conn.agent = nil
} }
}
if conn.proxy != nil { if conn.wgProxy != nil {
err := conn.proxy.Close() err2 = conn.wgProxy.CloseConn()
if err != nil { conn.wgProxy = nil
return err
}
conn.proxy = nil
} }
// todo: is it problem if we try to remove a peer what is never existed?
err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
if conn.notifyDisconnected != nil { if conn.notifyDisconnected != nil {
conn.notifyDisconnected() conn.notifyDisconnected()
conn.notifyDisconnected = nil conn.notifyDisconnected = nil
@@ -437,10 +456,11 @@ func (conn *Conn) cleanup() error {
conn.status = StatusDisconnected conn.status = StatusDisconnected
peerState := State{PubKey: conn.config.Key} peerState := State{
peerState.ConnStatus = conn.status PubKey: conn.config.Key,
peerState.ConnStatusUpdate = time.Now() ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
}
err := conn.statusRecorder.UpdatePeerState(peerState) err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil { if err != nil {
// pretty common error because by that time Engine can already remove the peer and status won't be available. // pretty common error because by that time Engine can already remove the peer and status won't be available.
@@ -449,8 +469,13 @@ func (conn *Conn) cleanup() error {
} }
log.Debugf("cleaned up connection to peer %s", conn.config.Key) log.Debugf("cleaned up connection to peer %s", conn.config.Key)
if err1 != nil {
return nil return err1
}
if err2 != nil {
return err2
}
return err3
} }
// SetSignalOffer sets a handler function to be triggered by Conn when a new connection offer has to be signalled to the remote peer // SetSignalOffer sets a handler function to be triggered by Conn when a new connection offer has to be signalled to the remote peer

View File

@@ -5,12 +5,11 @@ import (
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/magiconair/properties/assert" "github.com/magiconair/properties/assert"
"github.com/pion/ice/v2" "github.com/pion/ice/v2"
"github.com/netbirdio/netbird/client/internal/proxy" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
) )
@@ -20,7 +19,6 @@ var connConf = ConnConfig{
StunTurn: []*ice.URL{}, StunTurn: []*ice.URL{},
InterfaceBlackList: nil, InterfaceBlackList: nil,
Timeout: time.Second, Timeout: time.Second,
ProxyConfig: proxy.Config{},
LocalWgPort: 51820, LocalWgPort: 51820,
} }
@@ -37,7 +35,11 @@ func TestNewConn_interfaceFilter(t *testing.T) {
} }
func TestConn_GetKey(t *testing.T) { func TestConn_GetKey(t *testing.T) {
conn, err := NewConn(connConf, nil, nil, nil) wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(connConf, nil, wgProxyFactory, nil, nil)
if err != nil { if err != nil {
return return
} }
@@ -48,8 +50,11 @@ func TestConn_GetKey(t *testing.T) {
} }
func TestConn_OnRemoteOffer(t *testing.T) { func TestConn_OnRemoteOffer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil) defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
if err != nil { if err != nil {
return return
} }
@@ -82,8 +87,11 @@ func TestConn_OnRemoteOffer(t *testing.T) {
} }
func TestConn_OnRemoteAnswer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil) defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
if err != nil { if err != nil {
return return
} }
@@ -115,8 +123,11 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestConn_Status(t *testing.T) { func TestConn_Status(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil) defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
if err != nil { if err != nil {
return return
} }
@@ -142,8 +153,11 @@ func TestConn_Status(t *testing.T) {
} }
func TestConn_Close(t *testing.T) { func TestConn_Close(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil) defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
if err != nil { if err != nil {
return return
} }

View File

@@ -17,6 +17,7 @@ type notifier struct {
listener Listener listener Listener
currentClientState bool currentClientState bool
lastNotification int lastNotification int
lastNumberOfPeers int
} }
func newNotifier() *notifier { func newNotifier() *notifier {
@@ -29,6 +30,7 @@ func (n *notifier) setListener(listener Listener) {
n.serverStateLock.Lock() n.serverStateLock.Lock()
n.notifyListener(listener, n.lastNotification) n.notifyListener(listener, n.lastNotification)
listener.OnPeersListChanged(n.lastNumberOfPeers)
n.serverStateLock.Unlock() n.serverStateLock.Unlock()
n.listener = listener n.listener = listener
@@ -59,7 +61,7 @@ func (n *notifier) clientStart() {
n.serverStateLock.Lock() n.serverStateLock.Lock()
defer n.serverStateLock.Unlock() defer n.serverStateLock.Unlock()
n.currentClientState = true n.currentClientState = true
n.lastNotification = stateConnected n.lastNotification = stateConnecting
n.notify(n.lastNotification) n.notify(n.lastNotification)
} }
@@ -112,7 +114,7 @@ func (n *notifier) calculateState(managementConn, signalConn bool) int {
return stateConnected return stateConnected
} }
if !managementConn && !signalConn { if !managementConn && !signalConn && !n.currentClientState {
return stateDisconnected return stateDisconnected
} }
@@ -124,6 +126,7 @@ func (n *notifier) calculateState(managementConn, signalConn bool) int {
} }
func (n *notifier) peerListChanged(numOfPeers int) { func (n *notifier) peerListChanged(numOfPeers int) {
n.lastNumberOfPeers = numOfPeers
n.listenersLock.Lock() n.listenersLock.Lock()
defer n.listenersLock.Unlock() defer n.listenersLock.Unlock()
if n.listener == nil { if n.listener == nil {

View File

@@ -353,9 +353,13 @@ func (d *Status) onConnectionChanged() {
} }
func (d *Status) notifyPeerListChanged() { func (d *Status) notifyPeerListChanged() {
d.notifier.peerListChanged(len(d.peers) + len(d.offlinePeers)) d.notifier.peerListChanged(d.numOfPeers())
} }
func (d *Status) notifyAddressChanged() { func (d *Status) notifyAddressChanged() {
d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP) d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP)
} }
func (d *Status) numOfPeers() int {
return len(d.peers) + len(d.offlinePeers)
}

View File

@@ -0,0 +1,128 @@
package internal
import (
"context"
"fmt"
"net/url"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
mgm "github.com/netbirdio/netbird/management/client"
)
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
type PKCEAuthorizationFlow struct {
ProviderConfig PKCEAuthProviderConfig
}
// PKCEAuthProviderConfig has all attributes needed to initiate pkce authorization flow
type PKCEAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
AuthorizationEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// RedirectURL handles authorization code from IDP manager
RedirectURLs []string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
}
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (PKCEAuthorizationFlow, error) {
// validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
return PKCEAuthorizationFlow{}, err
}
var mgmTLSEnabled bool
if mgmURL.Scheme == "https" {
mgmTLSEnabled = true
}
log.Debugf("connecting to Management Service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
if err != nil {
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
return PKCEAuthorizationFlow{}, err
}
log.Debugf("connected to the Management service %s", mgmURL.String())
defer func() {
err = mgmClient.Close()
if err != nil {
log.Warnf("failed to close the Management service client %v", err)
}
}()
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return PKCEAuthorizationFlow{}, err
}
protoPKCEAuthorizationFlow, err := mgmClient.GetPKCEAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
return PKCEAuthorizationFlow{}, err
}
log.Errorf("failed to retrieve pkce flow: %v", err)
return PKCEAuthorizationFlow{}, err
}
authFlow := PKCEAuthorizationFlow{
ProviderConfig: PKCEAuthProviderConfig{
Audience: protoPKCEAuthorizationFlow.GetProviderConfig().GetAudience(),
ClientID: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientID(),
ClientSecret: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientSecret(),
TokenEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
AuthorizationEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetAuthorizationEndpoint(),
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
},
}
err = isPKCEProviderConfigValid(authFlow.ProviderConfig)
if err != nil {
return PKCEAuthorizationFlow{}, err
}
return authFlow, nil
}
func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error {
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.Audience == "" {
return fmt.Errorf(errorMSGFormat, "Audience")
}
if config.ClientID == "" {
return fmt.Errorf(errorMSGFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
}
if config.AuthorizationEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Authorization Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMSGFormat, "PKCE Auth Scopes")
}
if config.RedirectURLs == nil {
return fmt.Errorf(errorMSGFormat, "PKCE Redirect URLs")
}
return nil
}

View File

@@ -1,72 +0,0 @@
package proxy
import (
"context"
log "github.com/sirupsen/logrus"
"net"
"time"
)
// DummyProxy just sends pings to the RemoteKey peer and reads responses
type DummyProxy struct {
conn net.Conn
remote string
ctx context.Context
cancel context.CancelFunc
}
func NewDummyProxy(remote string) *DummyProxy {
p := &DummyProxy{remote: remote}
p.ctx, p.cancel = context.WithCancel(context.Background())
return p
}
func (p *DummyProxy) Close() error {
p.cancel()
return nil
}
func (p *DummyProxy) Start(remoteConn net.Conn) error {
p.conn = remoteConn
go func() {
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
return
default:
_, err := p.conn.Read(buf)
if err != nil {
log.Errorf("error while reading RemoteKey %s proxy %v", p.remote, err)
return
}
//log.Debugf("received %s from %s", string(buf[:n]), p.remote)
}
}
}()
go func() {
for {
select {
case <-p.ctx.Done():
return
default:
_, err := p.conn.Write([]byte("hello"))
//log.Debugf("sent ping to %s", p.remote)
if err != nil {
log.Errorf("error while writing to RemoteKey %s proxy %v", p.remote, err)
return
}
time.Sleep(5 * time.Second)
}
}
}()
return nil
}
func (p *DummyProxy) Type() Type {
return TypeDummy
}

View File

@@ -1,42 +0,0 @@
package proxy
import (
log "github.com/sirupsen/logrus"
"net"
)
// NoProxy is used just to configure WireGuard without any local proxy in between.
// Used when the WireGuard interface is userspace and uses bind.ICEBind
type NoProxy struct {
config Config
}
// NewNoProxy creates a new NoProxy with a provided config
func NewNoProxy(config Config) *NoProxy {
return &NoProxy{config: config}
}
// Close removes peer from the WireGuard interface
func (p *NoProxy) Close() error {
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
if err != nil {
return err
}
return nil
}
// Start just updates WireGuard peer with the remote address
func (p *NoProxy) Start(remoteConn net.Conn) error {
log.Debugf("using NoProxy to connect to peer %s at %s", p.config.RemoteKey, remoteConn.RemoteAddr().String())
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
if err != nil {
return err
}
return p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
addr, p.config.PreSharedKey)
}
func (p *NoProxy) Type() Type {
return TypeNoProxy
}

View File

@@ -1,35 +0,0 @@
package proxy
import (
"github.com/netbirdio/netbird/iface"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"io"
"net"
"time"
)
const DefaultWgKeepAlive = 25 * time.Second
type Type string
const (
TypeDirectNoProxy Type = "DirectNoProxy"
TypeWireGuard Type = "WireGuard"
TypeDummy Type = "Dummy"
TypeNoProxy Type = "NoProxy"
)
type Config struct {
WgListenAddr string
RemoteKey string
WgInterface *iface.WGIface
AllowedIps string
PreSharedKey *wgtypes.Key
}
type Proxy interface {
io.Closer
// Start creates a local remoteConn and starts proxying data from/to remoteConn
Start(remoteConn net.Conn) error
Type() Type
}

View File

@@ -1,128 +0,0 @@
package proxy
import (
"context"
log "github.com/sirupsen/logrus"
"net"
)
// WireGuardProxy proxies
type WireGuardProxy struct {
ctx context.Context
cancel context.CancelFunc
config Config
remoteConn net.Conn
localConn net.Conn
}
func NewWireGuardProxy(config Config) *WireGuardProxy {
p := &WireGuardProxy{config: config}
p.ctx, p.cancel = context.WithCancel(context.Background())
return p
}
func (p *WireGuardProxy) updateEndpoint() error {
udpAddr, err := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
if err != nil {
return err
}
// add local proxy connection as a Wireguard peer
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
udpAddr, p.config.PreSharedKey)
if err != nil {
return err
}
return nil
}
func (p *WireGuardProxy) Start(remoteConn net.Conn) error {
p.remoteConn = remoteConn
var err error
p.localConn, err = net.Dial("udp", p.config.WgListenAddr)
if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err)
return err
}
err = p.updateEndpoint()
if err != nil {
log.Errorf("error while updating Wireguard peer endpoint [%s] %v", p.config.RemoteKey, err)
return err
}
go p.proxyToRemote()
go p.proxyToLocal()
return nil
}
func (p *WireGuardProxy) Close() error {
p.cancel()
if c := p.localConn; c != nil {
err := p.localConn.Close()
if err != nil {
return err
}
}
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
if err != nil {
return err
}
return nil
}
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
// blocks
func (p *WireGuardProxy) proxyToRemote() {
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
log.Debugf("stopped proxying to remote peer %s due to closed connection", p.config.RemoteKey)
return
default:
n, err := p.localConn.Read(buf)
if err != nil {
continue
}
_, err = p.remoteConn.Write(buf[:n])
if err != nil {
continue
}
}
}
}
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
// blocks
func (p *WireGuardProxy) proxyToLocal() {
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
log.Debugf("stopped proxying from remote peer %s due to closed connection", p.config.RemoteKey)
return
default:
n, err := p.remoteConn.Read(buf)
if err != nil {
continue
}
_, err = p.localConn.Write(buf[:n])
if err != nil {
continue
}
}
}
}
func (p *WireGuardProxy) Type() Type {
return TypeWireGuard
}

View File

@@ -155,7 +155,10 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() {
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
state, err := c.statusRecorder.GetPeer(peerKey) state, err := c.statusRecorder.GetPeer(peerKey)
if err != nil || state.ConnStatus != peer.StatusConnected { if err != nil {
return err
}
if state.ConnStatus != peer.StatusConnected {
return nil return nil
} }

View File

@@ -7,6 +7,8 @@ import (
"fmt" "fmt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/checkfw"
) )
const ( const (
@@ -26,15 +28,20 @@ func genKey(format string, input string) string {
return fmt.Sprintf(format, input) return fmt.Sprintf(format, input)
} }
// NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager // newFirewall if supported, returns an iptables manager, otherwise returns a nftables manager
func NewFirewall(parentCTX context.Context) firewallManager { func newFirewall(parentCTX context.Context) (firewallManager, error) {
manager, err := newNFTablesManager(parentCTX) checkResult := checkfw.Check()
if err == nil { switch checkResult {
log.Debugf("nftables firewall manager will be used") case checkfw.IPTABLES, checkfw.IPTABLESWITHV6:
return manager log.Debug("creating an iptables firewall manager for route rules")
ipv6Supported := checkResult == checkfw.IPTABLESWITHV6
return newIptablesManager(parentCTX, ipv6Supported)
case checkfw.NFTABLES:
log.Info("creating an nftables firewall manager for route rules")
return newNFTablesManager(parentCTX), nil
} }
log.Debugf("fallback to iptables firewall manager: %s", err)
return newIptablesManager(parentCTX) return nil, fmt.Errorf("couldn't initialize nftables or iptables clients. Using a dummy firewall manager for route rules")
} }
func getInPair(pair routerPair) routerPair { func getInPair(pair routerPair) routerPair {

View File

@@ -3,24 +3,13 @@
package routemanager package routemanager
import "context" import (
"context"
"fmt"
"runtime"
)
type unimplementedFirewall struct{} // newFirewall returns a nil manager
func newFirewall(context.Context) (firewallManager, error) {
func (unimplementedFirewall) RestoreOrCreateContainers() error { return nil, fmt.Errorf("firewall not supported on %s", runtime.GOOS)
return nil
}
func (unimplementedFirewall) InsertRoutingRules(pair routerPair) error {
return nil
}
func (unimplementedFirewall) RemoveRoutingRules(pair routerPair) error {
return nil
}
func (unimplementedFirewall) CleanRoutingRules() {
}
// NewFirewall returns an unimplemented Firewall manager
func NewFirewall(parentCtx context.Context) firewallManager {
return unimplementedFirewall{}
} }

View File

@@ -49,26 +49,28 @@ type iptablesManager struct {
mux sync.Mutex mux sync.Mutex
} }
func newIptablesManager(parentCtx context.Context) *iptablesManager { func newIptablesManager(parentCtx context.Context, ipv6Supported bool) (*iptablesManager, error) {
ctx, cancel := context.WithCancel(parentCtx) ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) if err != nil {
if !isIptablesClientAvailable(ipv4Client) { return nil, fmt.Errorf("failed to initialize iptables for ipv4: %s", err)
log.Infof("iptables is missing for ipv4")
ipv4Client = nil
}
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if !isIptablesClientAvailable(ipv6Client) {
log.Infof("iptables is missing for ipv6")
ipv6Client = nil
} }
return &iptablesManager{ ctx, cancel := context.WithCancel(parentCtx)
manager := &iptablesManager{
ctx: ctx, ctx: ctx,
stop: cancel, stop: cancel,
ipv4Client: ipv4Client, ipv4Client: ipv4Client,
ipv6Client: ipv6Client,
rules: make(map[string]map[string][]string), rules: make(map[string]map[string][]string),
} }
if ipv6Supported {
manager.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
log.Warnf("failed to initialize iptables for ipv6: %s. Routes for this protocol won't be applied.", err)
}
}
return manager, nil
} }
// CleanRoutingRules cleans existing iptables resources that we created by the agent // CleanRoutingRules cleans existing iptables resources that we created by the agent
@@ -391,6 +393,10 @@ func (i *iptablesManager) insertRoutingRule(keyFormat, table, chain, jump string
ipVersion = ipv6 ipVersion = ipv6
} }
if iptablesClient == nil {
return fmt.Errorf("unable to insert iptables routing rules. Iptables client is not initialized")
}
ruleKey := genKey(keyFormat, pair.ID) ruleKey := genKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, ruleKey, pair.source, pair.destination) rule := genRuleSpec(jump, ruleKey, pair.source, pair.destination)
existingRule, found := i.rules[ipVersion][ruleKey] existingRule, found := i.rules[ipVersion][ruleKey]
@@ -455,6 +461,10 @@ func (i *iptablesManager) removeRoutingRule(keyFormat, table, chain string, pair
ipVersion = ipv6 ipVersion = ipv6
} }
if iptablesClient == nil {
return fmt.Errorf("unable to remove iptables routing rules. Iptables client is not initialized")
}
ruleKey := genKey(keyFormat, pair.ID) ruleKey := genKey(keyFormat, pair.ID)
existingRule, found := i.rules[ipVersion][ruleKey] existingRule, found := i.rules[ipVersion][ruleKey]
if found { if found {
@@ -475,8 +485,3 @@ func getIptablesRuleType(table string) string {
} }
return ruleType return ruleType
} }
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}

View File

@@ -16,11 +16,12 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
t.SkipNow() t.SkipNow()
} }
manager := newIptablesManager(context.TODO()) manager, err := newIptablesManager(context.TODO(), true)
require.NoError(t, err, "should return a valid iptables manager")
defer manager.CleanRoutingRules() defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers() err = manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
require.Len(t, manager.rules, 2, "should have created maps for ipv4 and ipv6") require.Len(t, manager.rules, 2, "should have created maps for ipv4 and ipv6")

View File

@@ -27,7 +27,7 @@ type DefaultManager struct {
stop context.CancelFunc stop context.CancelFunc
mux sync.Mutex mux sync.Mutex
clientNetworks map[string]*clientNetwork clientNetworks map[string]*clientNetwork
serverRouter *serverRouter serverRouter serverRouter
statusRecorder *peer.Status statusRecorder *peer.Status
wgInterface *iface.WGIface wgInterface *iface.WGIface
pubKey string pubKey string
@@ -36,13 +36,17 @@ type DefaultManager struct {
// NewManager returns a new route manager // NewManager returns a new route manager
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager { func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx) srvRouter, err := newServerRouter(ctx, wgInterface)
if err != nil {
log.Errorf("server router is not supported: %s", err)
}
mCTX, cancel := context.WithCancel(ctx)
dm := &DefaultManager{ dm := &DefaultManager{
ctx: mCTX, ctx: mCTX,
stop: cancel, stop: cancel,
clientNetworks: make(map[string]*clientNetwork), clientNetworks: make(map[string]*clientNetwork),
serverRouter: newServerRouter(ctx, wgInterface), serverRouter: srvRouter,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgInterface: wgInterface, wgInterface: wgInterface,
pubKey: pubKey, pubKey: pubKey,
@@ -59,7 +63,9 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
// Stop stops the manager watchers and clean firewall rules // Stop stops the manager watchers and clean firewall rules
func (m *DefaultManager) Stop() { func (m *DefaultManager) Stop() {
m.stop() m.stop()
if m.serverRouter != nil {
m.serverRouter.cleanUp() m.serverRouter.cleanUp()
}
m.ctx = nil m.ctx = nil
} }
@@ -77,10 +83,13 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
m.updateClientNetworks(updateSerial, newClientRoutesIDMap) m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
m.notifier.onNewRoutes(newClientRoutesIDMap) m.notifier.onNewRoutes(newClientRoutesIDMap)
if m.serverRouter != nil {
err := m.serverRouter.updateRoutes(newServerRoutesMap) err := m.serverRouter.updateRoutes(newServerRoutesMap)
if err != nil { if err != nil {
return err return err
} }
}
return nil return nil
} }

View File

@@ -3,11 +3,12 @@ package routemanager
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/pion/transport/v2/stdnet"
"net/netip" "net/netip"
"runtime" "runtime"
"testing" "testing"
"github.com/pion/transport/v2/stdnet"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
@@ -30,7 +31,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
inputInitRoutes []*route.Route inputInitRoutes []*route.Route
inputRoutes []*route.Route inputRoutes []*route.Route
inputSerial uint64 inputSerial uint64
shouldCheckServerRoutes bool removeSrvRouter bool
serverRoutesExpected int serverRoutesExpected int
clientNetworkWatchersExpected int clientNetworkWatchersExpected int
}{ }{
@@ -87,7 +88,6 @@ func TestManagerUpdateRoutes(t *testing.T) {
}, },
}, },
inputSerial: 1, inputSerial: 1,
shouldCheckServerRoutes: runtime.GOOS == "linux",
serverRoutesExpected: 2, serverRoutesExpected: 2,
clientNetworkWatchersExpected: 0, clientNetworkWatchersExpected: 0,
}, },
@@ -116,10 +116,38 @@ func TestManagerUpdateRoutes(t *testing.T) {
}, },
}, },
inputSerial: 1, inputSerial: 1,
shouldCheckServerRoutes: runtime.GOOS == "linux",
serverRoutesExpected: 1, serverRoutesExpected: 1,
clientNetworkWatchersExpected: 1, clientNetworkWatchersExpected: 1,
}, },
{
name: "Should Create 1 Route For Client and Skip Server Route On Empty Server Router",
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.30.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeB",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("8.8.9.9/32"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
removeSrvRouter: true,
serverRoutesExpected: 0,
clientNetworkWatchersExpected: 1,
},
{ {
name: "Should Create 1 HA Route and 1 Standalone", name: "Should Create 1 HA Route and 1 Standalone",
inputRoutes: []*route.Route{ inputRoutes: []*route.Route{
@@ -174,25 +202,6 @@ func TestManagerUpdateRoutes(t *testing.T) {
inputSerial: 1, inputSerial: 1,
clientNetworkWatchersExpected: 0, clientNetworkWatchersExpected: 0,
}, },
{
name: "No Server Routes Should Be Added To Non Linux",
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("1.2.3.4/32"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
shouldCheckServerRoutes: runtime.GOOS != "linux",
serverRoutesExpected: 0,
clientNetworkWatchersExpected: 0,
},
{ {
name: "Remove 1 Client Route", name: "Remove 1 Client Route",
inputInitRoutes: []*route.Route{ inputInitRoutes: []*route.Route{
@@ -335,7 +344,6 @@ func TestManagerUpdateRoutes(t *testing.T) {
}, },
inputRoutes: []*route.Route{}, inputRoutes: []*route.Route{},
inputSerial: 1, inputSerial: 1,
shouldCheckServerRoutes: true,
serverRoutesExpected: 0, serverRoutesExpected: 0,
clientNetworkWatchersExpected: 0, clientNetworkWatchersExpected: 0,
}, },
@@ -384,7 +392,6 @@ func TestManagerUpdateRoutes(t *testing.T) {
}, },
}, },
inputSerial: 1, inputSerial: 1,
shouldCheckServerRoutes: runtime.GOOS == "linux",
serverRoutesExpected: 2, serverRoutesExpected: 2,
clientNetworkWatchersExpected: 1, clientNetworkWatchersExpected: 1,
}, },
@@ -409,6 +416,10 @@ func TestManagerUpdateRoutes(t *testing.T) {
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil) routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
defer routeManager.Stop() defer routeManager.Stop()
if testCase.removeSrvRouter {
routeManager.serverRouter = nil
}
if len(testCase.inputInitRoutes) > 0 { if len(testCase.inputInitRoutes) > 0 {
err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes) err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
require.NoError(t, err, "should update routes with init routes") require.NoError(t, err, "should update routes with init routes")
@@ -419,8 +430,9 @@ func TestManagerUpdateRoutes(t *testing.T) {
require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match") require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match")
if testCase.shouldCheckServerRoutes { if runtime.GOOS == "linux" && routeManager.serverRouter != nil {
require.Len(t, routeManager.serverRouter.routes, testCase.serverRoutesExpected, "server networks size should match") sr := routeManager.serverRouter.(*defaultServerRouter)
require.Len(t, sr.routes, testCase.serverRoutesExpected, "server networks size should match")
} }
}) })
} }

View File

@@ -86,10 +86,10 @@ type nftablesManager struct {
mux sync.Mutex mux sync.Mutex
} }
func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) { func newNFTablesManager(parentCtx context.Context) *nftablesManager {
ctx, cancel := context.WithCancel(parentCtx) ctx, cancel := context.WithCancel(parentCtx)
mgr := &nftablesManager{ return &nftablesManager{
ctx: ctx, ctx: ctx,
stop: cancel, stop: cancel,
conn: &nftables.Conn{}, conn: &nftables.Conn{},
@@ -97,18 +97,6 @@ func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) {
rules: make(map[string]*nftables.Rule), rules: make(map[string]*nftables.Rule),
defaultForwardRules: make([]*nftables.Rule, 2), defaultForwardRules: make([]*nftables.Rule, 2),
} }
err := mgr.isSupported()
if err != nil {
return nil, err
}
err = mgr.readFilterTable()
if err != nil {
return nil, err
}
return mgr, nil
} }
// CleanRoutingRules cleans existing nftables rules from the system // CleanRoutingRules cleans existing nftables rules from the system
@@ -147,6 +135,10 @@ func (n *nftablesManager) RestoreOrCreateContainers() error {
} }
for _, table := range tables { for _, table := range tables {
if table.Name == "filter" {
n.filterTable = table
continue
}
if table.Name == nftablesTable { if table.Name == nftablesTable {
if table.Family == nftables.TableFamilyIPv4 { if table.Family == nftables.TableFamilyIPv4 {
n.tableIPv4 = table n.tableIPv4 = table
@@ -259,21 +251,6 @@ func (n *nftablesManager) refreshRulesMap() error {
return nil return nil
} }
func (n *nftablesManager) readFilterTable() error {
tables, err := n.conn.ListTables()
if err != nil {
return err
}
for _, t := range tables {
if t.Name == "filter" {
n.filterTable = t
return nil
}
}
return nil
}
func (n *nftablesManager) eraseDefaultForwardRule() error { func (n *nftablesManager) eraseDefaultForwardRule() error {
if n.defaultForwardRules[0] == nil { if n.defaultForwardRules[0] == nil {
return nil return nil
@@ -544,14 +521,6 @@ func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) erro
return nil return nil
} }
func (n *nftablesManager) isSupported() error {
_, err := n.conn.ListChains()
if err != nil {
return fmt.Errorf("nftables is not supported: %s", err)
}
return nil
}
// getPayloadDirectives get expression directives based on ip version and direction // getPayloadDirectives get expression directives based on ip version and direction
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) { func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
switch { switch {

View File

@@ -10,20 +10,23 @@ import (
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/checkfw"
) )
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) { func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
manager, err := newNFTablesManager(context.TODO()) if checkfw.Check() != checkfw.NFTABLES {
if err != nil { t.Skip("nftables not supported on this OS")
t.Fatalf("failed to create nftables manager: %s", err)
} }
manager := newNFTablesManager(context.TODO())
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules() defer manager.CleanRoutingRules()
err = manager.RestoreOrCreateContainers() err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6") require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
@@ -126,19 +129,19 @@ func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
} }
func TestNftablesManager_InsertRoutingRules(t *testing.T) { func TestNftablesManager_InsertRoutingRules(t *testing.T) {
if checkfw.Check() != checkfw.NFTABLES {
t.Skip("nftables not supported on this OS")
}
for _, testCase := range insertRuleTestCases { for _, testCase := range insertRuleTestCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
manager, err := newNFTablesManager(context.TODO()) manager := newNFTablesManager(context.TODO())
if err != nil {
t.Fatalf("failed to create nftables manager: %s", err)
}
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules() defer manager.CleanRoutingRules()
err = manager.RestoreOrCreateContainers() err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
err = manager.InsertRoutingRules(testCase.inputPair) err = manager.InsertRoutingRules(testCase.inputPair)
@@ -226,19 +229,19 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
} }
func TestNftablesManager_RemoveRoutingRules(t *testing.T) { func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
if checkfw.Check() != checkfw.NFTABLES {
t.Skip("nftables not supported on this OS")
}
for _, testCase := range removeRuleTestCases { for _, testCase := range removeRuleTestCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
manager, err := newNFTablesManager(context.TODO()) manager := newNFTablesManager(context.TODO())
if err != nil {
t.Fatalf("failed to create nftables manager: %s", err)
}
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules() defer manager.CleanRoutingRules()
err = manager.RestoreOrCreateContainers() err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
table := manager.tableIPv4 table := manager.tableIPv4

View File

@@ -0,0 +1,9 @@
package routemanager
import "github.com/netbirdio/netbird/route"
type serverRouter interface {
updateRoutes(map[string]*route.Route) error
removeFromServerNetwork(*route.Route) error
cleanUp()
}

View File

@@ -2,20 +2,11 @@ package routemanager
import ( import (
"context" "context"
"fmt"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
) )
type serverRouter struct { func newServerRouter(context.Context, *iface.WGIface) (serverRouter, error) {
return nil, fmt.Errorf("server route not supported on this os")
} }
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) *serverRouter {
return &serverRouter{}
}
func (r *serverRouter) updateRoutes(routesMap map[string]*route.Route) error {
return nil
}
func (r *serverRouter) cleanUp() {}

View File

@@ -13,7 +13,7 @@ import (
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
type serverRouter struct { type defaultServerRouter struct {
mux sync.Mutex mux sync.Mutex
ctx context.Context ctx context.Context
routes map[string]*route.Route routes map[string]*route.Route
@@ -21,16 +21,21 @@ type serverRouter struct {
wgInterface *iface.WGIface wgInterface *iface.WGIface
} }
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) *serverRouter { func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) (serverRouter, error) {
return &serverRouter{ firewall, err := newFirewall(ctx)
if err != nil {
return nil, err
}
return &defaultServerRouter{
ctx: ctx, ctx: ctx,
routes: make(map[string]*route.Route), routes: make(map[string]*route.Route),
firewall: NewFirewall(ctx), firewall: firewall,
wgInterface: wgInterface, wgInterface: wgInterface,
} }, nil
} }
func (m *serverRouter) updateRoutes(routesMap map[string]*route.Route) error { func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) error {
serverRoutesToRemove := make([]string, 0) serverRoutesToRemove := make([]string, 0)
if len(routesMap) > 0 { if len(routesMap) > 0 {
@@ -81,7 +86,7 @@ func (m *serverRouter) updateRoutes(routesMap map[string]*route.Route) error {
return nil return nil
} }
func (m *serverRouter) removeFromServerNetwork(route *route.Route) error { func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error {
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
log.Infof("not removing from server network because context is done") log.Infof("not removing from server network because context is done")
@@ -98,7 +103,7 @@ func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
} }
} }
func (m *serverRouter) addToServerNetwork(route *route.Route) error { func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
log.Infof("not adding to server network because context is done") log.Infof("not adding to server network because context is done")
@@ -115,6 +120,6 @@ func (m *serverRouter) addToServerNetwork(route *route.Route) error {
} }
} }
func (m *serverRouter) cleanUp() { func (m *defaultServerRouter) cleanUp() {
m.firewall.CleanRoutingRules() m.firewall.CleanRoutingRules()
} }

View File

@@ -20,7 +20,7 @@ func InterfaceFilter(disallowList []string) func(string) bool {
for _, s := range disallowList { for _, s := range disallowList {
if strings.HasPrefix(iFace, s) { if strings.HasPrefix(iFace, s) {
log.Debugf("ignoring interface %s - it is not allowed", iFace) log.Tracef("ignoring interface %s - it is not allowed", iFace)
return false return false
} }
} }

View File

@@ -0,0 +1,8 @@
package templates
import (
_ "embed"
)
//go:embed pkce-auth-msg.html
var PKCEAuthMsgTmpl string

View File

@@ -0,0 +1,87 @@
<!DOCTYPE html>
<html>
<head>
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<style>
body {
display: flex;
justify-content: center;
align-items: center;
min-height: 100vh;
margin: 0;
background: #f7f8f9;
font-family: sans-serif, Arial, Tahoma;
}
.container {
width: 100%;
background: white;
border: 1px solid #e8e9ea;
text-align: center;
padding: 20px;
padding-bottom: 50px;
max-width: 550px;
margin: 0 10px;
}
.logo {
height: 80px;
border-bottom: 1px solid #e8e9ea;
display: flex;
justify-content: center;
align-items: center;
}
.logo img {
width: 130px;
}
.content {
font-size: 13px;
color: #525252;
line-height: 18px;
padding: 10px 0;
}
.content div {
font-size: 18px;
line-height: normal;
margin-bottom: 5px;
color: black;
}
</style>
</head>
<body>
<div class="container">
<div class="logo">
<img src="https://img.mailinblue.com/6211297/images/content_library/original/64bd4ce82e1ea753e439b6a2.png">
</div>
<br>
{{ if .Error }}
<svg xmlns="http://www.w3.org/2000/svg" height="50" viewBox="0 0 100 100">
<circle cx="50" cy="50" r="45" fill="none" stroke="red" stroke-width="3"/>
<path d="M30 30 L70 70 M30 70 L70 30" fill="none" stroke="red" stroke-width="3"/>
</svg>
<div class="content">
<div>
Login failed
</div>
{{ .Error }}.
</div>
{{ else }}
<svg xmlns="http://www.w3.org/2000/svg" height="50" viewBox="0 0 100 100">
<circle cx="50" cy="50" r="45" fill="none" stroke="#5cb85c" stroke-width="3"/>
<path d="M30 50 L45 65 L70 35" fill="none" stroke="#5cb85c" stroke-width="5"/>
</svg>
<div class="content">
<div>
Login successful
</div>
Your device is now registered and logged in to NetBird.
<br>
You can now close this window.
</div>
{{ end }}
</div>
</body>
</html>

View File

@@ -0,0 +1,20 @@
package wgproxy
type Factory struct {
wgPort int
ebpfProxy Proxy
}
func (w *Factory) GetProxy() Proxy {
if w.ebpfProxy != nil {
return w.ebpfProxy
}
return NewWGUserSpaceProxy(w.wgPort)
}
func (w *Factory) Free() error {
if w.ebpfProxy != nil {
return w.ebpfProxy.Free()
}
return nil
}

View File

@@ -0,0 +1,21 @@
//go:build !android
package wgproxy
import (
log "github.com/sirupsen/logrus"
)
func NewFactory(wgPort int) *Factory {
f := &Factory{wgPort: wgPort}
ebpfProxy := NewWGEBPFProxy(wgPort)
err := ebpfProxy.Listen()
if err != nil {
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
return f
}
f.ebpfProxy = ebpfProxy
return f
}

View File

@@ -0,0 +1,7 @@
//go:build !linux || android
package wgproxy
func NewFactory(wgPort int) *Factory {
return &Factory{wgPort: wgPort}
}

View File

@@ -0,0 +1,32 @@
package wgproxy
import (
"fmt"
"net"
)
const (
portRangeStart = 3128
portRangeEnd = 3228
)
type portLookup struct {
}
func (pl portLookup) searchFreePort() (int, error) {
for i := portRangeStart; i <= portRangeEnd; i++ {
if pl.tryToBind(i) == nil {
return i, nil
}
}
return 0, fmt.Errorf("failed to bind free port for eBPF proxy")
}
func (pl portLookup) tryToBind(port int) error {
l, err := net.ListenPacket("udp", fmt.Sprintf(":%d", port))
if err != nil {
return err
}
_ = l.Close()
return nil
}

View File

@@ -0,0 +1,42 @@
package wgproxy
import (
"fmt"
"net"
"testing"
)
func Test_portLookup_searchFreePort(t *testing.T) {
pl := portLookup{}
_, err := pl.searchFreePort()
if err != nil {
t.Fatal(err)
}
}
func Test_portLookup_on_allocated(t *testing.T) {
pl := portLookup{}
allocatedPort, err := allocatePort(portRangeStart)
if err != nil {
t.Fatal(err)
}
defer allocatedPort.Close()
fp, err := pl.searchFreePort()
if err != nil {
t.Fatal(err)
}
if fp != (portRangeStart + 1) {
t.Errorf("invalid free port, expected: %d, got: %d", portRangeStart+1, fp)
}
}
func allocatePort(port int) (net.PacketConn, error) {
c, err := net.ListenPacket("udp", fmt.Sprintf(":%d", port))
if err != nil {
return nil, err
}
return c, err
}

View File

@@ -0,0 +1,12 @@
package wgproxy
import (
"net"
)
// Proxy is a transfer layer between the Turn connection and the WireGuard
type Proxy interface {
AddTurnConn(urnConn net.Conn) (net.Addr, error)
CloseConn() error
Free() error
}

View File

@@ -0,0 +1,255 @@
//go:build linux && !android
package wgproxy
import (
"fmt"
"io"
"net"
"os"
"sync"
"syscall"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
)
// WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct {
ebpfManager ebpfMgr.Manager
lastUsedPort uint16
localWGListenPort int
turnConnStore map[uint16]net.Conn
turnConnMutex sync.Mutex
rawConn net.PacketConn
conn *net.UDPConn
}
// NewWGEBPFProxy create new WGEBPFProxy instance
func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
log.Debugf("instantiate ebpf proxy")
wgProxy := &WGEBPFProxy{
localWGListenPort: wgPort,
ebpfManager: ebpf.GetEbpfManagerInstance(),
lastUsedPort: 0,
turnConnStore: make(map[uint16]net.Conn),
}
return wgProxy
}
// Listen load ebpf program and listen the proxy
func (p *WGEBPFProxy) Listen() error {
pl := portLookup{}
wgPorxyPort, err := pl.searchFreePort()
if err != nil {
return err
}
p.rawConn, err = p.prepareSenderRawSocket()
if err != nil {
return err
}
err = p.ebpfManager.LoadWgProxy(wgPorxyPort, p.localWGListenPort)
if err != nil {
return err
}
addr := net.UDPAddr{
Port: wgPorxyPort,
IP: net.ParseIP("127.0.0.1"),
}
p.conn, err = net.ListenUDP("udp", &addr)
if err != nil {
cErr := p.Free()
if cErr != nil {
log.Errorf("failed to close the wgproxy: %s", cErr)
}
return err
}
go p.proxyToRemote()
log.Infof("local wg proxy listening on: %d", wgPorxyPort)
return nil
}
// AddTurnConn add new turn connection for the proxy
func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
wgEndpointPort, err := p.storeTurnConn(turnConn)
if err != nil {
return nil, err
}
go p.proxyToLocal(wgEndpointPort, turnConn)
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
wgEndpoint := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: int(wgEndpointPort),
}
return wgEndpoint, nil
}
// CloseConn doing nothing because this type of proxy implementation does not store the connection
func (p *WGEBPFProxy) CloseConn() error {
return nil
}
// Free resources
func (p *WGEBPFProxy) Free() error {
log.Debugf("free up ebpf wg proxy")
var err1, err2, err3 error
if p.conn != nil {
err1 = p.conn.Close()
}
err2 = p.ebpfManager.FreeWGProxy()
if p.rawConn != nil {
err3 = p.rawConn.Close()
}
if err1 != nil {
return err1
}
if err2 != nil {
return err2
}
return err3
}
func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
buf := make([]byte, 1500)
for {
n, err := remoteConn.Read(buf)
if err != nil {
if err != io.EOF {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
p.removeTurnConn(endpointPort)
return
}
err = p.sendPkg(buf[:n], endpointPort)
if err != nil {
log.Errorf("failed to write out turn pkg to local conn: %v", err)
}
}
}
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
func (p *WGEBPFProxy) proxyToRemote() {
buf := make([]byte, 1500)
for {
n, addr, err := p.conn.ReadFromUDP(buf)
if err != nil {
log.Errorf("failed to read UDP pkg from WG: %s", err)
return
}
p.turnConnMutex.Lock()
conn, ok := p.turnConnStore[uint16(addr.Port)]
p.turnConnMutex.Unlock()
if !ok {
log.Errorf("turn conn not found by port: %d", addr.Port)
continue
}
_, err = conn.Write(buf[:n])
if err != nil {
log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err)
}
}
}
func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
p.turnConnMutex.Lock()
defer p.turnConnMutex.Unlock()
np, err := p.nextFreePort()
if err != nil {
return np, err
}
p.turnConnStore[np] = turnConn
return np, nil
}
func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) {
log.Tracef("remove turn conn from store by port: %d", turnConnID)
p.turnConnMutex.Lock()
defer p.turnConnMutex.Unlock()
delete(p.turnConnStore, turnConnID)
}
func (p *WGEBPFProxy) nextFreePort() (uint16, error) {
if len(p.turnConnStore) == 65535 {
return 0, fmt.Errorf("reached maximum turn connection numbers")
}
generatePort:
if p.lastUsedPort == 65535 {
p.lastUsedPort = 1
} else {
p.lastUsedPort++
}
if _, ok := p.turnConnStore[p.lastUsedPort]; ok {
goto generatePort
}
return p.lastUsedPort, nil
}
func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil {
return nil, err
}
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, err
}
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
return nil, err
}
return net.FilePacketConn(os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)))
}
func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error {
localhost := net.ParseIP("127.0.0.1")
payload := gopacket.Payload(data)
ipH := &layers.IPv4{
DstIP: localhost,
SrcIP: localhost,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(port),
DstPort: layers.UDPPort(p.localWGListenPort),
}
err := udpH.SetNetworkLayerForChecksum(ipH)
if err != nil {
return err
}
layerBuffer := gopacket.NewSerializeBuffer()
err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
if err != nil {
return err
}
_, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost})
return err
}

View File

@@ -0,0 +1,56 @@
//go:build linux && !android
package wgproxy
import (
"testing"
)
func TestWGEBPFProxy_connStore(t *testing.T) {
wgProxy := NewWGEBPFProxy(1)
p, _ := wgProxy.storeTurnConn(nil)
if p != 1 {
t.Errorf("invalid initial port: %d", wgProxy.lastUsedPort)
}
numOfConns := 10
for i := 0; i < numOfConns; i++ {
p, _ = wgProxy.storeTurnConn(nil)
}
if p != uint16(numOfConns)+1 {
t.Errorf("invalid last used port: %d, expected: %d", p, numOfConns+1)
}
if len(wgProxy.turnConnStore) != numOfConns+1 {
t.Errorf("invalid store size: %d, expected: %d", len(wgProxy.turnConnStore), numOfConns+1)
}
}
func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
wgProxy := NewWGEBPFProxy(1)
_, _ = wgProxy.storeTurnConn(nil)
wgProxy.lastUsedPort = 65535
p, _ := wgProxy.storeTurnConn(nil)
if len(wgProxy.turnConnStore) != 2 {
t.Errorf("invalid store size: %d, expected: %d", len(wgProxy.turnConnStore), 2)
}
if p != 2 {
t.Errorf("invalid last used port: %d, expected: %d", p, 2)
}
}
func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) {
wgProxy := NewWGEBPFProxy(1)
for i := 0; i < 65535; i++ {
_, _ = wgProxy.storeTurnConn(nil)
}
_, err := wgProxy.storeTurnConn(nil)
if err == nil {
t.Errorf("invalid turn conn store calculation")
}
}

View File

@@ -0,0 +1,106 @@
package wgproxy
import (
"context"
"fmt"
"net"
log "github.com/sirupsen/logrus"
)
// WGUserSpaceProxy proxies
type WGUserSpaceProxy struct {
localWGListenPort int
ctx context.Context
cancel context.CancelFunc
remoteConn net.Conn
localConn net.Conn
}
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy
func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
log.Debugf("instantiate new userspace proxy")
p := &WGUserSpaceProxy{
localWGListenPort: wgPort,
}
p.ctx, p.cancel = context.WithCancel(context.Background())
return p
}
// AddTurnConn start the proxy with the given remote conn
func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) {
p.remoteConn = remoteConn
var err error
p.localConn, err = net.Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err)
return nil, err
}
go p.proxyToRemote()
go p.proxyToLocal()
return p.localConn.LocalAddr(), err
}
// CloseConn close the localConn
func (p *WGUserSpaceProxy) CloseConn() error {
p.cancel()
if p.localConn == nil {
return nil
}
return p.localConn.Close()
}
// Free doing nothing because this implementation of proxy does not have global state
func (p *WGUserSpaceProxy) Free() error {
return nil
}
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
// blocks
func (p *WGUserSpaceProxy) proxyToRemote() {
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
return
default:
n, err := p.localConn.Read(buf)
if err != nil {
continue
}
_, err = p.remoteConn.Write(buf[:n])
if err != nil {
continue
}
}
}
}
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
// blocks
func (p *WGUserSpaceProxy) proxyToLocal() {
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
return
default:
n, err := p.remoteConn.Read(buf)
if err != nil {
continue
}
_, err = p.localConn.Write(buf[:n])
if err != nil {
continue
}
}
}
}

77
client/netbird.wxs Normal file
View File

@@ -0,0 +1,77 @@
<Wix
xmlns="http://wixtoolset.org/schemas/v4/wxs">
<Package Name="NetBird" Version="$(env.NETBIRD_VERSION)" Manufacturer="Wiretrustee UG (haftungsbeschreankt)" Language="1033" UpgradeCode="6456ec4e-3ad6-4b9b-a2be-98e81cb21ccf"
InstallerVersion="500" Compressed="yes" Codepage="utf-8" >
<MediaTemplate EmbedCab="yes" />
<Feature Id="NetbirdFeature" Title="Netbird" Level="1">
<ComponentGroupRef Id="NetbirdFilesComponent" />
</Feature>
<MajorUpgrade AllowSameVersionUpgrades='yes' DowngradeErrorMessage="A newer version of [ProductName] is already installed. Setup will now exit."/>
<StandardDirectory Id="ProgramFiles64Folder">
<Directory Id="NetbirdInstallDir" Name="Netbird">
<Component Id="NetbirdFiles" Guid="db3165de-cc6e-4922-8396-9d892950e23e" Bitness="always64">
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\netbird.exe" KeyPath="yes" />
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\netbird-ui.exe">
<Shortcut Id="NetbirdDesktopShortcut" Directory="DesktopFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
<Shortcut Id="NetbirdStartMenuShortcut" Directory="StartMenuFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
</File>
<File ProcessorArchitecture="x64" Source=".\dist\netbird_windows_amd64\wintun.dll" />
<ServiceInstall
Id="NetBirdService"
Name="NetBird"
DisplayName="NetBird"
Description="A WireGuard-based mesh network that connects your devices into a single private network."
Start="auto" Type="ownProcess"
ErrorControl="normal"
Account="LocalSystem"
Vital="yes"
Interactive="no"
Arguments='service run config [CommonAppDataFolder]Netbird\config.json log-level info'
/>
<ServiceControl Id="NetBirdService" Name="NetBird" Start="install" Stop="both" Remove="uninstall" Wait="yes" />
<Environment Id="UpdatePath" Name="PATH" Value="[NetbirdInstallDir]" Part="last" Action="set" System="yes" />
</Component>
</Directory>
</StandardDirectory>
<ComponentGroup Id="NetbirdFilesComponent">
<ComponentRef Id="NetbirdFiles" />
</ComponentGroup>
<Property Id="cmd" Value="cmd.exe"/>
<CustomAction Id="KillDaemon"
ExeCommand='/c "taskkill /im netbird.exe"'
Execute="deferred"
Property="cmd"
Impersonate="no"
Return="ignore"
/>
<CustomAction Id="KillUI"
ExeCommand='/c "taskkill /im netbird-ui.exe"'
Execute="deferred"
Property="cmd"
Impersonate="no"
Return="ignore"
/>
<InstallExecuteSequence>
<!-- For Uninstallation -->
<Custom Action="KillDaemon" Before="RemoveFiles" Condition="Installed"/>
<Custom Action="KillUI" After="KillDaemon" Condition="Installed"/>
</InstallExecuteSequence>
<!-- Icons -->
<Icon Id="NetbirdIcon" SourceFile=".\client\ui\netbird.ico" />
<Property Id="ARPPRODUCTICON" Value="NetbirdIcon" />
</Package>
</Wix>

View File

@@ -6,6 +6,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/netbirdio/netbird/client/internal/auth"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
@@ -38,8 +40,8 @@ type Server struct {
type oauthAuthFlow struct { type oauthAuthFlow struct {
expiresAt time.Time expiresAt time.Time
client internal.OAuthClient flow auth.OAuthFlow
info internal.DeviceAuthInfo info auth.AuthFlowInfo
waitCancel context.CancelFunc waitCancel context.CancelFunc
} }
@@ -206,28 +208,15 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
state.Set(internal.StatusConnecting) state.Set(internal.StatusConnecting)
if msg.SetupKey == "" { if msg.SetupKey == "" {
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) oAuthFlow, err := auth.NewOAuthFlow(ctx, config)
if err != nil { if err != nil {
state.Set(internal.StatusLoginFailed) state.Set(internal.StatusLoginFailed)
s, ok := gstatus.FromError(err)
if ok && s.Code() == codes.NotFound {
return nil, gstatus.Errorf(codes.NotFound, "no SSO provider returned from management. "+
"If you are using hosting Netbird see documentation at "+
"https://github.com/netbirdio/netbird/tree/main/management for details")
} else if ok && s.Code() == codes.Unimplemented {
return nil, gstatus.Errorf(codes.Unimplemented, "the management server, %s, does not support SSO providers, "+
"please update your server or use Setup Keys to login", config.ManagementURL)
} else {
log.Errorf("getting device authorization flow info failed with error: %v", err)
return nil, err return nil, err
} }
}
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig) if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(context.TODO()) {
if s.oauthAuthFlow.client != nil && s.oauthAuthFlow.client.GetClientID(ctx) == hostedClient.GetClientID(context.TODO()) {
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) { if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
log.Debugf("using previous device flow info") log.Debugf("using previous oauth flow info")
return &proto.LoginResponse{ return &proto.LoginResponse{
NeedsSSOLogin: true, NeedsSSOLogin: true,
VerificationURI: s.oauthAuthFlow.info.VerificationURI, VerificationURI: s.oauthAuthFlow.info.VerificationURI,
@@ -242,25 +231,25 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
} }
} }
deviceAuthInfo, err := hostedClient.RequestDeviceCode(context.TODO()) authInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
if err != nil { if err != nil {
log.Errorf("getting a request device code failed: %v", err) log.Errorf("getting a request OAuth flow failed: %v", err)
return nil, err return nil, err
} }
s.mutex.Lock() s.mutex.Lock()
s.oauthAuthFlow.client = hostedClient s.oauthAuthFlow.flow = oAuthFlow
s.oauthAuthFlow.info = deviceAuthInfo s.oauthAuthFlow.info = authInfo
s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(deviceAuthInfo.ExpiresIn) * time.Second) s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(authInfo.ExpiresIn) * time.Second)
s.mutex.Unlock() s.mutex.Unlock()
state.Set(internal.StatusNeedsLogin) state.Set(internal.StatusNeedsLogin)
return &proto.LoginResponse{ return &proto.LoginResponse{
NeedsSSOLogin: true, NeedsSSOLogin: true,
VerificationURI: deviceAuthInfo.VerificationURI, VerificationURI: authInfo.VerificationURI,
VerificationURIComplete: deviceAuthInfo.VerificationURIComplete, VerificationURIComplete: authInfo.VerificationURIComplete,
UserCode: deviceAuthInfo.UserCode, UserCode: authInfo.UserCode,
}, nil }, nil
} }
@@ -289,8 +278,8 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
s.actCancel = cancel s.actCancel = cancel
s.mutex.Unlock() s.mutex.Unlock()
if s.oauthAuthFlow.client == nil { if s.oauthAuthFlow.flow == nil {
return nil, gstatus.Errorf(codes.Internal, "oauth client is not initialized") return nil, gstatus.Errorf(codes.Internal, "oauth flow is not initialized")
} }
state := internal.CtxGetState(ctx) state := internal.CtxGetState(ctx)
@@ -304,10 +293,10 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
state.Set(internal.StatusConnecting) state.Set(internal.StatusConnecting)
s.mutex.Lock() s.mutex.Lock()
deviceAuthInfo := s.oauthAuthFlow.info flowInfo := s.oauthAuthFlow.info
s.mutex.Unlock() s.mutex.Unlock()
if deviceAuthInfo.UserCode != msg.UserCode { if flowInfo.UserCode != msg.UserCode {
state.Set(internal.StatusLoginFailed) state.Set(internal.StatusLoginFailed)
return nil, gstatus.Errorf(codes.InvalidArgument, "sso user code is invalid") return nil, gstatus.Errorf(codes.InvalidArgument, "sso user code is invalid")
} }
@@ -324,10 +313,10 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
s.oauthAuthFlow.waitCancel = cancel s.oauthAuthFlow.waitCancel = cancel
s.mutex.Unlock() s.mutex.Unlock()
tokenInfo, err := s.oauthAuthFlow.client.WaitToken(waitCTX, deviceAuthInfo) tokenInfo, err := s.oauthAuthFlow.flow.WaitToken(waitCTX, flowInfo)
if err != nil { if err != nil {
if err == context.Canceled { if err == context.Canceled {
return nil, nil return nil, nil //nolint:nilnil
} }
s.mutex.Lock() s.mutex.Lock()
s.oauthAuthFlow.expiresAt = time.Now() s.oauthAuthFlow.expiresAt = time.Now()

Some files were not shown because too many files have changed in this diff Show More