Compare commits

...

79 Commits

Author SHA1 Message Date
Maycon Santos
d9fa28d8a0 test postgres 2023-08-02 18:56:06 +02:00
Maycon Santos
144ac868e0 create empty files 2023-08-02 18:32:31 +02:00
Maycon Santos
b70339d3bd use check_jq 2023-08-02 12:32:43 +02:00
Maycon Santos
ee890971a3 update texts and expose errors to stderr 2023-08-02 12:19:35 +02:00
Maycon Santos
eeb1b619b7 Automate Zitadel IDP configuration (#1006)
Add automated zitadel configuration and netbird using a single script

run infra files test only when they change or when pushed to main

run releases only on changes to related files or when pushed to main

push install and automated getting started script to releases page
2023-08-01 23:31:14 +02:00
Maycon Santos
4b47f6b23c Merge branch 'main' into update-getting-started-flow 2023-07-30 16:39:49 +02:00
Bethuel Mmbaga
64f6343fcc Add html screen for pkce flow (#1034)
* add html screen for pkce flow

* remove unused CSS classes in pkce-auth-msg.html

* remove links to external sources
2023-07-28 18:10:12 +02:00
Maycon Santos
24713fbe59 Move ebpf code to its own package to avoid crash issues in Android (#1033)
* Move ebpf code to its own package to avoid crash issues in Android

Older versions of android crashes because of the bytecode files
Even when they aren't loaded as it was our case

* move c file to own folder

* fix lint
2023-07-27 15:34:27 +02:00
Bethuel Mmbaga
7794b744f8 Add PKCE authorization flow (#1012)
Enhance the user experience by enabling authentication to Netbird using Single Sign-On (SSO) with any Identity Provider (IDP) provider. Current client offers this capability through the Device Authorization Flow, however, is not widely supported by many IDPs, and even some that do support it do not provide a complete verification URL.

To address these challenges, this pull request enable Authorization Code Flow with Proof Key for Code Exchange (PKCE) for client logins, which is a more widely adopted and secure approach to facilitate SSO with various IDP providers.
2023-07-27 11:31:07 +02:00
Maycon Santos
0d0c30c16d Avoid compiling linux NewFactory for Android (#1032) 2023-07-26 16:21:04 +02:00
Zoltan Papp
b0364da67c Wg ebpf proxy (#911)
EBPF proxy between TURN (relay) and WireGuard to reduce number of used ports used by the NetBird agent.
- Separate the wg configuration from the proxy logic
- In case if eBPF type proxy has only one single proxy instance
- In case if the eBPF is not supported fallback to the original proxy Implementation

Between the signature of eBPF type proxy and original proxy has 
differences so this is why the factory structure exists
2023-07-26 14:00:47 +02:00
Givi Khojanashvili
6dee89379b Feat optimize acl performance iptables (#1025)
* use ipset for iptables

* Update unit-tests for iptables

* Remove debug code

* Update dependencies

* Create separate sets for dPort and sPort rules

* Fix iptables tests

* Fix 0.0.0.0 processing in iptables with ipset
2023-07-24 13:00:23 +02:00
Maycon Santos
76db4f801a Record idp manager type (#1027)
This allows to define priority on support different managers
2023-07-22 19:30:59 +02:00
Zoltan Papp
6c2ed4b4f2 Add default forward rule (#1021)
* Add default forward rule

* Fix

* Add multiple forward rules

* Fix delete rule error handling
2023-07-22 18:39:23 +02:00
Maycon Santos
2541c78dd0 Use error level for JWT parsing error logs (#1026) 2023-07-22 17:56:27 +02:00
Yury Gargay
97b6e79809 Fix DefaultAccountManager GetGroupsFromTheToken false positive tests (#1019)
This fixes the test logic creates copy of account with empty id and
re-pointing the indices to it.

Also, adds additional check for empty ID in SaveAccount method of FileStore.
2023-07-22 15:54:08 +04:00
Givi Khojanashvili
6ad3847615 Fix nfset not binds to the rule (#1024) 2023-07-21 17:45:58 +02:00
Bethuel Mmbaga
a4d830ef83 Fix Okta IDP device authorization (#1023)
* hide okta netbird attributes fields

* fix: update full user profile
2023-07-21 09:34:49 +02:00
pascal-fischer
9e540cd5b4 Merge pull request #1016 from surik/filestore-index-deletion-optimisation
Do not persist filestore when deleting indices
2023-07-20 18:07:33 +02:00
Zoltan Papp
3027d8f27e Sync the iptables/nftables usage with acl logic (#1017) 2023-07-19 19:10:27 +02:00
Givi Khojanashvili
e69ec6ab6a Optimize ACL performance (#994)
* Optimize rules with All groups

* Use IP sets in ACLs (nftables implementation)

* Fix squash rule when we receive optimized rules list from management
2023-07-18 13:12:50 +04:00
Yury Gargay
7ddde41c92 Do not persist filestore when deleting indices
As both TokenID2UserID and HashedPAT2TokenID are in-memory indices and
not stored in the file.
2023-07-17 11:52:45 +02:00
Zoltan Papp
7ebe58f20a Feature/permanent dns (#967)
* Add DNS list argument for mobile client

* Write testable code

Many places are checked the wgInterface != nil condition.
It is doing it just because to avoid the real wgInterface creation for tests.
Instead of this involve a wgInterface interface what is moc-able.

* Refactor the DNS server internal code structure

With the fake resolver has been involved several
if-else statement and generated some unused
variables to distinguish the listener and fake
resolver solutions at running time. With this
commit the fake resolver and listener based
solution has been moved into two separated
structure. Name of this layer is the 'service'.
With this modification the unit test looks
simpler and open the option to add new logic for
the permanent DNS service usage for mobile
systems.



* Remove is running check in test

We can not ensure the state well so remove this
check. The test will fail if the server is not
running well.
2023-07-14 21:56:22 +02:00
Zoltan Papp
9c2c0e7934 Check links of groups before delete it (#1010)
* Check links of groups before delete it

* Add delete group handler test

* Rename dns error msg

* Add delete group test

* Remove rule check

The policy cover this scenario

* Fix test

* Check disabled management grps

* Change error message

* Add new activity for group delete event
2023-07-14 20:45:40 +02:00
pascal-fischer
c6af1037d9 FIx error on ip6tables not available (#999)
* adding check operation to confirm if ip*tables is available

* linter

* linter
2023-07-14 20:44:35 +02:00
Bethuel Mmbaga
5cb9a126f1 Fix pre-shared key not persistent (#1011)
* update pre-shared key if new key is not empty

* add unit test for empty pre-shared key
2023-07-13 10:49:15 +02:00
pascal-fischer
f40951cdf5 Merge pull request #991 from netbirdio/fix/improve_uspfilter_performance
Improve userspace filter performance
2023-07-12 18:02:29 +02:00
Pascal Fischer
6e264d9de7 fix rule order to solve DNS resolver issue 2023-07-11 19:58:21 +02:00
Bethuel Mmbaga
42db9773f4 Remove unused netbird UI dependencies (#1007)
* remove unused netbird-ui dependencies in deb package

* build netbird-ui with support for legacy appindicator

* add rpm package dendencies

* add binary build package

* remove dependencies
2023-07-10 21:09:16 +02:00
Maycon Santos
5883e019c9 update readme 2023-07-07 17:05:21 +02:00
Bethuel Mmbaga
bb9f6f6d0a Add API Endpoint for Resending User Invitations in Auth0 (#989)
* add request handler for sending invite

* add InviteUser method to account manager interface

* add InviteUser mock

* add invite user endpoint to user handler

* add InviteUserByID to manager interface

* implement InviteUserByID in all idp managers

* resend user invitation

* add invite user handler tests

* refactor

* user userID for sending invitation

* fix typo

* refactor

* pass userId in url params
2023-07-03 12:20:19 +02:00
Yury Gargay
829ce6573e Fix broken links in README.md (#992) 2023-06-29 11:42:55 +02:00
Maycon Santos
a366d9e208 Prevent sending nameserver configuration when peer is set as NS (#962)
* Prevent sending nameserver configuration when peer is set as NS

* Add DNS filter tests
2023-06-28 17:29:02 +02:00
Pascal Fischer
e074c24487 add type for RuleSet 2023-06-28 14:09:23 +02:00
Pascal Fischer
54fe05f6d8 fix test 2023-06-28 10:35:29 +02:00
Pascal Fischer
33a155d9aa fix all rules check 2023-06-28 03:03:01 +02:00
Pascal Fischer
51878659f8 remove Rule index map 2023-06-28 02:50:12 +02:00
pascal-fischer
c000c05435 Merge pull request #983 from netbirdio/fix/ssh_connection_freeze
Fix ssh connection freeze
2023-06-27 18:10:30 +02:00
Pascal Fischer
b39ffef22c add missing all rule 2023-06-27 17:44:05 +02:00
Pascal Fischer
d96f882acb seems to work but delete fails 2023-06-27 17:26:15 +02:00
Misha Bragin
d409219b51 Don't create setup keys on new account (#972) 2023-06-27 17:17:24 +02:00
Givi Khojanashvili
8b619a8224 JWT Groups support (#966)
Get groups from the JWT tokens if the feature enabled for the account
2023-06-27 18:51:05 +04:00
Maycon Santos
ed075bc9b9 Refactor: Configurable supported scopes (#985)
* Refactor: Configurable supported scopes

Previously, supported scopes were hardcoded and limited to Auth0
and Keycloak. This update removes the default set of values,
providing flexibility. The value to be set for each Identity
Provider (IDP) is specified in their respective documentation.

* correct var

* correct var

* skip fetching scopes from openid-configuration
2023-06-25 13:59:45 +02:00
Pascal Fischer
8eb098d6fd add sleep and comment 2023-06-23 17:02:34 +02:00
Pascal Fischer
68a8687c80 fix linter 2023-06-23 16:45:07 +02:00
Pascal Fischer
f7d97b02fd fix error codes on cli 2023-06-23 16:27:10 +02:00
Pascal Fischer
2691e729cd fix ssh 2023-06-23 12:20:14 +02:00
Givi Khojanashvili
b524a9d49d Fix use wrpped device in windows (#981) 2023-06-23 10:01:22 +02:00
Givi Khojanashvili
774d8e955c Fix disabled DNS resolver fail (#978)
Fix fail of DNS when it disabled in the settings
2023-06-22 16:59:21 +04:00
Givi Khojanashvili
c20f98c8b6 ACL firewall manager fix/improvement (#970)
* ACL firewall manager fix/improvement

Fix issue with rule squashing, it contained issue when calculated
total amount of IPs in the Peer map (doesn't included offline peers).
That why squashing not worked.
Also this commit changes the rules apply behaviour. Instead policy:
1. Apply all rules from network map
2. Remove all previous applied rules
We do:
1. Apply only new rules
2. Remove outdated rules
Why first variant was implemented: because when you have drop policy
it is important in which order order you rules are and you need totally
clean previous state to apply the new. But in the release we didn't
include drop policy so we can do this improvement.

* Print log message about processed ACL rules
2023-06-20 20:33:41 +02:00
Zoltan Papp
20ae540fb1 Fix the stop procedure in DefaultDns (#971) 2023-06-20 20:33:26 +02:00
Bethuel
58cfa2bb17 Add Google Workspace IdP (#949)
Added integration with Google Workspace user directory API.
2023-06-20 19:15:36 +02:00
pascal-fischer
06005cc10e Merge pull request #968 from netbirdio/chore/extend_gitignore_for_multiple_configs
Extend gitignore to ignore multiple configs
2023-06-19 17:17:12 +02:00
Pascal Fischer
1a3e377304 extend gitignore to ignore multiple config files 2023-06-19 15:07:27 +02:00
Zoltan Papp
dd29f4c01e Reduce the peer status notifications (#956)
Reduce the peer status notifications

When receive new network map invoke multiple notifications for 
every single peers. It cause high cpu usage We handle the in a 
batch the peer notification in update network map.

- Remove the unnecessary UpdatePeerFQDN calls in addNewPeer
- Fix notification in RemovePeer function
- Involve FinishPeerListModifications logic
2023-06-19 11:20:34 +02:00
pascal-fischer
cb7ecd1cc4 Merge pull request #945 from netbirdio/feat/refactor_route_adding_in_client
Refactor check logic when adding routes
2023-06-19 10:16:22 +02:00
Pascal Fischer
b5d8142705 test windows 2023-06-12 16:22:53 +02:00
Pascal Fischer
f45eb1a1da test windows 2023-06-12 16:12:24 +02:00
Pascal Fischer
2567006412 test windows 2023-06-12 16:01:06 +02:00
Pascal Fischer
b92107efc8 test windows 2023-06-12 15:38:47 +02:00
Pascal Fischer
5d19811331 test windows 2023-06-12 15:26:28 +02:00
Pascal Fischer
697d41c94e test windows 2023-06-12 15:14:51 +02:00
Pascal Fischer
75d541f967 test windows 2023-06-12 14:56:30 +02:00
Pascal Fischer
7dfbb71f7a test windows 2023-06-12 12:49:21 +02:00
Pascal Fischer
a5d14c92ff test windows 2023-06-12 12:16:00 +02:00
Pascal Fischer
ce091ab42b test windows 2023-06-12 11:43:18 +02:00
Pascal Fischer
d2fad1cfd9 testing windows 2023-06-12 11:06:49 +02:00
Pascal Fischer
0b5594f145 testing windows 2023-06-09 19:17:26 +02:00
Pascal Fischer
9beaa91db9 testing windows 2023-06-09 19:15:39 +02:00
Pascal Fischer
c8b4c08139 split systemops for operating systems and add linux 2023-06-09 18:48:21 +02:00
Pascal Fischer
dad5501a44 split systemops for operating systems and add linux 2023-06-09 18:40:35 +02:00
Pascal Fischer
1ced2462c1 split systemops for operating systems and add linux 2023-06-09 18:36:49 +02:00
Pascal Fischer
64adaeb276 split systemops for operating systems and add linux 2023-06-09 18:30:36 +02:00
Pascal Fischer
6e26d03fb8 split systemops for operating systems and add linux 2023-06-09 18:27:09 +02:00
Pascal Fischer
493ddb4fe3 Revert "hacky all-operating-systems solution"
This reverts commit 75fac258e7.
2023-06-09 17:59:06 +02:00
Pascal Fischer
75fac258e7 hacky all-operating-systems solution 2023-06-09 17:40:10 +02:00
Pascal Fischer
bc8ee8fc3c add tests 2023-06-09 16:18:48 +02:00
Pascal Fischer
3724323f76 test still failing 2023-06-09 15:33:22 +02:00
Pascal Fischer
3ef33874b1 change checks before route adding to not only check for default gateway (test missing) 2023-06-09 12:35:57 +02:00
164 changed files with 7882 additions and 2113 deletions

View File

@@ -7,6 +7,16 @@ on:
branches: branches:
- main - main
pull_request: pull_request:
paths:
- 'go.mod'
- 'go.sum'
- '.goreleaser.yml'
- '.goreleaser_ui.yaml'
- '.goreleaser_ui_darwin.yaml'
- '.github/workflows/release.yml'
- 'release_files/**'
- '**/Dockerfile'
- '**/Dockerfile.*'
env: env:
SIGN_PIPE_VER: "v0.0.8" SIGN_PIPE_VER: "v0.0.8"
@@ -116,7 +126,7 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-mingw-w64-x86-64 run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
- name: Install rsrc - name: Install rsrc
run: go install github.com/akavel/rsrc@v0.10.2 run: go install github.com/akavel/rsrc@v0.10.2
- name: Generate windows rsrc - name: Generate windows rsrc

View File

@@ -1,10 +1,13 @@
name: Test Docker Compose Linux name: Test Infrastructure files
on: on:
push: push:
branches: branches:
- main - main
pull_request: pull_request:
paths:
- 'infrastructure_files/**'
- '.github/workflows/test-infrastructure-files.yml'
concurrency: concurrency:
@@ -12,7 +15,7 @@ concurrency:
cancel-in-progress: true cancel-in-progress: true
jobs: jobs:
test: test-docker-compose:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Install jq - name: Install jq
@@ -35,7 +38,7 @@ jobs:
${{ runner.os }}-go- ${{ runner.os }}-go-
- name: Checkout code - name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v3
- name: cp setup.env - name: cp setup.env
run: cp infrastructure_files/tests/setup.env infrastructure_files/ run: cp infrastructure_files/tests/setup.env infrastructure_files/
@@ -53,6 +56,7 @@ jobs:
CI_NETBIRD_MGMT_IDP: "none" CI_NETBIRD_MGMT_IDP: "none"
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
- name: check values - name: check values
working-directory: infrastructure_files working-directory: infrastructure_files
@@ -68,6 +72,7 @@ jobs:
CI_NETBIRD_AUTH_JWT_CERTS: https://example.eu.auth0.com/.well-known/jwks.json CI_NETBIRD_AUTH_JWT_CERTS: https://example.eu.auth0.com/.well-known/jwks.json
CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token
CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code
CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT: https://example.eu.auth0.com/authorize
CI_NETBIRD_AUTH_REDIRECT_URI: "/peers" CI_NETBIRD_AUTH_REDIRECT_URI: "/peers"
CI_NETBIRD_TOKEN_SOURCE: "idToken" CI_NETBIRD_TOKEN_SOURCE: "idToken"
CI_NETBIRD_AUTH_USER_ID_CLAIM: "email" CI_NETBIRD_AUTH_USER_ID_CLAIM: "email"
@@ -90,8 +95,8 @@ jobs:
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$' grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
grep -A 1 ProviderConfig management.json | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE grep -A 3 DeviceAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
grep Scope management.json | grep "$CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE" grep -A 8 DeviceAuthorizationFlow management.json | grep -A 6 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE"
grep UseIDToken management.json | grep false grep UseIDToken management.json | grep false
grep -A 1 IdpManagerConfig management.json | grep ManagerType | grep $CI_NETBIRD_MGMT_IDP grep -A 1 IdpManagerConfig management.json | grep ManagerType | grep $CI_NETBIRD_MGMT_IDP
grep -A 3 IdpManagerConfig management.json | grep -A 1 ClientConfig | grep Issuer | grep $CI_NETBIRD_AUTH_AUTHORITY grep -A 3 IdpManagerConfig management.json | grep -A 1 ClientConfig | grep Issuer | grep $CI_NETBIRD_AUTH_AUTHORITY
@@ -99,6 +104,12 @@ jobs:
grep -A 5 IdpManagerConfig management.json | grep -A 3 ClientConfig | grep ClientID | grep $CI_NETBIRD_IDP_MGMT_CLIENT_ID grep -A 5 IdpManagerConfig management.json | grep -A 3 ClientConfig | grep ClientID | grep $CI_NETBIRD_IDP_MGMT_CLIENT_ID
grep -A 6 IdpManagerConfig management.json | grep -A 4 ClientConfig | grep ClientSecret | grep $CI_NETBIRD_IDP_MGMT_CLIENT_SECRET grep -A 6 IdpManagerConfig management.json | grep -A 4 ClientConfig | grep ClientSecret | grep $CI_NETBIRD_IDP_MGMT_CLIENT_SECRET
grep -A 7 IdpManagerConfig management.json | grep -A 5 ClientConfig | grep GrantType | grep client_credentials grep -A 7 IdpManagerConfig management.json | grep -A 5 ClientConfig | grep GrantType | grep client_credentials
grep -A 2 PKCEAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_AUDIENCE
grep -A 3 PKCEAuthorizationFlow management.json | grep -A 2 ProviderConfig | grep ClientID | grep $CI_NETBIRD_AUTH_CLIENT_ID
grep -A 4 PKCEAuthorizationFlow management.json | grep -A 3 ProviderConfig | grep ClientSecret | grep $CI_NETBIRD_AUTH_CLIENT_SECRET
grep -A 5 PKCEAuthorizationFlow management.json | grep -A 4 ProviderConfig | grep AuthorizationEndpoint | grep $CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
grep -A 6 PKCEAuthorizationFlow management.json | grep -A 5 ProviderConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
grep -A 7 PKCEAuthorizationFlow management.json | grep -A 6 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
- name: run docker compose up - name: run docker compose up
working-directory: infrastructure_files working-directory: infrastructure_files
@@ -113,3 +124,28 @@ jobs:
count=$(docker compose ps --format json | jq '.[] | select(.Project | contains("infrastructure_files")) | .State' | grep -c running) count=$(docker compose ps --format json | jq '.[] | select(.Project | contains("infrastructure_files")) | .State' | grep -c running)
test $count -eq 4 test $count -eq 4
working-directory: infrastructure_files working-directory: infrastructure_files
test-getting-started-script:
runs-on: ubuntu-latest
steps:
- name: Install jq
run: sudo apt-get install -y jq
- name: Checkout code
uses: actions/checkout@v3
- name: run script
run: bash -x infrastructure_files/getting-started-with-zitadel.sh
- name: test Caddy file gen
run: test -f Caddyfile
- name: test docker-compose file gen
run: test -f docker-compose.yml
- name: test management.json file gen
run: test -f management.json
- name: test turnserver.conf file gen
run: test -f turnserver.conf
- name: test zitadel.env file gen
run: test -f zitadel.env
- name: test dashboard.env file gen
run: test -f dashboard.env

7
.gitignore vendored
View File

@@ -7,8 +7,15 @@ bin/
conf.json conf.json
http-cmds.sh http-cmds.sh
infrastructure_files/management.json infrastructure_files/management.json
infrastructure_files/management-*.json
infrastructure_files/docker-compose.yml infrastructure_files/docker-compose.yml
infrastructure_files/openid-configuration.json
infrastructure_files/turnserver.conf
management/management
client/client
client/client.exe
*.syso *.syso
client/.distfiles/ client/.distfiles/
infrastructure_files/setup.env infrastructure_files/setup.env
infrastructure_files/setup-*.env
.vscode .vscode

View File

@@ -377,3 +377,8 @@ uploads:
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }} target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
username: dev@wiretrustee.com username: dev@wiretrustee.com
method: PUT method: PUT
release:
extra_files:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
- glob: ./release_files/install.sh

View File

@@ -11,6 +11,8 @@ builds:
- amd64 - amd64
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
tags:
- legacy_appindicator
mod_timestamp: '{{ .CommitTimestamp }}' mod_timestamp: '{{ .CommitTimestamp }}'
- id: netbird-ui-windows - id: netbird-ui-windows
@@ -55,9 +57,6 @@ nfpms:
- src: client/ui/disconnected.png - src: client/ui/disconnected.png
dst: /usr/share/pixmaps/netbird.png dst: /usr/share/pixmaps/netbird.png
dependencies: dependencies:
- libayatana-appindicator3-1
- libgtk-3-dev
- libappindicator3-dev
- netbird - netbird
- maintainer: Netbird <dev@netbird.io> - maintainer: Netbird <dev@netbird.io>
@@ -75,9 +74,6 @@ nfpms:
- src: client/ui/disconnected.png - src: client/ui/disconnected.png
dst: /usr/share/pixmaps/netbird.png dst: /usr/share/pixmaps/netbird.png
dependencies: dependencies:
- libayatana-appindicator3-1
- libgtk-3-dev
- libappindicator3-dev
- netbird - netbird
uploads: uploads:

View File

@@ -57,9 +57,10 @@ NetBird uses [NAT traversal techniques](https://en.wikipedia.org/wiki/Interactiv
- \[x] Network Routes. - \[x] Network Routes.
- \[x] Private DNS. - \[x] Private DNS.
- \[x] Network Activity Monitoring. - \[x] Network Activity Monitoring.
- \[x] Mobile clients (Android).
-
**Coming soon:** **Coming soon:**
- \[ ] Mobile clients. - \[ ] Mobile clients (iOS).
### Secure peer-to-peer VPN with SSO and MFA in minutes ### Secure peer-to-peer VPN with SSO and MFA in minutes
@@ -70,9 +71,9 @@ For stable versions, see [releases](https://github.com/netbirdio/netbird/release
### Start using NetBird ### Start using NetBird
- Hosted version: [https://app.netbird.io/](https://app.netbird.io/). - Hosted version: [https://app.netbird.io/](https://app.netbird.io/).
- See our documentation for [Quickstart Guide](https://netbird.io/docs/getting-started/quickstart). - See our documentation for [Quickstart Guide](https://docs.netbird.io/how-to/getting-started).
- If you are looking to self-host NetBird, check our [Self-Hosting Guide](https://netbird.io/docs/getting-started/self-hosting). - If you are looking to self-host NetBird, check our [Self-Hosting Guide](https://docs.netbird.io/selfhosted/selfhosted-guide).
- Step-by-step [Installation Guide](https://netbird.io/docs/getting-started/installation) for different platforms. - Step-by-step [Installation Guide](https://docs.netbird.io/how-to/getting-started#installation) for different platforms.
- Web UI [repository](https://github.com/netbirdio/dashboard). - Web UI [repository](https://github.com/netbirdio/dashboard).
- 5 min [demo video](https://youtu.be/Tu9tPsUWaY0) on YouTube. - 5 min [demo video](https://youtu.be/Tu9tPsUWaY0) on YouTube.
@@ -91,7 +92,7 @@ For stable versions, see [releases](https://github.com/netbirdio/netbird/release
<img src="https://netbird.io/docs/img/architecture/high-level-dia.png" width="700"/> <img src="https://netbird.io/docs/img/architecture/high-level-dia.png" width="700"/>
</p> </p>
See a complete [architecture overview](https://netbird.io/docs/overview/architecture) for details. See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
### Roadmap ### Roadmap
- [Public Roadmap](https://github.com/netbirdio/netbird/projects/2) - [Public Roadmap](https://github.com/netbirdio/netbird/projects/2)

View File

@@ -7,6 +7,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
@@ -35,6 +36,11 @@ type RouteListener interface {
routemanager.RouteListener routemanager.RouteListener
} }
// DnsReadyListener export internal dns ReadyListener for mobile
type DnsReadyListener interface {
dns.ReadyListener
}
func init() { func init() {
formatter.SetLogcatFormatter(log.StandardLogger()) formatter.SetLogcatFormatter(log.StandardLogger())
} }
@@ -49,6 +55,7 @@ type Client struct {
ctxCancelLock *sync.Mutex ctxCancelLock *sync.Mutex
deviceName string deviceName string
routeListener routemanager.RouteListener routeListener routemanager.RouteListener
onHostDnsFn func([]string)
} }
// NewClient instantiate a new Client // NewClient instantiate a new Client
@@ -65,7 +72,7 @@ func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover
} }
// Run start the internal client. It is a blocker function // Run start the internal client. It is a blocker function
func (c *Client) Run(urlOpener URLOpener) error { func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{ cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
ConfigPath: c.cfgFile, ConfigPath: c.cfgFile,
}) })
@@ -90,7 +97,8 @@ func (c *Client) Run(urlOpener URLOpener) error {
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener) c.onHostDnsFn = func([]string) {}
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources
@@ -126,6 +134,17 @@ func (c *Client) PeersList() *PeerInfoArray {
return &PeerInfoArray{items: peerInfos} return &PeerInfoArray{items: peerInfos}
} }
// OnUpdatedHostDNS update the DNS servers addresses for root zones
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
dnsServer, err := dns.GetServerDns()
if err != nil {
return err
}
dnsServer.OnUpdatedHostDNSServer(list.items)
return nil
}
// SetConnectionListener set the network connection listener // SetConnectionListener set the network connection listener
func (c *Client) SetConnectionListener(listener ConnectionListener) { func (c *Client) SetConnectionListener(listener ConnectionListener) {
c.recorder.SetConnectionListener(listener) c.recorder.SetConnectionListener(listener)

View File

@@ -0,0 +1,26 @@
package android
import "fmt"
// DNSList is a wrapper of []string
type DNSList struct {
items []string
}
// Add new DNS address to the collection
func (array *DNSList) Add(s string) {
array.items = append(array.items, s)
}
// Get return an element of the collection
func (array *DNSList) Get(i int) (string, error) {
if i >= len(array.items) || i < 0 {
return "", fmt.Errorf("out of range")
}
return array.items[i], nil
}
// Size return with the size of the collection
func (array *DNSList) Size() int {
return len(array.items)
}

View File

@@ -0,0 +1,24 @@
package android
import "testing"
func TestDNSList_Get(t *testing.T) {
l := DNSList{
items: make([]string, 1),
}
_, err := l.Get(0)
if err != nil {
t.Errorf("invalid error: %s", err)
}
_, err = l.Get(-1)
if err == nil {
t.Errorf("expected error but got nil")
}
_, err = l.Get(1)
if err == nil {
t.Errorf("expected error but got nil")
}
}

View File

@@ -6,15 +6,14 @@ import (
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/cmd" "github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/system"
) )
// SSOListener is async listener for mobile framework // SSOListener is async listener for mobile framework
@@ -87,9 +86,15 @@ func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
err := a.withBackOff(a.ctx, func() (err error) { err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound { if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
supportsSSO = false _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
err = nil if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
supportsSSO = false
err = nil
}
return err
} }
return err return err
}) })
@@ -183,27 +188,15 @@ func (a *Auth) login(urlOpener URLOpener) error {
return nil return nil
} }
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*internal.TokenInfo, error) { func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config)
if err != nil { if err != nil {
s, ok := gstatus.FromError(err) return nil, err
if ok && s.Code() == codes.NotFound {
return nil, fmt.Errorf("no SSO provider returned from management. " +
"If you are using hosting Netbird see documentation at " +
"https://github.com/netbirdio/netbird/tree/main/management for details")
} else if ok && s.Code() == codes.Unimplemented {
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
"please update your servver or use Setup Keys to login", a.config.ManagementURL)
} else {
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
}
} }
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig) flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
if err != nil { if err != nil {
return nil, fmt.Errorf("getting a request device code failed: %v", err) return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
} }
go urlOpener.Open(flowInfo.VerificationURIComplete) go urlOpener.Open(flowInfo.VerificationURIComplete)
@@ -211,7 +204,7 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*internal.TokenInfo,
waitTimeout := time.Duration(flowInfo.ExpiresIn) waitTimeout := time.Duration(flowInfo.ExpiresIn)
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout*time.Second) waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout*time.Second)
defer cancel() defer cancel()
tokenInfo, err := hostedClient.WaitToken(waitCTX, flowInfo) tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
if err != nil { if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err) return nil, fmt.Errorf("waiting for browser login failed: %v", err)
} }

View File

@@ -3,6 +3,7 @@ package cmd
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/netbirdio/netbird/client/internal/auth"
"strings" "strings"
"time" "time"
@@ -163,31 +164,15 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
return nil return nil
} }
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*internal.TokenInfo, error) { func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) oAuthFlow, err := auth.NewOAuthFlow(ctx, config)
if err != nil { if err != nil {
s, ok := gstatus.FromError(err) return nil, err
if ok && s.Code() == codes.NotFound {
return nil, fmt.Errorf("no SSO provider returned from management. " +
"If you are using hosting Netbird see documentation at " +
"https://github.com/netbirdio/netbird/tree/main/management for details")
} else if ok && s.Code() == codes.Unimplemented {
mgmtURL := managementURL
if mgmtURL == "" {
mgmtURL = internal.DefaultManagementURL
}
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
"please update your servver or use Setup Keys to login", mgmtURL)
} else {
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
}
} }
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig) flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
if err != nil { if err != nil {
return nil, fmt.Errorf("getting a request device code failed: %v", err) return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
} }
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode) openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
@@ -196,7 +181,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout*time.Second) waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout*time.Second)
defer c() defer c()
tokenInfo, err := hostedClient.WaitToken(waitCTX, flowInfo) tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
if err != nil { if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err) return nil, fmt.Errorf("waiting for browser login failed: %v", err)
} }
@@ -206,8 +191,10 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) { func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
var codeMsg string var codeMsg string
if !strings.Contains(verificationURIComplete, userCode) { if userCode != "" {
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode) if !strings.Contains(verificationURIComplete, userCode) {
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
}
} }
err := open.Run(verificationURIComplete) err := open.Run(verificationURIComplete)

View File

@@ -73,7 +73,8 @@ var sshCmd = &cobra.Command{
go func() { go func() {
// blocking // blocking
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil { if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
log.Print(err) log.Debug(err)
os.Exit(1)
} }
cancel() cancel()
}() }()
@@ -92,12 +93,10 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey) c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
if err != nil { if err != nil {
cmd.Printf("Error: %v\n", err) cmd.Printf("Error: %v\n", err)
cmd.Printf("Couldn't connect. " + cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
"You might be disconnected from the NetBird network, or the NetBird agent isn't running.\n" + "You can verify the connection by running:\n\n" +
"Run the status command: \n\n" + " netbird status\n\n")
" netbird status\n\n" + return err
"It might also be that the SSH server is disabled on the agent you are trying to connect to.\n")
return nil
} }
go func() { go func() {
<-ctx.Done() <-ctx.Done()

View File

@@ -104,7 +104,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx) ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel) SetupCloseHandler(ctx, cancel)
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil, nil, nil) return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
} }
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {

View File

@@ -51,6 +51,7 @@ type Manager interface {
dPort *Port, dPort *Port,
direction RuleDirection, direction RuleDirection,
action Action, action Action,
ipsetName string,
comment string, comment string,
) (Rule, error) ) (Rule, error)
@@ -60,5 +61,8 @@ type Manager interface {
// Reset firewall to the default state // Reset firewall to the default state
Reset() error Reset() error
// Flush the changes to firewall controller
Flush() error
// TODO: migrate routemanager firewal actions to this interface // TODO: migrate routemanager firewal actions to this interface
} }

View File

@@ -8,6 +8,7 @@ import (
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/nadoo/ipset"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
fw "github.com/netbirdio/netbird/client/firewall" fw "github.com/netbirdio/netbird/client/firewall"
@@ -35,6 +36,8 @@ type Manager struct {
inputDefaultRuleSpecs []string inputDefaultRuleSpecs []string
outputDefaultRuleSpecs []string outputDefaultRuleSpecs []string
wgIface iFaceMapper wgIface iFaceMapper
rulesets map[string]ruleset
} }
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
@@ -43,6 +46,11 @@ type iFaceMapper interface {
Address() iface.WGAddress Address() iface.WGAddress
} }
type ruleset struct {
rule *Rule
ips map[string]string
}
// Create iptables firewall manager // Create iptables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) { func Create(wgIface iFaceMapper) (*Manager, error) {
m := &Manager{ m := &Manager{
@@ -51,6 +59,11 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
"-i", wgIface.Name(), "-j", ChainInputFilterName, "-s", wgIface.Address().String()}, "-i", wgIface.Name(), "-j", ChainInputFilterName, "-s", wgIface.Address().String()},
outputDefaultRuleSpecs: []string{ outputDefaultRuleSpecs: []string{
"-o", wgIface.Name(), "-j", ChainOutputFilterName, "-d", wgIface.Address().String()}, "-o", wgIface.Name(), "-j", ChainOutputFilterName, "-d", wgIface.Address().String()},
rulesets: make(map[string]ruleset),
}
if err := ipset.Init(); err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
} }
// init clients for booth ipv4 and ipv6 // init clients for booth ipv4 and ipv6
@@ -58,13 +71,17 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("iptables is not installed in the system or not supported") return nil, fmt.Errorf("iptables is not installed in the system or not supported")
} }
m.ipv4Client = ipv4Client if isIptablesClientAvailable(ipv4Client) {
m.ipv4Client = ipv4Client
}
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6) ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil { if err != nil {
log.Errorf("ip6tables is not installed in the system or not supported: %v", err) log.Errorf("ip6tables is not installed in the system or not supported: %v", err)
} else { } else {
m.ipv6Client = ipv6Client if isIptablesClientAvailable(ipv6Client) {
m.ipv6Client = ipv6Client
}
} }
if err := m.Reset(); err != nil { if err := m.Reset(); err != nil {
@@ -73,6 +90,11 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
return m, nil return m, nil
} }
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}
// AddFiltering rule to the firewall // AddFiltering rule to the firewall
// //
// If comment is empty rule ID is used as comment // If comment is empty rule ID is used as comment
@@ -83,6 +105,7 @@ func (m *Manager) AddFiltering(
dPort *fw.Port, dPort *fw.Port,
direction fw.RuleDirection, direction fw.RuleDirection,
action fw.Action, action fw.Action,
ipsetName string,
comment string, comment string,
) (fw.Rule, error) { ) (fw.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
@@ -101,22 +124,45 @@ func (m *Manager) AddFiltering(
if sPort != nil && sPort.Values != nil { if sPort != nil && sPort.Values != nil {
sPortVal = strconv.Itoa(sPort.Values[0]) sPortVal = strconv.Itoa(sPort.Values[0])
} }
ipsetName = m.transformIPsetName(ipsetName, sPortVal, dPortVal)
ruleID := uuid.New().String() ruleID := uuid.New().String()
if comment == "" { if comment == "" {
comment = ruleID comment = ruleID
} }
specs := m.filterRuleSpecs( if ipsetName != "" {
"filter", rs, rsExists := m.rulesets[ipsetName]
ip, if !rsExists {
string(protocol), if err := ipset.Flush(ipsetName); err != nil {
sPortVal, log.Errorf("flush ipset %q before use it: %v", ipsetName, err)
dPortVal, }
direction, if err := ipset.Create(ipsetName); err != nil {
action, return nil, fmt.Errorf("failed to create ipset: %w", err)
comment, }
) }
if err := ipset.Add(ipsetName, ip.String()); err != nil {
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
}
if rsExists {
// if ruleset already exists it means we already have the firewall rule
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
rs.ips[ip.String()] = ruleID
return &Rule{
ruleID: ruleID,
ipsetName: ipsetName,
ip: ip.String(),
dst: direction == fw.RuleDirectionOUT,
v6: ip.To4() == nil,
}, nil
}
// this is new ipset so we need to create firewall rule for it
}
specs := m.filterRuleSpecs("filter", ip, string(protocol), sPortVal, dPortVal,
direction, action, comment, ipsetName)
if direction == fw.RuleDirectionOUT { if direction == fw.RuleDirectionOUT {
ok, err := client.Exists("filter", ChainOutputFilterName, specs...) ok, err := client.Exists("filter", ChainOutputFilterName, specs...)
@@ -144,12 +190,24 @@ func (m *Manager) AddFiltering(
} }
} }
return &Rule{ rule := &Rule{
id: ruleID, ruleID: ruleID,
specs: specs, specs: specs,
dst: direction == fw.RuleDirectionOUT, ipsetName: ipsetName,
v6: ip.To4() == nil, ip: ip.String(),
}, nil dst: direction == fw.RuleDirectionOUT,
v6: ip.To4() == nil,
}
if ipsetName != "" {
// ipset name is defined and it means that this rule was created
// for it, need to assosiate it with ruleset
m.rulesets[ipsetName] = ruleset{
rule: rule,
ips: map[string]string{rule.ip: ruleID},
}
}
return rule, nil
} }
// DeleteRule from the firewall by rule definition // DeleteRule from the firewall by rule definition
@@ -170,6 +228,31 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
client = m.ipv6Client client = m.ipv6Client
} }
if rs, ok := m.rulesets[r.ipsetName]; ok {
// delete IP from ruleset IPs list and ipset
if _, ok := rs.ips[r.ip]; ok {
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
return fmt.Errorf("failed to delete ip from ipset: %w", err)
}
delete(rs.ips, r.ip)
}
// if after delete, set still contains other IPs,
// no need to delete firewall rule and we should exit here
if len(rs.ips) != 0 {
return nil
}
// we delete last IP from the set, that means we need to delete
// set itself and assosiated firewall rule too
delete(m.rulesets, r.ipsetName)
if err := ipset.Destroy(r.ipsetName); err != nil {
log.Errorf("delete empty ipset: %v", err)
}
r = rs.rule
}
if r.dst { if r.dst {
return client.Delete("filter", ChainOutputFilterName, r.specs...) return client.Delete("filter", ChainOutputFilterName, r.specs...)
} }
@@ -193,6 +276,9 @@ func (m *Manager) Reset() error {
return nil return nil
} }
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// reset firewall chain, clear it and drop it // reset firewall chain, clear it and drop it
func (m *Manager) reset(client *iptables.IPTables, table string) error { func (m *Manager) reset(client *iptables.IPTables, table string) error {
ok, err := client.ChainExists(table, ChainInputFilterName) ok, err := client.ChainExists(table, ChainInputFilterName)
@@ -233,6 +319,16 @@ func (m *Manager) reset(client *iptables.IPTables, table string) error {
return nil return nil
} }
for ipsetName := range m.rulesets {
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
}
if err := ipset.Destroy(ipsetName); err != nil {
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
}
delete(m.rulesets, ipsetName)
}
return nil return nil
} }
@@ -240,6 +336,7 @@ func (m *Manager) reset(client *iptables.IPTables, table string) error {
func (m *Manager) filterRuleSpecs( func (m *Manager) filterRuleSpecs(
table string, ip net.IP, protocol string, sPort, dPort string, table string, ip net.IP, protocol string, sPort, dPort string,
direction fw.RuleDirection, action fw.Action, comment string, direction fw.RuleDirection, action fw.Action, comment string,
ipsetName string,
) (specs []string) { ) (specs []string) {
matchByIP := true matchByIP := true
// don't use IP matching if IP is ip 0.0.0.0 // don't use IP matching if IP is ip 0.0.0.0
@@ -249,11 +346,19 @@ func (m *Manager) filterRuleSpecs(
switch direction { switch direction {
case fw.RuleDirectionIN: case fw.RuleDirectionIN:
if matchByIP { if matchByIP {
specs = append(specs, "-s", ip.String()) if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
} else {
specs = append(specs, "-s", ip.String())
}
} }
case fw.RuleDirectionOUT: case fw.RuleDirectionOUT:
if matchByIP { if matchByIP {
specs = append(specs, "-d", ip.String()) if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
} else {
specs = append(specs, "-d", ip.String())
}
} }
} }
if protocol != "all" { if protocol != "all" {
@@ -335,3 +440,16 @@ func (m *Manager) actionToStr(action fw.Action) string {
} }
return "DROP" return "DROP"
} }
func (m *Manager) transformIPsetName(ipsetName string, sPort, dPort string) string {
if ipsetName == "" {
return ""
} else if sPort != "" && dPort != "" {
return ipsetName + "-sport-dport"
} else if sPort != "" {
return ipsetName + "-sport"
} else if dPort != "" {
return ipsetName + "-dport"
}
return ipsetName
}

View File

@@ -55,12 +55,13 @@ func TestIptablesManager(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(mock)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
if err := manager.Reset(); err != nil { err := manager.Reset()
t.Errorf("clear the manager state: %v", err) require.NoError(t, err, "clear the manager state")
}
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
@@ -68,7 +69,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("add first rule", func(t *testing.T) { t.Run("add first rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2") ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}} port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic") rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...) checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
@@ -81,33 +82,31 @@ func TestIptablesManager(t *testing.T) {
Values: []int{8043: 8046}, Values: []int{8043: 8046},
} }
rule2, err = manager.AddFiltering( rule2, err = manager.AddFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTPS traffic from ports range") ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...) checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...)
}) })
t.Run("delete first rule", func(t *testing.T) { t.Run("delete first rule", func(t *testing.T) {
if err := manager.DeleteRule(rule1); err != nil { err := manager.DeleteRule(rule1)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
}
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, false, rule1.(*Rule).specs...) checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, false, rule1.(*Rule).specs...)
}) })
t.Run("delete second rule", func(t *testing.T) { t.Run("delete second rule", func(t *testing.T) {
if err := manager.DeleteRule(rule2); err != nil { err := manager.DeleteRule(rule2)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
}
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, false, rule2.(*Rule).specs...) require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
}) })
t.Run("reset check", func(t *testing.T) { t.Run("reset check", func(t *testing.T) {
// add second rule // add second rule
ip := net.ParseIP("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []int{5353}} port := &fw.Port{Values: []int{5353}}
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept Fake DNS traffic") _, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Reset() err = manager.Reset()
@@ -122,6 +121,88 @@ func TestIptablesManager(t *testing.T) {
}) })
} }
func TestIptablesManagerIPSet(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)
mock := &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// just check on the local interface
manager, err := Create(mock)
require.NoError(t, err)
time.Sleep(time.Second)
defer func() {
err := manager.Reset()
require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second)
}()
var rule1 fw.Rule
t.Run("add first rule with set", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddFiltering(
ip, "tcp", nil, port, fw.RuleDirectionOUT,
fw.ActionAccept, "default", "accept HTTP traffic",
)
require.NoError(t, err, "failed to add rule")
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
require.Equal(t, rule1.(*Rule).ipsetName, "default-dport", "ipset name must be set")
require.Equal(t, rule1.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
})
var rule2 fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
Values: []int{443},
}
rule2, err = manager.AddFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
"default", "accept HTTPS traffic from ports range",
)
require.NoError(t, err, "failed to add rule")
require.Equal(t, rule2.(*Rule).ipsetName, "default-sport", "ipset name must be set")
require.Equal(t, rule2.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
})
t.Run("delete first rule", func(t *testing.T) {
err := manager.DeleteRule(rule1)
require.NoError(t, err, "failed to delete rule")
require.NotContains(t, manager.rulesets, rule1.(*Rule).ruleID, "rule must be removed form the ruleset index")
})
t.Run("delete second rule", func(t *testing.T) {
err := manager.DeleteRule(rule2)
require.NoError(t, err, "failed to delete rule")
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
})
t.Run("reset check", func(t *testing.T) {
err = manager.Reset()
require.NoError(t, err, "failed to reset")
})
}
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) { func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) {
exists, err := ipv4Client.Exists("filter", chainName, rulespec...) exists, err := ipv4Client.Exists("filter", chainName, rulespec...)
require.NoError(t, err, "failed to check rule") require.NoError(t, err, "failed to check rule")
@@ -153,9 +234,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
if err := manager.Reset(); err != nil { err := manager.Reset()
t.Errorf("clear the manager state: %v", err) require.NoError(t, err, "clear the manager state")
}
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
@@ -167,9 +248,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { if i%2 == 0 {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else { } else {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
} }
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")

View File

@@ -2,13 +2,16 @@ package iptables
// Rule to handle management of rules // Rule to handle management of rules
type Rule struct { type Rule struct {
id string ruleID string
ipsetName string
specs []string specs []string
ip string
dst bool dst bool
v6 bool v6 bool
} }
// GetRuleID returns the rule id // GetRuleID returns the rule id
func (r *Rule) GetRuleID() string { func (r *Rule) GetRuleID() string {
return r.id return r.ruleID
} }

View File

@@ -6,12 +6,14 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"strconv"
"strings" "strings"
"sync" "sync"
"time"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/google/uuid" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall" fw "github.com/netbirdio/netbird/client/firewall"
@@ -29,11 +31,14 @@ const (
FilterOutputChainName = "netbird-acl-output-filter" FilterOutputChainName = "netbird-acl-output-filter"
) )
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
// Manager of iptables firewall // Manager of iptables firewall
type Manager struct { type Manager struct {
mutex sync.Mutex mutex sync.Mutex
conn *nftables.Conn rConn *nftables.Conn
sConn *nftables.Conn
tableIPv4 *nftables.Table tableIPv4 *nftables.Table
tableIPv6 *nftables.Table tableIPv6 *nftables.Table
@@ -43,6 +48,10 @@ type Manager struct {
filterInputChainIPv6 *nftables.Chain filterInputChainIPv6 *nftables.Chain
filterOutputChainIPv6 *nftables.Chain filterOutputChainIPv6 *nftables.Chain
rulesetManager *rulesetManager
setRemovedIPs map[string]struct{}
setRemoved map[string]*nftables.Set
wgIface iFaceMapper wgIface iFaceMapper
} }
@@ -54,8 +63,23 @@ type iFaceMapper interface {
// Create nftables firewall manager // Create nftables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) { func Create(wgIface iFaceMapper) (*Manager, error) {
// sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation)
// and is permanent. Using same connection for booth type of operations
// overloads netlink with high amount of rules ( > 10000)
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return nil, err
}
m := &Manager{ m := &Manager{
conn: &nftables.Conn{}, rConn: &nftables.Conn{},
sConn: sConn,
rulesetManager: newRuleManager(),
setRemovedIPs: map[string]struct{}{},
setRemoved: map[string]*nftables.Set{},
wgIface: wgIface, wgIface: wgIface,
} }
@@ -77,6 +101,7 @@ func (m *Manager) AddFiltering(
dPort *fw.Port, dPort *fw.Port,
direction fw.RuleDirection, direction fw.RuleDirection,
action fw.Action, action fw.Action,
ipsetName string,
comment string, comment string,
) (fw.Rule, error) { ) (fw.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
@@ -84,6 +109,7 @@ func (m *Manager) AddFiltering(
var ( var (
err error err error
ipset *nftables.Set
table *nftables.Table table *nftables.Table
chain *nftables.Chain chain *nftables.Chain
) )
@@ -107,6 +133,46 @@ func (m *Manager) AddFiltering(
return nil, err return nil, err
} }
rawIP := ip.To4()
if rawIP == nil {
rawIP = ip.To16()
}
rulesetID := m.getRulesetID(ip, proto, sPort, dPort, direction, action, ipsetName)
if ipsetName != "" {
// if we already have set with given name, just add ip to the set
// and return rule with new ID in other case let's create rule
// with fresh created set and set element
var isSetNew bool
ipset, err = m.rConn.GetSetByName(table, ipsetName)
if err != nil {
if ipset, err = m.createSet(table, rawIP, ipsetName); err != nil {
return nil, fmt.Errorf("get set name: %v", err)
}
isSetNew = true
}
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
return nil, fmt.Errorf("add set element for the first time: %v", err)
}
if err := m.sConn.Flush(); err != nil {
return nil, fmt.Errorf("flush add elements: %v", err)
}
if !isSetNew {
// if we already have nftables rules with set for given direction
// just add new rule to the ruleset and return new fw.Rule object
if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok {
return m.rulesetManager.addRule(ruleset, rawIP)
}
// if ipset exists but it is not linked to rule for given direction
// create new rule for direction and bind ipset to it later
}
}
ifaceKey := expr.MetaKeyIIFNAME ifaceKey := expr.MetaKeyIIFNAME
if direction == fw.RuleDirectionOUT { if direction == fw.RuleDirectionOUT {
ifaceKey = expr.MetaKeyOIFNAME ifaceKey = expr.MetaKeyOIFNAME
@@ -146,39 +212,47 @@ func (m *Manager) AddFiltering(
}) })
} }
// don't use IP matching if IP is ip 0.0.0.0 // check if rawIP contains zeroed IPv4 0.0.0.0 or same IPv6 value
if s := ip.String(); s != "0.0.0.0" && s != "::" { // in that case not add IP match expression into the rule definition
if !bytes.HasPrefix(anyIP, rawIP) {
// source address position // source address position
var adrLen, adrOffset uint32 addrLen := uint32(len(rawIP))
if ip.To4() == nil { addrOffset := uint32(12)
adrLen = 16 if addrLen == 16 {
adrOffset = 8 addrOffset = 8
} else {
adrLen = 4
adrOffset = 12
} }
// change to destination address position if need // change to destination address position if need
if direction == fw.RuleDirectionOUT { if direction == fw.RuleDirectionOUT {
adrOffset += adrLen addrOffset += addrLen
} }
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
expressions = append(expressions, expressions = append(expressions,
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
Offset: adrOffset, Offset: addrOffset,
Len: adrLen, Len: addrLen,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: add.AsSlice(),
}, },
) )
// add individual IP for match if no ipset defined
if ipset == nil {
expressions = append(expressions,
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rawIP,
},
)
} else {
expressions = append(expressions,
&expr.Lookup{
SourceRegister: 1,
SetName: ipsetName,
SetID: ipset.ID,
},
)
}
} }
if sPort != nil && len(sPort.Values) != 0 { if sPort != nil && len(sPort.Values) != 0 {
@@ -219,39 +293,76 @@ func (m *Manager) AddFiltering(
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop}) expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
} }
id := uuid.New().String() userData := []byte(strings.Join([]string{rulesetID, comment}, " "))
userData := []byte(strings.Join([]string{id, comment}, " "))
_ = m.conn.InsertRule(&nftables.Rule{ rule := m.rConn.InsertRule(&nftables.Rule{
Table: table, Table: table,
Chain: chain, Chain: chain,
Position: 0, Position: 0,
Exprs: expressions, Exprs: expressions,
UserData: userData, UserData: userData,
}) })
if err := m.rConn.Flush(); err != nil {
if err := m.conn.Flush(); err != nil { return nil, fmt.Errorf("flush insert rule: %v", err)
return nil, err
} }
list, err := m.conn.GetRules(table, chain) ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset)
if err != nil { return m.rulesetManager.addRule(ruleset, rawIP)
return nil, err }
// getRulesetID returns ruleset ID based on given parameters
func (m *Manager) getRulesetID(
ip net.IP,
proto fw.Protocol,
sPort *fw.Port,
dPort *fw.Port,
direction fw.RuleDirection,
action fw.Action,
ipsetName string,
) string {
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
if sPort != nil {
rulesetID += sPort.String()
}
rulesetID += ":"
if dPort != nil {
rulesetID += dPort.String()
}
rulesetID += ":"
rulesetID += strconv.Itoa(int(action))
if ipsetName == "" {
return "ip:" + ip.String() + rulesetID
}
return "set:" + ipsetName + rulesetID
}
// createSet in given table by name
func (m *Manager) createSet(
table *nftables.Table,
rawIP []byte,
name string,
) (*nftables.Set, error) {
keyType := nftables.TypeIPAddr
if len(rawIP) == 16 {
keyType = nftables.TypeIP6Addr
}
// else we create new ipset and continue creating rule
ipset := &nftables.Set{
Name: name,
Table: table,
Dynamic: true,
KeyType: keyType,
} }
// Add the rule to the chain if err := m.rConn.AddSet(ipset, nil); err != nil {
rule := &Rule{id: id} return nil, fmt.Errorf("create set: %v", err)
for _, r := range list {
if bytes.Equal(r.UserData, userData) {
rule.Rule = r
break
}
}
if rule.Rule == nil {
return nil, fmt.Errorf("rule not found")
} }
return rule, nil if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf("flush created set: %v", err)
}
return ipset, nil
} }
// chain returns the chain for the given IP address with specific settings // chain returns the chain for the given IP address with specific settings
@@ -315,7 +426,7 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
} }
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) { func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) {
tables, err := m.conn.ListTablesOfFamily(family) tables, err := m.rConn.ListTablesOfFamily(family)
if err != nil { if err != nil {
return nil, fmt.Errorf("list of tables: %w", err) return nil, fmt.Errorf("list of tables: %w", err)
} }
@@ -326,7 +437,11 @@ func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables
} }
} }
return m.conn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4}), nil table := m.rConn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4})
if err := m.rConn.Flush(); err != nil {
return nil, err
}
return table, nil
} }
func (m *Manager) createChainIfNotExists( func (m *Manager) createChainIfNotExists(
@@ -341,7 +456,7 @@ func (m *Manager) createChainIfNotExists(
return nil, err return nil, err
} }
chains, err := m.conn.ListChainsOfTableFamily(family) chains, err := m.rConn.ListChainsOfTableFamily(family)
if err != nil { if err != nil {
return nil, fmt.Errorf("list of chains: %w", err) return nil, fmt.Errorf("list of chains: %w", err)
} }
@@ -362,7 +477,7 @@ func (m *Manager) createChainIfNotExists(
Policy: &polAccept, Policy: &polAccept,
} }
chain = m.conn.AddChain(chain) chain = m.rConn.AddChain(chain)
ifaceKey := expr.MetaKeyIIFNAME ifaceKey := expr.MetaKeyIIFNAME
shiftDSTAddr := 0 shiftDSTAddr := 0
@@ -429,7 +544,7 @@ func (m *Manager) createChainIfNotExists(
) )
} }
_ = m.conn.AddRule(&nftables.Rule{ _ = m.rConn.AddRule(&nftables.Rule{
Table: table, Table: table,
Chain: chain, Chain: chain,
Exprs: expressions, Exprs: expressions,
@@ -444,12 +559,13 @@ func (m *Manager) createChainIfNotExists(
}, },
&expr.Verdict{Kind: expr.VerdictDrop}, &expr.Verdict{Kind: expr.VerdictDrop},
} }
_ = m.conn.AddRule(&nftables.Rule{ _ = m.rConn.AddRule(&nftables.Rule{
Table: table, Table: table,
Chain: chain, Chain: chain,
Exprs: expressions, Exprs: expressions,
}) })
if err := m.conn.Flush(); err != nil {
if err := m.rConn.Flush(); err != nil {
return nil, err return nil, err
} }
@@ -458,16 +574,58 @@ func (m *Manager) createChainIfNotExists(
// DeleteRule from the firewall by rule definition // DeleteRule from the firewall by rule definition
func (m *Manager) DeleteRule(rule fw.Rule) error { func (m *Manager) DeleteRule(rule fw.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
nativeRule, ok := rule.(*Rule) nativeRule, ok := rule.(*Rule)
if !ok { if !ok {
return fmt.Errorf("invalid rule type") return fmt.Errorf("invalid rule type")
} }
if err := m.conn.DelRule(nativeRule.Rule); err != nil { if nativeRule.nftRule == nil {
return err return nil
} }
return m.conn.Flush() if nativeRule.nftSet != nil {
// call twice of delete set element raises error
// so we need to check if element is already removed
key := fmt.Sprintf("%s:%v", nativeRule.nftSet.Name, nativeRule.ip)
if _, ok := m.setRemovedIPs[key]; !ok {
err := m.sConn.SetDeleteElements(nativeRule.nftSet, []nftables.SetElement{{Key: nativeRule.ip}})
if err != nil {
log.Errorf("delete elements for set %q: %v", nativeRule.nftSet.Name, err)
}
if err := m.sConn.Flush(); err != nil {
return err
}
m.setRemovedIPs[key] = struct{}{}
}
}
if m.rulesetManager.deleteRule(nativeRule) {
// deleteRule indicates that we still have IP in the ruleset
// it means we should not remove the nftables rule but need to update set
// so we prepare IP to be removed from set on the next flush call
return nil
}
// ruleset doesn't contain IP anymore (or contains only one), remove nft rule
if err := m.rConn.DelRule(nativeRule.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
if err := m.rConn.Flush(); err != nil {
return err
}
nativeRule.nftRule = nil
if nativeRule.nftSet != nil {
if _, ok := m.setRemoved[nativeRule.nftSet.Name]; !ok {
m.setRemoved[nativeRule.nftSet.Name] = nativeRule.nftSet
}
nativeRule.nftSet = nil
}
return nil
} }
// Reset firewall to the default state // Reset firewall to the default state
@@ -475,27 +633,116 @@ func (m *Manager) Reset() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
chains, err := m.conn.ListChains() chains, err := m.rConn.ListChains()
if err != nil { if err != nil {
return fmt.Errorf("list of chains: %w", err) return fmt.Errorf("list of chains: %w", err)
} }
for _, c := range chains { for _, c := range chains {
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName { if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
m.conn.DelChain(c) m.rConn.DelChain(c)
} }
} }
tables, err := m.conn.ListTables() tables, err := m.rConn.ListTables()
if err != nil { if err != nil {
return fmt.Errorf("list of tables: %w", err) return fmt.Errorf("list of tables: %w", err)
} }
for _, t := range tables { for _, t := range tables {
if t.Name == FilterTableName { if t.Name == FilterTableName {
m.conn.DelTable(t) m.rConn.DelTable(t)
} }
} }
return m.conn.Flush() return m.rConn.Flush()
}
// Flush rule/chain/set operations from the buffer
//
// Method also get all rules after flush and refreshes handle values in the rulesets
func (m *Manager) Flush() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.flushWithBackoff(); err != nil {
return err
}
// set must be removed after flush rule changes
// otherwise we will get error
for _, s := range m.setRemoved {
m.rConn.FlushSet(s)
m.rConn.DelSet(s)
}
if len(m.setRemoved) > 0 {
if err := m.flushWithBackoff(); err != nil {
return err
}
}
m.setRemovedIPs = map[string]struct{}{}
m.setRemoved = map[string]*nftables.Set{}
if err := m.refreshRuleHandles(m.tableIPv4, m.filterInputChainIPv4); err != nil {
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
}
if err := m.refreshRuleHandles(m.tableIPv4, m.filterOutputChainIPv4); err != nil {
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
}
if err := m.refreshRuleHandles(m.tableIPv6, m.filterInputChainIPv6); err != nil {
log.Errorf("failed to refresh rule handles IPv6 input chain: %v", err)
}
if err := m.refreshRuleHandles(m.tableIPv6, m.filterOutputChainIPv6); err != nil {
log.Errorf("failed to refresh rule handles IPv6 output chain: %v", err)
}
return nil
}
func (m *Manager) flushWithBackoff() (err error) {
backoff := 4
backoffTime := 1000 * time.Millisecond
for i := 0; ; i++ {
err = m.rConn.Flush()
if err != nil {
if !strings.Contains(err.Error(), "busy") {
return
}
log.Error("failed to flush nftables, retrying...")
if i == backoff-1 {
return err
}
time.Sleep(backoffTime)
backoffTime = backoffTime * 2
continue
}
break
}
return
}
func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chain) error {
if table == nil || chain == nil {
return nil
}
list, err := m.rConn.GetRules(table, chain)
if err != nil {
return err
}
for _, rule := range list {
if len(rule.UserData) != 0 {
if err := m.rulesetManager.setNftRuleHandle(rule); err != nil {
log.Errorf("failed to set rule handle: %v", err)
}
}
}
return nil
} }
func encodePort(port fw.Port) []byte { func encodePort(port fw.Port) []byte {

View File

@@ -55,7 +55,7 @@ func TestNftablesManager(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(mock)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second * 3)
defer func() { defer func() {
err = manager.Reset() err = manager.Reset()
@@ -75,11 +75,16 @@ func TestNftablesManager(t *testing.T) {
fw.RuleDirectionIN, fw.RuleDirectionIN,
fw.ActionDrop, fw.ActionDrop,
"", "",
"",
) )
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4) rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
require.NoError(t, err, "failed to get rules") require.NoError(t, err, "failed to get rules")
// test expectations: // test expectations:
// 1) regular rule // 1) regular rule
// 2) "accept extra routed traffic rule" for the interface // 2) "accept extra routed traffic rule" for the interface
@@ -135,6 +140,9 @@ func TestNftablesManager(t *testing.T) {
err = manager.DeleteRule(rule) err = manager.DeleteRule(rule)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4) rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
require.NoError(t, err, "failed to get rules") require.NoError(t, err, "failed to get rules")
// test expectations: // test expectations:
@@ -167,7 +175,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(mock)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second * 3)
defer func() { defer func() {
if err := manager.Reset(); err != nil { if err := manager.Reset(); err != nil {
@@ -181,13 +189,18 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { if i%2 == 0 {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else { } else {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
} }
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
if i%100 == 0 {
err = manager.Flush()
require.NoError(t, err, "failed to flush")
}
} }
t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax)) t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax))
}) })
} }

View File

@@ -6,11 +6,14 @@ import (
// Rule to handle management of rules // Rule to handle management of rules
type Rule struct { type Rule struct {
*nftables.Rule nftRule *nftables.Rule
id string nftSet *nftables.Set
ruleID string
ip []byte
} }
// GetRuleID returns the rule id // GetRuleID returns the rule id
func (r *Rule) GetRuleID() string { func (r *Rule) GetRuleID() string {
return r.id return r.ruleID
} }

View File

@@ -0,0 +1,115 @@
package nftables
import (
"bytes"
"fmt"
"github.com/google/nftables"
"github.com/rs/xid"
)
// nftRuleset links native firewall rule and ipset to ACL generated rules
type nftRuleset struct {
nftRule *nftables.Rule
nftSet *nftables.Set
issuedRules map[string]*Rule
rulesetID string
}
type rulesetManager struct {
rulesets map[string]*nftRuleset
nftSetName2rulesetID map[string]string
issuedRuleID2rulesetID map[string]string
}
func newRuleManager() *rulesetManager {
return &rulesetManager{
rulesets: map[string]*nftRuleset{},
nftSetName2rulesetID: map[string]string{},
issuedRuleID2rulesetID: map[string]string{},
}
}
func (r *rulesetManager) getRuleset(rulesetID string) (*nftRuleset, bool) {
ruleset, ok := r.rulesets[rulesetID]
return ruleset, ok
}
func (r *rulesetManager) createRuleset(
rulesetID string,
nftRule *nftables.Rule,
nftSet *nftables.Set,
) *nftRuleset {
ruleset := nftRuleset{
rulesetID: rulesetID,
nftRule: nftRule,
nftSet: nftSet,
issuedRules: map[string]*Rule{},
}
r.rulesets[ruleset.rulesetID] = &ruleset
if nftSet != nil {
r.nftSetName2rulesetID[nftSet.Name] = ruleset.rulesetID
}
return &ruleset
}
func (r *rulesetManager) addRule(
ruleset *nftRuleset,
ip []byte,
) (*Rule, error) {
if _, ok := r.rulesets[ruleset.rulesetID]; !ok {
return nil, fmt.Errorf("ruleset not found")
}
rule := Rule{
nftRule: ruleset.nftRule,
nftSet: ruleset.nftSet,
ruleID: xid.New().String(),
ip: ip,
}
ruleset.issuedRules[rule.ruleID] = &rule
r.issuedRuleID2rulesetID[rule.ruleID] = ruleset.rulesetID
return &rule, nil
}
// deleteRule from ruleset and returns true if contains other rules
func (r *rulesetManager) deleteRule(rule *Rule) bool {
rulesetID, ok := r.issuedRuleID2rulesetID[rule.ruleID]
if !ok {
return false
}
ruleset := r.rulesets[rulesetID]
if ruleset.nftRule == nil {
return false
}
delete(r.issuedRuleID2rulesetID, rule.ruleID)
delete(ruleset.issuedRules, rule.ruleID)
if len(ruleset.issuedRules) == 0 {
delete(r.rulesets, ruleset.rulesetID)
if rule.nftSet != nil {
delete(r.nftSetName2rulesetID, rule.nftSet.Name)
}
return false
}
return true
}
// setNftRuleHandle finds rule by userdata which contains rulesetID and updates it's handle number
//
// This is important to do, because after we add rule to the nftables we can't update it until
// we set correct handle value to it.
func (r *rulesetManager) setNftRuleHandle(nftRule *nftables.Rule) error {
split := bytes.Split(nftRule.UserData, []byte(" "))
ruleset, ok := r.rulesets[string(split[0])]
if !ok {
return fmt.Errorf("ruleset not found")
}
*ruleset.nftRule = *nftRule
return nil
}

View File

@@ -0,0 +1,122 @@
package nftables
import (
"testing"
"github.com/google/nftables"
"github.com/stretchr/testify/require"
)
func TestRulesetManager_createRuleset(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{
UserData: []byte(rulesetID),
}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
require.NotNil(t, ruleset, "createRuleset() failed")
require.Equal(t, ruleset.rulesetID, rulesetID, "rulesetID is incorrect")
require.Equal(t, ruleset.nftRule, &nftRule, "nftRule is incorrect")
}
func TestRulesetManager_addRule(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
// Add a rule to the ruleset.
ip := []byte("192.168.1.1")
rule, err := rulesetManager.addRule(ruleset, ip)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule, "rule should not be nil")
require.NotEqual(t, rule.ruleID, "ruleID is empty")
require.EqualValues(t, rule.ip, ip, "ip is incorrect")
require.Contains(t, ruleset.issuedRules, rule.ruleID, "ruleID already exists in ruleset")
require.Contains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "ruleID already exists in ruleset manager")
ruleset2 := &nftRuleset{
rulesetID: "ruleset-2",
}
_, err = rulesetManager.addRule(ruleset2, ip)
require.Error(t, err, "addRule() should have failed")
}
func TestRulesetManager_deleteRule(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
// Add a rule to the ruleset.
ip := []byte("192.168.1.1")
rule, err := rulesetManager.addRule(ruleset, ip)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule, "rule should not be nil")
ip2 := []byte("192.168.1.1")
rule2, err := rulesetManager.addRule(ruleset, ip2)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule2, "rule should not be nil")
hasNext := rulesetManager.deleteRule(rule)
require.True(t, hasNext, "deleteRule() should have returned true")
// Check that the rule is no longer in the manager.
require.NotContains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "rule should have been deleted")
hasNext = rulesetManager.deleteRule(rule2)
require.False(t, hasNext, "deleteRule() should have returned false")
}
func TestRulesetManager_setNftRuleHandle(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
// Add a rule to the ruleset.
ip := []byte("192.168.0.1")
rule, err := rulesetManager.addRule(ruleset, ip)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule, "rule should not be nil")
nftRuleCopy := nftRule
nftRuleCopy.Handle = 2
nftRuleCopy.UserData = []byte(rulesetID)
err = rulesetManager.setNftRuleHandle(&nftRuleCopy)
require.NoError(t, err, "setNftRuleHandle() failed")
// check correct work with references
require.Equal(t, nftRule.Handle, uint64(2), "nftRule.Handle is incorrect")
}
func TestRulesetManager_getRuleset(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
nftSet := nftables.Set{
ID: 2,
}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, &nftSet)
require.NotNil(t, ruleset, "createRuleset() failed")
find, ok := rulesetManager.getRuleset(rulesetID)
require.True(t, ok, "getRuleset() failed")
require.Equal(t, ruleset, find, "getRulesetBySetID() failed")
_, ok = rulesetManager.getRuleset("does-not-exist")
require.False(t, ok, "getRuleset() failed")
}

View File

@@ -1,5 +1,9 @@
package firewall package firewall
import (
"strconv"
)
// Protocol is the protocol of the port // Protocol is the protocol of the port
type Protocol string type Protocol string
@@ -28,3 +32,15 @@ type Port struct {
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports // Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
Values []int Values []int
} }
// String interface implementation
func (p *Port) String() string {
var ports string
for _, port := range p.Values {
if ports != "" {
ports += ","
}
ports += strconv.Itoa(port)
}
return ports
}

View File

@@ -21,11 +21,13 @@ type IFaceMapper interface {
SetFilter(iface.PacketFilter) error SetFilter(iface.PacketFilter) error
} }
// RuleSet is a set of rules grouped by a string key
type RuleSet map[string]Rule
// Manager userspace firewall manager // Manager userspace firewall manager
type Manager struct { type Manager struct {
outgoingRules []Rule outgoingRules map[string]RuleSet
incomingRules []Rule incomingRules map[string]RuleSet
rulesIndex map[string]int
wgNetwork *net.IPNet wgNetwork *net.IPNet
decoders sync.Pool decoders sync.Pool
@@ -48,7 +50,6 @@ type decoder struct {
// Create userspace firewall manager constructor // Create userspace firewall manager constructor
func Create(iface IFaceMapper) (*Manager, error) { func Create(iface IFaceMapper) (*Manager, error) {
m := &Manager{ m := &Manager{
rulesIndex: make(map[string]int),
decoders: sync.Pool{ decoders: sync.Pool{
New: func() any { New: func() any {
d := &decoder{ d := &decoder{
@@ -62,6 +63,8 @@ func Create(iface IFaceMapper) (*Manager, error) {
return d return d
}, },
}, },
outgoingRules: make(map[string]RuleSet),
incomingRules: make(map[string]RuleSet),
} }
if err := iface.SetFilter(m); err != nil { if err := iface.SetFilter(m); err != nil {
@@ -81,6 +84,7 @@ func (m *Manager) AddFiltering(
dPort *fw.Port, dPort *fw.Port,
direction fw.RuleDirection, direction fw.RuleDirection,
action fw.Action, action fw.Action,
ipsetName string,
comment string, comment string,
) (fw.Rule, error) { ) (fw.Rule, error) {
r := Rule{ r := Rule{
@@ -124,15 +128,17 @@ func (m *Manager) AddFiltering(
} }
m.mutex.Lock() m.mutex.Lock()
var p int
if direction == fw.RuleDirectionIN { if direction == fw.RuleDirectionIN {
m.incomingRules = append(m.incomingRules, r) if _, ok := m.incomingRules[r.ip.String()]; !ok {
p = len(m.incomingRules) - 1 m.incomingRules[r.ip.String()] = make(RuleSet)
}
m.incomingRules[r.ip.String()][r.id] = r
} else { } else {
m.outgoingRules = append(m.outgoingRules, r) if _, ok := m.outgoingRules[r.ip.String()]; !ok {
p = len(m.outgoingRules) - 1 m.outgoingRules[r.ip.String()] = make(RuleSet)
}
m.outgoingRules[r.ip.String()][r.id] = r
} }
m.rulesIndex[r.id] = p
m.mutex.Unlock() m.mutex.Unlock()
return &r, nil return &r, nil
@@ -148,24 +154,20 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
return fmt.Errorf("delete rule: invalid rule type: %T", rule) return fmt.Errorf("delete rule: invalid rule type: %T", rule)
} }
p, ok := m.rulesIndex[r.id]
if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.rulesIndex, r.id)
var toUpdate []Rule
if r.direction == fw.RuleDirectionIN { if r.direction == fw.RuleDirectionIN {
m.incomingRules = append(m.incomingRules[:p], m.incomingRules[p+1:]...) _, ok := m.incomingRules[r.ip.String()][r.id]
toUpdate = m.incomingRules if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.incomingRules[r.ip.String()], r.id)
} else { } else {
m.outgoingRules = append(m.outgoingRules[:p], m.outgoingRules[p+1:]...) _, ok := m.outgoingRules[r.ip.String()][r.id]
toUpdate = m.outgoingRules if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.outgoingRules[r.ip.String()], r.id)
} }
for i := 0; i < len(toUpdate); i++ {
m.rulesIndex[toUpdate[i].id] = i
}
return nil return nil
} }
@@ -174,13 +176,15 @@ func (m *Manager) Reset() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.outgoingRules = m.outgoingRules[:0] m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = m.incomingRules[:0] m.incomingRules = make(map[string]RuleSet)
m.rulesIndex = make(map[string]int)
return nil return nil
} }
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// DropOutgoing filter outgoing packets // DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte) bool { func (m *Manager) DropOutgoing(packetData []byte) bool {
return m.dropFilter(packetData, m.outgoingRules, false) return m.dropFilter(packetData, m.outgoingRules, false)
@@ -192,7 +196,7 @@ func (m *Manager) DropIncoming(packetData []byte) bool {
} }
// dropFilter imlements same logic for booth direction of the traffic // dropFilter imlements same logic for booth direction of the traffic
func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket bool) bool { func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool {
m.mutex.RLock() m.mutex.RLock()
defer m.mutex.RUnlock() defer m.mutex.RUnlock()
@@ -224,37 +228,49 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
log.Errorf("unknown layer: %v", d.decoded[0]) log.Errorf("unknown layer: %v", d.decoded[0])
return true return true
} }
payloadLayer := d.decoded[1]
// check if IP address match by IP var ip net.IP
switch ipLayer {
case layers.LayerTypeIPv4:
if isIncomingPacket {
ip = d.ip4.SrcIP
} else {
ip = d.ip4.DstIP
}
case layers.LayerTypeIPv6:
if isIncomingPacket {
ip = d.ip6.SrcIP
} else {
ip = d.ip6.DstIP
}
}
filter, ok := validateRule(ip, packetData, rules[ip.String()], d)
if ok {
return filter
}
filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d)
if ok {
return filter
}
filter, ok = validateRule(ip, packetData, rules["::"], d)
if ok {
return filter
}
// default policy is DROP ALL
return true
}
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
payloadLayer := d.decoded[1]
for _, rule := range rules { for _, rule := range rules {
if rule.matchByIP { if rule.matchByIP && !ip.Equal(rule.ip) {
switch ipLayer { continue
case layers.LayerTypeIPv4:
if isIncomingPacket {
if !d.ip4.SrcIP.Equal(rule.ip) {
continue
}
} else {
if !d.ip4.DstIP.Equal(rule.ip) {
continue
}
}
case layers.LayerTypeIPv6:
if isIncomingPacket {
if !d.ip6.SrcIP.Equal(rule.ip) {
continue
}
} else {
if !d.ip6.DstIP.Equal(rule.ip) {
continue
}
}
}
} }
if rule.protoLayer == layerTypeAll { if rule.protoLayer == layerTypeAll {
return rule.drop return rule.drop, true
} }
if payloadLayer != rule.protoLayer { if payloadLayer != rule.protoLayer {
@@ -264,38 +280,36 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
switch payloadLayer { switch payloadLayer {
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
if rule.sPort == 0 && rule.dPort == 0 { if rule.sPort == 0 && rule.dPort == 0 {
return rule.drop return rule.drop, true
} }
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) { if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
return rule.drop return rule.drop, true
} }
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) { if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
return rule.drop return rule.drop, true
} }
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
// if rule has UDP hook (and if we are here we match this rule) // if rule has UDP hook (and if we are here we match this rule)
// we ignore rule.drop and call this hook // we ignore rule.drop and call this hook
if rule.udpHook != nil { if rule.udpHook != nil {
return rule.udpHook(packetData) return rule.udpHook(packetData), true
} }
if rule.sPort == 0 && rule.dPort == 0 { if rule.sPort == 0 && rule.dPort == 0 {
return rule.drop return rule.drop, true
} }
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) { if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
return rule.drop return rule.drop, true
} }
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) { if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
return rule.drop return rule.drop, true
} }
return rule.drop return rule.drop, true
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return rule.drop return rule.drop, true
} }
} }
return false, false
// default policy is DROP ALL
return true
} }
// SetNetwork of the wireguard interface to which filtering applied // SetNetwork of the wireguard interface to which filtering applied
@@ -325,19 +339,19 @@ func (m *Manager) AddUDPPacketHook(
} }
m.mutex.Lock() m.mutex.Lock()
var toUpdate []Rule
if in { if in {
r.direction = fw.RuleDirectionIN r.direction = fw.RuleDirectionIN
m.incomingRules = append([]Rule{r}, m.incomingRules...) if _, ok := m.incomingRules[r.ip.String()]; !ok {
toUpdate = m.incomingRules m.incomingRules[r.ip.String()] = make(map[string]Rule)
}
m.incomingRules[r.ip.String()][r.id] = r
} else { } else {
m.outgoingRules = append([]Rule{r}, m.outgoingRules...) if _, ok := m.outgoingRules[r.ip.String()]; !ok {
toUpdate = m.outgoingRules m.outgoingRules[r.ip.String()] = make(map[string]Rule)
}
m.outgoingRules[r.ip.String()][r.id] = r
} }
for i := range toUpdate {
m.rulesIndex[toUpdate[i].id] = i
}
m.mutex.Unlock() m.mutex.Unlock()
return r.id return r.id
@@ -345,14 +359,18 @@ func (m *Manager) AddUDPPacketHook(
// RemovePacketHook removes packet hook by given ID // RemovePacketHook removes packet hook by given ID
func (m *Manager) RemovePacketHook(hookID string) error { func (m *Manager) RemovePacketHook(hookID string) error {
for _, r := range m.incomingRules { for _, arr := range m.incomingRules {
if r.id == hookID { for _, r := range arr {
return m.DeleteRule(&r) if r.id == hookID {
return m.DeleteRule(&r)
}
} }
} }
for _, r := range m.outgoingRules { for _, arr := range m.outgoingRules {
if r.id == hookID { for _, r := range arr {
return m.DeleteRule(&r) if r.id == hookID {
return m.DeleteRule(&r)
}
} }
} }
return fmt.Errorf("hook with given id not found") return fmt.Errorf("hook with given id not found")

View File

@@ -63,7 +63,7 @@ func TestManagerAddFiltering(t *testing.T) {
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment) rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -98,7 +98,7 @@ func TestManagerDeleteRule(t *testing.T) {
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment) rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -111,7 +111,7 @@ func TestManagerDeleteRule(t *testing.T) {
action = fw.ActionDrop action = fw.ActionDrop
comment = "Test rule 2" comment = "Test rule 2"
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment) rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -123,8 +123,8 @@ func TestManagerDeleteRule(t *testing.T) {
return return
} }
if idx, ok := m.rulesIndex[rule2.GetRuleID()]; !ok || len(m.incomingRules) != 1 || idx != 0 { if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; !ok {
t.Errorf("rule2 is not in the rulesIndex") t.Errorf("rule2 is not in the incomingRules")
} }
err = m.DeleteRule(rule2) err = m.DeleteRule(rule2)
@@ -133,8 +133,8 @@ func TestManagerDeleteRule(t *testing.T) {
return return
} }
if len(m.rulesIndex) != 0 || len(m.incomingRules) != 0 { if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; ok {
t.Errorf("rule1 still in the rulesIndex") t.Errorf("rule2 is not in the incomingRules")
} }
} }
@@ -169,26 +169,29 @@ func TestAddUDPPacketHook(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
manager := &Manager{ manager := &Manager{
incomingRules: []Rule{}, incomingRules: map[string]RuleSet{},
outgoingRules: []Rule{}, outgoingRules: map[string]RuleSet{},
rulesIndex: make(map[string]int),
} }
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
var addedRule Rule var addedRule Rule
if tt.in { if tt.in {
if len(manager.incomingRules) != 1 { if len(manager.incomingRules[tt.ip.String()]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules)) t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
return return
} }
addedRule = manager.incomingRules[0] for _, rule := range manager.incomingRules[tt.ip.String()] {
addedRule = rule
}
} else { } else {
if len(manager.outgoingRules) != 1 { if len(manager.outgoingRules) != 1 {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules)) t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
return return
} }
addedRule = manager.outgoingRules[0] for _, rule := range manager.outgoingRules[tt.ip.String()] {
addedRule = rule
}
} }
if !tt.ip.Equal(addedRule.ip) { if !tt.ip.Equal(addedRule.ip) {
@@ -211,17 +214,6 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Errorf("expected udpHook to be set") t.Errorf("expected udpHook to be set")
return return
} }
// Ensure rulesIndex is correctly updated
index, ok := manager.rulesIndex[addedRule.id]
if !ok {
t.Errorf("expected rule to be in rulesIndex")
return
}
if index != 0 {
t.Errorf("expected rule index to be 0, got %d", index)
return
}
}) })
} }
} }
@@ -244,7 +236,7 @@ func TestManagerReset(t *testing.T) {
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, comment) _, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -256,7 +248,7 @@ func TestManagerReset(t *testing.T) {
return return
} }
if len(m.rulesIndex) != 0 || len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 { if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
t.Errorf("rules is not empty") t.Errorf("rules is not empty")
} }
} }
@@ -282,7 +274,7 @@ func TestNotMatchByIP(t *testing.T) {
action := fw.ActionAccept action := fw.ActionAccept
comment := "Test rule" comment := "Test rule"
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, comment) _, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -346,10 +338,12 @@ func TestRemovePacketHook(t *testing.T) {
// Assert the hook is added by finding it in the manager's outgoing rules // Assert the hook is added by finding it in the manager's outgoing rules
found := false found := false
for _, rule := range manager.outgoingRules { for _, arr := range manager.outgoingRules {
if rule.id == hookID { for _, rule := range arr {
found = true if rule.id == hookID {
break found = true
break
}
} }
} }
@@ -364,9 +358,11 @@ func TestRemovePacketHook(t *testing.T) {
} }
// Assert the hook is removed by checking it in the manager's outgoing rules // Assert the hook is removed by checking it in the manager's outgoing rules
for _, rule := range manager.outgoingRules { for _, arr := range manager.outgoingRules {
if rule.id == hookID { for _, rule := range arr {
t.Fatalf("The hook was not removed properly.") if rule.id == hookID {
t.Fatalf("The hook was not removed properly.")
}
} }
} }
} }
@@ -394,9 +390,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { if i%2 == 0 {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else { } else {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
} }
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")

View File

@@ -1,10 +1,13 @@
package acl package acl
import ( import (
"crypto/md5"
"encoding/hex"
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
"sync" "sync"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -30,9 +33,22 @@ type Manager interface {
// DefaultManager uses firewall manager to handle // DefaultManager uses firewall manager to handle
type DefaultManager struct { type DefaultManager struct {
manager firewall.Manager manager firewall.Manager
rulesPairs map[string][]firewall.Rule ipsetCounter int
mutex sync.Mutex rulesPairs map[string][]firewall.Rule
mutex sync.Mutex
}
type ipsetInfo struct {
name string
ipCount int
}
func newDefaultManager(fm firewall.Manager) *DefaultManager {
return &DefaultManager{
manager: fm,
rulesPairs: make(map[string][]firewall.Rule),
}
} }
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy. // ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
@@ -42,11 +58,28 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
d.mutex.Lock() d.mutex.Lock()
defer d.mutex.Unlock() defer d.mutex.Unlock()
start := time.Now()
defer func() {
total := 0
for _, pairs := range d.rulesPairs {
total += len(pairs)
}
log.Infof(
"ACL rules processed in: %v, total rules count: %d",
time.Since(start), total)
}()
if d.manager == nil { if d.manager == nil {
log.Debug("firewall manager is not supported, skipping firewall rules") log.Debug("firewall manager is not supported, skipping firewall rules")
return return
} }
defer func() {
if err := d.manager.Flush(); err != nil {
log.Error("failed to flush firewall rules: ", err)
}
}()
rules, squashedProtocols := d.squashAcceptRules(networkMap) rules, squashedProtocols := d.squashAcceptRules(networkMap)
enableSSH := (networkMap.PeerConfig != nil && enableSSH := (networkMap.PeerConfig != nil &&
@@ -94,14 +127,38 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
applyFailed := false applyFailed := false
newRulePairs := make(map[string][]firewall.Rule) newRulePairs := make(map[string][]firewall.Rule)
ipsetByRuleSelectors := make(map[string]*ipsetInfo)
// calculate which IP's can be grouped in by which ipset
// to do that we use rule selector (which is just rule properties without IP's)
for _, r := range rules { for _, r := range rules {
rulePair, err := d.protoRuleToFirewallRule(r) selector := d.getRuleGroupingSelector(r)
ipset, ok := ipsetByRuleSelectors[selector]
if !ok {
ipset = &ipsetInfo{}
}
ipset.ipCount++
ipsetByRuleSelectors[selector] = ipset
}
for _, r := range rules {
// if this rule is member of rule selection with more than DefaultIPsCountForSet
// it's IP address can be used in the ipset for firewall manager which supports it
ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)]
ipsetName := ""
if ipset.name == "" {
d.ipsetCounter++
ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter)
}
ipsetName = ipset.name
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
if err != nil { if err != nil {
log.Errorf("failed to apply firewall rule: %+v, %v", r, err) log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
applyFailed = true applyFailed = true
break break
} }
newRulePairs[rulePair[0].GetRuleID()] = rulePair newRulePairs[pairID] = rulePair
} }
if applyFailed { if applyFailed {
log.Error("failed to apply firewall rules, rollback ACL to previous state") log.Error("failed to apply firewall rules, rollback ACL to previous state")
@@ -140,55 +197,71 @@ func (d *DefaultManager) Stop() {
} }
} }
func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) ([]firewall.Rule, error) { func (d *DefaultManager) protoRuleToFirewallRule(
r *mgmProto.FirewallRule,
ipsetName string,
) (string, []firewall.Rule, error) {
ip := net.ParseIP(r.PeerIP) ip := net.ParseIP(r.PeerIP)
if ip == nil { if ip == nil {
return nil, fmt.Errorf("invalid IP address, skipping firewall rule") return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
} }
protocol := convertToFirewallProtocol(r.Protocol) protocol := convertToFirewallProtocol(r.Protocol)
if protocol == firewall.ProtocolUnknown { if protocol == firewall.ProtocolUnknown {
return nil, fmt.Errorf("invalid protocol type: %d, skipping firewall rule", r.Protocol) return "", nil, fmt.Errorf("invalid protocol type: %d, skipping firewall rule", r.Protocol)
} }
action := convertFirewallAction(r.Action) action := convertFirewallAction(r.Action)
if action == firewall.ActionUnknown { if action == firewall.ActionUnknown {
return nil, fmt.Errorf("invalid action type: %d, skipping firewall rule", r.Action) return "", nil, fmt.Errorf("invalid action type: %d, skipping firewall rule", r.Action)
} }
var port *firewall.Port var port *firewall.Port
if r.Port != "" { if r.Port != "" {
value, err := strconv.Atoi(r.Port) value, err := strconv.Atoi(r.Port)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid port, skipping firewall rule") return "", nil, fmt.Errorf("invalid port, skipping firewall rule")
} }
port = &firewall.Port{ port = &firewall.Port{
Values: []int{value}, Values: []int{value},
} }
} }
ruleID := d.getRuleID(ip, protocol, int(r.Direction), port, action, "")
if rulesPair, ok := d.rulesPairs[ruleID]; ok {
return ruleID, rulesPair, nil
}
var rules []firewall.Rule var rules []firewall.Rule
var err error var err error
switch r.Direction { switch r.Direction {
case mgmProto.FirewallRule_IN: case mgmProto.FirewallRule_IN:
rules, err = d.addInRules(ip, protocol, port, action, "") rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
case mgmProto.FirewallRule_OUT: case mgmProto.FirewallRule_OUT:
rules, err = d.addOutRules(ip, protocol, port, action, "") rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
default: default:
return nil, fmt.Errorf("invalid direction, skipping firewall rule") return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
} }
if err != nil { if err != nil {
return nil, err return "", nil, err
} }
d.rulesPairs[rules[0].GetRuleID()] = rules d.rulesPairs[ruleID] = rules
return rules, nil return ruleID, rules, nil
} }
func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) { func (d *DefaultManager) addInRules(
ip net.IP,
protocol firewall.Protocol,
port *firewall.Port,
action firewall.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule var rules []firewall.Rule
rule, err := d.manager.AddFiltering(ip, protocol, nil, port, firewall.RuleDirectionIN, action, comment) rule, err := d.manager.AddFiltering(
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
} }
@@ -198,7 +271,8 @@ func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port
return rules, nil return rules, nil
} }
rule, err = d.manager.AddFiltering(ip, protocol, port, nil, firewall.RuleDirectionOUT, action, comment) rule, err = d.manager.AddFiltering(
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
} }
@@ -206,9 +280,17 @@ func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port
return append(rules, rule), nil return append(rules, rule), nil
} }
func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) { func (d *DefaultManager) addOutRules(
ip net.IP,
protocol firewall.Protocol,
port *firewall.Port,
action firewall.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule var rules []firewall.Rule
rule, err := d.manager.AddFiltering(ip, protocol, nil, port, firewall.RuleDirectionOUT, action, comment) rule, err := d.manager.AddFiltering(
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
} }
@@ -218,7 +300,8 @@ func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port
return rules, nil return rules, nil
} }
rule, err = d.manager.AddFiltering(ip, protocol, port, nil, firewall.RuleDirectionIN, action, comment) rule, err = d.manager.AddFiltering(
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
} }
@@ -226,6 +309,23 @@ func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port
return append(rules, rule), nil return append(rules, rule), nil
} }
// getRuleID() returns unique ID for the rule based on its parameters.
func (d *DefaultManager) getRuleID(
ip net.IP,
proto firewall.Protocol,
direction int,
port *firewall.Port,
action firewall.Action,
comment string,
) string {
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
if port != nil {
idStr += port.String()
}
return hex.EncodeToString(md5.New().Sum([]byte(idStr)))
}
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type // squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
// to all peers in the network map to one rule which just accepts that type of the traffic. // to all peers in the network map to one rule which just accepts that type of the traffic.
// //
@@ -235,7 +335,7 @@ func (d *DefaultManager) squashAcceptRules(
networkMap *mgmProto.NetworkMap, networkMap *mgmProto.NetworkMap,
) ([]*mgmProto.FirewallRule, map[mgmProto.FirewallRuleProtocol]struct{}) { ) ([]*mgmProto.FirewallRule, map[mgmProto.FirewallRuleProtocol]struct{}) {
totalIPs := 0 totalIPs := 0
for _, p := range networkMap.RemotePeers { for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
for range p.AllowedIps { for range p.AllowedIps {
totalIPs++ totalIPs++
} }
@@ -246,6 +346,10 @@ func (d *DefaultManager) squashAcceptRules(
in := protoMatch{} in := protoMatch{}
out := protoMatch{} out := protoMatch{}
// trace which type of protocols was squashed
squashedRules := []*mgmProto.FirewallRule{}
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
// this function we use to do calculation, can we squash the rules by protocol or not. // this function we use to do calculation, can we squash the rules by protocol or not.
// We summ amount of Peers IP for given protocol we found in original rules list. // We summ amount of Peers IP for given protocol we found in original rules list.
// But we zeroed the IP's for protocol if: // But we zeroed the IP's for protocol if:
@@ -262,12 +366,22 @@ func (d *DefaultManager) squashAcceptRules(
if _, ok := protocols[r.Protocol]; !ok { if _, ok := protocols[r.Protocol]; !ok {
protocols[r.Protocol] = map[string]int{} protocols[r.Protocol] = map[string]int{}
} }
match := protocols[r.Protocol]
if _, ok := match[r.PeerIP]; ok { // special case, when we recieve this all network IP address
// it means that rules for that protocol was already optimized on the
// management side
if r.PeerIP == "0.0.0.0" {
squashedRules = append(squashedRules, r)
squashedProtocols[r.Protocol] = struct{}{}
return return
} }
match[r.PeerIP] = i
ipset := protocols[r.Protocol]
if _, ok := ipset[r.PeerIP]; ok {
return
}
ipset[r.PeerIP] = i
} }
for i, r := range networkMap.FirewallRules { for i, r := range networkMap.FirewallRules {
@@ -288,9 +402,6 @@ func (d *DefaultManager) squashAcceptRules(
mgmProto.FirewallRule_UDP, mgmProto.FirewallRule_UDP,
} }
// trace which type of protocols was squashed
squashedRules := []*mgmProto.FirewallRule{}
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) { squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
for _, protocol := range protocolOrders { for _, protocol := range protocolOrders {
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 { if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
@@ -346,6 +457,11 @@ func (d *DefaultManager) squashAcceptRules(
return append(rules, squashedRules...), squashedProtocols return append(rules, squashedRules...), squashedProtocols
} }
// getRuleGroupingSelector takes all rule properties except IP address to build selector
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
}
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol { func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol {
switch protocol { switch protocol {
case mgmProto.FirewallRule_TCP: case mgmProto.FirewallRule_TCP:

View File

@@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"runtime" "runtime"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
) )
@@ -18,10 +17,7 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &DefaultManager{ return newDefaultManager(fm), nil
manager: fm,
rulesPairs: make(map[string][]firewall.Rule),
}, nil
} }
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
} }

View File

@@ -29,8 +29,5 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
} }
} }
return &DefaultManager{ return newDefaultManager(fm), nil
manager: fm,
rulesPairs: make(map[string][]firewall.Rule),
}, nil
} }

View File

@@ -55,6 +55,11 @@ func TestDefaultManager(t *testing.T) {
}) })
t.Run("add extra rules", func(t *testing.T) { t.Run("add extra rules", func(t *testing.T) {
existedPairs := map[string]struct{}{}
for id := range acl.rulesPairs {
existedPairs[id] = struct{}{}
}
// remove first rule // remove first rule
networkMap.FirewallRules = networkMap.FirewallRules[1:] networkMap.FirewallRules = networkMap.FirewallRules[1:]
networkMap.FirewallRules = append( networkMap.FirewallRules = append(
@@ -67,11 +72,6 @@ func TestDefaultManager(t *testing.T) {
}, },
) )
existedRulesID := map[string]struct{}{}
for id := range acl.rulesPairs {
existedRulesID[id] = struct{}{}
}
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap)
// we should have one old and one new rule in the existed rules // we should have one old and one new rule in the existed rules
@@ -80,13 +80,16 @@ func TestDefaultManager(t *testing.T) {
return return
} }
// check that old rules was removed // check that old rule was removed
for id := range existedRulesID { previousCount := 0
if _, ok := acl.rulesPairs[id]; ok { for id := range acl.rulesPairs {
t.Errorf("old rule was not removed") if _, ok := existedPairs[id]; ok {
return previousCount++
} }
} }
if previousCount != 1 {
t.Errorf("old rule was not removed")
}
}) })
t.Run("handle default rules", func(t *testing.T) { t.Run("handle default rules", func(t *testing.T) {

View File

@@ -0,0 +1,202 @@
package auth
import (
"context"
"encoding/json"
"fmt"
"github.com/netbirdio/netbird/client/internal"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// HostedGrantType grant type for device flow on Hosted
const (
HostedGrantType = "urn:ietf:params:oauth:grant-type:device_code"
)
var _ OAuthFlow = &DeviceAuthorizationFlow{}
// DeviceAuthorizationFlow implements the OAuthFlow interface,
// for the Device Authorization Flow.
type DeviceAuthorizationFlow struct {
providerConfig internal.DeviceAuthProviderConfig
HTTPClient HTTPClient
}
// RequestDeviceCodePayload used for request device code payload for auth0
type RequestDeviceCodePayload struct {
Audience string `json:"audience"`
ClientID string `json:"client_id"`
Scope string `json:"scope"`
}
// TokenRequestPayload used for requesting the auth0 token
type TokenRequestPayload struct {
GrantType string `json:"grant_type"`
DeviceCode string `json:"device_code,omitempty"`
ClientID string `json:"client_id"`
RefreshToken string `json:"refresh_token,omitempty"`
}
// TokenRequestResponse used for parsing Hosted token's response
type TokenRequestResponse struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
TokenInfo
}
// NewDeviceAuthorizationFlow returns device authorization flow client
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Transport: httpTransport,
}
return &DeviceAuthorizationFlow{
providerConfig: config,
HTTPClient: httpClient,
}, nil
}
// GetClientID returns the provider client id
func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
return d.providerConfig.ClientID
}
// RequestAuthInfo requests a device code login flow information from Hosted
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
form := url.Values{}
form.Add("client_id", d.providerConfig.ClientID)
form.Add("audience", d.providerConfig.Audience)
form.Add("scope", d.providerConfig.Scope)
req, err := http.NewRequest("POST", d.providerConfig.DeviceAuthEndpoint,
strings.NewReader(form.Encode()))
if err != nil {
return AuthFlowInfo{}, fmt.Errorf("creating request failed with error: %v", err)
}
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
res, err := d.HTTPClient.Do(req)
if err != nil {
return AuthFlowInfo{}, fmt.Errorf("doing request failed with error: %v", err)
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
return AuthFlowInfo{}, fmt.Errorf("reading body failed with error: %v", err)
}
if res.StatusCode != 200 {
return AuthFlowInfo{}, fmt.Errorf("request device code returned status %d error: %s", res.StatusCode, string(body))
}
deviceCode := AuthFlowInfo{}
err = json.Unmarshal(body, &deviceCode)
if err != nil {
return AuthFlowInfo{}, fmt.Errorf("unmarshaling response failed with error: %v", err)
}
// Fallback to the verification_uri if the IdP doesn't support verification_uri_complete
if deviceCode.VerificationURIComplete == "" {
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
}
return deviceCode, err
}
func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) {
form := url.Values{}
form.Add("client_id", d.providerConfig.ClientID)
form.Add("grant_type", HostedGrantType)
form.Add("device_code", info.DeviceCode)
req, err := http.NewRequest("POST", d.providerConfig.TokenEndpoint, strings.NewReader(form.Encode()))
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("failed to create request access token: %v", err)
}
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
res, err := d.HTTPClient.Do(req)
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("failed to request access token with error: %v", err)
}
defer func() {
err := res.Body.Close()
if err != nil {
return
}
}()
body, err := io.ReadAll(res.Body)
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("failed reading access token response body with error: %v", err)
}
if res.StatusCode > 499 {
return TokenRequestResponse{}, fmt.Errorf("access token response returned code: %s", string(body))
}
tokenResponse := TokenRequestResponse{}
err = json.Unmarshal(body, &tokenResponse)
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("parsing token response failed with error: %v", err)
}
return tokenResponse, nil
}
// WaitToken waits user's login and authorize the app. Once the user's authorize
// it retrieves the access token from Hosted's endpoint and validates it before returning
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
interval := time.Duration(info.Interval) * time.Second
ticker := time.NewTicker(interval)
for {
select {
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case <-ticker.C:
tokenResponse, err := d.requestToken(info)
if err != nil {
return TokenInfo{}, fmt.Errorf("parsing token response failed with error: %v", err)
}
if tokenResponse.Error != "" {
if tokenResponse.Error == "authorization_pending" {
continue
} else if tokenResponse.Error == "slow_down" {
interval = interval + (3 * time.Second)
ticker.Reset(interval)
continue
}
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
}
tokenInfo := TokenInfo{
AccessToken: tokenResponse.AccessToken,
TokenType: tokenResponse.TokenType,
RefreshToken: tokenResponse.RefreshToken,
IDToken: tokenResponse.IDToken,
ExpiresIn: tokenResponse.ExpiresIn,
UseIDToken: d.providerConfig.UseIDToken,
}
err = isValidAccessToken(tokenInfo.GetTokenToUse(), d.providerConfig.Audience)
if err != nil {
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
}
return tokenInfo, err
}
}
}

View File

@@ -1,17 +1,17 @@
package internal package auth
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/client/internal"
"github.com/stretchr/testify/require"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/golang-jwt/jwt"
"github.com/stretchr/testify/require"
) )
type mockHTTPClient struct { type mockHTTPClient struct {
@@ -53,7 +53,7 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
testingErrFunc require.ErrorAssertionFunc testingErrFunc require.ErrorAssertionFunc
expectedErrorMSG string expectedErrorMSG string
testingFunc require.ComparisonAssertionFunc testingFunc require.ComparisonAssertionFunc
expectedOut DeviceAuthInfo expectedOut AuthFlowInfo
expectedMSG string expectedMSG string
expectPayload string expectPayload string
} }
@@ -92,7 +92,7 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
testingFunc: require.EqualValues, testingFunc: require.EqualValues,
expectPayload: expectPayload, expectPayload: expectPayload,
} }
testCase4Out := DeviceAuthInfo{ExpiresIn: 10} testCase4Out := AuthFlowInfo{ExpiresIn: 10}
testCase4 := test{ testCase4 := test{
name: "Got Device Code", name: "Got Device Code",
inputResBody: fmt.Sprintf("{\"expires_in\":%d}", testCase4Out.ExpiresIn), inputResBody: fmt.Sprintf("{\"expires_in\":%d}", testCase4Out.ExpiresIn),
@@ -113,8 +113,8 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
err: testCase.inputReqError, err: testCase.inputReqError,
} }
hosted := Hosted{ deviceFlow := &DeviceAuthorizationFlow{
providerConfig: ProviderConfig{ providerConfig: internal.DeviceAuthProviderConfig{
Audience: expectedAudience, Audience: expectedAudience,
ClientID: expectedClientID, ClientID: expectedClientID,
Scope: expectedScope, Scope: expectedScope,
@@ -125,7 +125,7 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
HTTPClient: &httpClient, HTTPClient: &httpClient,
} }
authInfo, err := hosted.RequestDeviceCode(context.TODO()) authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG) testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
require.EqualValues(t, expectPayload, httpClient.reqBody, "payload should match") require.EqualValues(t, expectPayload, httpClient.reqBody, "payload should match")
@@ -145,7 +145,7 @@ func TestHosted_WaitToken(t *testing.T) {
inputMaxReqs int inputMaxReqs int
inputCountResBody string inputCountResBody string
inputTimeout time.Duration inputTimeout time.Duration
inputInfo DeviceAuthInfo inputInfo AuthFlowInfo
inputAudience string inputAudience string
testingErrFunc require.ErrorAssertionFunc testingErrFunc require.ErrorAssertionFunc
expectedErrorMSG string expectedErrorMSG string
@@ -155,7 +155,7 @@ func TestHosted_WaitToken(t *testing.T) {
expectPayload string expectPayload string
} }
defaultInfo := DeviceAuthInfo{ defaultInfo := AuthFlowInfo{
DeviceCode: "test", DeviceCode: "test",
ExpiresIn: 10, ExpiresIn: 10,
Interval: 1, Interval: 1,
@@ -278,8 +278,8 @@ func TestHosted_WaitToken(t *testing.T) {
countResBody: testCase.inputCountResBody, countResBody: testCase.inputCountResBody,
} }
hosted := Hosted{ deviceFlow := DeviceAuthorizationFlow{
providerConfig: ProviderConfig{ providerConfig: internal.DeviceAuthProviderConfig{
Audience: testCase.inputAudience, Audience: testCase.inputAudience,
ClientID: clientID, ClientID: clientID,
TokenEndpoint: "test.hosted.com/token", TokenEndpoint: "test.hosted.com/token",
@@ -287,11 +287,12 @@ func TestHosted_WaitToken(t *testing.T) {
Scope: "openid", Scope: "openid",
UseIDToken: false, UseIDToken: false,
}, },
HTTPClient: &httpClient} HTTPClient: &httpClient,
}
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout) ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
defer cancel() defer cancel()
tokenInfo, err := hosted.WaitToken(ctx, testCase.inputInfo) tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG) testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
require.EqualValues(t, testCase.expectPayload, httpClient.reqBody, "payload should match") require.EqualValues(t, testCase.expectPayload, httpClient.reqBody, "payload should match")

View File

@@ -0,0 +1,90 @@
package auth
import (
"context"
"fmt"
"net/http"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
)
// OAuthFlow represents an interface for authorization using different OAuth 2.0 flows
type OAuthFlow interface {
RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error)
WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error)
GetClientID(ctx context.Context) string
}
// HTTPClient http client interface for API calls
type HTTPClient interface {
Do(req *http.Request) (*http.Response, error)
}
// AuthFlowInfo holds information for the OAuth 2.0 authorization flow
type AuthFlowInfo struct {
DeviceCode string `json:"device_code"`
UserCode string `json:"user_code"`
VerificationURI string `json:"verification_uri"`
VerificationURIComplete string `json:"verification_uri_complete"`
ExpiresIn int `json:"expires_in"`
Interval int `json:"interval"`
}
// Claims used when validating the access token
type Claims struct {
Audience interface{} `json:"aud"`
}
// TokenInfo holds information of issued access token
type TokenInfo struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
UseIDToken bool `json:"-"`
}
// GetTokenToUse returns either the access or id token based on UseIDToken field
func (t TokenInfo) GetTokenToUse() string {
if t.UseIDToken {
return t.IDToken
}
return t.AccessToken
}
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration.
func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
log.Debug("getting device authorization flow info")
// Try to initialize the Device Authorization Flow
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err == nil {
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
}
log.Debugf("getting device authorization flow info failed with error: %v", err)
log.Debugf("falling back to pkce authorization flow info")
// If Device Authorization Flow failed, try the PKCE Authorization Flow
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
s, ok := gstatus.FromError(err)
if ok && s.Code() == codes.NotFound {
return nil, fmt.Errorf("no SSO provider returned from management. " +
"If you are using hosting Netbird see documentation at " +
"https://github.com/netbirdio/netbird/tree/main/management for details")
} else if ok && s.Code() == codes.Unimplemented {
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
"please update your server or use Setup Keys to login", config.ManagementURL)
} else {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
}
}
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
}

View File

@@ -0,0 +1,238 @@
package auth
import (
"context"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"fmt"
"html/template"
"net"
"net/http"
"net/url"
"strings"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/templates"
)
var _ OAuthFlow = &PKCEAuthorizationFlow{}
const (
queryState = "state"
queryCode = "code"
defaultPKCETimeoutSeconds = 300
)
// PKCEAuthorizationFlow implements the OAuthFlow interface for
// the Authorization Code Flow with PKCE.
type PKCEAuthorizationFlow struct {
providerConfig internal.PKCEAuthProviderConfig
state string
codeVerifier string
oAuthConfig *oauth2.Config
}
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
var availableRedirectURL string
// find the first available redirect URL
for _, redirectURL := range config.RedirectURLs {
if !isRedirectURLPortUsed(redirectURL) {
availableRedirectURL = redirectURL
break
}
}
if availableRedirectURL == "" {
return nil, fmt.Errorf("no available port found from configured redirect URLs: %q", config.RedirectURLs)
}
cfg := &oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: config.AuthorizationEndpoint,
TokenURL: config.TokenEndpoint,
},
RedirectURL: availableRedirectURL,
Scopes: strings.Split(config.Scope, " "),
}
return &PKCEAuthorizationFlow{
providerConfig: config,
oAuthConfig: cfg,
}, nil
}
// GetClientID returns the provider client id
func (p *PKCEAuthorizationFlow) GetClientID(_ context.Context) string {
return p.providerConfig.ClientID
}
// RequestAuthInfo requests a authorization code login flow information.
func (p *PKCEAuthorizationFlow) RequestAuthInfo(_ context.Context) (AuthFlowInfo, error) {
state, err := randomBytesInHex(24)
if err != nil {
return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err)
}
p.state = state
codeVerifier, err := randomBytesInHex(64)
if err != nil {
return AuthFlowInfo{}, fmt.Errorf("could not create a code verifier: %v", err)
}
p.codeVerifier = codeVerifier
codeChallenge := createCodeChallenge(codeVerifier)
authURL := p.oAuthConfig.AuthCodeURL(
state,
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
)
return AuthFlowInfo{
VerificationURIComplete: authURL,
ExpiresIn: defaultPKCETimeoutSeconds,
}, nil
}
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
// Once the token is received, it is converted to TokenInfo and validated before returning.
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
tokenChan := make(chan *oauth2.Token, 1)
errChan := make(chan error, 1)
go p.startServer(tokenChan, errChan)
select {
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case token := <-tokenChan:
return p.handleOAuthToken(token)
case err := <-errChan:
return TokenInfo{}, err
}
}
func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errChan chan<- error) {
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
if err != nil {
errChan <- fmt.Errorf("failed to parse redirect URL: %v", err)
return
}
port := parsedURL.Port()
server := http.Server{Addr: fmt.Sprintf(":%s", port)}
defer func() {
if err := server.Shutdown(context.Background()); err != nil {
log.Errorf("error while shutting down pkce flow server: %v", err)
}
}()
http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
tokenValidatorFunc := func() (*oauth2.Token, error) {
query := req.URL.Query()
state := query.Get(queryState)
// Prevent timing attacks on state
if subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
return nil, fmt.Errorf("invalid state")
}
code := query.Get(queryCode)
if code == "" {
return nil, fmt.Errorf("missing code")
}
return p.oAuthConfig.Exchange(
req.Context(),
code,
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
)
}
token, err := tokenValidatorFunc()
if err != nil {
errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err)
renderPKCEFlowTmpl(w, err)
}
tokenChan <- token
renderPKCEFlowTmpl(w, nil)
})
if err := server.ListenAndServe(); err != nil {
errChan <- err
}
}
func (p *PKCEAuthorizationFlow) handleOAuthToken(token *oauth2.Token) (TokenInfo, error) {
tokenInfo := TokenInfo{
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
TokenType: token.TokenType,
ExpiresIn: token.Expiry.Second(),
UseIDToken: p.providerConfig.UseIDToken,
}
if idToken, ok := token.Extra("id_token").(string); ok {
tokenInfo.IDToken = idToken
}
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), p.providerConfig.Audience); err != nil {
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
}
return tokenInfo, nil
}
func createCodeChallenge(codeVerifier string) string {
sha2 := sha256.Sum256([]byte(codeVerifier))
return base64.RawURLEncoding.EncodeToString(sha2[:])
}
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use.
func isRedirectURLPortUsed(redirectURL string) bool {
parsedURL, err := url.Parse(redirectURL)
if err != nil {
log.Errorf("failed to parse redirect URL: %v", err)
return true
}
addr := fmt.Sprintf(":%s", parsedURL.Port())
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
if err != nil {
return false
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf("error while closing the connection: %v", err)
}
}()
return true
}
func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) {
tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
data := make(map[string]string)
if authError != nil {
data["Error"] = authError.Error()
}
if err := tmpl.Execute(w, data); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}

View File

@@ -0,0 +1,62 @@
package auth
import (
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"reflect"
"strings"
)
func randomBytesInHex(count int) (string, error) {
buf := make([]byte, count)
_, err := io.ReadFull(rand.Reader, buf)
if err != nil {
return "", fmt.Errorf("could not generate %d random bytes: %v", count, err)
}
return hex.EncodeToString(buf), nil
}
// isValidAccessToken is a simple validation of the access token
func isValidAccessToken(token string, audience string) error {
if token == "" {
return fmt.Errorf("token received is empty")
}
encodedClaims := strings.Split(token, ".")[1]
claimsString, err := base64.RawURLEncoding.DecodeString(encodedClaims)
if err != nil {
return err
}
claims := Claims{}
err = json.Unmarshal(claimsString, &claims)
if err != nil {
return err
}
if claims.Audience == nil {
return fmt.Errorf("required token field audience is absent")
}
// Audience claim of JWT can be a string or an array of strings
typ := reflect.TypeOf(claims.Audience)
switch typ.Kind() {
case reflect.String:
if claims.Audience == audience {
return nil
}
case reflect.Slice:
for _, aud := range claims.Audience.([]interface{}) {
if audience == aud {
return nil
}
}
}
return fmt.Errorf("invalid JWT token audience field")
}

View File

@@ -215,10 +215,12 @@ func update(input ConfigInput) (*Config, error) {
} }
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey { if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
log.Infof("new pre-shared key provided, updated to %s (old value %s)", if *input.PreSharedKey != "" {
*input.PreSharedKey, config.PreSharedKey) log.Infof("new pre-shared key provides, updated to %s (old value %s)",
config.PreSharedKey = *input.PreSharedKey *input.PreSharedKey, config.PreSharedKey)
refresh = true config.PreSharedKey = *input.PreSharedKey
refresh = true
}
} }
if config.SSHKey == "" { if config.SSHKey == "" {

View File

@@ -63,7 +63,22 @@ func TestGetConfig(t *testing.T) {
assert.Equal(t, config.ManagementURL.String(), managementURL) assert.Equal(t, config.ManagementURL.String(), managementURL)
assert.Equal(t, config.PreSharedKey, preSharedKey) assert.Equal(t, config.PreSharedKey, preSharedKey)
// case 4: existing config, but new managementURL has been provided -> update config // case 4: new empty pre-shared key config -> fetch it
newPreSharedKey := ""
config, err = UpdateOrCreateConfig(ConfigInput{
ManagementURL: managementURL,
AdminURL: adminURL,
ConfigPath: path,
PreSharedKey: &newPreSharedKey,
})
if err != nil {
return
}
assert.Equal(t, config.ManagementURL.String(), managementURL)
assert.Equal(t, config.PreSharedKey, preSharedKey)
// case 5: existing config, but new managementURL has been provided -> update config
newManagementURL := "https://test.newManagement.url:33071" newManagementURL := "https://test.newManagement.url:33071"
config, err = UpdateOrCreateConfig(ConfigInput{ config, err = UpdateOrCreateConfig(ConfigInput{
ManagementURL: newManagementURL, ManagementURL: newManagementURL,

View File

@@ -12,6 +12,7 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
@@ -24,7 +25,24 @@ import (
) )
// RunClient with main logic. // RunClient with main logic.
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener) error { func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
return runClient(ctx, config, statusRecorder, MobileDependency{})
}
// RunClientMobile with main logic on mobile system
func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener, dnsAddresses []string, dnsReadyListener dns.ReadyListener) error {
// in case of non Android os these variables will be nil
mobileDependency := MobileDependency{
TunAdapter: tunAdapter,
IFaceDiscover: iFaceDiscover,
RouteListener: routeListener,
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
return runClient(ctx, config, statusRecorder, mobileDependency)
}
func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status, mobileDependency MobileDependency) error {
backOff := &backoff.ExponentialBackOff{ backOff := &backoff.ExponentialBackOff{
InitialInterval: time.Second, InitialInterval: time.Second,
RandomizationFactor: 1, RandomizationFactor: 1,
@@ -151,14 +169,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
return wrapErr(err) return wrapErr(err)
} }
// in case of non Android os these variables will be nil engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder)
md := MobileDependency{
TunAdapter: tunAdapter,
IFaceDiscover: iFaceDiscover,
RouteListener: routeListener,
}
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, md, statusRecorder)
err = engine.Start() err = engine.Start()
if err != nil { if err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err) log.Errorf("error while starting Netbird Connection Engine: %s", err)

View File

@@ -16,11 +16,11 @@ import (
// DeviceAuthorizationFlow represents Device Authorization Flow information // DeviceAuthorizationFlow represents Device Authorization Flow information
type DeviceAuthorizationFlow struct { type DeviceAuthorizationFlow struct {
Provider string Provider string
ProviderConfig ProviderConfig ProviderConfig DeviceAuthProviderConfig
} }
// ProviderConfig has all attributes needed to initiate a device authorization flow // DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
type ProviderConfig struct { type DeviceAuthProviderConfig struct {
// ClientID An IDP application client id // ClientID An IDP application client id
ClientID string ClientID string
// ClientSecret An IDP application client secret // ClientSecret An IDP application client secret
@@ -88,7 +88,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
deviceAuthorizationFlow := DeviceAuthorizationFlow{ deviceAuthorizationFlow := DeviceAuthorizationFlow{
Provider: protoDeviceAuthorizationFlow.Provider.String(), Provider: protoDeviceAuthorizationFlow.Provider.String(),
ProviderConfig: ProviderConfig{ ProviderConfig: DeviceAuthProviderConfig{
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(), Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(), ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(), ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
@@ -105,7 +105,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
deviceAuthorizationFlow.ProviderConfig.Scope = "openid" deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
} }
err = isProviderConfigValid(deviceAuthorizationFlow.ProviderConfig) err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
if err != nil { if err != nil {
return DeviceAuthorizationFlow{}, err return DeviceAuthorizationFlow{}, err
} }
@@ -113,7 +113,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
return deviceAuthorizationFlow, nil return deviceAuthorizationFlow, nil
} }
func isProviderConfigValid(config ProviderConfig) error { func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator" errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.Audience == "" { if config.Audience == "" {
return fmt.Errorf(errorMSGFormat, "Audience") return fmt.Errorf(errorMSGFormat, "Audience")

View File

@@ -1,13 +1,9 @@
package dns package dns
import (
"github.com/netbirdio/netbird/iface"
)
type androidHostManager struct { type androidHostManager struct {
} }
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) { func newHostManager(wgInterface WGIface) (hostManager, error) {
return &androidHostManager{}, nil return &androidHostManager{}, nil
} }

View File

@@ -9,8 +9,6 @@ import (
"strings" "strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/iface"
) )
const ( const (
@@ -34,7 +32,7 @@ type systemConfigurator struct {
createdKeys map[string]struct{} createdKeys map[string]struct{}
} }
func newHostManager(_ *iface.WGIface) (hostManager, error) { func newHostManager(_ WGIface) (hostManager, error) {
return &systemConfigurator{ return &systemConfigurator{
createdKeys: make(map[string]struct{}), createdKeys: make(map[string]struct{}),
}, nil }, nil

View File

@@ -5,10 +5,10 @@ package dns
import ( import (
"bufio" "bufio"
"fmt" "fmt"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
"os" "os"
"strings" "strings"
log "github.com/sirupsen/logrus"
) )
const ( const (
@@ -25,7 +25,7 @@ const (
type osManagerType int type osManagerType int
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) { func newHostManager(wgInterface WGIface) (hostManager, error) {
osManager, err := getOSDNSManagerType() osManager, err := getOSDNSManagerType()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -6,8 +6,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows/registry" "golang.org/x/sys/windows/registry"
"github.com/netbirdio/netbird/iface"
) )
const ( const (
@@ -33,7 +31,7 @@ type registryConfigurator struct {
existingSearchDomains []string existingSearchDomains []string
} }
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) { func newHostManager(wgInterface WGIface) (hostManager, error) {
guid, err := wgInterface.GetInterfaceGUIDString() guid, err := wgInterface.GetInterfaceGUIDString()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -31,6 +31,11 @@ func (m *MockServer) DnsIP() string {
return "" return ""
} }
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) {
//TODO implement me
panic("implement me")
}
// UpdateDNSServer mock implementation of UpdateDNSServer from Server interface // UpdateDNSServer mock implementation of UpdateDNSServer from Server interface
func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error { func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
if m.UpdateDNSServerFunc != nil { if m.UpdateDNSServerFunc != nil {

View File

@@ -14,8 +14,6 @@ import (
"github.com/hashicorp/go-version" "github.com/hashicorp/go-version"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/iface"
) )
const ( const (
@@ -72,7 +70,7 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() {
} }
} }
func newNetworkManagerDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) { func newNetworkManagerDbusConfigurator(wgInterface WGIface) (hostManager, error) {
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode) obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -8,8 +8,6 @@ import (
"strings" "strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/iface"
) )
const resolvconfCommand = "resolvconf" const resolvconfCommand = "resolvconf"
@@ -18,7 +16,7 @@ type resolvconf struct {
ifaceName string ifaceName string
} }
func newResolvConfConfigurator(wgInterface *iface.WGIface) (hostManager, error) { func newResolvConfConfigurator(wgInterface WGIface) (hostManager, error) {
return &resolvconf{ return &resolvconf{
ifaceName: wgInterface.Name(), ifaceName: wgInterface.Name(),
}, nil }, nil

View File

@@ -3,29 +3,20 @@ package dns
import ( import (
"context" "context"
"fmt" "fmt"
"math/big"
"net"
"net/netip" "net/netip"
"runtime"
"sync" "sync"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2" "github.com/mitchellh/hashstructure/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
) )
const ( // ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
defaultPort = 53 type ReadyListener interface {
customPort = 5053 OnReady()
defaultIP = "127.0.0.1" }
customIP = "127.0.0.153"
)
// Server is a dns server interface // Server is a dns server interface
type Server interface { type Server interface {
@@ -33,6 +24,7 @@ type Server interface {
Stop() Stop()
DnsIP() string DnsIP() string
UpdateDNSServer(serial uint64, update nbdns.Config) error UpdateDNSServer(serial uint64, update nbdns.Config) error
OnUpdatedHostDNSServer(strings []string)
} }
type registeredHandlerMap map[string]handlerWithStop type registeredHandlerMap map[string]handlerWithStop
@@ -42,21 +34,19 @@ type DefaultServer struct {
ctx context.Context ctx context.Context
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
mux sync.Mutex mux sync.Mutex
fakeResolverWG sync.WaitGroup service service
server *dns.Server
dnsMux *dns.ServeMux
dnsMuxMap registeredHandlerMap dnsMuxMap registeredHandlerMap
localResolver *localResolver localResolver *localResolver
wgInterface *iface.WGIface wgInterface WGIface
hostManager hostManager hostManager hostManager
updateSerial uint64 updateSerial uint64
listenerIsRunning bool
runtimePort int
runtimeIP string
previousConfigHash uint64 previousConfigHash uint64
currentConfig hostDNSConfig currentConfig hostDNSConfig
customAddress *netip.AddrPort
enabled bool // permanent related properties
permanent bool
hostsDnsList []string
hostsDnsListLock sync.Mutex
} }
type handlerWithStop interface { type handlerWithStop interface {
@@ -70,9 +60,7 @@ type muxUpdate struct {
} }
// NewDefaultServer returns a new dns server // NewDefaultServer returns a new dns server
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string, initialDnsCfg *nbdns.Config) (*DefaultServer, error) { func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string) (*DefaultServer, error) {
mux := dns.NewServeMux()
var addrPort *netip.AddrPort var addrPort *netip.AddrPort
if customAddress != "" { if customAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(customAddress) parsedAddrPort, err := netip.ParseAddrPort(customAddress)
@@ -82,34 +70,44 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
addrPort = &parsedAddrPort addrPort = &parsedAddrPort
} }
ctx, stop := context.WithCancel(ctx) var dnsService service
if wgInterface.IsUserspaceBind() {
dnsService = newServiceViaMemory(wgInterface)
} else {
dnsService = newServiceViaListener(wgInterface, addrPort)
}
return newDefaultServer(ctx, wgInterface, dnsService), nil
}
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
ds.permanent = true
ds.hostsDnsList = hostsDnsList
ds.addHostRootZone()
setServerDns(ds)
return ds
}
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer {
ctx, stop := context.WithCancel(ctx)
defaultServer := &DefaultServer{ defaultServer := &DefaultServer{
ctx: ctx, ctx: ctx,
ctxCancel: stop, ctxCancel: stop,
server: &dns.Server{ service: dnsService,
Net: "udp",
Handler: mux,
UDPSize: 65535,
},
dnsMux: mux,
dnsMuxMap: make(registeredHandlerMap), dnsMuxMap: make(registeredHandlerMap),
localResolver: &localResolver{ localResolver: &localResolver{
registeredMap: make(registrationMap), registeredMap: make(registrationMap),
}, },
wgInterface: wgInterface, wgInterface: wgInterface,
customAddress: addrPort,
} }
if initialDnsCfg != nil { return defaultServer
defaultServer.enabled = hasValidDnsServer(initialDnsCfg)
}
defaultServer.evalRuntimeAddress()
return defaultServer, nil
} }
// Initialize instantiate host manager. It required to be initialized wginterface // Initialize instantiate host manager and the dns service
func (s *DefaultServer) Initialize() (err error) { func (s *DefaultServer) Initialize() (err error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@@ -118,74 +116,23 @@ func (s *DefaultServer) Initialize() (err error) {
return nil return nil
} }
if s.permanent {
err = s.service.Listen()
if err != nil {
return err
}
}
s.hostManager, err = newHostManager(s.wgInterface) s.hostManager, err = newHostManager(s.wgInterface)
return return
} }
// listen runs the listener in a go routine // DnsIP returns the DNS resolver server IP address
func (s *DefaultServer) listen() { //
// nil check required in unit tests // When kernel space interface used it return real DNS server listener IP address
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() { // For bind interface, fake DNS resolver address returned (second last IP address from Nebird network)
s.fakeResolverWG.Add(1)
go func() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)
hookID := s.filterDNSTraffic()
s.fakeResolverWG.Wait()
if err := s.wgInterface.GetFilter().RemovePacketHook(hookID); err != nil {
log.Errorf("unable to remove DNS packet hook: %s", err)
}
}()
return
}
log.Debugf("starting dns on %s", s.server.Addr)
go func() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)
err := s.server.ListenAndServe()
if err != nil {
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
}
}()
}
func (s *DefaultServer) DnsIP() string { func (s *DefaultServer) DnsIP() string {
if !s.enabled { return s.service.RuntimeIP()
return ""
}
return s.runtimeIP
}
func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) {
ips := []string{defaultIP, customIP}
if runtime.GOOS != "darwin" && s.wgInterface != nil {
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
}
ports := []int{defaultPort, customPort}
for _, port := range ports {
for _, ip := range ips {
addrString := fmt.Sprintf("%s:%d", ip, port)
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
probeListener, err := net.ListenUDP("udp", udpAddr)
if err == nil {
err = probeListener.Close()
if err != nil {
log.Errorf("got an error closing the probe listener, error: %s", err)
}
return ip, port, nil
}
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
}
}
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
}
func (s *DefaultServer) setListenerStatus(running bool) {
s.listenerIsRunning = running
} }
// Stop stops the server // Stop stops the server
@@ -194,34 +141,30 @@ func (s *DefaultServer) Stop() {
defer s.mux.Unlock() defer s.mux.Unlock()
s.ctxCancel() s.ctxCancel()
err := s.hostManager.restoreHostDNS() if s.hostManager != nil {
if err != nil { err := s.hostManager.restoreHostDNS()
log.Error(err) if err != nil {
log.Error(err)
}
} }
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning { s.service.Stop()
s.fakeResolverWG.Done()
}
err = s.stopListener()
if err != nil {
log.Error(err)
}
} }
func (s *DefaultServer) stopListener() error { // OnUpdatedHostDNSServer update the DNS servers addresses for root zones
if !s.listenerIsRunning { // It will be applied if the mgm server do not enforce DNS settings for root zone
return nil func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
} s.hostsDnsListLock.Lock()
defer s.hostsDnsListLock.Unlock()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) s.hostsDnsList = hostsDnsList
defer cancel() _, ok := s.dnsMuxMap[nbdns.RootZone]
if ok {
err := s.server.ShutdownContext(ctx) log.Debugf("on new host DNS config but skip to apply it")
if err != nil { return
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
} }
return nil log.Debugf("update host DNS settings: %+v", hostsDnsList)
s.addHostRootZone()
} }
// UpdateDNSServer processes an update received from the management service // UpdateDNSServer processes an update received from the management service
@@ -272,16 +215,10 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// is the service should be disabled, we stop the listener or fake resolver // is the service should be disabled, we stop the listener or fake resolver
// and proceed with a regular update to clean up the handlers and records // and proceed with a regular update to clean up the handlers and records
if !update.ServiceEnable { if update.ServiceEnable {
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() { _ = s.service.Listen()
s.fakeResolverWG.Done() } else if !s.permanent {
} else { s.service.Stop()
if err := s.stopListener(); err != nil {
log.Error(err)
}
}
} else if !s.listenerIsRunning {
s.listen()
} }
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones) localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
@@ -292,15 +229,14 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
if err != nil { if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err) return fmt.Errorf("not applying dns update, error: %v", err)
} }
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
s.updateMux(muxUpdates) s.updateMux(muxUpdates)
s.updateLocalResolver(localRecords) s.updateLocalResolver(localRecords)
s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort) s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
hostUpdate := s.currentConfig hostUpdate := s.currentConfig
if s.runtimePort != defaultPort && !s.hostManager.supportCustomPort() { if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " + log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
"Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver") "Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver")
hostUpdate.routeAll = false hostUpdate.routeAll = false
@@ -405,19 +341,32 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
muxUpdateMap := make(registeredHandlerMap) muxUpdateMap := make(registeredHandlerMap)
var isContainRootUpdate bool
for _, update := range muxUpdates { for _, update := range muxUpdates {
s.registerMux(update.domain, update.handler) s.service.RegisterMux(update.domain, update.handler)
muxUpdateMap[update.domain] = update.handler muxUpdateMap[update.domain] = update.handler
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok { if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
existingHandler.stop() existingHandler.stop()
} }
if update.domain == nbdns.RootZone {
isContainRootUpdate = true
}
} }
for key, existingHandler := range s.dnsMuxMap { for key, existingHandler := range s.dnsMuxMap {
_, found := muxUpdateMap[key] _, found := muxUpdateMap[key]
if !found { if !found {
existingHandler.stop() if !isContainRootUpdate && key == nbdns.RootZone {
s.deregisterMux(key) s.hostsDnsListLock.Lock()
s.addHostRootZone()
s.hostsDnsListLock.Unlock()
existingHandler.stop()
} else {
existingHandler.stop()
s.service.DeregisterMux(key)
}
} }
} }
@@ -448,14 +397,6 @@ func getNSHostPort(ns nbdns.NameServer) string {
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port) return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
} }
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
func (s *DefaultServer) deregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}
// upstreamCallbacks returns two functions, the first one is used to deactivate // upstreamCallbacks returns two functions, the first one is used to deactivate
// the upstream resolver from the configuration, the second one is used to // the upstream resolver from the configuration, the second one is used to
// reactivate it. Not allowed to call reactivate before deactivate. // reactivate it. Not allowed to call reactivate before deactivate.
@@ -483,7 +424,7 @@ func (s *DefaultServer) upstreamCallbacks(
for i, item := range s.currentConfig.domains { for i, item := range s.currentConfig.domains {
if _, found := removeIndex[item.domain]; found { if _, found := removeIndex[item.domain]; found {
s.currentConfig.domains[i].disabled = true s.currentConfig.domains[i].disabled = true
s.deregisterMux(item.domain) s.service.DeregisterMux(item.domain)
removeIndex[item.domain] = i removeIndex[item.domain] = i
} }
} }
@@ -500,7 +441,7 @@ func (s *DefaultServer) upstreamCallbacks(
continue continue
} }
s.currentConfig.domains[i].disabled = false s.currentConfig.domains[i].disabled = false
s.registerMux(domain, handler) s.service.RegisterMux(domain, handler)
} }
l := log.WithField("nameservers", nsGroup.NameServers) l := log.WithField("nameservers", nsGroup.NameServers)
@@ -516,93 +457,13 @@ func (s *DefaultServer) upstreamCallbacks(
return return
} }
func (s *DefaultServer) filterDNSTraffic() string { func (s *DefaultServer) addHostRootZone() {
filter := s.wgInterface.GetFilter() handler := newUpstreamResolver(s.ctx)
if filter == nil { handler.upstreamServers = make([]string, len(s.hostsDnsList))
log.Error("can't set DNS filter, filter not initialized") for n, ua := range s.hostsDnsList {
return "" handler.upstreamServers[n] = fmt.Sprintf("%s:53", ua)
} }
handler.deactivate = func() {}
firstLayerDecoder := layers.LayerTypeIPv4 handler.reactivate = func() {}
if s.wgInterface.Address().Network.IP.To4() == nil { s.service.RegisterMux(nbdns.RootZone, handler)
firstLayerDecoder = layers.LayerTypeIPv6
}
hook := func(packetData []byte) bool {
// Decode the packet
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
// Get the UDP layer
udpLayer := packet.Layer(layers.LayerTypeUDP)
udp := udpLayer.(*layers.UDP)
msg := new(dns.Msg)
if err := msg.Unpack(udp.Payload); err != nil {
log.Tracef("parse DNS request: %v", err)
return true
}
writer := responseWriter{
packet: packet,
device: s.wgInterface.GetDevice().Device,
}
go s.dnsMux.ServeDNS(&writer, msg)
return true
}
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook)
}
func (s *DefaultServer) evalRuntimeAddress() {
defer func() {
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
}()
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
s.runtimePort = defaultPort
return
}
if s.customAddress != nil {
s.runtimeIP = s.customAddress.Addr().String()
s.runtimePort = int(s.customAddress.Port())
return
}
ip, port, err := s.getFirstListenerAvailable()
if err != nil {
log.Error(err)
return
}
s.runtimeIP = ip
s.runtimePort = port
}
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
// Calculate the last IP in the CIDR range
var endIP net.IP
for i := 0; i < len(network.IP); i++ {
endIP = append(endIP, network.IP[i]|^network.Mask[i])
}
// convert to big.Int
endInt := big.NewInt(0)
endInt.SetBytes(endIP)
// subtract fromEnd from the last ip
fromEndBig := big.NewInt(int64(fromEnd))
resultInt := big.NewInt(0)
resultInt.Sub(endInt, fromEndBig)
return net.IP(resultInt.Bytes()).String()
}
func hasValidDnsServer(cfg *nbdns.Config) bool {
for _, c := range cfg.NameServerGroups {
if c.Primary {
return true
}
}
return false
} }

View File

@@ -0,0 +1,29 @@
package dns
import (
"fmt"
"sync"
)
var (
mutex sync.Mutex
server Server
)
// GetServerDns export the DNS server instance in static way. It used by the Mobile client
func GetServerDns() (Server, error) {
mutex.Lock()
if server == nil {
mutex.Unlock()
return nil, fmt.Errorf("DNS server not instantiated yet")
}
s := server
mutex.Unlock()
return s, nil
}
func setServerDns(newServerServer Server) {
mutex.Lock()
server = newServerServer
defer mutex.Unlock()
}

View File

@@ -0,0 +1,24 @@
package dns
import (
"testing"
)
func TestGetServerDns(t *testing.T) {
_, err := GetServerDns()
if err == nil {
t.Errorf("invalid dns server instance")
}
srv := &MockServer{}
setServerDns(srv)
srvB, err := GetServerDns()
if err != nil {
t.Errorf("invalid dns server instance: %s", err)
}
if srvB != srv {
t.Errorf("missmatch dns instances")
}
}

View File

@@ -5,17 +5,59 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/miekg/dns" "github.com/golang/mock/gomock"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
pfmock "github.com/netbirdio/netbird/iface/mocks"
) )
type mocWGIface struct {
filter iface.PacketFilter
}
func (w *mocWGIface) Name() string {
panic("implement me")
}
func (w *mocWGIface) Address() iface.WGAddress {
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
return iface.WGAddress{
IP: ip,
Network: network,
}
}
func (w *mocWGIface) GetFilter() iface.PacketFilter {
return w.filter
}
func (w *mocWGIface) GetDevice() *iface.DeviceWrapper {
panic("implement me")
}
func (w *mocWGIface) GetInterfaceGUIDString() (string, error) {
panic("implement me")
}
func (w *mocWGIface) IsUserspaceBind() bool {
return false
}
func (w *mocWGIface) SetFilter(filter iface.PacketFilter) error {
w.filter = filter
return nil
}
var zoneRecords = []nbdns.SimpleRecord{ var zoneRecords = []nbdns.SimpleRecord{
{ {
Name: "peera.netbird.cloud", Name: "peera.netbird.cloud",
@@ -26,6 +68,11 @@ var zoneRecords = []nbdns.SimpleRecord{
}, },
} }
func init() {
log.SetLevel(log.TraceLevel)
formatter.SetTextFormatter(log.StandardLogger())
}
func TestUpdateDNSServer(t *testing.T) { func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{ nameServers := []nbdns.NameServer{
{ {
@@ -221,7 +268,7 @@ func TestUpdateDNSServer(t *testing.T) {
t.Log(err) t.Log(err)
} }
}() }()
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil) dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -239,9 +286,6 @@ func TestUpdateDNSServer(t *testing.T) {
dnsServer.dnsMuxMap = testCase.initUpstreamMap dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.localResolver.registeredMap = testCase.initLocalMap dnsServer.localResolver.registeredMap = testCase.initLocalMap
dnsServer.updateSerial = testCase.initSerial dnsServer.updateSerial = testCase.initSerial
// pretend we are running
dnsServer.listenerIsRunning = true
dnsServer.fakeResolverWG.Add(1)
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate) err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
if err != nil { if err != nil {
@@ -276,6 +320,133 @@ func TestUpdateDNSServer(t *testing.T) {
} }
} }
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
_ = os.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(nil)
if err != nil {
t.Errorf("create stdnet: %v", err)
return
}
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", iface.DefaultMTU, nil, newNet)
if err != nil {
t.Errorf("build interface wireguard: %v", err)
return
}
err = wgIface.Create()
if err != nil {
t.Errorf("crate and init wireguard interface: %v", err)
return
}
defer func() {
if err = wgIface.Close(); err != nil {
t.Logf("close wireguard interface: %v", err)
}
}()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
_, ipNet, err := net.ParseCIDR("100.66.100.1/32")
if err != nil {
t.Errorf("parse CIDR: %v", err)
return
}
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
packetfilter.EXPECT().SetNetwork(ipNet)
if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err)
return
}
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
if err != nil {
t.Errorf("create DNS server: %v", err)
return
}
err = dnsServer.Initialize()
if err != nil {
t.Errorf("run DNS server: %v", err)
return
}
defer func() {
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
t.Logf("restore DNS settings on the host: %v", err)
return
}
}()
dnsServer.dnsMuxMap = registeredHandlerMap{zoneRecords[0].Name: &localResolver{}}
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
dnsServer.updateSerial = 0
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
}
update := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
{
NameServers: nameServers,
Primary: true,
},
},
}
// Start the server with regular configuration
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
update2 := update
update2.ServiceEnable = false
// Disable the server, stop the listener
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
update3 := update2
update3.NameServerGroups = update3.NameServerGroups[:1]
// But service still get updates and we checking that we handle
// internal state in the right way
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
}
func TestDNSServerStartStop(t *testing.T) { func TestDNSServerStartStop(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
@@ -292,21 +463,23 @@ func TestDNSServerStartStop(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
dnsServer := getDefaultServerWithNoHostManager(t, testCase.addrPort) dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort)
if err != nil {
dnsServer.hostManager = newNoopHostMocker() t.Fatalf("%v", err)
dnsServer.listen()
time.Sleep(100 * time.Millisecond)
if !dnsServer.listenerIsRunning {
t.Fatal("dns server listener is not running")
} }
dnsServer.hostManager = newNoopHostMocker()
err = dnsServer.service.Listen()
if err != nil {
t.Fatalf("dns server is not running: %s", err)
}
time.Sleep(100 * time.Millisecond)
defer dnsServer.Stop() defer dnsServer.Stop()
err := dnsServer.localResolver.registerRecord(zoneRecords[0]) err = dnsServer.localResolver.registerRecord(zoneRecords[0])
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
dnsServer.dnsMux.Handle("netbird.cloud", dnsServer.localResolver) dnsServer.service.RegisterMux("netbird.cloud", dnsServer.localResolver)
resolver := &net.Resolver{ resolver := &net.Resolver{
PreferGo: true, PreferGo: true,
@@ -314,7 +487,7 @@ func TestDNSServerStartStop(t *testing.T) {
d := net.Dialer{ d := net.Dialer{
Timeout: time.Second * 5, Timeout: time.Second * 5,
} }
addr := fmt.Sprintf("%s:%d", dnsServer.runtimeIP, dnsServer.runtimePort) addr := fmt.Sprintf("%s:%d", dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
conn, err := d.DialContext(ctx, network, addr) conn, err := d.DialContext(ctx, network, addr)
if err != nil { if err != nil {
t.Log(err) t.Log(err)
@@ -349,7 +522,7 @@ func TestDNSServerStartStop(t *testing.T) {
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{} hostManager := &mockHostConfigurator{}
server := DefaultServer{ server := DefaultServer{
dnsMux: dns.DefaultServeMux, service: newServiceViaMemory(&mocWGIface{}),
localResolver: &localResolver{ localResolver: &localResolver{
registeredMap: make(registrationMap), registeredMap: make(registrationMap),
}, },
@@ -412,62 +585,237 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
} }
} }
func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultServer { func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
mux := dns.NewServeMux() wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
}
defer wgIFace.Close()
var parsedAddrPort *netip.AddrPort var dnsList []string
if addrPort != "" { dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList)
parsed, err := netip.ParseAddrPort(addrPort) err = dnsServer.Initialize()
if err != nil { if err != nil {
t.Fatal(err) t.Errorf("failed to initialize DNS server: %v", err)
} return
parsedAddrPort = &parsed }
defer dnsServer.Stop()
dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"})
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
}
func TestDNSPermanent_updateUpstream(t *testing.T) {
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
}
defer wgIFace.Close()
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"})
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
return
}
defer dnsServer.Stop()
// check initial state
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
} }
dnsServer := &dns.Server{ update := nbdns.Config{
Net: "udp", ServiceEnable: true,
Handler: mux, CustomZones: []nbdns.CustomZone{
UDPSize: 65535, {
} Domain: "netbird.cloud",
Records: zoneRecords,
ctx, cancel := context.WithCancel(context.TODO()) },
},
ds := &DefaultServer{ NameServerGroups: []*nbdns.NameServerGroup{
ctx: ctx, {
ctxCancel: cancel, NameServers: []nbdns.NameServer{
server: dnsServer, {
dnsMux: mux, IP: netip.MustParseAddr("8.8.4.4"),
dnsMuxMap: make(registeredHandlerMap), NSType: nbdns.UDPNameServerType,
localResolver: &localResolver{ Port: 53,
registeredMap: make(registrationMap), },
},
Enabled: true,
Primary: true,
},
}, },
customAddress: parsedAddrPort,
}
ds.evalRuntimeAddress()
return ds
}
func TestGetLastIPFromNetwork(t *testing.T) {
tests := []struct {
addr string
ip string
}{
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
{"192.168.0.0/30", "192.168.0.2"},
{"192.168.0.0/16", "192.168.255.254"},
{"192.168.0.0/24", "192.168.0.254"},
} }
for _, tt := range tests { err = dnsServer.UpdateDNSServer(1, update)
_, ipnet, err := net.ParseCIDR(tt.addr) if err != nil {
if err != nil { t.Errorf("failed to update dns server: %s", err)
t.Errorf("Error parsing CIDR: %v", err) }
return
}
lastIP := getLastIPFromNetwork(ipnet, 1) _, err = resolver.LookupHost(context.Background(), "netbird.io")
if lastIP != tt.ip { if err != nil {
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP) t.Errorf("failed to resolve: %s", err)
} }
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
if err != nil {
t.Fatalf("failed resolve zone record: %v", err)
}
if ips[0] != zoneRecords[0].RData {
t.Fatalf("invalid zone record: %v", err)
}
update2 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{},
}
err = dnsServer.UpdateDNSServer(2, update2)
if err != nil {
t.Errorf("failed to update dns server: %s", err)
}
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
ips, err = resolver.LookupHost(context.Background(), zoneRecords[0].Name)
if err != nil {
t.Fatalf("failed resolve zone record: %v", err)
}
if ips[0] != zoneRecords[0].RData {
t.Fatalf("invalid zone record: %v", err)
}
}
func TestDNSPermanent_matchOnly(t *testing.T) {
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
}
defer wgIFace.Close()
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"})
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
return
}
defer dnsServer.Stop()
// check initial state
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
update := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
},
Domains: []string{"customdomain.com"},
Primary: false,
},
},
}
err = dnsServer.UpdateDNSServer(1, update)
if err != nil {
t.Errorf("failed to update dns server: %s", err)
}
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
if err != nil {
t.Fatalf("failed resolve zone record: %v", err)
}
if ips[0] != zoneRecords[0].RData {
t.Fatalf("invalid zone record: %v", err)
}
_, err = resolver.LookupHost(context.Background(), "customdomain.com")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
}
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
_ = os.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(nil)
if err != nil {
t.Fatalf("create stdnet: %v", err)
return nil, nil
}
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", iface.DefaultMTU, nil, newNet)
if err != nil {
t.Fatalf("build interface wireguard: %v", err)
return nil, err
}
err = wgIface.Create()
if err != nil {
t.Fatalf("crate and init wireguard interface: %v", err)
return nil, err
}
pf, err := uspfilter.Create(wgIface)
if err != nil {
t.Fatalf("failed to create uspfilter: %v", err)
return nil, err
}
err = wgIface.SetFilter(pf)
if err != nil {
t.Fatalf("set packet filter: %v", err)
return nil, err
}
return wgIface, nil
}
func newDnsResolver(ip string, port int) *net.Resolver {
return &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{
Timeout: time.Second * 3,
}
addr := fmt.Sprintf("%s:%d", ip, port)
return d.DialContext(ctx, network, addr)
},
} }
} }

View File

@@ -0,0 +1,18 @@
package dns
import (
"github.com/miekg/dns"
)
const (
defaultPort = 53
)
type service interface {
Listen() error
Stop()
RegisterMux(domain string, handler dns.Handler)
DeregisterMux(key string)
RuntimePort() int
RuntimeIP() string
}

View File

@@ -0,0 +1,145 @@
package dns
import (
"context"
"fmt"
"net"
"net/netip"
"runtime"
"sync"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
const (
customPort = 5053
defaultIP = "127.0.0.1"
customIP = "127.0.0.153"
)
type serviceViaListener struct {
wgInterface WGIface
dnsMux *dns.ServeMux
customAddr *netip.AddrPort
server *dns.Server
runtimeIP string
runtimePort int
listenerIsRunning bool
listenerFlagLock sync.Mutex
}
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
mux := dns.NewServeMux()
s := &serviceViaListener{
wgInterface: wgIface,
dnsMux: mux,
customAddr: customAddr,
server: &dns.Server{
Net: "udp",
Handler: mux,
UDPSize: 65535,
},
}
return s
}
func (s *serviceViaListener) Listen() error {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if s.listenerIsRunning {
return nil
}
var err error
s.runtimeIP, s.runtimePort, err = s.evalRuntimeAddress()
if err != nil {
log.Errorf("failed to eval runtime address: %s", err)
return err
}
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
log.Debugf("starting dns on %s", s.server.Addr)
go func() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)
err := s.server.ListenAndServe()
if err != nil {
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
}
}()
return nil
}
func (s *serviceViaListener) Stop() {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if !s.listenerIsRunning {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := s.server.ShutdownContext(ctx)
if err != nil {
log.Errorf("stopping dns server listener returned an error: %v", err)
}
}
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
func (s *serviceViaListener) DeregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}
func (s *serviceViaListener) RuntimePort() int {
return s.runtimePort
}
func (s *serviceViaListener) RuntimeIP() string {
return s.runtimeIP
}
func (s *serviceViaListener) setListenerStatus(running bool) {
s.listenerIsRunning = running
}
func (s *serviceViaListener) getFirstListenerAvailable() (string, int, error) {
ips := []string{defaultIP, customIP}
if runtime.GOOS != "darwin" {
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
}
ports := []int{defaultPort, customPort}
for _, port := range ports {
for _, ip := range ips {
addrString := fmt.Sprintf("%s:%d", ip, port)
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
probeListener, err := net.ListenUDP("udp", udpAddr)
if err == nil {
err = probeListener.Close()
if err != nil {
log.Errorf("got an error closing the probe listener, error: %s", err)
}
return ip, port, nil
}
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
}
}
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
}
func (s *serviceViaListener) evalRuntimeAddress() (string, int, error) {
if s.customAddr != nil {
return s.customAddr.Addr().String(), int(s.customAddr.Port()), nil
}
return s.getFirstListenerAvailable()
}

View File

@@ -0,0 +1,139 @@
package dns
import (
"fmt"
"math/big"
"net"
"sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
type serviceViaMemory struct {
wgInterface WGIface
dnsMux *dns.ServeMux
runtimeIP string
runtimePort int
udpFilterHookID string
listenerIsRunning bool
listenerFlagLock sync.Mutex
}
func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
s := &serviceViaMemory{
wgInterface: wgIface,
dnsMux: dns.NewServeMux(),
runtimeIP: getLastIPFromNetwork(wgIface.Address().Network, 1),
runtimePort: defaultPort,
}
return s
}
func (s *serviceViaMemory) Listen() error {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if s.listenerIsRunning {
return nil
}
var err error
s.udpFilterHookID, err = s.filterDNSTraffic()
if err != nil {
return err
}
s.listenerIsRunning = true
log.Debugf("dns service listening on: %s", s.RuntimeIP())
return nil
}
func (s *serviceViaMemory) Stop() {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if !s.listenerIsRunning {
return
}
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
log.Errorf("unable to remove DNS packet hook: %s", err)
}
s.listenerIsRunning = false
}
func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
func (s *serviceViaMemory) DeregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}
func (s *serviceViaMemory) RuntimePort() int {
return s.runtimePort
}
func (s *serviceViaMemory) RuntimeIP() string {
return s.runtimeIP
}
func (s *serviceViaMemory) filterDNSTraffic() (string, error) {
filter := s.wgInterface.GetFilter()
if filter == nil {
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
}
firstLayerDecoder := layers.LayerTypeIPv4
if s.wgInterface.Address().Network.IP.To4() == nil {
firstLayerDecoder = layers.LayerTypeIPv6
}
hook := func(packetData []byte) bool {
// Decode the packet
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
// Get the UDP layer
udpLayer := packet.Layer(layers.LayerTypeUDP)
udp := udpLayer.(*layers.UDP)
msg := new(dns.Msg)
if err := msg.Unpack(udp.Payload); err != nil {
log.Tracef("parse DNS request: %v", err)
return true
}
writer := responseWriter{
packet: packet,
device: s.wgInterface.GetDevice().Device,
}
go s.dnsMux.ServeDNS(&writer, msg)
return true
}
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
}
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
// Calculate the last IP in the CIDR range
var endIP net.IP
for i := 0; i < len(network.IP); i++ {
endIP = append(endIP, network.IP[i]|^network.Mask[i])
}
// convert to big.Int
endInt := big.NewInt(0)
endInt.SetBytes(endIP)
// subtract fromEnd from the last ip
fromEndBig := big.NewInt(int64(fromEnd))
resultInt := big.NewInt(0)
resultInt.Sub(endInt, fromEndBig)
return net.IP(resultInt.Bytes()).String()
}

View File

@@ -0,0 +1,31 @@
package dns
import (
"net"
"testing"
)
func TestGetLastIPFromNetwork(t *testing.T) {
tests := []struct {
addr string
ip string
}{
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
{"192.168.0.0/30", "192.168.0.2"},
{"192.168.0.0/16", "192.168.255.254"},
{"192.168.0.0/24", "192.168.0.254"},
}
for _, tt := range tests {
_, ipnet, err := net.ParseCIDR(tt.addr)
if err != nil {
t.Errorf("Error parsing CIDR: %v", err)
return
}
lastIP := getLastIPFromNetwork(ipnet, 1)
if lastIP != tt.ip {
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
}
}
}

View File

@@ -15,7 +15,6 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
) )
const ( const (
@@ -53,7 +52,7 @@ type systemdDbusLinkDomainsInput struct {
MatchOnly bool MatchOnly bool
} }
func newSystemdDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) { func newSystemdDbusConfigurator(wgInterface WGIface) (hostManager, error) {
iface, err := net.InterfaceByName(wgInterface.Name()) iface, err := net.InterfaceByName(wgInterface.Name())
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -0,0 +1,14 @@
//go:build !windows
package dns
import "github.com/netbirdio/netbird/iface"
// WGIface defines subset methods of interface required for manager
type WGIface interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
GetFilter() iface.PacketFilter
GetDevice() *iface.DeviceWrapper
}

View File

@@ -0,0 +1,13 @@
package dns
import "github.com/netbirdio/netbird/iface"
// WGIface defines subset methods of interface required for manager
type WGIface interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
GetFilter() iface.PacketFilter
GetDevice() *iface.DeviceWrapper
GetInterfaceGUIDString() (string, error)
}

View File

@@ -20,8 +20,8 @@ import (
"github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/proxy"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/wgproxy"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
@@ -101,7 +101,8 @@ type Engine struct {
ctx context.Context ctx context.Context
wgInterface *iface.WGIface wgInterface *iface.WGIface
wgProxyFactory *wgproxy.Factory
udpMux *bind.UniversalUDPMuxDefault udpMux *bind.UniversalUDPMuxDefault
udpMuxConn io.Closer udpMuxConn io.Closer
@@ -132,6 +133,7 @@ func NewEngine(
signalClient signal.Client, mgmClient mgm.Client, signalClient signal.Client, mgmClient mgm.Client,
config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status,
) *Engine { ) *Engine {
return &Engine{ return &Engine{
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
@@ -146,6 +148,7 @@ func NewEngine(
networkSerial: 0, networkSerial: 0,
sshServerFunc: nbssh.DefaultSSHServer, sshServerFunc: nbssh.DefaultSSHServer,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgProxyFactory: wgproxy.NewFactory(config.WgPort),
} }
} }
@@ -190,23 +193,25 @@ func (e *Engine) Start() error {
} }
var routes []*route.Route var routes []*route.Route
var dnsCfg *nbdns.Config
if runtime.GOOS == "android" { if runtime.GOOS == "android" {
routes, dnsCfg, err = e.readInitialSettings() routes, err = e.readInitialSettings()
if err != nil { if err != nil {
return err return err
} }
} if e.dnsServer == nil {
e.dnsServer = dns.NewDefaultServerPermanentUpstream(e.ctx, e.wgInterface, e.mobileDep.HostDNSAddresses)
if e.dnsServer == nil { go e.mobileDep.DnsReadyListener.OnReady()
}
} else {
// todo fix custom address // todo fix custom address
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, dnsCfg) if e.dnsServer == nil {
if err != nil { e.dnsServer, err = dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
e.close() if err != nil {
return err e.close()
return err
}
} }
e.dnsServer = dnsServer
} }
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes) e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes)
@@ -280,7 +285,7 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
for _, p := range peersUpdate { for _, p := range peersUpdate {
peerPubKey := p.GetWgPubKey() peerPubKey := p.GetWgPubKey()
if peerConn, ok := e.peerConns[peerPubKey]; ok { if peerConn, ok := e.peerConns[peerPubKey]; ok {
if peerConn.GetConf().ProxyConfig.AllowedIps != strings.Join(p.AllowedIps, ",") { if peerConn.WgConfig().AllowedIps != strings.Join(p.AllowedIps, ",") {
modified = append(modified, p) modified = append(modified, p)
continue continue
} }
@@ -605,6 +610,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
// cleanup request, most likely our peer has been deleted // cleanup request, most likely our peer has been deleted
if networkMap.GetRemotePeersIsEmpty() { if networkMap.GetRemotePeersIsEmpty() {
err := e.removeAllPeers() err := e.removeAllPeers()
e.statusRecorder.FinishPeerListModifications()
if err != nil { if err != nil {
return err return err
} }
@@ -624,6 +630,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return err return err
} }
e.statusRecorder.FinishPeerListModifications()
// update SSHServer by adding remote peer SSH keys // update SSHServer by adding remote peer SSH keys
if !isNil(e.sshServer) { if !isNil(e.sshServer) {
for _, config := range networkMap.GetRemotePeers() { for _, config := range networkMap.GetRemotePeers() {
@@ -759,17 +767,13 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
} }
e.peerConns[peerKey] = conn e.peerConns[peerKey] = conn
err = e.statusRecorder.AddPeer(peerKey) err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
if err != nil { if err != nil {
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err) log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
} }
go e.connWorker(conn, peerKey) go e.connWorker(conn, peerKey)
} }
err := e.statusRecorder.UpdatePeerFQDN(peerKey, peerConfig.Fqdn)
if err != nil {
log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerKey, err)
}
return nil return nil
} }
@@ -794,9 +798,7 @@ func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
// we might have received new STUN and TURN servers meanwhile, so update them // we might have received new STUN and TURN servers meanwhile, so update them
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
conf := conn.GetConf() conn.UpdateStunTurn(append(e.STUNs, e.TURNs...))
conf.StunTurn = append(e.STUNs, e.TURNs...)
conn.UpdateConf(conf)
e.syncMsgMux.Unlock() e.syncMsgMux.Unlock()
err := conn.Open() err := conn.Open()
@@ -825,9 +827,9 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
stunTurn = append(stunTurn, e.STUNs...) stunTurn = append(stunTurn, e.STUNs...)
stunTurn = append(stunTurn, e.TURNs...) stunTurn = append(stunTurn, e.TURNs...)
proxyConfig := proxy.Config{ wgConfig := peer.WgConfig{
RemoteKey: pubKey, RemoteKey: pubKey,
WgListenAddr: fmt.Sprintf("127.0.0.1:%d", e.config.WgPort), WgListenPort: e.config.WgPort,
WgInterface: e.wgInterface, WgInterface: e.wgInterface,
AllowedIps: allowedIPs, AllowedIps: allowedIPs,
PreSharedKey: e.config.PreSharedKey, PreSharedKey: e.config.PreSharedKey,
@@ -844,13 +846,13 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
Timeout: timeout, Timeout: timeout,
UDPMux: e.udpMux.UDPMuxDefault, UDPMux: e.udpMux.UDPMuxDefault,
UDPMuxSrflx: e.udpMux, UDPMuxSrflx: e.udpMux,
ProxyConfig: proxyConfig, WgConfig: wgConfig,
LocalWgPort: e.config.WgPort, LocalWgPort: e.config.WgPort,
NATExternalIPs: e.parseNATExternalIPMappings(), NATExternalIPs: e.parseNATExternalIPMappings(),
UserspaceBind: e.wgInterface.IsUserspaceBind(), UserspaceBind: e.wgInterface.IsUserspaceBind(),
} }
peerConn, err := peer.NewConn(config, e.statusRecorder, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover) peerConn, err := peer.NewConn(config, e.statusRecorder, e.wgProxyFactory, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1007,6 +1009,10 @@ func (e *Engine) parseNATExternalIPMappings() []string {
} }
func (e *Engine) close() { func (e *Engine) close() {
if err := e.wgProxyFactory.Free(); err != nil {
log.Errorf("failed closing ebpf proxy: %s", err)
}
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
if e.wgInterface != nil { if e.wgInterface != nil {
if err := e.wgInterface.Close(); err != nil { if err := e.wgInterface.Close(); err != nil {
@@ -1046,14 +1052,13 @@ func (e *Engine) close() {
} }
} }
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) { func (e *Engine) readInitialSettings() ([]*route.Route, error) {
netMap, err := e.mgmClient.GetNetworkMap() netMap, err := e.mgmClient.GetNetworkMap()
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
routes := toRoutes(netMap.GetRoutes()) routes := toRoutes(netMap.GetRoutes())
dnsCfg := toDNSConfig(netMap.GetDNSConfig()) return routes, nil
return routes, &dnsCfg, nil
} }
func findIPFromInterfaceName(ifaceName string) (net.IP, error) { func findIPFromInterfaceName(ifaceName string) (net.IP, error) {

View File

@@ -367,9 +367,9 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
t.Errorf("expecting Engine.peerConns to contain peer %s", p) t.Errorf("expecting Engine.peerConns to contain peer %s", p)
} }
expectedAllowedIPs := strings.Join(p.AllowedIps, ",") expectedAllowedIPs := strings.Join(p.AllowedIps, ",")
if conn.GetConf().ProxyConfig.AllowedIps != expectedAllowedIPs { if conn.WgConfig().AllowedIps != expectedAllowedIPs {
t.Errorf("expecting peer %s to have AllowedIPs= %s, got %s", p.GetWgPubKey(), t.Errorf("expecting peer %s to have AllowedIPs= %s, got %s", p.GetWgPubKey(),
expectedAllowedIPs, conn.GetConf().ProxyConfig.AllowedIps) expectedAllowedIPs, conn.WgConfig().AllowedIps)
} }
} }
}) })

View File

@@ -1,6 +1,7 @@
package internal package internal
import ( import (
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
@@ -8,7 +9,9 @@ import (
// MobileDependency collect all dependencies for mobile platform // MobileDependency collect all dependencies for mobile platform
type MobileDependency struct { type MobileDependency struct {
TunAdapter iface.TunAdapter TunAdapter iface.TunAdapter
IFaceDiscover stdnet.ExternalIFaceDiscover IFaceDiscover stdnet.ExternalIFaceDiscover
RouteListener routemanager.RouteListener RouteListener routemanager.RouteListener
HostDNSAddresses []string
DnsReadyListener dns.ReadyListener
} }

View File

@@ -1,286 +0,0 @@
package internal
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"reflect"
"strings"
"time"
)
// OAuthClient is a OAuth client interface for various idp providers
type OAuthClient interface {
RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error)
GetClientID(ctx context.Context) string
}
// HTTPClient http client interface for API calls
type HTTPClient interface {
Do(req *http.Request) (*http.Response, error)
}
// DeviceAuthInfo holds information for the OAuth device login flow
type DeviceAuthInfo struct {
DeviceCode string `json:"device_code"`
UserCode string `json:"user_code"`
VerificationURI string `json:"verification_uri"`
VerificationURIComplete string `json:"verification_uri_complete"`
ExpiresIn int `json:"expires_in"`
Interval int `json:"interval"`
}
// HostedGrantType grant type for device flow on Hosted
const (
HostedGrantType = "urn:ietf:params:oauth:grant-type:device_code"
HostedRefreshGrant = "refresh_token"
)
// Hosted client
type Hosted struct {
providerConfig ProviderConfig
HTTPClient HTTPClient
}
// RequestDeviceCodePayload used for request device code payload for auth0
type RequestDeviceCodePayload struct {
Audience string `json:"audience"`
ClientID string `json:"client_id"`
Scope string `json:"scope"`
}
// TokenRequestPayload used for requesting the auth0 token
type TokenRequestPayload struct {
GrantType string `json:"grant_type"`
DeviceCode string `json:"device_code,omitempty"`
ClientID string `json:"client_id"`
RefreshToken string `json:"refresh_token,omitempty"`
}
// TokenRequestResponse used for parsing Hosted token's response
type TokenRequestResponse struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
TokenInfo
}
// Claims used when validating the access token
type Claims struct {
Audience interface{} `json:"aud"`
}
// TokenInfo holds information of issued access token
type TokenInfo struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
UseIDToken bool `json:"-"`
}
// GetTokenToUse returns either the access or id token based on UseIDToken field
func (t TokenInfo) GetTokenToUse() string {
if t.UseIDToken {
return t.IDToken
}
return t.AccessToken
}
// NewHostedDeviceFlow returns an Hosted OAuth client
func NewHostedDeviceFlow(config ProviderConfig) *Hosted {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Transport: httpTransport,
}
return &Hosted{
providerConfig: config,
HTTPClient: httpClient,
}
}
// GetClientID returns the provider client id
func (h *Hosted) GetClientID(ctx context.Context) string {
return h.providerConfig.ClientID
}
// RequestDeviceCode requests a device code login flow information from Hosted
func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) {
form := url.Values{}
form.Add("client_id", h.providerConfig.ClientID)
form.Add("audience", h.providerConfig.Audience)
form.Add("scope", h.providerConfig.Scope)
req, err := http.NewRequest("POST", h.providerConfig.DeviceAuthEndpoint,
strings.NewReader(form.Encode()))
if err != nil {
return DeviceAuthInfo{}, fmt.Errorf("creating request failed with error: %v", err)
}
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
res, err := h.HTTPClient.Do(req)
if err != nil {
return DeviceAuthInfo{}, fmt.Errorf("doing request failed with error: %v", err)
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
return DeviceAuthInfo{}, fmt.Errorf("reading body failed with error: %v", err)
}
if res.StatusCode != 200 {
return DeviceAuthInfo{}, fmt.Errorf("request device code returned status %d error: %s", res.StatusCode, string(body))
}
deviceCode := DeviceAuthInfo{}
err = json.Unmarshal(body, &deviceCode)
if err != nil {
return DeviceAuthInfo{}, fmt.Errorf("unmarshaling response failed with error: %v", err)
}
// Fallback to the verification_uri if the IdP doesn't support verification_uri_complete
if deviceCode.VerificationURIComplete == "" {
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
}
return deviceCode, err
}
func (h *Hosted) requestToken(info DeviceAuthInfo) (TokenRequestResponse, error) {
form := url.Values{}
form.Add("client_id", h.providerConfig.ClientID)
form.Add("grant_type", HostedGrantType)
form.Add("device_code", info.DeviceCode)
req, err := http.NewRequest("POST", h.providerConfig.TokenEndpoint, strings.NewReader(form.Encode()))
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("failed to create request access token: %v", err)
}
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
res, err := h.HTTPClient.Do(req)
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("failed to request access token with error: %v", err)
}
defer func() {
err := res.Body.Close()
if err != nil {
return
}
}()
body, err := io.ReadAll(res.Body)
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("failed reading access token response body with error: %v", err)
}
if res.StatusCode > 499 {
return TokenRequestResponse{}, fmt.Errorf("access token response returned code: %s", string(body))
}
tokenResponse := TokenRequestResponse{}
err = json.Unmarshal(body, &tokenResponse)
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("parsing token response failed with error: %v", err)
}
return tokenResponse, nil
}
// WaitToken waits user's login and authorize the app. Once the user's authorize
// it retrieves the access token from Hosted's endpoint and validates it before returning
func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error) {
interval := time.Duration(info.Interval) * time.Second
ticker := time.NewTicker(interval)
for {
select {
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case <-ticker.C:
tokenResponse, err := h.requestToken(info)
if err != nil {
return TokenInfo{}, fmt.Errorf("parsing token response failed with error: %v", err)
}
if tokenResponse.Error != "" {
if tokenResponse.Error == "authorization_pending" {
continue
} else if tokenResponse.Error == "slow_down" {
interval = interval + (3 * time.Second)
ticker.Reset(interval)
continue
}
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
}
tokenInfo := TokenInfo{
AccessToken: tokenResponse.AccessToken,
TokenType: tokenResponse.TokenType,
RefreshToken: tokenResponse.RefreshToken,
IDToken: tokenResponse.IDToken,
ExpiresIn: tokenResponse.ExpiresIn,
UseIDToken: h.providerConfig.UseIDToken,
}
err = isValidAccessToken(tokenInfo.GetTokenToUse(), h.providerConfig.Audience)
if err != nil {
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
}
return tokenInfo, err
}
}
}
// isValidAccessToken is a simple validation of the access token
func isValidAccessToken(token string, audience string) error {
if token == "" {
return fmt.Errorf("token received is empty")
}
encodedClaims := strings.Split(token, ".")[1]
claimsString, err := base64.RawURLEncoding.DecodeString(encodedClaims)
if err != nil {
return err
}
claims := Claims{}
err = json.Unmarshal(claimsString, &claims)
if err != nil {
return err
}
if claims.Audience == nil {
return fmt.Errorf("required token field audience is absent")
}
// Audience claim of JWT can be a string or an array of strings
typ := reflect.TypeOf(claims.Audience)
switch typ.Kind() {
case reflect.String:
if claims.Audience == audience {
return nil
}
case reflect.Slice:
for _, aud := range claims.Audience.([]interface{}) {
if audience == aud {
return nil
}
}
}
return fmt.Errorf("invalid JWT token audience field")
}

View File

@@ -10,9 +10,10 @@ import (
"github.com/pion/ice/v2" "github.com/pion/ice/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/internal/proxy"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/iface/bind"
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
@@ -23,8 +24,18 @@ import (
const ( const (
iceKeepAliveDefault = 4 * time.Second iceKeepAliveDefault = 4 * time.Second
iceDisconnectedTimeoutDefault = 6 * time.Second iceDisconnectedTimeoutDefault = 6 * time.Second
defaultWgKeepAlive = 25 * time.Second
) )
type WgConfig struct {
WgListenPort int
RemoteKey string
WgInterface *iface.WGIface
AllowedIps string
PreSharedKey *wgtypes.Key
}
// ConnConfig is a peer Connection configuration // ConnConfig is a peer Connection configuration
type ConnConfig struct { type ConnConfig struct {
@@ -43,7 +54,7 @@ type ConnConfig struct {
Timeout time.Duration Timeout time.Duration
ProxyConfig proxy.Config WgConfig WgConfig
UDPMux ice.UDPMux UDPMux ice.UDPMux
UDPMuxSrflx ice.UniversalUDPMux UDPMuxSrflx ice.UniversalUDPMux
@@ -98,7 +109,9 @@ type Conn struct {
statusRecorder *Status statusRecorder *Status
proxy proxy.Proxy wgProxyFactory *wgproxy.Factory
wgProxy wgproxy.Proxy
remoteModeCh chan ModeMessage remoteModeCh chan ModeMessage
meta meta meta meta
@@ -122,14 +135,19 @@ func (conn *Conn) GetConf() ConnConfig {
return conn.config return conn.config
} }
// UpdateConf updates the connection config // WgConfig returns the WireGuard config
func (conn *Conn) UpdateConf(conf ConnConfig) { func (conn *Conn) WgConfig() WgConfig {
conn.config = conf return conn.config.WgConfig
}
// UpdateStunTurn update the turn and stun addresses
func (conn *Conn) UpdateStunTurn(turnStun []*ice.URL) {
conn.config.StunTurn = turnStun
} }
// NewConn creates a new not opened Conn to the remote peer. // NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open // To establish a connection run Conn.Open
func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) { func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) {
return &Conn{ return &Conn{
config: config, config: config,
mu: sync.Mutex{}, mu: sync.Mutex{},
@@ -139,6 +157,7 @@ func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter
remoteAnswerCh: make(chan OfferAnswer), remoteAnswerCh: make(chan OfferAnswer),
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
remoteModeCh: make(chan ModeMessage, 1), remoteModeCh: make(chan ModeMessage, 1),
wgProxyFactory: wgProxyFactory,
adapter: adapter, adapter: adapter,
iFaceDiscover: iFaceDiscover, iFaceDiscover: iFaceDiscover,
}, nil }, nil
@@ -215,12 +234,12 @@ func (conn *Conn) candidateTypes() []ice.CandidateType {
func (conn *Conn) Open() error { func (conn *Conn) Open() error {
log.Debugf("trying to connect to peer %s", conn.config.Key) log.Debugf("trying to connect to peer %s", conn.config.Key)
peerState := State{PubKey: conn.config.Key} peerState := State{
PubKey: conn.config.Key,
peerState.IP = strings.Split(conn.config.ProxyConfig.AllowedIps, "/")[0] IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
peerState.ConnStatusUpdate = time.Now() ConnStatusUpdate: time.Now(),
peerState.ConnStatus = conn.status ConnStatus: conn.status,
}
err := conn.statusRecorder.UpdatePeerState(peerState) err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil { if err != nil {
log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err) log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err)
@@ -275,10 +294,11 @@ func (conn *Conn) Open() error {
defer conn.notifyDisconnected() defer conn.notifyDisconnected()
conn.mu.Unlock() conn.mu.Unlock()
peerState = State{PubKey: conn.config.Key} peerState = State{
PubKey: conn.config.Key,
peerState.ConnStatus = conn.status ConnStatus: conn.status,
peerState.ConnStatusUpdate = time.Now() ConnStatusUpdate: time.Now(),
}
err = conn.statusRecorder.UpdatePeerState(peerState) err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil { if err != nil {
log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err) log.Warnf("erro while updating the state of peer %s,err: %v", conn.config.Key, err)
@@ -309,19 +329,12 @@ func (conn *Conn) Open() error {
remoteWgPort = remoteOfferAnswer.WgListenPort remoteWgPort = remoteOfferAnswer.WgListenPort
} }
// the ice connection has been established successfully so we are ready to start the proxy // the ice connection has been established successfully so we are ready to start the proxy
err = conn.startProxy(remoteConn, remoteWgPort) remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort)
if err != nil { if err != nil {
return err return err
} }
if conn.proxy.Type() == proxy.TypeDirectNoProxy { log.Infof("connected to peer %s, endpoint address: %s", conn.config.Key, remoteAddr.String())
host, _, _ := net.SplitHostPort(remoteConn.LocalAddr().String())
rhost, _, _ := net.SplitHostPort(remoteConn.RemoteAddr().String())
// direct Wireguard connection
log.Infof("directly connected to peer %s [laddr <-> raddr] [%s:%d <-> %s:%d]", conn.config.Key, host, conn.config.LocalWgPort, rhost, remoteWgPort)
} else {
log.Infof("connected to peer %s [laddr <-> raddr] [%s <-> %s]", conn.config.Key, remoteConn.LocalAddr().String(), remoteConn.RemoteAddr().String())
}
// wait until connection disconnected or has been closed externally (upper layer, e.g. engine) // wait until connection disconnected or has been closed externally (upper layer, e.g. engine)
select { select {
@@ -338,54 +351,60 @@ func isRelayCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay return candidate.Type() == ice.CandidateTypeRelay
} }
// startProxy starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error { func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int) (net.Addr, error) {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
var pair *ice.CandidatePair
pair, err := conn.agent.GetSelectedCandidatePair() pair, err := conn.agent.GetSelectedCandidatePair()
if err != nil { if err != nil {
return err return nil, err
} }
peerState := State{PubKey: conn.config.Key} var endpoint net.Addr
p := conn.getProxy(pair, remoteWgPort) if isRelayCandidate(pair.Local) {
conn.proxy = p log.Debugf("setup relay connection")
err = p.Start(remoteConn) conn.wgProxy = conn.wgProxyFactory.GetProxy()
endpoint, err = conn.wgProxy.AddTurnConn(remoteConn)
if err != nil {
return nil, err
}
} else {
// To support old version's with direct mode we attempt to punch an additional role with the remote wireguard port
go conn.punchRemoteWGPort(pair, remoteWgPort)
endpoint = remoteConn.RemoteAddr()
}
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
if err != nil { if err != nil {
return err if conn.wgProxy != nil {
_ = conn.wgProxy.CloseConn()
}
return nil, err
} }
conn.status = StatusConnected conn.status = StatusConnected
peerState.ConnStatus = conn.status peerState := State{
peerState.ConnStatusUpdate = time.Now() PubKey: conn.config.Key,
peerState.LocalIceCandidateType = pair.Local.Type().String() ConnStatus: conn.status,
peerState.RemoteIceCandidateType = pair.Remote.Type().String() ConnStatusUpdate: time.Now(),
LocalIceCandidateType: pair.Local.Type().String(),
RemoteIceCandidateType: pair.Remote.Type().String(),
Direct: !isRelayCandidate(pair.Local),
}
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay { if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
peerState.Relayed = true peerState.Relayed = true
} }
peerState.Direct = p.Type() == proxy.TypeDirectNoProxy || p.Type() == proxy.TypeNoProxy
err = conn.statusRecorder.UpdatePeerState(peerState) err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil { if err != nil {
log.Warnf("unable to save peer's state, got error: %v", err) log.Warnf("unable to save peer's state, got error: %v", err)
} }
return nil return endpoint, nil
}
// todo rename this method and the proxy package to something more appropriate
func (conn *Conn) getProxy(pair *ice.CandidatePair, remoteWgPort int) proxy.Proxy {
if isRelayCandidate(pair.Local) {
return proxy.NewWireGuardProxy(conn.config.ProxyConfig)
}
// To support old version's with direct mode we attempt to punch an additional role with the remote wireguard port
go conn.punchRemoteWGPort(pair, remoteWgPort)
return proxy.NewNoProxy(conn.config.ProxyConfig)
} }
func (conn *Conn) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) { func (conn *Conn) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
@@ -414,22 +433,22 @@ func (conn *Conn) cleanup() error {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
var err1, err2, err3 error
if conn.agent != nil { if conn.agent != nil {
err := conn.agent.Close() err1 = conn.agent.Close()
if err != nil { if err1 == nil {
return err conn.agent = nil
} }
conn.agent = nil
} }
if conn.proxy != nil { if conn.wgProxy != nil {
err := conn.proxy.Close() err2 = conn.wgProxy.CloseConn()
if err != nil { conn.wgProxy = nil
return err
}
conn.proxy = nil
} }
// todo: is it problem if we try to remove a peer what is never existed?
err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
if conn.notifyDisconnected != nil { if conn.notifyDisconnected != nil {
conn.notifyDisconnected() conn.notifyDisconnected()
conn.notifyDisconnected = nil conn.notifyDisconnected = nil
@@ -437,10 +456,11 @@ func (conn *Conn) cleanup() error {
conn.status = StatusDisconnected conn.status = StatusDisconnected
peerState := State{PubKey: conn.config.Key} peerState := State{
peerState.ConnStatus = conn.status PubKey: conn.config.Key,
peerState.ConnStatusUpdate = time.Now() ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
}
err := conn.statusRecorder.UpdatePeerState(peerState) err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil { if err != nil {
// pretty common error because by that time Engine can already remove the peer and status won't be available. // pretty common error because by that time Engine can already remove the peer and status won't be available.
@@ -449,8 +469,13 @@ func (conn *Conn) cleanup() error {
} }
log.Debugf("cleaned up connection to peer %s", conn.config.Key) log.Debugf("cleaned up connection to peer %s", conn.config.Key)
if err1 != nil {
return nil return err1
}
if err2 != nil {
return err2
}
return err3
} }
// SetSignalOffer sets a handler function to be triggered by Conn when a new connection offer has to be signalled to the remote peer // SetSignalOffer sets a handler function to be triggered by Conn when a new connection offer has to be signalled to the remote peer

View File

@@ -5,12 +5,11 @@ import (
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/magiconair/properties/assert" "github.com/magiconair/properties/assert"
"github.com/pion/ice/v2" "github.com/pion/ice/v2"
"github.com/netbirdio/netbird/client/internal/proxy" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
) )
@@ -20,7 +19,6 @@ var connConf = ConnConfig{
StunTurn: []*ice.URL{}, StunTurn: []*ice.URL{},
InterfaceBlackList: nil, InterfaceBlackList: nil,
Timeout: time.Second, Timeout: time.Second,
ProxyConfig: proxy.Config{},
LocalWgPort: 51820, LocalWgPort: 51820,
} }
@@ -37,7 +35,11 @@ func TestNewConn_interfaceFilter(t *testing.T) {
} }
func TestConn_GetKey(t *testing.T) { func TestConn_GetKey(t *testing.T) {
conn, err := NewConn(connConf, nil, nil, nil) wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(connConf, nil, wgProxyFactory, nil, nil)
if err != nil { if err != nil {
return return
} }
@@ -48,8 +50,11 @@ func TestConn_GetKey(t *testing.T) {
} }
func TestConn_OnRemoteOffer(t *testing.T) { func TestConn_OnRemoteOffer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil) defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
if err != nil { if err != nil {
return return
} }
@@ -82,8 +87,11 @@ func TestConn_OnRemoteOffer(t *testing.T) {
} }
func TestConn_OnRemoteAnswer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil) defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
if err != nil { if err != nil {
return return
} }
@@ -115,8 +123,11 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestConn_Status(t *testing.T) { func TestConn_Status(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil) defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
if err != nil { if err != nil {
return return
} }
@@ -142,8 +153,11 @@ func TestConn_Status(t *testing.T) {
} }
func TestConn_Close(t *testing.T) { func TestConn_Close(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil) defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
if err != nil { if err != nil {
return return
} }

View File

@@ -59,6 +59,11 @@ type Status struct {
mgmAddress string mgmAddress string
signalAddress string signalAddress string
notifier *notifier notifier *notifier
// To reduce the number of notification invocation this bool will be true when need to call the notification
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
// set to true this variable and at the end of the processing we will reset it by the FinishPeerListModifications()
peerListChangedForNotification bool
} }
// NewRecorder returns a new Status instance // NewRecorder returns a new Status instance
@@ -78,11 +83,13 @@ func (d *Status) ReplaceOfflinePeers(replacement []State) {
defer d.mux.Unlock() defer d.mux.Unlock()
d.offlinePeers = make([]State, len(replacement)) d.offlinePeers = make([]State, len(replacement))
copy(d.offlinePeers, replacement) copy(d.offlinePeers, replacement)
d.notifyPeerListChanged()
// todo we should set to true in case if the list changed only
d.peerListChangedForNotification = true
} }
// AddPeer adds peer to Daemon status map // AddPeer adds peer to Daemon status map
func (d *Status) AddPeer(peerPubKey string) error { func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock() defer d.mux.Unlock()
@@ -90,7 +97,12 @@ func (d *Status) AddPeer(peerPubKey string) error {
if ok { if ok {
return errors.New("peer already exist") return errors.New("peer already exist")
} }
d.peers[peerPubKey] = State{PubKey: peerPubKey, ConnStatus: StatusDisconnected} d.peers[peerPubKey] = State{
PubKey: peerPubKey,
ConnStatus: StatusDisconnected,
FQDN: fqdn,
}
d.peerListChangedForNotification = true
return nil return nil
} }
@@ -112,13 +124,13 @@ func (d *Status) RemovePeer(peerPubKey string) error {
defer d.mux.Unlock() defer d.mux.Unlock()
_, ok := d.peers[peerPubKey] _, ok := d.peers[peerPubKey]
if ok { if !ok {
delete(d.peers, peerPubKey) return errors.New("no peer with to remove")
return nil
} }
d.notifyPeerListChanged() delete(d.peers, peerPubKey)
return errors.New("no peer with to remove") d.peerListChangedForNotification = true
return nil
} }
// UpdatePeerState updates peer status // UpdatePeerState updates peer status
@@ -188,10 +200,23 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
peerState.FQDN = fqdn peerState.FQDN = fqdn
d.peers[peerPubKey] = peerState d.peers[peerPubKey] = peerState
d.notifyPeerListChanged()
return nil return nil
} }
// FinishPeerListModifications this event invoke the notification
func (d *Status) FinishPeerListModifications() {
d.mux.Lock()
if !d.peerListChangedForNotification {
d.mux.Unlock()
return
}
d.peerListChangedForNotification = false
d.mux.Unlock()
d.notifyPeerListChanged()
}
// GetPeerStateChangeNotifier returns a change notifier channel for a peer // GetPeerStateChangeNotifier returns a change notifier channel for a peer
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} { func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
d.mux.Lock() d.mux.Lock()

View File

@@ -9,13 +9,13 @@ import (
func TestAddPeer(t *testing.T) { func TestAddPeer(t *testing.T) {
key := "abc" key := "abc"
status := NewRecorder("https://mgm") status := NewRecorder("https://mgm")
err := status.AddPeer(key) err := status.AddPeer(key, "abc.netbird")
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
_, exists := status.peers[key] _, exists := status.peers[key]
assert.True(t, exists, "value was found") assert.True(t, exists, "value was found")
err = status.AddPeer(key) err = status.AddPeer(key, "abc.netbird")
assert.Error(t, err, "should return error on duplicate") assert.Error(t, err, "should return error on duplicate")
} }
@@ -23,7 +23,7 @@ func TestAddPeer(t *testing.T) {
func TestGetPeer(t *testing.T) { func TestGetPeer(t *testing.T) {
key := "abc" key := "abc"
status := NewRecorder("https://mgm") status := NewRecorder("https://mgm")
err := status.AddPeer(key) err := status.AddPeer(key, "abc.netbird")
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
peerStatus, err := status.GetPeer(key) peerStatus, err := status.GetPeer(key)

View File

@@ -0,0 +1,128 @@
package internal
import (
"context"
"fmt"
"net/url"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
mgm "github.com/netbirdio/netbird/management/client"
)
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
type PKCEAuthorizationFlow struct {
ProviderConfig PKCEAuthProviderConfig
}
// PKCEAuthProviderConfig has all attributes needed to initiate pkce authorization flow
type PKCEAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
AuthorizationEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// RedirectURL handles authorization code from IDP manager
RedirectURLs []string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
}
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (PKCEAuthorizationFlow, error) {
// validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
return PKCEAuthorizationFlow{}, err
}
var mgmTLSEnabled bool
if mgmURL.Scheme == "https" {
mgmTLSEnabled = true
}
log.Debugf("connecting to Management Service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
if err != nil {
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
return PKCEAuthorizationFlow{}, err
}
log.Debugf("connected to the Management service %s", mgmURL.String())
defer func() {
err = mgmClient.Close()
if err != nil {
log.Warnf("failed to close the Management service client %v", err)
}
}()
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return PKCEAuthorizationFlow{}, err
}
protoPKCEAuthorizationFlow, err := mgmClient.GetPKCEAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
return PKCEAuthorizationFlow{}, err
}
log.Errorf("failed to retrieve pkce flow: %v", err)
return PKCEAuthorizationFlow{}, err
}
authFlow := PKCEAuthorizationFlow{
ProviderConfig: PKCEAuthProviderConfig{
Audience: protoPKCEAuthorizationFlow.GetProviderConfig().GetAudience(),
ClientID: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientID(),
ClientSecret: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientSecret(),
TokenEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
AuthorizationEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetAuthorizationEndpoint(),
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
},
}
err = isPKCEProviderConfigValid(authFlow.ProviderConfig)
if err != nil {
return PKCEAuthorizationFlow{}, err
}
return authFlow, nil
}
func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error {
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.Audience == "" {
return fmt.Errorf(errorMSGFormat, "Audience")
}
if config.ClientID == "" {
return fmt.Errorf(errorMSGFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
}
if config.AuthorizationEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Authorization Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMSGFormat, "PKCE Auth Scopes")
}
if config.RedirectURLs == nil {
return fmt.Errorf(errorMSGFormat, "PKCE Redirect URLs")
}
return nil
}

View File

@@ -1,72 +0,0 @@
package proxy
import (
"context"
log "github.com/sirupsen/logrus"
"net"
"time"
)
// DummyProxy just sends pings to the RemoteKey peer and reads responses
type DummyProxy struct {
conn net.Conn
remote string
ctx context.Context
cancel context.CancelFunc
}
func NewDummyProxy(remote string) *DummyProxy {
p := &DummyProxy{remote: remote}
p.ctx, p.cancel = context.WithCancel(context.Background())
return p
}
func (p *DummyProxy) Close() error {
p.cancel()
return nil
}
func (p *DummyProxy) Start(remoteConn net.Conn) error {
p.conn = remoteConn
go func() {
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
return
default:
_, err := p.conn.Read(buf)
if err != nil {
log.Errorf("error while reading RemoteKey %s proxy %v", p.remote, err)
return
}
//log.Debugf("received %s from %s", string(buf[:n]), p.remote)
}
}
}()
go func() {
for {
select {
case <-p.ctx.Done():
return
default:
_, err := p.conn.Write([]byte("hello"))
//log.Debugf("sent ping to %s", p.remote)
if err != nil {
log.Errorf("error while writing to RemoteKey %s proxy %v", p.remote, err)
return
}
time.Sleep(5 * time.Second)
}
}
}()
return nil
}
func (p *DummyProxy) Type() Type {
return TypeDummy
}

View File

@@ -1,42 +0,0 @@
package proxy
import (
log "github.com/sirupsen/logrus"
"net"
)
// NoProxy is used just to configure WireGuard without any local proxy in between.
// Used when the WireGuard interface is userspace and uses bind.ICEBind
type NoProxy struct {
config Config
}
// NewNoProxy creates a new NoProxy with a provided config
func NewNoProxy(config Config) *NoProxy {
return &NoProxy{config: config}
}
// Close removes peer from the WireGuard interface
func (p *NoProxy) Close() error {
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
if err != nil {
return err
}
return nil
}
// Start just updates WireGuard peer with the remote address
func (p *NoProxy) Start(remoteConn net.Conn) error {
log.Debugf("using NoProxy to connect to peer %s at %s", p.config.RemoteKey, remoteConn.RemoteAddr().String())
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
if err != nil {
return err
}
return p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
addr, p.config.PreSharedKey)
}
func (p *NoProxy) Type() Type {
return TypeNoProxy
}

View File

@@ -1,35 +0,0 @@
package proxy
import (
"github.com/netbirdio/netbird/iface"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"io"
"net"
"time"
)
const DefaultWgKeepAlive = 25 * time.Second
type Type string
const (
TypeDirectNoProxy Type = "DirectNoProxy"
TypeWireGuard Type = "WireGuard"
TypeDummy Type = "Dummy"
TypeNoProxy Type = "NoProxy"
)
type Config struct {
WgListenAddr string
RemoteKey string
WgInterface *iface.WGIface
AllowedIps string
PreSharedKey *wgtypes.Key
}
type Proxy interface {
io.Closer
// Start creates a local remoteConn and starts proxying data from/to remoteConn
Start(remoteConn net.Conn) error
Type() Type
}

View File

@@ -1,128 +0,0 @@
package proxy
import (
"context"
log "github.com/sirupsen/logrus"
"net"
)
// WireGuardProxy proxies
type WireGuardProxy struct {
ctx context.Context
cancel context.CancelFunc
config Config
remoteConn net.Conn
localConn net.Conn
}
func NewWireGuardProxy(config Config) *WireGuardProxy {
p := &WireGuardProxy{config: config}
p.ctx, p.cancel = context.WithCancel(context.Background())
return p
}
func (p *WireGuardProxy) updateEndpoint() error {
udpAddr, err := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
if err != nil {
return err
}
// add local proxy connection as a Wireguard peer
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
udpAddr, p.config.PreSharedKey)
if err != nil {
return err
}
return nil
}
func (p *WireGuardProxy) Start(remoteConn net.Conn) error {
p.remoteConn = remoteConn
var err error
p.localConn, err = net.Dial("udp", p.config.WgListenAddr)
if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err)
return err
}
err = p.updateEndpoint()
if err != nil {
log.Errorf("error while updating Wireguard peer endpoint [%s] %v", p.config.RemoteKey, err)
return err
}
go p.proxyToRemote()
go p.proxyToLocal()
return nil
}
func (p *WireGuardProxy) Close() error {
p.cancel()
if c := p.localConn; c != nil {
err := p.localConn.Close()
if err != nil {
return err
}
}
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
if err != nil {
return err
}
return nil
}
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
// blocks
func (p *WireGuardProxy) proxyToRemote() {
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
log.Debugf("stopped proxying to remote peer %s due to closed connection", p.config.RemoteKey)
return
default:
n, err := p.localConn.Read(buf)
if err != nil {
continue
}
_, err = p.remoteConn.Write(buf[:n])
if err != nil {
continue
}
}
}
}
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
// blocks
func (p *WireGuardProxy) proxyToLocal() {
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
log.Debugf("stopped proxying from remote peer %s due to closed connection", p.config.RemoteKey)
return
default:
n, err := p.remoteConn.Read(buf)
if err != nil {
continue
}
_, err = p.localConn.Write(buf[:n])
if err != nil {
continue
}
}
}
}
func (p *WireGuardProxy) Type() Type {
return TypeWireGuard
}

View File

@@ -6,8 +6,6 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@@ -30,33 +28,13 @@ func genKey(format string, input string) string {
// NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager // NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager
func NewFirewall(parentCTX context.Context) firewallManager { func NewFirewall(parentCTX context.Context) firewallManager {
ctx, cancel := context.WithCancel(parentCTX) manager, err := newNFTablesManager(parentCTX)
if err == nil {
if isIptablesSupported() { log.Debugf("nftables firewall manager will be used")
log.Debugf("iptables is supported") return manager
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
return &iptablesManager{
ctx: ctx,
stop: cancel,
ipv4Client: ipv4Client,
ipv6Client: ipv6Client,
rules: make(map[string]map[string][]string),
}
} }
log.Debugf("fallback to iptables firewall manager: %s", err)
log.Debugf("iptables is not supported, using nftables") return newIptablesManager(parentCTX)
manager := &nftablesManager{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
chains: make(map[string]map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
}
return manager
} }
func getInPair(pair routerPair) routerPair { func getInPair(pair routerPair) routerPair {

View File

@@ -49,6 +49,28 @@ type iptablesManager struct {
mux sync.Mutex mux sync.Mutex
} }
func newIptablesManager(parentCtx context.Context) *iptablesManager {
ctx, cancel := context.WithCancel(parentCtx)
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if !isIptablesClientAvailable(ipv4Client) {
log.Infof("iptables is missing for ipv4")
ipv4Client = nil
}
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if !isIptablesClientAvailable(ipv6Client) {
log.Infof("iptables is missing for ipv6")
ipv6Client = nil
}
return &iptablesManager{
ctx: ctx,
stop: cancel,
ipv4Client: ipv4Client,
ipv6Client: ipv6Client,
rules: make(map[string]map[string][]string),
}
}
// CleanRoutingRules cleans existing iptables resources that we created by the agent // CleanRoutingRules cleans existing iptables resources that we created by the agent
func (i *iptablesManager) CleanRoutingRules() { func (i *iptablesManager) CleanRoutingRules() {
i.mux.Lock() i.mux.Lock()
@@ -61,24 +83,28 @@ func (i *iptablesManager) CleanRoutingRules() {
log.Debug("flushing tables") log.Debug("flushing tables")
errMSGFormat := "iptables: failed cleaning %s chain %s,error: %v" errMSGFormat := "iptables: failed cleaning %s chain %s,error: %v"
err = i.ipv4Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain) if i.ipv4Client != nil {
if err != nil { err = i.ipv4Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
log.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err) if err != nil {
log.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
}
err = i.ipv4Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
if err != nil {
log.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
}
} }
err = i.ipv4Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain) if i.ipv6Client != nil {
if err != nil { err = i.ipv6Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
log.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err) if err != nil {
} log.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
}
err = i.ipv6Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain) err = i.ipv6Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
if err != nil { if err != nil {
log.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err) log.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
} }
err = i.ipv6Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
if err != nil {
log.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
} }
log.Info("done cleaning up iptables rules") log.Info("done cleaning up iptables rules")
@@ -96,37 +122,41 @@ func (i *iptablesManager) RestoreOrCreateContainers() error {
errMSGFormat := "iptables: failed creating %s chain %s,error: %v" errMSGFormat := "iptables: failed creating %s chain %s,error: %v"
err := createChain(i.ipv4Client, iptablesFilterTable, iptablesRoutingForwardingChain) if i.ipv4Client != nil {
if err != nil { err := createChain(i.ipv4Client, iptablesFilterTable, iptablesRoutingForwardingChain)
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err) if err != nil {
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
}
err = createChain(i.ipv4Client, iptablesNatTable, iptablesRoutingNatChain)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
}
err = i.restoreRules(i.ipv4Client)
if err != nil {
return fmt.Errorf("iptables: error while restoring ipv4 rules: %v", err)
}
} }
err = createChain(i.ipv4Client, iptablesNatTable, iptablesRoutingNatChain) if i.ipv6Client != nil {
if err != nil { err := createChain(i.ipv6Client, iptablesFilterTable, iptablesRoutingForwardingChain)
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err) if err != nil {
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
}
err = createChain(i.ipv6Client, iptablesNatTable, iptablesRoutingNatChain)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
}
err = i.restoreRules(i.ipv6Client)
if err != nil {
return fmt.Errorf("iptables: error while restoring ipv6 rules: %v", err)
}
} }
err = createChain(i.ipv6Client, iptablesFilterTable, iptablesRoutingForwardingChain) err := i.addJumpRules()
if err != nil {
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
}
err = createChain(i.ipv6Client, iptablesNatTable, iptablesRoutingNatChain)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
}
err = i.restoreRules(i.ipv4Client)
if err != nil {
return fmt.Errorf("iptables: error while restoring ipv4 rules: %v", err)
}
err = i.restoreRules(i.ipv6Client)
if err != nil {
return fmt.Errorf("iptables: error while restoring ipv6 rules: %v", err)
}
err = i.addJumpRules()
if err != nil { if err != nil {
return fmt.Errorf("iptables: error while creating jump rules: %v", err) return fmt.Errorf("iptables: error while creating jump rules: %v", err)
} }
@@ -140,34 +170,38 @@ func (i *iptablesManager) addJumpRules() error {
if err != nil { if err != nil {
return err return err
} }
rule := append(iptablesDefaultForwardingRule, ipv4Forwarding) if i.ipv4Client != nil {
err = i.ipv4Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...) rule := append(iptablesDefaultForwardingRule, ipv4Forwarding)
if err != nil {
return err err = i.ipv4Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
if err != nil {
return err
}
i.rules[ipv4][ipv4Forwarding] = rule
rule = append(iptablesDefaultNatRule, ipv4Nat)
err = i.ipv4Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
if err != nil {
return err
}
i.rules[ipv4][ipv4Nat] = rule
} }
i.rules[ipv4][ipv4Forwarding] = rule if i.ipv6Client != nil {
rule := append(iptablesDefaultForwardingRule, ipv6Forwarding)
err = i.ipv6Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
if err != nil {
return err
}
i.rules[ipv6][ipv6Forwarding] = rule
rule = append(iptablesDefaultNatRule, ipv4Nat) rule = append(iptablesDefaultNatRule, ipv6Nat)
err = i.ipv4Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...) err = i.ipv6Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
if err != nil { if err != nil {
return err return err
}
i.rules[ipv6][ipv6Nat] = rule
} }
i.rules[ipv4][ipv4Nat] = rule
rule = append(iptablesDefaultForwardingRule, ipv6Forwarding)
err = i.ipv6Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
if err != nil {
return err
}
i.rules[ipv6][ipv6Forwarding] = rule
rule = append(iptablesDefaultNatRule, ipv6Nat)
err = i.ipv6Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
if err != nil {
return err
}
i.rules[ipv6][ipv6Nat] = rule
return nil return nil
} }
@@ -177,35 +211,39 @@ func (i *iptablesManager) cleanJumpRules() error {
var err error var err error
errMSGFormat := "iptables: failed cleaning rule from %s chain %s,err: %v" errMSGFormat := "iptables: failed cleaning rule from %s chain %s,err: %v"
rule, found := i.rules[ipv4][ipv4Forwarding] rule, found := i.rules[ipv4][ipv4Forwarding]
if found { if i.ipv4Client != nil {
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Forwarding) if found {
err = i.ipv4Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...) log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Forwarding)
if err != nil { err = i.ipv4Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
return fmt.Errorf(errMSGFormat, ipv4, iptablesForwardChain, err) if err != nil {
return fmt.Errorf(errMSGFormat, ipv4, iptablesForwardChain, err)
}
}
rule, found = i.rules[ipv4][ipv4Nat]
if found {
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Nat)
err = i.ipv4Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv4, iptablesPostRoutingChain, err)
}
} }
} }
rule, found = i.rules[ipv4][ipv4Nat] if i.ipv6Client == nil {
if found { rule, found = i.rules[ipv6][ipv6Forwarding]
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Nat) if found {
err = i.ipv4Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...) log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Forwarding)
if err != nil { err = i.ipv6Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
return fmt.Errorf(errMSGFormat, ipv4, iptablesPostRoutingChain, err) if err != nil {
return fmt.Errorf(errMSGFormat, ipv6, iptablesForwardChain, err)
}
} }
} rule, found = i.rules[ipv6][ipv6Nat]
rule, found = i.rules[ipv6][ipv6Forwarding] if found {
if found { log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Nat)
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Forwarding) err = i.ipv6Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
err = i.ipv6Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...) if err != nil {
if err != nil { return fmt.Errorf(errMSGFormat, ipv6, iptablesPostRoutingChain, err)
return fmt.Errorf(errMSGFormat, ipv6, iptablesForwardChain, err) }
}
}
rule, found = i.rules[ipv6][ipv6Nat]
if found {
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Nat)
err = i.ipv6Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv6, iptablesPostRoutingChain, err)
} }
} }
return nil return nil
@@ -437,3 +475,8 @@ func getIptablesRuleType(table string) string {
} }
return ruleType return ruleType
} }
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}

View File

@@ -16,17 +16,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
t.SkipNow() t.SkipNow()
} }
ctx, cancel := context.WithCancel(context.TODO()) manager := newIptablesManager(context.TODO())
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
manager := &iptablesManager{
ctx: ctx,
stop: cancel,
ipv4Client: ipv4Client,
ipv6Client: ipv6Client,
rules: make(map[string]map[string][]string),
}
defer manager.CleanRoutingRules() defer manager.CleanRoutingRules()
@@ -37,21 +27,21 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
require.Len(t, manager.rules[ipv4], 2, "should have created minimal rules for ipv4") require.Len(t, manager.rules[ipv4], 2, "should have created minimal rules for ipv4")
exists, err := ipv4Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv4][ipv4Forwarding]...) exists, err := manager.ipv4Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv4][ipv4Forwarding]...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesFilterTable, iptablesForwardChain) require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesFilterTable, iptablesForwardChain)
require.True(t, exists, "forwarding rule should exist") require.True(t, exists, "forwarding rule should exist")
exists, err = ipv4Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv4][ipv4Nat]...) exists, err = manager.ipv4Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv4][ipv4Nat]...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesNatTable, iptablesPostRoutingChain) require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesNatTable, iptablesPostRoutingChain)
require.True(t, exists, "postrouting rule should exist") require.True(t, exists, "postrouting rule should exist")
require.Len(t, manager.rules[ipv6], 2, "should have created minimal rules for ipv6") require.Len(t, manager.rules[ipv6], 2, "should have created minimal rules for ipv6")
exists, err = ipv6Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv6][ipv6Forwarding]...) exists, err = manager.ipv6Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv6][ipv6Forwarding]...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesFilterTable, iptablesForwardChain) require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesFilterTable, iptablesForwardChain)
require.True(t, exists, "forwarding rule should exist") require.True(t, exists, "forwarding rule should exist")
exists, err = ipv6Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv6][ipv6Nat]...) exists, err = manager.ipv6Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv6][ipv6Nat]...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesNatTable, iptablesPostRoutingChain) require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesNatTable, iptablesPostRoutingChain)
require.True(t, exists, "postrouting rule should exist") require.True(t, exists, "postrouting rule should exist")
@@ -64,13 +54,13 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
forward4RuleKey := genKey(forwardingFormat, pair.ID) forward4RuleKey := genKey(forwardingFormat, pair.ID)
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.source, pair.destination) forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.source, pair.destination)
err = ipv4Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward4Rule...) err = manager.ipv4Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward4Rule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
nat4RuleKey := genKey(natFormat, pair.ID) nat4RuleKey := genKey(natFormat, pair.ID)
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.source, pair.destination) nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.source, pair.destination)
err = ipv4Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat4Rule...) err = manager.ipv4Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat4Rule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
pair = routerPair{ pair = routerPair{
@@ -83,13 +73,13 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
forward6RuleKey := genKey(forwardingFormat, pair.ID) forward6RuleKey := genKey(forwardingFormat, pair.ID)
forward6Rule := genRuleSpec(routingFinalForwardJump, forward6RuleKey, pair.source, pair.destination) forward6Rule := genRuleSpec(routingFinalForwardJump, forward6RuleKey, pair.source, pair.destination)
err = ipv6Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward6Rule...) err = manager.ipv6Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward6Rule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
nat6RuleKey := genKey(natFormat, pair.ID) nat6RuleKey := genKey(natFormat, pair.ID)
nat6Rule := genRuleSpec(routingFinalNatJump, nat6RuleKey, pair.source, pair.destination) nat6Rule := genRuleSpec(routingFinalNatJump, nat6RuleKey, pair.source, pair.destination)
err = ipv6Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat6Rule...) err = manager.ipv6Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat6Rule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
delete(manager.rules, ipv4) delete(manager.rules, ipv4)

View File

@@ -19,6 +19,9 @@ const (
nftablesTable = "netbird-rt" nftablesTable = "netbird-rt"
nftablesRoutingForwardingChain = "netbird-rt-fwd" nftablesRoutingForwardingChain = "netbird-rt-fwd"
nftablesRoutingNatChain = "netbird-rt-nat" nftablesRoutingNatChain = "netbird-rt-nat"
userDataAcceptForwardRuleSrc = "frwacceptsrc"
userDataAcceptForwardRuleDst = "frwacceptdst"
) )
// constants needed to create nftable rules // constants needed to create nftable rules
@@ -71,14 +74,41 @@ var (
) )
type nftablesManager struct { type nftablesManager struct {
ctx context.Context ctx context.Context
stop context.CancelFunc stop context.CancelFunc
conn *nftables.Conn conn *nftables.Conn
tableIPv4 *nftables.Table tableIPv4 *nftables.Table
tableIPv6 *nftables.Table tableIPv6 *nftables.Table
chains map[string]map[string]*nftables.Chain chains map[string]map[string]*nftables.Chain
rules map[string]*nftables.Rule rules map[string]*nftables.Rule
mux sync.Mutex filterTable *nftables.Table
defaultForwardRules []*nftables.Rule
mux sync.Mutex
}
func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) {
ctx, cancel := context.WithCancel(parentCtx)
mgr := &nftablesManager{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
chains: make(map[string]map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
defaultForwardRules: make([]*nftables.Rule, 2),
}
err := mgr.isSupported()
if err != nil {
return nil, err
}
err = mgr.readFilterTable()
if err != nil {
return nil, err
}
return mgr, nil
} }
// CleanRoutingRules cleans existing nftables rules from the system // CleanRoutingRules cleans existing nftables rules from the system
@@ -90,6 +120,13 @@ func (n *nftablesManager) CleanRoutingRules() {
n.conn.FlushTable(n.tableIPv6) n.conn.FlushTable(n.tableIPv6)
n.conn.FlushTable(n.tableIPv4) n.conn.FlushTable(n.tableIPv4)
} }
if n.defaultForwardRules[0] != nil {
err := n.eraseDefaultForwardRule()
if err != nil {
log.Errorf("failed to delete forward rule: %s", err)
}
}
log.Debugf("flushing tables result in: %v error", n.conn.Flush()) log.Debugf("flushing tables result in: %v error", n.conn.Flush())
} }
@@ -222,6 +259,112 @@ func (n *nftablesManager) refreshRulesMap() error {
return nil return nil
} }
func (n *nftablesManager) readFilterTable() error {
tables, err := n.conn.ListTables()
if err != nil {
return err
}
for _, t := range tables {
if t.Name == "filter" {
n.filterTable = t
return nil
}
}
return nil
}
func (n *nftablesManager) eraseDefaultForwardRule() error {
if n.defaultForwardRules[0] == nil {
return nil
}
err := n.refreshDefaultForwardRule()
if err != nil {
return err
}
for i, r := range n.defaultForwardRules {
err = n.conn.DelRule(r)
if err != nil {
log.Errorf("failed to delete forward rule (%d): %s", i, err)
}
n.defaultForwardRules[i] = nil
}
return nil
}
func (n *nftablesManager) refreshDefaultForwardRule() error {
rules, err := n.conn.GetRules(n.defaultForwardRules[0].Table, n.defaultForwardRules[0].Chain)
if err != nil {
return fmt.Errorf("unable to list rules in forward chain: %s", err)
}
found := false
for i, r := range n.defaultForwardRules {
for _, rule := range rules {
if string(rule.UserData) == string(r.UserData) {
n.defaultForwardRules[i] = rule
found = true
break
}
}
}
if !found {
return fmt.Errorf("unable to find forward accept rule")
}
return nil
}
func (n *nftablesManager) acceptForwardRule(sourceNetwork string) error {
src := generateCIDRMatcherExpressions("source", sourceNetwork)
dst := generateCIDRMatcherExpressions("destination", "0.0.0.0/0")
var exprs []expr.Any
exprs = append(src, append(dst, &expr.Verdict{
Kind: expr.VerdictAccept,
})...)
r := &nftables.Rule{
Table: n.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: n.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleSrc),
}
n.defaultForwardRules[0] = n.conn.AddRule(r)
src = generateCIDRMatcherExpressions("source", "0.0.0.0/0")
dst = generateCIDRMatcherExpressions("destination", sourceNetwork)
exprs = append(src, append(dst, &expr.Verdict{
Kind: expr.VerdictAccept,
})...)
r = &nftables.Rule{
Table: n.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: n.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleDst),
}
n.defaultForwardRules[1] = n.conn.AddRule(r)
return nil
}
// checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled // checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled
func (n *nftablesManager) checkOrCreateDefaultForwardingRules() { func (n *nftablesManager) checkOrCreateDefaultForwardingRules() {
_, foundIPv4 := n.rules[ipv4Forwarding] _, foundIPv4 := n.rules[ipv4Forwarding]
@@ -275,6 +418,14 @@ func (n *nftablesManager) InsertRoutingRules(pair routerPair) error {
} }
} }
if n.defaultForwardRules[0] == nil && n.filterTable != nil {
err = n.acceptForwardRule(pair.source)
if err != nil {
log.Errorf("unable to create default forward rule: %s", err)
}
log.Debugf("default accept forward rule added")
}
err = n.conn.Flush() err = n.conn.Flush()
if err != nil { if err != nil {
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err) return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
@@ -355,6 +506,13 @@ func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
return err return err
} }
if len(n.rules) == 2 && n.defaultForwardRules[0] != nil {
err := n.eraseDefaultForwardRule()
if err != nil {
log.Errorf("failed to delte default fwd rule: %s", err)
}
}
err = n.conn.Flush() err = n.conn.Flush()
if err != nil { if err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err) return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
@@ -386,6 +544,14 @@ func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) erro
return nil return nil
} }
func (n *nftablesManager) isSupported() error {
_, err := n.conn.ListChains()
if err != nil {
return fmt.Errorf("nftables is not supported: %s", err)
}
return nil
}
// getPayloadDirectives get expression directives based on ip version and direction // getPayloadDirectives get expression directives based on ip version and direction
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) { func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
switch { switch {

View File

@@ -14,21 +14,16 @@ import (
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) { func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO()) manager, err := newNFTablesManager(context.TODO())
if err != nil {
manager := &nftablesManager{ t.Fatalf("failed to create nftables manager: %s", err)
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
chains: make(map[string]map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
} }
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules() defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers() err = manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6") require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
@@ -134,21 +129,16 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
for _, testCase := range insertRuleTestCases { for _, testCase := range insertRuleTestCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO()) manager, err := newNFTablesManager(context.TODO())
if err != nil {
manager := &nftablesManager{ t.Fatalf("failed to create nftables manager: %s", err)
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
chains: make(map[string]map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
} }
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules() defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers() err = manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
err = manager.InsertRoutingRules(testCase.inputPair) err = manager.InsertRoutingRules(testCase.inputPair)
@@ -239,21 +229,16 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
for _, testCase := range removeRuleTestCases { for _, testCase := range removeRuleTestCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO()) manager, err := newNFTablesManager(context.TODO())
if err != nil {
manager := &nftablesManager{ t.Fatalf("failed to create nftables manager: %s", err)
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
chains: make(map[string]map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
} }
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules() defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers() err = manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
table := manager.tableIPv4 table := manager.tableIPv4

View File

@@ -0,0 +1,82 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
// +build darwin dragonfly freebsd netbsd openbsd
package routemanager
import (
"fmt"
"net"
"net/netip"
"syscall"
"golang.org/x/net/route"
)
// selected BSD Route flags.
const (
RTF_UP = 0x1
RTF_GATEWAY = 0x2
RTF_HOST = 0x4
RTF_REJECT = 0x8
RTF_DYNAMIC = 0x10
RTF_MODIFIED = 0x20
RTF_STATIC = 0x800
RTF_BLACKHOLE = 0x1000
RTF_LOCAL = 0x200000
RTF_BROADCAST = 0x400000
RTF_MULTICAST = 0x800000
)
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
tab, err := route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0)
if err != nil {
return false, err
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, tab)
if err != nil {
return false, err
}
for _, msg := range msgs {
m := msg.(*route.RouteMessage)
if m.Version < 3 || m.Version > 5 {
return false, fmt.Errorf("unexpected RIB message version: %d", m.Version)
}
if m.Type != 4 /* RTM_GET */ {
return true, fmt.Errorf("unexpected RIB message type: %d", m.Type)
}
if m.Flags&RTF_UP == 0 ||
m.Flags&(RTF_REJECT|RTF_BLACKHOLE) != 0 {
continue
}
dst, err := toIPAddr(m.Addrs[0])
if err != nil {
return true, fmt.Errorf("unexpected RIB destination: %v", err)
}
mask, _ := toIPAddr(m.Addrs[2])
cidr, _ := net.IPMask(mask.To4()).Size()
if dst.String() == prefix.Addr().String() && cidr == prefix.Bits() {
return true, nil
}
}
return false, nil
}
func toIPAddr(a route.Addr) (net.IP, error) {
switch t := a.(type) {
case *route.Inet4Addr:
ip := net.IPv4(t.IP[0], t.IP[1], t.IP[2], t.IP[3])
return ip, nil
case *route.Inet6Addr:
ip := make(net.IP, net.IPv6len)
copy(ip, t.IP[:])
return ip, nil
default:
return net.IP{}, fmt.Errorf("unknown family: %v", t)
}
}

View File

@@ -6,10 +6,28 @@ import (
"net" "net"
"net/netip" "net/netip"
"os" "os"
"syscall"
"unsafe"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
) )
// Pulled from http://man7.org/linux/man-pages/man7/rtnetlink.7.html
// See the section on RTM_NEWROUTE, specifically 'struct rtmsg'.
type routeInfoInMemory struct {
Family byte
DstLen byte
SrcLen byte
TOS byte
Table byte
Protocol byte
Scope byte
Type byte
Flags uint32
}
const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
func addToRouteTable(prefix netip.Prefix, addr string) error { func addToRouteTable(prefix netip.Prefix, addr string) error {
@@ -61,6 +79,45 @@ func removeFromRouteTable(prefix netip.Prefix) error {
return nil return nil
} }
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC)
if err != nil {
return true, err
}
msgs, err := syscall.ParseNetlinkMessage(tab)
if err != nil {
return true, err
}
loop:
for _, m := range msgs {
switch m.Header.Type {
case syscall.NLMSG_DONE:
break loop
case syscall.RTM_NEWROUTE:
rt := (*routeInfoInMemory)(unsafe.Pointer(&m.Data[0]))
attrs, err := syscall.ParseNetlinkRouteAttr(&m)
if err != nil {
return true, err
}
if rt.Family != syscall.AF_INET {
continue loop
}
for _, attr := range attrs {
if attr.Attr.Type == syscall.RTA_DST {
ip := net.IP(attr.Value)
mask := net.CIDRMask(int(rt.DstLen), len(attr.Value)*8)
cidr, _ := mask.Size()
if ip.String() == prefix.Addr().String() && cidr == prefix.Bits() {
return true, nil
}
}
}
}
}
return false, nil
}
func enableIPForwarding() error { func enableIPForwarding() error {
bytes, err := os.ReadFile(ipv4ForwardingPath) bytes, err := os.ReadFile(ipv4ForwardingPath)
if err != nil { if err != nil {

View File

@@ -14,19 +14,26 @@ import (
var errRouteNotFound = fmt.Errorf("route not found") var errRouteNotFound = fmt.Errorf("route not found")
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
if err != nil && err != errRouteNotFound {
return err
}
prefixGateway, err := getExistingRIBRouteGateway(prefix)
if err != nil && err != errRouteNotFound { if err != nil && err != errRouteNotFound {
return err return err
} }
if prefixGateway != nil && !prefixGateway.Equal(gateway) { gatewayIP := netip.MustParseAddr(defaultGateway.String())
log.Warnf("skipping adding a new route for network %s because it already exists and is pointing to the non default gateway: %s", prefix, prefixGateway) if prefix.Contains(gatewayIP) {
log.Warnf("skipping adding a new route for network %s because it overlaps with the default gateway: %s", prefix, gatewayIP)
return nil return nil
} }
ok, err := existsInRouteTable(prefix)
if err != nil {
return err
}
if ok {
log.Warnf("skipping adding a new route for network %s because it already exists", prefix)
return nil
}
return addToRouteTable(prefix, addr) return addToRouteTable(prefix, addr)
} }
@@ -53,6 +60,7 @@ func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) {
log.Errorf("getting routes returned an error: %v", err) log.Errorf("getting routes returned an error: %v", err)
return nil, errRouteNotFound return nil, errRouteNotFound
} }
if gateway == nil { if gateway == nil {
return preferredSrc, nil return preferredSrc, nil
} }

View File

@@ -1,13 +1,19 @@
package routemanager package routemanager
import ( import (
"bytes"
"fmt" "fmt"
"github.com/netbirdio/netbird/iface"
"github.com/pion/transport/v2/stdnet"
"github.com/stretchr/testify/require"
"net" "net"
"net/netip" "net/netip"
"os"
"strings"
"testing" "testing"
"github.com/pion/transport/v2/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/iface"
) )
func TestAddRemoveRoutes(t *testing.T) { func TestAddRemoveRoutes(t *testing.T) {
@@ -114,3 +120,98 @@ func TestGetExistingRIBRouteGateway(t *testing.T) {
t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String()) t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String())
} }
} }
func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
fmt.Println("defaultGateway: ", defaultGateway)
if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err)
}
testCases := []struct {
name string
prefix netip.Prefix
preExistingPrefix netip.Prefix
shouldAddRoute bool
}{
{
name: "Should Add And Remove random Route",
prefix: netip.MustParsePrefix("99.99.99.99/32"),
shouldAddRoute: true,
},
{
name: "Should Not Add Route if overlaps with default gateway",
prefix: netip.MustParsePrefix(defaultGateway.String() + "/31"),
shouldAddRoute: false,
},
{
name: "Should Add Route if bigger network exists",
prefix: netip.MustParsePrefix("100.100.100.0/24"),
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
shouldAddRoute: true,
},
{
name: "Should Add Route if smaller network exists",
prefix: netip.MustParsePrefix("100.100.0.0/16"),
preExistingPrefix: netip.MustParsePrefix("100.100.100.0/24"),
shouldAddRoute: true,
},
{
name: "Should Not Add Route if same network exists",
prefix: netip.MustParsePrefix("100.100.0.0/16"),
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
shouldAddRoute: false,
},
}
for n, testCase := range testCases {
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
t.Run(testCase.name, func(t *testing.T) {
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil, newNet)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
MockAddr := wgInterface.Address().IP.String()
// Prepare the environment
if testCase.preExistingPrefix.IsValid() {
err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr)
require.NoError(t, err, "should not return err when adding pre-existing route")
}
// Add the route
err = addToRouteTableIfNoExists(testCase.prefix, MockAddr)
require.NoError(t, err, "should not return err when adding route")
if testCase.shouldAddRoute {
// test if route exists after adding
ok, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "should not return err")
require.True(t, ok, "route should exist")
// remove route again if added
err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr)
require.NoError(t, err, "should not return err")
}
// route should either not have been added or should have been removed
// In case of already existing route, it should not have been added (but still exist)
ok, err := existsInRouteTable(testCase.prefix)
fmt.Println("Buffer string: ", buf.String())
require.NoError(t, err, "should not return err")
if !strings.Contains(buf.String(), "because it already exists") {
require.False(t, ok, "route should not exist")
}
})
}
}

View File

@@ -0,0 +1,37 @@
//go:build windows
// +build windows
package routemanager
import (
"net"
"net/netip"
"github.com/yusufpapurcu/wmi"
)
type Win32_IP4RouteTable struct {
Destination string
Mask string
}
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
var routes []Win32_IP4RouteTable
query := "SELECT Destination, Mask FROM Win32_IP4RouteTable"
err := wmi.Query(query, &routes)
if err != nil {
return true, err
}
for _, route := range routes {
ip := net.ParseIP(route.Mask)
ip = ip.To4()
mask := net.IPv4Mask(ip[0], ip[1], ip[2], ip[3])
cidr, _ := mask.Size()
if route.Destination == prefix.Addr().String() && cidr == prefix.Bits() {
return true, nil
}
}
return false, nil
}

View File

@@ -0,0 +1,8 @@
package templates
import (
_ "embed"
)
//go:embed pkce-auth-msg.html
var PKCEAuthMsgTmpl string

View File

@@ -0,0 +1,87 @@
<!DOCTYPE html>
<html>
<head>
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<style>
body {
display: flex;
justify-content: center;
align-items: center;
min-height: 100vh;
margin: 0;
background: #f7f8f9;
font-family: sans-serif, Arial, Tahoma;
}
.container {
width: 100%;
background: white;
border: 1px solid #e8e9ea;
text-align: center;
padding: 20px;
padding-bottom: 50px;
max-width: 550px;
margin: 0 10px;
}
.logo {
height: 80px;
border-bottom: 1px solid #e8e9ea;
display: flex;
justify-content: center;
align-items: center;
}
.logo img {
width: 130px;
}
.content {
font-size: 13px;
color: #525252;
line-height: 18px;
padding: 10px 0;
}
.content div {
font-size: 18px;
line-height: normal;
margin-bottom: 5px;
color: black;
}
</style>
</head>
<body>
<div class="container">
<div class="logo">
<img src="https://img.mailinblue.com/6211297/images/content_library/original/64bd4ce82e1ea753e439b6a2.png">
</div>
<br>
{{ if .Error }}
<svg xmlns="http://www.w3.org/2000/svg" height="50" viewBox="0 0 100 100">
<circle cx="50" cy="50" r="45" fill="none" stroke="red" stroke-width="3"/>
<path d="M30 30 L70 70 M30 70 L70 30" fill="none" stroke="red" stroke-width="3"/>
</svg>
<div class="content">
<div>
Login failed
</div>
{{ .Error }}.
</div>
{{ else }}
<svg xmlns="http://www.w3.org/2000/svg" height="50" viewBox="0 0 100 100">
<circle cx="50" cy="50" r="45" fill="none" stroke="#5cb85c" stroke-width="3"/>
<path d="M30 50 L45 65 L70 35" fill="none" stroke="#5cb85c" stroke-width="5"/>
</svg>
<div class="content">
<div>
Login successful
</div>
Your device is now registered and logged in to NetBird.
<br>
You can now close this window.
</div>
{{ end }}
</div>
</body>
</html>

View File

@@ -0,0 +1,120 @@
// Code generated by bpf2go; DO NOT EDIT.
//go:build arm64be || armbe || mips || mips64 || mips64p32 || ppc64 || s390 || s390x || sparc || sparc64
// +build arm64be armbe mips mips64 mips64p32 ppc64 s390 s390x sparc sparc64
package ebpf
import (
"bytes"
_ "embed"
"fmt"
"io"
"github.com/cilium/ebpf"
)
// loadBpf returns the embedded CollectionSpec for bpf.
func loadBpf() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_BpfBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load bpf: %w", err)
}
return spec, err
}
// loadBpfObjects loads bpf and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *bpfObjects
// *bpfPrograms
// *bpfMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := loadBpf()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// bpfSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfSpecs struct {
bpfProgramSpecs
bpfMapSpecs
}
// bpfSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct {
XdpProgFunc *ebpf.ProgramSpec `ebpf:"xdp_prog_func"`
}
// bpfMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct {
XdpPortMap *ebpf.MapSpec `ebpf:"xdp_port_map"`
}
// bpfObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfObjects struct {
bpfPrograms
bpfMaps
}
func (o *bpfObjects) Close() error {
return _BpfClose(
&o.bpfPrograms,
&o.bpfMaps,
)
}
// bpfMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct {
XdpPortMap *ebpf.Map `ebpf:"xdp_port_map"`
}
func (m *bpfMaps) Close() error {
return _BpfClose(
m.XdpPortMap,
)
}
// bpfPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct {
XdpProgFunc *ebpf.Program `ebpf:"xdp_prog_func"`
}
func (p *bpfPrograms) Close() error {
return _BpfClose(
p.XdpProgFunc,
)
}
func _BpfClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// Do not access this directly.
//
//go:embed bpf_bpfeb.o
var _BpfBytes []byte

Binary file not shown.

View File

@@ -0,0 +1,120 @@
// Code generated by bpf2go; DO NOT EDIT.
//go:build 386 || amd64 || amd64p32 || arm || arm64 || mips64le || mips64p32le || mipsle || ppc64le || riscv64
// +build 386 amd64 amd64p32 arm arm64 mips64le mips64p32le mipsle ppc64le riscv64
package ebpf
import (
"bytes"
_ "embed"
"fmt"
"io"
"github.com/cilium/ebpf"
)
// loadBpf returns the embedded CollectionSpec for bpf.
func loadBpf() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_BpfBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load bpf: %w", err)
}
return spec, err
}
// loadBpfObjects loads bpf and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *bpfObjects
// *bpfPrograms
// *bpfMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := loadBpf()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// bpfSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfSpecs struct {
bpfProgramSpecs
bpfMapSpecs
}
// bpfSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct {
XdpProgFunc *ebpf.ProgramSpec `ebpf:"xdp_prog_func"`
}
// bpfMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct {
XdpPortMap *ebpf.MapSpec `ebpf:"xdp_port_map"`
}
// bpfObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfObjects struct {
bpfPrograms
bpfMaps
}
func (o *bpfObjects) Close() error {
return _BpfClose(
&o.bpfPrograms,
&o.bpfMaps,
)
}
// bpfMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct {
XdpPortMap *ebpf.Map `ebpf:"xdp_port_map"`
}
func (m *bpfMaps) Close() error {
return _BpfClose(
m.XdpPortMap,
)
}
// bpfPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct {
XdpProgFunc *ebpf.Program `ebpf:"xdp_prog_func"`
}
func (p *bpfPrograms) Close() error {
return _BpfClose(
p.XdpProgFunc,
)
}
func _BpfClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// Do not access this directly.
//
//go:embed bpf_bpfel.o
var _BpfBytes []byte

Binary file not shown.

View File

@@ -0,0 +1,84 @@
//go:build linux && !android
package ebpf
import (
_ "embed"
"net"
"github.com/cilium/ebpf/link"
"github.com/cilium/ebpf/rlimit"
)
const (
mapKeyProxyPort uint32 = 0
mapKeyWgPort uint32 = 1
)
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang-14 bpf src/portreplace.c --
// EBPF is a wrapper for eBPF program
type EBPF struct {
link link.Link
}
// NewEBPF create new EBPF instance
func NewEBPF() *EBPF {
return &EBPF{}
}
// Load load ebpf program
func (l *EBPF) Load(proxyPort, wgPort int) error {
// it required for Docker
err := rlimit.RemoveMemlock()
if err != nil {
return err
}
ifce, err := net.InterfaceByName("lo")
if err != nil {
return err
}
// Load pre-compiled programs into the kernel.
objs := bpfObjects{}
err = loadBpfObjects(&objs, nil)
if err != nil {
return err
}
defer func() {
_ = objs.Close()
}()
err = objs.XdpPortMap.Put(mapKeyProxyPort, uint16(proxyPort))
if err != nil {
return err
}
err = objs.XdpPortMap.Put(mapKeyWgPort, uint16(wgPort))
if err != nil {
return err
}
defer func() {
_ = objs.XdpPortMap.Close()
}()
l.link, err = link.AttachXDP(link.XDPOptions{
Program: objs.XdpProgFunc,
Interface: ifce.Index,
})
if err != nil {
return err
}
return err
}
// Free free ebpf program
func (l *EBPF) Free() error {
if l.link != nil {
return l.link.Close()
}
return nil
}

View File

@@ -0,0 +1,18 @@
//go:build linux
package ebpf
import (
"testing"
)
func Test_newEBPF(t *testing.T) {
ebpf := NewEBPF()
err := ebpf.Load(1234, 51892)
defer func() {
_ = ebpf.Free()
}()
if err != nil {
t.Errorf("%s", err)
}
}

View File

@@ -0,0 +1,90 @@
#include <stdbool.h>
#include <linux/if_ether.h> // ETH_P_IP
#include <linux/udp.h>
#include <linux/ip.h>
#include <netinet/in.h>
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
#define bpf_printk(fmt, ...) \
({ \
char ____fmt[] = fmt; \
bpf_trace_printk(____fmt, sizeof(____fmt), ##__VA_ARGS__); \
})
const __u32 map_key_proxy_port = 0;
const __u32 map_key_wg_port = 1;
struct bpf_map_def SEC("maps") xdp_port_map = {
.type = BPF_MAP_TYPE_ARRAY,
.key_size = sizeof(__u32),
.value_size = sizeof(__u16),
.max_entries = 10,
};
__u16 proxy_port = 0;
__u16 wg_port = 0;
bool read_port_settings() {
__u16 *value;
value = bpf_map_lookup_elem(&xdp_port_map, &map_key_proxy_port);
if(!value) {
return false;
}
proxy_port = *value;
value = bpf_map_lookup_elem(&xdp_port_map, &map_key_wg_port);
if(!value) {
return false;
}
wg_port = *value;
return true;
}
SEC("xdp")
int xdp_prog_func(struct xdp_md *ctx) {
if(proxy_port == 0 || wg_port == 0) {
if(!read_port_settings()){
return XDP_PASS;
}
bpf_printk("proxy port: %d, wg port: %d", proxy_port, wg_port);
}
void *data = (void *)(long)ctx->data;
void *data_end = (void *)(long)ctx->data_end;
struct ethhdr *eth = data;
struct iphdr *ip = (data + sizeof(struct ethhdr));
struct udphdr *udp = (data + sizeof(struct ethhdr) + sizeof(struct iphdr));
// return early if not enough data
if (data + sizeof(struct ethhdr) + sizeof(struct iphdr) + sizeof(struct udphdr) > data_end){
return XDP_PASS;
}
// skip non IPv4 packages
if (eth->h_proto != htons(ETH_P_IP)) {
return XDP_PASS;
}
if (ip->protocol != IPPROTO_UDP) {
return XDP_PASS;
}
// 2130706433 = 127.0.0.1
if (ip->daddr != htonl(2130706433)) {
return XDP_PASS;
}
if (udp->source != htons(wg_port)){
return XDP_PASS;
}
__be16 new_src_port = udp->dest;
__be16 new_dst_port = htons(proxy_port);
udp->dest = new_dst_port;
udp->source = new_src_port;
return XDP_PASS;
}
char _license[] SEC("license") = "GPL";

View File

@@ -0,0 +1,20 @@
package wgproxy
type Factory struct {
wgPort int
ebpfProxy Proxy
}
func (w *Factory) GetProxy() Proxy {
if w.ebpfProxy != nil {
return w.ebpfProxy
}
return NewWGUserSpaceProxy(w.wgPort)
}
func (w *Factory) Free() error {
if w.ebpfProxy != nil {
return w.ebpfProxy.CloseConn()
}
return nil
}

View File

@@ -0,0 +1,21 @@
//go:build !android
package wgproxy
import (
log "github.com/sirupsen/logrus"
)
func NewFactory(wgPort int) *Factory {
f := &Factory{wgPort: wgPort}
ebpfProxy := NewWGEBPFProxy(wgPort)
err := ebpfProxy.Listen()
if err != nil {
log.Errorf("failed to initialize ebpf proxy: %s", err)
return f
}
f.ebpfProxy = ebpfProxy
return f
}

View File

@@ -0,0 +1,7 @@
//go:build !linux || android
package wgproxy
func NewFactory(wgPort int) *Factory {
return &Factory{wgPort: wgPort}
}

View File

@@ -0,0 +1,32 @@
package wgproxy
import (
"fmt"
"net"
)
const (
portRangeStart = 3128
portRangeEnd = 3228
)
type portLookup struct {
}
func (pl portLookup) searchFreePort() (int, error) {
for i := portRangeStart; i <= portRangeEnd; i++ {
if pl.tryToBind(i) == nil {
return i, nil
}
}
return 0, fmt.Errorf("failed to bind free port for eBPF proxy")
}
func (pl portLookup) tryToBind(port int) error {
l, err := net.ListenPacket("udp", fmt.Sprintf(":%d", port))
if err != nil {
return err
}
_ = l.Close()
return nil
}

View File

@@ -0,0 +1,42 @@
package wgproxy
import (
"fmt"
"net"
"testing"
)
func Test_portLookup_searchFreePort(t *testing.T) {
pl := portLookup{}
_, err := pl.searchFreePort()
if err != nil {
t.Fatal(err)
}
}
func Test_portLookup_on_allocated(t *testing.T) {
pl := portLookup{}
allocatedPort, err := allocatePort(portRangeStart)
if err != nil {
t.Fatal(err)
}
defer allocatedPort.Close()
fp, err := pl.searchFreePort()
if err != nil {
t.Fatal(err)
}
if fp != (portRangeStart + 1) {
t.Errorf("invalid free port, expected: %d, got: %d", portRangeStart+1, fp)
}
}
func allocatePort(port int) (net.PacketConn, error) {
c, err := net.ListenPacket("udp", fmt.Sprintf(":%d", port))
if err != nil {
return nil, err
}
return c, err
}

View File

@@ -0,0 +1,12 @@
package wgproxy
import (
"net"
)
// Proxy is a transfer layer between the Turn connection and the WireGuard
type Proxy interface {
AddTurnConn(urnConn net.Conn) (net.Addr, error)
CloseConn() error
Free() error
}

View File

@@ -0,0 +1,252 @@
//go:build linux && !android
package wgproxy
import (
"fmt"
"io"
"net"
"os"
"sync"
"syscall"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
log "github.com/sirupsen/logrus"
ebpf2 "github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
)
// WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct {
ebpf *ebpf2.EBPF
lastUsedPort uint16
localWGListenPort int
turnConnStore map[uint16]net.Conn
turnConnMutex sync.Mutex
rawConn net.PacketConn
conn *net.UDPConn
}
// NewWGEBPFProxy create new WGEBPFProxy instance
func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
log.Debugf("instantiate ebpf proxy")
wgProxy := &WGEBPFProxy{
localWGListenPort: wgPort,
ebpf: ebpf2.NewEBPF(),
lastUsedPort: 0,
turnConnStore: make(map[uint16]net.Conn),
}
return wgProxy
}
// Listen load ebpf program and listen the proxy
func (p *WGEBPFProxy) Listen() error {
pl := portLookup{}
wgPorxyPort, err := pl.searchFreePort()
if err != nil {
return err
}
p.rawConn, err = p.prepareSenderRawSocket()
if err != nil {
return err
}
err = p.ebpf.Load(wgPorxyPort, p.localWGListenPort)
if err != nil {
return err
}
addr := net.UDPAddr{
Port: wgPorxyPort,
IP: net.ParseIP("127.0.0.1"),
}
p.conn, err = net.ListenUDP("udp", &addr)
if err != nil {
cErr := p.Free()
if err != nil {
log.Errorf("failed to close the wgproxy: %s", cErr)
}
return err
}
go p.proxyToRemote()
log.Infof("local wg proxy listening on: %d", wgPorxyPort)
return nil
}
// AddTurnConn add new turn connection for the proxy
func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
wgEndpointPort, err := p.storeTurnConn(turnConn)
if err != nil {
return nil, err
}
go p.proxyToLocal(wgEndpointPort, turnConn)
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
wgEndpoint := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: int(wgEndpointPort),
}
return wgEndpoint, nil
}
// CloseConn doing nothing because this type of proxy implementation does not store the connection
func (p *WGEBPFProxy) CloseConn() error {
return nil
}
// Free resources
func (p *WGEBPFProxy) Free() error {
var err1, err2, err3 error
if p.conn != nil {
err1 = p.conn.Close()
}
err2 = p.ebpf.Free()
if p.rawConn != nil {
err3 = p.rawConn.Close()
}
if err1 != nil {
return err1
}
if err2 != nil {
return err2
}
return err3
}
func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
buf := make([]byte, 1500)
for {
n, err := remoteConn.Read(buf)
if err != nil {
if err != io.EOF {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
p.removeTurnConn(endpointPort)
return
}
err = p.sendPkg(buf[:n], endpointPort)
if err != nil {
log.Errorf("failed to write out turn pkg to local conn: %v", err)
}
}
}
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
func (p *WGEBPFProxy) proxyToRemote() {
buf := make([]byte, 1500)
for {
n, addr, err := p.conn.ReadFromUDP(buf)
if err != nil {
log.Errorf("failed to read UDP pkg from WG: %s", err)
return
}
conn, ok := p.turnConnStore[uint16(addr.Port)]
if !ok {
log.Errorf("turn conn not found by port: %d", addr.Port)
continue
}
_, err = conn.Write(buf[:n])
if err != nil {
log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err)
}
}
}
func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
p.turnConnMutex.Lock()
defer p.turnConnMutex.Unlock()
np, err := p.nextFreePort()
if err != nil {
return np, err
}
p.turnConnStore[np] = turnConn
return np, nil
}
func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) {
log.Tracef("remove turn conn from store by port: %d", turnConnID)
p.turnConnMutex.Lock()
defer p.turnConnMutex.Unlock()
delete(p.turnConnStore, turnConnID)
}
func (p *WGEBPFProxy) nextFreePort() (uint16, error) {
if len(p.turnConnStore) == 65535 {
return 0, fmt.Errorf("reached maximum turn connection numbers")
}
generatePort:
if p.lastUsedPort == 65535 {
p.lastUsedPort = 1
} else {
p.lastUsedPort++
}
if _, ok := p.turnConnStore[p.lastUsedPort]; ok {
goto generatePort
}
return p.lastUsedPort, nil
}
func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil {
return nil, err
}
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, err
}
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
return nil, err
}
return net.FilePacketConn(os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)))
}
func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error {
localhost := net.ParseIP("127.0.0.1")
payload := gopacket.Payload(data)
ipH := &layers.IPv4{
DstIP: localhost,
SrcIP: localhost,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(port),
DstPort: layers.UDPPort(p.localWGListenPort),
}
err := udpH.SetNetworkLayerForChecksum(ipH)
if err != nil {
return err
}
layerBuffer := gopacket.NewSerializeBuffer()
err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
if err != nil {
return err
}
_, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost})
return err
}

View File

@@ -0,0 +1,56 @@
//go:build linux && !android
package wgproxy
import (
"testing"
)
func TestWGEBPFProxy_connStore(t *testing.T) {
wgProxy := NewWGEBPFProxy(1)
p, _ := wgProxy.storeTurnConn(nil)
if p != 1 {
t.Errorf("invalid initial port: %d", wgProxy.lastUsedPort)
}
numOfConns := 10
for i := 0; i < numOfConns; i++ {
p, _ = wgProxy.storeTurnConn(nil)
}
if p != uint16(numOfConns)+1 {
t.Errorf("invalid last used port: %d, expected: %d", p, numOfConns+1)
}
if len(wgProxy.turnConnStore) != numOfConns+1 {
t.Errorf("invalid store size: %d, expected: %d", len(wgProxy.turnConnStore), numOfConns+1)
}
}
func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
wgProxy := NewWGEBPFProxy(1)
_, _ = wgProxy.storeTurnConn(nil)
wgProxy.lastUsedPort = 65535
p, _ := wgProxy.storeTurnConn(nil)
if len(wgProxy.turnConnStore) != 2 {
t.Errorf("invalid store size: %d, expected: %d", len(wgProxy.turnConnStore), 2)
}
if p != 2 {
t.Errorf("invalid last used port: %d, expected: %d", p, 2)
}
}
func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) {
wgProxy := NewWGEBPFProxy(1)
for i := 0; i < 65535; i++ {
_, _ = wgProxy.storeTurnConn(nil)
}
_, err := wgProxy.storeTurnConn(nil)
if err == nil {
t.Errorf("invalid turn conn store calculation")
}
}

View File

@@ -0,0 +1,105 @@
package wgproxy
import (
"context"
"fmt"
"net"
log "github.com/sirupsen/logrus"
)
// WGUserSpaceProxy proxies
type WGUserSpaceProxy struct {
localWGListenPort int
ctx context.Context
cancel context.CancelFunc
remoteConn net.Conn
localConn net.Conn
}
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy
func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
p := &WGUserSpaceProxy{
localWGListenPort: wgPort,
}
p.ctx, p.cancel = context.WithCancel(context.Background())
return p
}
// AddTurnConn start the proxy with the given remote conn
func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) {
p.remoteConn = remoteConn
var err error
p.localConn, err = net.Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err)
return nil, err
}
go p.proxyToRemote()
go p.proxyToLocal()
return p.localConn.LocalAddr(), err
}
// CloseConn close the localConn
func (p *WGUserSpaceProxy) CloseConn() error {
p.cancel()
if p.localConn == nil {
return nil
}
return p.localConn.Close()
}
// Free doing nothing because this implementation of proxy does not have global state
func (p *WGUserSpaceProxy) Free() error {
return nil
}
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
// blocks
func (p *WGUserSpaceProxy) proxyToRemote() {
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
return
default:
n, err := p.localConn.Read(buf)
if err != nil {
continue
}
_, err = p.remoteConn.Write(buf[:n])
if err != nil {
continue
}
}
}
}
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
// blocks
func (p *WGUserSpaceProxy) proxyToLocal() {
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
return
default:
n, err := p.remoteConn.Read(buf)
if err != nil {
continue
}
_, err = p.localConn.Write(buf[:n])
if err != nil {
continue
}
}
}
}

View File

@@ -3,6 +3,7 @@ package server
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/netbirdio/netbird/client/internal/auth"
"sync" "sync"
"time" "time"
@@ -38,8 +39,8 @@ type Server struct {
type oauthAuthFlow struct { type oauthAuthFlow struct {
expiresAt time.Time expiresAt time.Time
client internal.OAuthClient flow auth.OAuthFlow
info internal.DeviceAuthInfo info auth.AuthFlowInfo
waitCancel context.CancelFunc waitCancel context.CancelFunc
} }
@@ -102,7 +103,7 @@ func (s *Server) Start() error {
} }
go func() { go func() {
if err := internal.RunClient(ctx, config, s.statusRecorder, nil, nil, nil); err != nil { if err := internal.RunClient(ctx, config, s.statusRecorder); err != nil {
log.Errorf("init connections: %v", err) log.Errorf("init connections: %v", err)
} }
}() }()
@@ -206,28 +207,15 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
state.Set(internal.StatusConnecting) state.Set(internal.StatusConnecting)
if msg.SetupKey == "" { if msg.SetupKey == "" {
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) oAuthFlow, err := auth.NewOAuthFlow(ctx, config)
if err != nil { if err != nil {
state.Set(internal.StatusLoginFailed) state.Set(internal.StatusLoginFailed)
s, ok := gstatus.FromError(err) return nil, err
if ok && s.Code() == codes.NotFound {
return nil, gstatus.Errorf(codes.NotFound, "no SSO provider returned from management. "+
"If you are using hosting Netbird see documentation at "+
"https://github.com/netbirdio/netbird/tree/main/management for details")
} else if ok && s.Code() == codes.Unimplemented {
return nil, gstatus.Errorf(codes.Unimplemented, "the management server, %s, does not support SSO providers, "+
"please update your server or use Setup Keys to login", config.ManagementURL)
} else {
log.Errorf("getting device authorization flow info failed with error: %v", err)
return nil, err
}
} }
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig) if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(context.TODO()) {
if s.oauthAuthFlow.client != nil && s.oauthAuthFlow.client.GetClientID(ctx) == hostedClient.GetClientID(context.TODO()) {
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) { if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
log.Debugf("using previous device flow info") log.Debugf("using previous oauth flow info")
return &proto.LoginResponse{ return &proto.LoginResponse{
NeedsSSOLogin: true, NeedsSSOLogin: true,
VerificationURI: s.oauthAuthFlow.info.VerificationURI, VerificationURI: s.oauthAuthFlow.info.VerificationURI,
@@ -242,25 +230,25 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
} }
} }
deviceAuthInfo, err := hostedClient.RequestDeviceCode(context.TODO()) authInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
if err != nil { if err != nil {
log.Errorf("getting a request device code failed: %v", err) log.Errorf("getting a request OAuth flow failed: %v", err)
return nil, err return nil, err
} }
s.mutex.Lock() s.mutex.Lock()
s.oauthAuthFlow.client = hostedClient s.oauthAuthFlow.flow = oAuthFlow
s.oauthAuthFlow.info = deviceAuthInfo s.oauthAuthFlow.info = authInfo
s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(deviceAuthInfo.ExpiresIn) * time.Second) s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(authInfo.ExpiresIn) * time.Second)
s.mutex.Unlock() s.mutex.Unlock()
state.Set(internal.StatusNeedsLogin) state.Set(internal.StatusNeedsLogin)
return &proto.LoginResponse{ return &proto.LoginResponse{
NeedsSSOLogin: true, NeedsSSOLogin: true,
VerificationURI: deviceAuthInfo.VerificationURI, VerificationURI: authInfo.VerificationURI,
VerificationURIComplete: deviceAuthInfo.VerificationURIComplete, VerificationURIComplete: authInfo.VerificationURIComplete,
UserCode: deviceAuthInfo.UserCode, UserCode: authInfo.UserCode,
}, nil }, nil
} }
@@ -289,8 +277,8 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
s.actCancel = cancel s.actCancel = cancel
s.mutex.Unlock() s.mutex.Unlock()
if s.oauthAuthFlow.client == nil { if s.oauthAuthFlow.flow == nil {
return nil, gstatus.Errorf(codes.Internal, "oauth client is not initialized") return nil, gstatus.Errorf(codes.Internal, "oauth flow is not initialized")
} }
state := internal.CtxGetState(ctx) state := internal.CtxGetState(ctx)
@@ -304,10 +292,10 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
state.Set(internal.StatusConnecting) state.Set(internal.StatusConnecting)
s.mutex.Lock() s.mutex.Lock()
deviceAuthInfo := s.oauthAuthFlow.info flowInfo := s.oauthAuthFlow.info
s.mutex.Unlock() s.mutex.Unlock()
if deviceAuthInfo.UserCode != msg.UserCode { if flowInfo.UserCode != msg.UserCode {
state.Set(internal.StatusLoginFailed) state.Set(internal.StatusLoginFailed)
return nil, gstatus.Errorf(codes.InvalidArgument, "sso user code is invalid") return nil, gstatus.Errorf(codes.InvalidArgument, "sso user code is invalid")
} }
@@ -324,7 +312,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
s.oauthAuthFlow.waitCancel = cancel s.oauthAuthFlow.waitCancel = cancel
s.mutex.Unlock() s.mutex.Unlock()
tokenInfo, err := s.oauthAuthFlow.client.WaitToken(waitCTX, deviceAuthInfo) tokenInfo, err := s.oauthAuthFlow.flow.WaitToken(waitCTX, flowInfo)
if err != nil { if err != nil {
if err == context.Canceled { if err == context.Canceled {
return nil, nil return nil, nil
@@ -391,7 +379,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
} }
go func() { go func() {
if err := internal.RunClient(ctx, s.config, s.statusRecorder, nil, nil, nil); err != nil { if err := internal.RunClient(ctx, s.config, s.statusRecorder); err != nil {
log.Errorf("run client connection: %v", err) log.Errorf("run client connection: %v", err)
return return
} }

View File

@@ -2,9 +2,6 @@ package ssh
import ( import (
"fmt" "fmt"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
"io" "io"
"net" "net"
"os" "os"
@@ -13,11 +10,22 @@ import (
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
"time"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
) )
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server // DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
const DefaultSSHPort = 44338 const DefaultSSHPort = 44338
// TerminalTimeout is the timeout for terminal session to be ready
const TerminalTimeout = 10 * time.Second
// TerminalBackoffDelay is the delay between terminal session readiness checks
const TerminalBackoffDelay = 500 * time.Millisecond
// DefaultSSHServer is a function that creates DefaultServer // DefaultSSHServer is a function that creates DefaultServer
func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) { func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) {
return newDefaultServer(hostKeyPEM, addr) return newDefaultServer(hostKeyPEM, addr)
@@ -137,6 +145,8 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
} }
}() }()
log.Infof("Establishing SSH session for %s from host %s", session.User(), session.RemoteAddr().String())
localUser, err := userNameLookup(session.User()) localUser, err := userNameLookup(session.User())
if err != nil { if err != nil {
_, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint _, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint
@@ -172,6 +182,7 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
} }
} }
log.Debugf("Login command: %s", cmd.String())
file, err := pty.Start(cmd) file, err := pty.Start(cmd)
if err != nil { if err != nil {
log.Errorf("failed starting SSH server %v", err) log.Errorf("failed starting SSH server %v", err)
@@ -199,6 +210,7 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
return return
} }
} }
log.Debugf("SSH session ended")
} }
func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) { func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
@@ -206,17 +218,29 @@ func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
// stdin // stdin
_, err := io.Copy(file, session) _, err := io.Copy(file, session)
if err != nil { if err != nil {
_ = session.Exit(1)
return return
} }
}() }()
go func() { // AWS Linux 2 machines need some time to open the terminal so we need to wait for it
// stdout timer := time.NewTimer(TerminalTimeout)
_, err := io.Copy(session, file) for {
if err != nil { select {
case <-timer.C:
_, _ = session.Write([]byte("Reached timeout while opening connection\n"))
_ = session.Exit(1)
return return
default:
// stdout
writtenBytes, err := io.Copy(session, file)
if err != nil && writtenBytes != 0 {
_ = session.Exit(0)
return
}
time.Sleep(TerminalBackoffDelay)
} }
}() }
} }
// Start starts SSH server. Blocking // Start starts SSH server. Blocking

25
go.mod
View File

@@ -17,12 +17,12 @@ require (
github.com/spf13/cobra v1.6.1 github.com/spf13/cobra v1.6.1
github.com/spf13/pflag v1.0.5 github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.1.0 github.com/vishvananda/netlink v1.1.0
golang.org/x/crypto v0.7.0 golang.org/x/crypto v0.9.0
golang.org/x/sys v0.8.0 golang.org/x/sys v0.8.0
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675 golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de
golang.zx2c4.com/wireguard/windows v0.5.3 golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.52.3 google.golang.org/grpc v1.55.0
google.golang.org/protobuf v1.30.0 google.golang.org/protobuf v1.30.0
gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0
) )
@@ -30,6 +30,7 @@ require (
require ( require (
fyne.io/fyne/v2 v2.1.4 fyne.io/fyne/v2 v2.1.4
github.com/c-robinson/iplib v1.0.3 github.com/c-robinson/iplib v1.0.3
github.com/cilium/ebpf v0.10.0
github.com/coreos/go-iptables v0.6.0 github.com/coreos/go-iptables v0.6.0
github.com/creack/pty v1.1.18 github.com/creack/pty v1.1.18
github.com/eko/gocache/v3 v3.1.1 github.com/eko/gocache/v3 v3.1.1
@@ -48,6 +49,7 @@ require (
github.com/mdlayher/socket v0.4.0 github.com/mdlayher/socket v0.4.0
github.com/miekg/dns v1.1.43 github.com/miekg/dns v1.1.43
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pion/logging v0.2.2 github.com/pion/logging v0.2.2
@@ -57,6 +59,7 @@ require (
github.com/rs/xid v1.3.0 github.com/rs/xid v1.3.0
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/stretchr/testify v1.8.1 github.com/stretchr/testify v1.8.1
github.com/yusufpapurcu/wmi v1.2.3
go.opentelemetry.io/otel v1.11.1 go.opentelemetry.io/otel v1.11.1
go.opentelemetry.io/otel/exporters/prometheus v0.33.0 go.opentelemetry.io/otel/exporters/prometheus v0.33.0
go.opentelemetry.io/otel/metric v0.33.0 go.opentelemetry.io/otel/metric v0.33.0
@@ -65,18 +68,22 @@ require (
golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028
golang.org/x/net v0.10.0 golang.org/x/net v0.10.0
golang.org/x/sync v0.1.0 golang.org/x/oauth2 v0.8.0
golang.org/x/sync v0.2.0
golang.org/x/term v0.8.0 golang.org/x/term v0.8.0
google.golang.org/api v0.126.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
) )
require ( require (
cloud.google.com/go/compute v1.19.3 // indirect
cloud.google.com/go/compute/metadata v0.2.3 // indirect
github.com/BurntSushi/toml v1.2.1 // indirect github.com/BurntSushi/toml v1.2.1 // indirect
github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2 // indirect github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2 // indirect
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
github.com/beorn7/perks v1.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d // indirect github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dgraph-io/ristretto v0.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
@@ -92,9 +99,14 @@ require (
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211024062804-40e447a793be // indirect github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211024062804-40e447a793be // indirect
github.com/go-logr/logr v1.2.3 // indirect github.com/go-logr/logr v1.2.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-redis/redis/v8 v8.11.5 // indirect github.com/go-redis/redis/v8 v8.11.5 // indirect
github.com/go-stack/stack v1.8.0 // indirect github.com/go-stack/stack v1.8.0 // indirect
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/s2a-go v0.1.4 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect
github.com/googleapis/gax-go/v2 v2.10.0 // indirect
github.com/hashicorp/go-uuid v1.0.2 // indirect github.com/hashicorp/go-uuid v1.0.2 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/josharian/native v1.0.0 // indirect github.com/josharian/native v1.0.0 // indirect
@@ -114,23 +126,22 @@ require (
github.com/prometheus/client_model v0.3.0 // indirect github.com/prometheus/client_model v0.3.0 // indirect
github.com/prometheus/common v0.37.0 // indirect github.com/prometheus/common v0.37.0 // indirect
github.com/prometheus/procfs v0.8.0 // indirect github.com/prometheus/procfs v0.8.0 // indirect
github.com/rogpeppe/go-internal v1.8.0 // indirect
github.com/spf13/cast v1.5.0 // indirect github.com/spf13/cast v1.5.0 // indirect
github.com/srwiley/oksvg v0.0.0-20200311192757-870daf9aa564 // indirect github.com/srwiley/oksvg v0.0.0-20200311192757-870daf9aa564 // indirect
github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect
github.com/yuin/goldmark v1.4.13 // indirect github.com/yuin/goldmark v1.4.13 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/otel/sdk v1.11.1 // indirect go.opentelemetry.io/otel/sdk v1.11.1 // indirect
go.opentelemetry.io/otel/trace v1.11.1 // indirect go.opentelemetry.io/otel/trace v1.11.1 // indirect
golang.org/x/image v0.5.0 // indirect golang.org/x/image v0.5.0 // indirect
golang.org/x/mod v0.8.0 // indirect golang.org/x/mod v0.8.0 // indirect
golang.org/x/oauth2 v0.8.0 // indirect
golang.org/x/text v0.9.0 // indirect golang.org/x/text v0.9.0 // indirect
golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac // indirect golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac // indirect
golang.org/x/tools v0.6.0 // indirect golang.org/x/tools v0.6.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/appengine v1.6.7 // indirect google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20221118155620-16455021b5e6 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect

Some files were not shown because too many files have changed in this diff Show More