mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 23:36:39 +00:00
Compare commits
83 Commits
feature/fl
...
v0.41.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7839d2c169 | ||
|
|
b9f82e2f8a | ||
|
|
fd2a21c65d | ||
|
|
82d982b0ab | ||
|
|
9e24fe7701 | ||
|
|
e470701b80 | ||
|
|
e3ce026355 | ||
|
|
5ea2806663 | ||
|
|
d6b0673580 | ||
|
|
14913cfa7a | ||
|
|
03f600b576 | ||
|
|
192c97aa63 | ||
|
|
4db78db49a | ||
|
|
87e600a4f3 | ||
|
|
6162aeb82d | ||
|
|
1ba1e092ce | ||
|
|
86dbb4ee4f | ||
|
|
4af177215f | ||
|
|
df9c1b9883 | ||
|
|
5752bb78f2 | ||
|
|
fbd783ad58 | ||
|
|
80702b9323 | ||
|
|
09243a0fe0 | ||
|
|
3658215747 | ||
|
|
48ffec95dd | ||
|
|
cbec7bda80 | ||
|
|
21464ac770 | ||
|
|
ed5647028a | ||
|
|
29a6e5be71 | ||
|
|
6124e3b937 | ||
|
|
50f5cc48cd | ||
|
|
101cce27f2 | ||
|
|
a4f04f5570 | ||
|
|
fceb3ca392 | ||
|
|
34d86c5ab8 | ||
|
|
9cbcf7531f | ||
|
|
bd8f0c1ef3 | ||
|
|
051a5a4adc | ||
|
|
8b4c0c58e4 | ||
|
|
99b41543b8 | ||
|
|
2bbe0f3f09 | ||
|
|
9325fb7990 | ||
|
|
f081435a56 | ||
|
|
b62a1b56ce | ||
|
|
8d7c92c661 | ||
|
|
d9d051cb1e | ||
|
|
cb318b7ef4 | ||
|
|
8f0aa8352a | ||
|
|
c02e236196 | ||
|
|
f51e0b59bd | ||
|
|
32ec42a667 | ||
|
|
9929daf6ce | ||
|
|
939419a0ea | ||
|
|
919fe94fd5 | ||
|
|
df71cb4690 | ||
|
|
4508c61728 | ||
|
|
0ef476b014 | ||
|
|
6f82e96d6a | ||
|
|
a2faae5d62 | ||
|
|
4a3cbcd38a | ||
|
|
c2980bc8cf | ||
|
|
67ae871ce4 | ||
|
|
39ff5e833a | ||
|
|
cd9eff5331 | ||
|
|
80ceb80197 | ||
|
|
636a0e2475 | ||
|
|
e66e329bf6 | ||
|
|
aaa23beeec | ||
|
|
6bef474e9e | ||
|
|
81040ff80a | ||
|
|
c73481aee4 | ||
|
|
fc1da94520 | ||
|
|
ae6b61301c | ||
|
|
a444e551b3 | ||
|
|
53b9a2002f | ||
|
|
4b76d93cec | ||
|
|
062d1ec76f | ||
|
|
c111675dd8 | ||
|
|
60ffe0dc87 | ||
|
|
bcc5824980 | ||
|
|
af5796de1c | ||
|
|
9d604b7e66 | ||
|
|
82c12cc8ae |
27
.git-branches.toml
Normal file
27
.git-branches.toml
Normal file
@@ -0,0 +1,27 @@
|
||||
# More info around this file at https://www.git-town.com/configuration-file
|
||||
|
||||
[branches]
|
||||
main = "main"
|
||||
perennials = []
|
||||
perennial-regex = ""
|
||||
|
||||
[create]
|
||||
new-branch-type = "feature"
|
||||
push-new-branches = false
|
||||
|
||||
[hosting]
|
||||
dev-remote = "origin"
|
||||
# platform = ""
|
||||
# origin-hostname = ""
|
||||
|
||||
[ship]
|
||||
delete-tracking-branch = false
|
||||
strategy = "squash-merge"
|
||||
|
||||
[sync]
|
||||
feature-strategy = "merge"
|
||||
perennial-strategy = "rebase"
|
||||
prototype-strategy = "merge"
|
||||
push-hook = true
|
||||
tags = true
|
||||
upstream = false
|
||||
25
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
25
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
@@ -31,14 +31,22 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
|
||||
|
||||
`netbird version`
|
||||
|
||||
**NetBird status -dA output:**
|
||||
**Is any other VPN software installed?**
|
||||
|
||||
If applicable, add the `netbird status -dA' command output.
|
||||
If yes, which one?
|
||||
|
||||
**Do you face any (non-mobile) client issues?**
|
||||
**Debug output**
|
||||
|
||||
Please provide the file created by `netbird debug for 1m -AS`.
|
||||
We advise reviewing the anonymized files for any remaining PII.
|
||||
To help us resolve the problem, please attach the following debug output
|
||||
|
||||
netbird status -dA
|
||||
|
||||
As well as the file created by
|
||||
|
||||
netbird debug for 1m -AS
|
||||
|
||||
|
||||
We advise reviewing the anonymized output for any remaining personal information.
|
||||
|
||||
**Screenshots**
|
||||
|
||||
@@ -47,3 +55,10 @@ If applicable, add screenshots to help explain your problem.
|
||||
**Additional context**
|
||||
|
||||
Add any other context about the problem here.
|
||||
|
||||
**Have you tried these troubleshooting steps?**
|
||||
- [ ] Checked for newer NetBird versions
|
||||
- [ ] Searched for similar issues on GitHub (including closed ones)
|
||||
- [ ] Restarted the NetBird client
|
||||
- [ ] Disabled other VPN software
|
||||
- [ ] Checked firewall settings
|
||||
|
||||
4
.github/pull_request_template.md
vendored
4
.github/pull_request_template.md
vendored
@@ -2,6 +2,10 @@
|
||||
|
||||
## Issue ticket number and link
|
||||
|
||||
## Stack
|
||||
|
||||
<!-- branch-stack -->
|
||||
|
||||
### Checklist
|
||||
- [ ] Is it a bug fix
|
||||
- [ ] Is a typo/documentation fix
|
||||
|
||||
21
.github/workflows/git-town.yml
vendored
Normal file
21
.github/workflows/git-town.yml
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
name: Git Town
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- '**'
|
||||
|
||||
jobs:
|
||||
git-town:
|
||||
name: Display the branch stack
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: git-town/action@v1
|
||||
with:
|
||||
skip-single-stacks: true
|
||||
10
.github/workflows/golang-test-freebsd.yml
vendored
10
.github/workflows/golang-test-freebsd.yml
vendored
@@ -22,14 +22,20 @@ jobs:
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
release: "14.1"
|
||||
release: "14.2"
|
||||
prepare: |
|
||||
pkg install -y go pkgconf xorg
|
||||
pkg install -y curl pkgconf xorg
|
||||
LATEST_VERSION=$(curl -s https://go.dev/VERSION?m=text|head -n 1)
|
||||
GO_TARBALL="$LATEST_VERSION.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -vLO "$GO_URL"
|
||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||
|
||||
# -x - to print all executed commands
|
||||
# -e - to faile on first error
|
||||
run: |
|
||||
set -e -x
|
||||
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
|
||||
time go build -o netbird client/main.go
|
||||
# check all component except management, since we do not support management server on freebsd
|
||||
time go test -timeout 1m -failfast ./base62/...
|
||||
|
||||
10
.github/workflows/golang-test-linux.yml
vendored
10
.github/workflows/golang-test-linux.yml
vendored
@@ -258,7 +258,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
@@ -325,8 +325,8 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
@@ -392,7 +392,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
@@ -461,7 +461,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres']
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
|
||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
|
||||
skip: go.mod,go.sum
|
||||
only_warn: 1
|
||||
golangci:
|
||||
|
||||
12
.github/workflows/release.yml
vendored
12
.github/workflows/release.yml
vendored
@@ -71,7 +71,7 @@ jobs:
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Generate windows syso amd64
|
||||
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
|
||||
run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
@@ -87,25 +87,25 @@ jobs:
|
||||
with:
|
||||
name: release
|
||||
path: dist/
|
||||
retention-days: 3
|
||||
retention-days: 7
|
||||
- name: upload linux packages
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: linux-packages
|
||||
path: dist/netbird_linux**
|
||||
retention-days: 3
|
||||
retention-days: 7
|
||||
- name: upload windows packages
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: windows-packages
|
||||
path: dist/netbird_windows**
|
||||
retention-days: 3
|
||||
retention-days: 7
|
||||
- name: upload macos packages
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: macos-packages
|
||||
path: dist/netbird_darwin**
|
||||
retention-days: 3
|
||||
retention-days: 7
|
||||
|
||||
release_ui:
|
||||
runs-on: ubuntu-latest
|
||||
@@ -150,7 +150,7 @@ jobs:
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Generate windows syso amd64
|
||||
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
|
||||
run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
|
||||
@@ -178,6 +178,7 @@ jobs:
|
||||
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
|
||||
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
|
||||
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
||||
grep DisablePromptLogin management.json | grep 'true'
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -53,9 +53,9 @@ nfpms:
|
||||
scripts:
|
||||
postinstall: "release_files/ui-post-install.sh"
|
||||
contents:
|
||||
- src: client/ui/netbird.desktop
|
||||
- src: client/ui/build/netbird.desktop
|
||||
dst: /usr/share/applications/netbird.desktop
|
||||
- src: client/ui/netbird.png
|
||||
- src: client/ui/assets/netbird.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- netbird
|
||||
@@ -72,9 +72,9 @@ nfpms:
|
||||
scripts:
|
||||
postinstall: "release_files/ui-post-install.sh"
|
||||
contents:
|
||||
- src: client/ui/netbird.desktop
|
||||
- src: client/ui/build/netbird.desktop
|
||||
dst: /usr/share/applications/netbird.desktop
|
||||
- src: client/ui/netbird.png
|
||||
- src: client/ui/assets/netbird.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- netbird
|
||||
|
||||
@@ -1,148 +1,64 @@
|
||||
# Contributor License Agreement
|
||||
## Contributor License Agreement
|
||||
|
||||
We are incredibly thankful for the contributions we receive from the community.
|
||||
We require our external contributors to sign a Contributor License Agreement ("CLA") in
|
||||
order to ensure that our projects remain licensed under Free and Open Source licenses such
|
||||
as BSD-3 while allowing NetBird to build a sustainable business.
|
||||
|
||||
NetBird is committed to having a true Open Source Software ("OSS") license for
|
||||
our software. A CLA enables NetBird to safely commercialize our products
|
||||
while keeping a standard OSS license with all the rights that license grants to users: the
|
||||
ability to use the project in their own projects or businesses, to republish modified
|
||||
source, or to completely fork the project.
|
||||
|
||||
This page gives a human-friendly summary of our CLA, details on why we require a CLA, how
|
||||
contributors can sign our CLA, and more. You may view the full legal CLA document (below).
|
||||
|
||||
# Human-friendly summary
|
||||
|
||||
This is a human-readable summary of (and not a substitute for) the full agreement (below).
|
||||
This highlights only some of key terms of the CLA. It has no legal value and you should
|
||||
carefully review all the terms of the actual CLA before agreeing.
|
||||
|
||||
<li>Grant of copyright license. You give NetBird permission to use your copyrighted work
|
||||
in commercial products.
|
||||
</li>
|
||||
|
||||
<li>Grant of patent license. If your contributed work uses a patent, you give NetBird a
|
||||
license to use that patent including within commercial products. You also agree that you
|
||||
have permission to grant this license.
|
||||
</li>
|
||||
|
||||
<li>No Warranty or Support Obligations.
|
||||
By making a contribution, you are not obligating yourself to provide support for the
|
||||
contribution, and you are not taking on any warranty obligations or providing any
|
||||
assurances about how it will perform.
|
||||
</li>
|
||||
|
||||
The CLA does not change the terms of the standard open source license used by our software
|
||||
such as BSD-3 or MIT.
|
||||
You are still free to use our projects within your own projects or businesses, republish
|
||||
modified source, and more.
|
||||
Please reference the appropriate license for the project you're contributing to to learn
|
||||
more.
|
||||
|
||||
# Why require a CLA?
|
||||
|
||||
Agreeing to a CLA explicitly states that you are entitled to provide a contribution, that you cannot withdraw permission
|
||||
to use your contribution at a later date, and that NetBird has permission to use your contribution in our commercial
|
||||
products.
|
||||
|
||||
This removes any ambiguities or uncertainties caused by not having a CLA and allows users and customers to confidently
|
||||
adopt our projects. At the same time, the CLA ensures that all contributions to our open source projects are licensed
|
||||
under the project's respective open source license, such as BSD-3.
|
||||
|
||||
Requiring a CLA is a common and well-accepted practice in open source. Major open source projects require CLAs such as
|
||||
Apache Software Foundation projects, Facebook projects (such as React), Google projects (including Go), Python, Django,
|
||||
and more. Each of these projects remains licensed under permissive OSS licenses such as MIT, Apache, BSD, and more.
|
||||
|
||||
# Signing the CLA
|
||||
|
||||
Open a pull request ("PR") to any of our open source projects to sign the CLA. A bot will comment on the PR asking you
|
||||
to sign the CLA if you haven't already.
|
||||
|
||||
Follow the steps given by the bot to sign the CLA. This will require you to log in with GitHub (we only request public
|
||||
information from your account) and to fill in a few additional details such as your name and email address. We will only
|
||||
use this information for CLA tracking; none of your submitted information will be used for marketing purposes.
|
||||
|
||||
You only have to sign the CLA once. Once you've signed the CLA, future contributions to any NetBird project will not
|
||||
require you to sign again.
|
||||
|
||||
# Legal Terms and Agreement
|
||||
|
||||
In order to clarify the intellectual property license granted with Contributions from any person or entity, NetBird
|
||||
GmbH ("NetBird") must have a Contributor License Agreement ("CLA") on file that has been signed
|
||||
by each Contributor, indicating agreement to the license terms below. This license does not change your rights to use
|
||||
your own Contributions for any other purpose.
|
||||
|
||||
You accept and agree to the following terms and conditions for Your present and future Contributions submitted to
|
||||
NetBird. Except for the license granted herein to NetBird and recipients of software distributed by NetBird,
|
||||
You reserve all right, title, and interest in and to Your Contributions.
|
||||
|
||||
1. Definitions.
|
||||
|
||||
```
|
||||
"You" (or "Your") shall mean the copyright owner or legal entity authorized by the copyright owner
|
||||
that is making this Agreement with NetBird. For legal entities, the entity making a Contribution and all other
|
||||
entities that control, are controlled by, or are under common control with that entity are considered
|
||||
to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect,
|
||||
to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty
|
||||
percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
```
|
||||
```
|
||||
"Contribution" shall mean any original work of authorship, including any modifications or additions to
|
||||
an existing work, that is or previously has been intentionally submitted by You to NetBird for inclusion in,
|
||||
or documentation of, any of the products owned or managed by NetBird (the "Work").
|
||||
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication
|
||||
sent to NetBird or its representatives, including but not limited to communication on electronic mailing lists,
|
||||
source code control systems, and issue tracking systems that are managed by, or on behalf of,
|
||||
NetBird for the purpose of discussing and improving the Work, but excluding communication that is conspicuously
|
||||
marked or otherwise designated in writing by You as "Not a Contribution."
|
||||
```
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird
|
||||
and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge,
|
||||
royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly
|
||||
perform, sublicense, and distribute Your Contributions and such derivative works.
|
||||
This Contributor License Agreement (referred to as the "Agreement") is entered into by the individual
|
||||
submitting this Agreement and NetBird GmbH, c/o Max-Beer-Straße 2-4 Münzstraße 12 10178 Berlin, Germany,
|
||||
referred to as "NetBird" (collectively, the "Parties"). The Agreement outlines the terms and conditions
|
||||
under which NetBird may utilize software contributions provided by the Contributor for inclusion in
|
||||
its software development projects. By submitting this Agreement, the Contributor confirms their acceptance
|
||||
of the terms and conditions outlined below. The Contributor further represents that they are authorized to
|
||||
complete this process as described herein.
|
||||
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird and
|
||||
to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
|
||||
irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import,
|
||||
and otherwise transfer the Work, where such license applies only to those patent claims licensable by You that are
|
||||
necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Work to which
|
||||
such Contribution(s) was submitted. If any entity institutes patent litigation against You or any other entity (
|
||||
including a cross-claim or counterclaim in a lawsuit) alleging that your Contribution, or the Work to which you have
|
||||
contributed, constitutes direct or contributory patent infringement, then any patent licenses granted to that entity
|
||||
under this Agreement for that Contribution or Work shall terminate as of the date such litigation is filed.
|
||||
## 1 Preamble
|
||||
In order to clarify the IP Rights situation with regard to Contributions from any person or entity, NetBird
|
||||
must have a contributor license agreement on file to be signed by each Contributor, containing the license
|
||||
terms below. This license serves as protection for both the Contributor as well as NetBird and its software users;
|
||||
it does not change Contributor’s rights to use his/her own Contributions for any other purpose.
|
||||
|
||||
## 2 Definitions
|
||||
2.1 “IP Rights” shall mean all industrial and intellectual property rights, whether registered or not registered, whether created by Contributor or acquired by Contributor from third parties, and similar rights, including (but not limited to) semiconductor property rights, design rights, copyrights (including in the form of database rights and rights to software), all neighbouring rights (Leistungsschutzrechte), trademarks, service marks, titles, internet domain names, trade names and other labelling rights, rights deriving from corresponding applications and registrations of such rights as well as any licenses (Nutzungsrechte) under and entitlements to any such intellectual and industrial property rights.
|
||||
|
||||
4. You represent that you are legally entitled to grant the above license. If your employer(s) has rights to
|
||||
intellectual property that you create that includes your Contributions, you represent that you have received
|
||||
permission to make Contributions on behalf of that employer, that you will have received permission from your current
|
||||
and future employers for all future Contributions, that your applicable employer has waived such rights for all of
|
||||
your current and future Contributions to NetBird, or that your employer has executed a separate Corporate CLA
|
||||
with NetBird.
|
||||
2.2 "Contribution" shall mean any original work of authorship, including any modifications or additions to an existing work, that is or previously has been intentionally Submitted by Contributor to NetBird for inclusion in, or documentation of any Work.
|
||||
|
||||
2.3 "Contributor" shall mean the copyright owner or legal entity authorized by the copyright owner that is concluding this Agreement with NetBird. For legal entities, the entity making a Contribution and all other entities that control, are controlled by, or are under common control with that entity are considered to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
5. You represent that each of Your Contributions is Your original creation (see section 7 for submissions on behalf of
|
||||
others). You represent that Your Contribution submissions include complete details of any third-party license or
|
||||
other restriction (including, but not limited to, related patents and trademarks) of which you are personally aware
|
||||
and which are associated with any part of Your Contributions.
|
||||
2.4 "Submitted" shall mean any form of electronic, verbal, or written communication sent to NetBird or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, NetBird for the purpose of discussing and improving the Work, but excluding communication that is marked or otherwise designated in writing by Contributor as "Not a Contribution".
|
||||
|
||||
2.5 "Work" means any of the products owned or managed by NetBird, in particular, but not exclusively, software.
|
||||
|
||||
6. You are not expected to provide support for Your Contributions, except to the extent You desire to provide support.
|
||||
You may provide support for free, for a fee, or not at all. Unless required by applicable law or agreed to in
|
||||
writing, You provide Your Contributions on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
express or implied, including, without limitation, any warranties or conditions of TITLE, NON- INFRINGEMENT,
|
||||
MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
## 3 Licenses
|
||||
3.1 Subject to the terms and conditions of this agreement, Contributor hereby grants to NetBird and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable license to reproduce by any means and in any form, in whole or in part, permanently or temporarily, the Contributions (including loading, displaying, executing, transmitting or storing works for the purpose of executing and processing data or transferring them to video, audio and other data carriers), including the right to distribute, display and present such Contributions and make them available to the public (e.g. via the internet) and to transmit and display such Contributions by any means. The license also includes the right to modify, translate, adapt, edit and otherwise alter the Contributions and to use these results in the same manner as the original Contributions and derivative works. Except for licenses in patents acc. to Sec. 3, such license refers to any IP Rights in the Contributions and derivative works. The Contributor acknowledges that NetBird is not required to credit them by name for their Contribution and agrees to waive any moral rights associated with their Contribution in relation to NetBird or its sublicensees.
|
||||
|
||||
3.2 Subject to the terms and conditions of this agreement, Contributor hereby grants to NetBird and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license in the Contributions to make, have made, use, sell, offer to sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by the Contributor which are necessarily infringed by Contributor‘s Contribution(s) alone or by combination of Contributor’s Contribution(s) with the Work to which such Contribution(s) was Submitted.
|
||||
|
||||
7. Should You wish to submit work that is not Your original creation, You may submit it to NetBird separately from
|
||||
any Contribution, identifying the complete details of its source and of any license or other restriction (including,
|
||||
but not limited to, related patents, trademarks, and license agreements) of which you are personally aware, and
|
||||
conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
|
||||
3.3 NetBird hereby accepts such licenses.
|
||||
|
||||
## 4 Contributor’s Representations
|
||||
4.1 Contributor represents that Contributor is legally entitled to grant the above license. If Contributor’s employer has IP Rights to Contributor’s Contributions, Contributor represent that he/she has received permission to make Contributions on behalf of such employer, that such employer has waived such IP Rights to the Contributions of Contributor to NetBird, or that such employer has executed a separate contributor license agreement with NetBird.
|
||||
|
||||
4.2 Contributor represents that any Contribution is his/her original creation.
|
||||
|
||||
4.3 Contributor represents to his/her best knowledge that any Contribution does not violate any third party IP Rights.
|
||||
|
||||
4.4 Contributor represents that any Contribution submission includes complete details of any third-party license or other restriction (including, but not limited to, related patents and trademarks) of which Contributor is personally aware and which are associated with any part of the Contribution.
|
||||
|
||||
4.5 The Contributor represents that their Contribution does not include any work distributed under a copyleft license.
|
||||
|
||||
## 5 Information obligation
|
||||
Contributor agrees to notify NetBird of any facts or circumstances of which Contributor become aware that would make these representations inaccurate in any respect.
|
||||
|
||||
## 6 Submission of Third-Party works
|
||||
Should Contributor wish to submit work that is not Contributor’s original creation, Contributor may submit it to NetBird separately from any Contribution, identifying the complete details of its source and of any license or other restriction (including, but not limited to, related patents, trademarks, and license agreements) of which Contributor are personally aware, and conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
|
||||
|
||||
## 7 No Consideration
|
||||
Unless compensation is mandatory under statutory law, no compensation for any license under this agreement shall be payable.
|
||||
|
||||
## 8 Final Provisions
|
||||
8.1 Laws. This Agreement is governed by the laws of the Federal Republic of Germany.
|
||||
|
||||
8.2 Venue. Place of jurisdiction shall, to the extent legally permissible, be Berlin, Germany.
|
||||
|
||||
8.3 Severability. If any provision in this agreement is unlawful, invalid or ineffective, it shall not affect the enforceability or effectiveness of the remainder of this agreement. The parties agree to replace any unlawful, invalid or ineffective provision with a provision that comes as close as possible to the commercial intent and purpose of the original provision. This section also applies accordingly to any gaps in the contract.
|
||||
|
||||
8.4 Variations. Any variations, amendments or supplements to this Agreement must be in writing. This also applies to any variation of this Section 8.4.
|
||||
|
||||
8. You agree to notify NetBird of any facts or circumstances of which you become aware that would make these
|
||||
representations inaccurate in any respect.
|
||||
|
||||
28
README.md
28
README.md
@@ -12,7 +12,7 @@
|
||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">
|
||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||
</a>
|
||||
<br>
|
||||
@@ -29,13 +29,13 @@
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">Slack channel</a>
|
||||
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">Slack channel</a>
|
||||
<br/>
|
||||
|
||||
</strong>
|
||||
<br>
|
||||
<a href="https://netbird.io/webinars/achieve-zero-trust-access-to-k8s?utm_source=github&utm_campaign=2502%20-%20webinar%20-%20How%20to%20Achieve%20Zero%20Trust%20Access%20to%20Kubernetes%20-%20Effortlessly&utm_medium=github">
|
||||
Webinar: Securely Access Kubernetes without Port Forwarding and Jump Hosts
|
||||
<a href="https://github.com/netbirdio/kubernetes-operator">
|
||||
New: NetBird Kubernetes Operator
|
||||
</a>
|
||||
</p>
|
||||
|
||||
@@ -57,16 +57,16 @@
|
||||
|
||||
### Key features
|
||||
|
||||
| Connectivity | Management | Security | Automation | Platforms |
|
||||
|------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|
|
||||
| <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
|
||||
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
|
||||
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
|
||||
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | <ul><li> - \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
|
||||
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | <ul><li> - \[x] Peer-to-peer encryption </ul></li> | | <ul><li> - \[x] iOS </ul></li> |
|
||||
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
|
||||
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
|
||||
| | | | | <ul><li> - \[x] Docker </ul></li> |
|
||||
| Connectivity | Management | Security | Automation| Platforms |
|
||||
|----|----|----|----|----|
|
||||
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
|
||||
| <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> |
|
||||
| <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
|
||||
| <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> |
|
||||
| <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> |
|
||||
||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> |
|
||||
||| <ul><li>- \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ui></li> || <ul><li>- \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)</ui></li> |
|
||||
||||| <ul><li>- \[x] Docker</ui></li> |
|
||||
|
||||
### Quickstart with NetBird Cloud
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ type Anonymizer struct {
|
||||
}
|
||||
|
||||
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
||||
// 192.51.100.0, 100::
|
||||
// 198.51.100.0, 100::
|
||||
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
|
||||
}
|
||||
|
||||
|
||||
98
client/cmd/forwarding_rules.go
Normal file
98
client/cmd/forwarding_rules.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
var forwardingRulesCmd = &cobra.Command{
|
||||
Use: "forwarding",
|
||||
Short: "List forwarding rules",
|
||||
Long: `Commands to list forwarding rules.`,
|
||||
}
|
||||
|
||||
var forwardingRulesListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List forwarding rules",
|
||||
Example: " netbird forwarding list",
|
||||
Long: "Commands to list forwarding rules.",
|
||||
RunE: listForwardingRules,
|
||||
}
|
||||
|
||||
func listForwardingRules(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.ForwardingRules(cmd.Context(), &proto.EmptyRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list network: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if len(resp.GetRules()) == 0 {
|
||||
cmd.Println("No forwarding rules available.")
|
||||
return nil
|
||||
}
|
||||
|
||||
printForwardingRules(cmd, resp.GetRules())
|
||||
return nil
|
||||
}
|
||||
|
||||
func printForwardingRules(cmd *cobra.Command, rules []*proto.ForwardingRule) {
|
||||
cmd.Println("Available forwarding rules:")
|
||||
|
||||
// Sort rules by translated address
|
||||
sort.Slice(rules, func(i, j int) bool {
|
||||
if rules[i].GetTranslatedAddress() != rules[j].GetTranslatedAddress() {
|
||||
return rules[i].GetTranslatedAddress() < rules[j].GetTranslatedAddress()
|
||||
}
|
||||
if rules[i].GetProtocol() != rules[j].GetProtocol() {
|
||||
return rules[i].GetProtocol() < rules[j].GetProtocol()
|
||||
}
|
||||
|
||||
return getFirstPort(rules[i].GetDestinationPort()) < getFirstPort(rules[j].GetDestinationPort())
|
||||
})
|
||||
|
||||
var lastIP string
|
||||
for _, rule := range rules {
|
||||
dPort := portToString(rule.GetDestinationPort())
|
||||
tPort := portToString(rule.GetTranslatedPort())
|
||||
if lastIP != rule.GetTranslatedAddress() {
|
||||
lastIP = rule.GetTranslatedAddress()
|
||||
cmd.Printf("\nTranslated peer: %s\n", rule.GetTranslatedHostname())
|
||||
}
|
||||
|
||||
cmd.Printf(" Local %s/%s to %s:%s\n", rule.GetProtocol(), dPort, rule.GetTranslatedAddress(), tPort)
|
||||
}
|
||||
}
|
||||
|
||||
func getFirstPort(portInfo *proto.PortInfo) int {
|
||||
switch v := portInfo.PortSelection.(type) {
|
||||
case *proto.PortInfo_Port:
|
||||
return int(v.Port)
|
||||
case *proto.PortInfo_Range_:
|
||||
return int(v.Range.GetStart())
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func portToString(translatedPort *proto.PortInfo) string {
|
||||
switch v := translatedPort.PortSelection.(type) {
|
||||
case *proto.PortInfo_Port:
|
||||
return fmt.Sprintf("%d", v.Port)
|
||||
case *proto.PortInfo_Range_:
|
||||
return fmt.Sprintf("%d-%d", v.Range.GetStart(), v.Range.GetEnd())
|
||||
default:
|
||||
return "No port specified"
|
||||
}
|
||||
}
|
||||
@@ -19,6 +19,10 @@ import (
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func init() {
|
||||
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
}
|
||||
|
||||
var loginCmd = &cobra.Command{
|
||||
Use: "login",
|
||||
Short: "login to the Netbird Management Service (first run)",
|
||||
@@ -127,7 +131,7 @@ var loginCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
if loginResp.NeedsSSOLogin {
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||
if err != nil {
|
||||
@@ -198,7 +202,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||
}
|
||||
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
||||
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
|
||||
@@ -212,19 +216,27 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
||||
return &tokenInfo, nil
|
||||
}
|
||||
|
||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
|
||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) {
|
||||
var codeMsg string
|
||||
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
|
||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||
}
|
||||
|
||||
cmd.Println("Please do the SSO login in your browser. \n" +
|
||||
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
||||
verificationURIComplete + " " + codeMsg)
|
||||
if noBrowser {
|
||||
cmd.Println("Use this URL to log in:\n\n" + verificationURIComplete + " " + codeMsg)
|
||||
} else {
|
||||
cmd.Println("Please do the SSO login in your browser. \n" +
|
||||
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
||||
verificationURIComplete + " " + codeMsg)
|
||||
}
|
||||
|
||||
cmd.Println("")
|
||||
if err := open.Run(verificationURIComplete); err != nil {
|
||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||
|
||||
if !noBrowser {
|
||||
if err := open.Run(verificationURIComplete); err != nil {
|
||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -145,6 +145,7 @@ func init() {
|
||||
rootCmd.AddCommand(versionCmd)
|
||||
rootCmd.AddCommand(sshCmd)
|
||||
rootCmd.AddCommand(networksCMD)
|
||||
rootCmd.AddCommand(forwardingRulesCmd)
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
|
||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
|
||||
@@ -153,6 +154,8 @@ func init() {
|
||||
networksCMD.AddCommand(routesListCmd)
|
||||
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
||||
|
||||
forwardingRulesCmd.AddCommand(forwardingRulesListCmd)
|
||||
|
||||
debugCmd.AddCommand(debugBundleCmd)
|
||||
debugCmd.AddCommand(logCmd)
|
||||
logCmd.AddCommand(logLevelCmd)
|
||||
@@ -177,7 +180,7 @@ func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||
|
||||
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", false, "Adds system information to the debug bundle")
|
||||
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
|
||||
}
|
||||
|
||||
// SetupCloseHandler handles SIGTERM signal and exits with success
|
||||
|
||||
@@ -6,13 +6,17 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
@@ -30,7 +34,7 @@ import (
|
||||
|
||||
func startTestingServices(t *testing.T) string {
|
||||
t.Helper()
|
||||
config := &mgmt.Config{}
|
||||
config := &types.Config{}
|
||||
_, err := util.ReadJson("../testdata/management.json", config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -65,7 +69,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
return s, lis
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) {
|
||||
func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc.Server, net.Listener) {
|
||||
t.Helper()
|
||||
|
||||
lis, err := net.Listen("tcp", ":0")
|
||||
@@ -88,14 +92,19 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -32,12 +32,16 @@ const (
|
||||
|
||||
const (
|
||||
dnsLabelsFlag = "extra-dns-labels"
|
||||
|
||||
noBrowserFlag = "no-browser"
|
||||
noBrowserDesc = "do not open the browser for SSO login"
|
||||
)
|
||||
|
||||
var (
|
||||
foregroundMode bool
|
||||
dnsLabels []string
|
||||
dnsLabelsValidated domain.List
|
||||
noBrowser bool
|
||||
|
||||
upCmd = &cobra.Command{
|
||||
Use: "up",
|
||||
@@ -65,6 +69,9 @@ func init() {
|
||||
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
|
||||
`or --extra-dns-labels ""`,
|
||||
)
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
|
||||
}
|
||||
|
||||
func upFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -349,7 +356,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
|
||||
if loginResp.NeedsSSOLogin {
|
||||
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||
if err != nil {
|
||||
|
||||
@@ -134,10 +134,11 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
|
||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||
// TODO: make after-startup backoff err available
|
||||
run := make(chan error, 1)
|
||||
run := make(chan struct{}, 1)
|
||||
clientErr := make(chan error, 1)
|
||||
go func() {
|
||||
if err := client.Run(run); err != nil {
|
||||
run <- err
|
||||
clientErr <- err
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -147,13 +148,9 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
||||
}
|
||||
return startCtx.Err()
|
||||
case err := <-run:
|
||||
if err != nil {
|
||||
if stopErr := client.Stop(); stopErr != nil {
|
||||
return fmt.Errorf("stop error after failed to startup. Stop error: %w. Start error: %w", stopErr, err)
|
||||
}
|
||||
return fmt.Errorf("startup: %w", err)
|
||||
}
|
||||
case err := <-clientErr:
|
||||
return fmt.Errorf("startup: %w", err)
|
||||
case <-run:
|
||||
}
|
||||
|
||||
c.connect = client
|
||||
|
||||
@@ -10,17 +10,18 @@ import (
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// NewFirewall creates a firewall manager instance
|
||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
if !iface.IsUserspaceBind() {
|
||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
// use userspace packet filtering firewall
|
||||
fm, err := uspfilter.Create(iface, disableServerRoutes)
|
||||
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -33,7 +34,7 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||
// FWType is the type for the firewall type
|
||||
type FWType int
|
||||
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
// on the linux system we try to user nftables or iptables
|
||||
// in any case, because we need to allow netbird interface traffic
|
||||
// so we use AllowNetbird traffic from these firewall managers
|
||||
@@ -47,7 +48,7 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableS
|
||||
if err != nil {
|
||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||
}
|
||||
return createUserspaceFirewall(iface, fm, disableServerRoutes)
|
||||
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger)
|
||||
}
|
||||
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
|
||||
@@ -77,12 +78,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) {
|
||||
var errUsp error
|
||||
if fm != nil {
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes)
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger)
|
||||
} else {
|
||||
fm, errUsp = uspfilter.Create(iface, disableServerRoutes)
|
||||
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
||||
}
|
||||
|
||||
if errUsp != nil {
|
||||
|
||||
@@ -4,12 +4,13 @@ import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
Name() string
|
||||
Address() device.WGAddress
|
||||
Address() wgaddr.Address
|
||||
IsUserspaceBind() bool
|
||||
SetFilter(device.PacketFilter) error
|
||||
GetDevice() *device.FilteredDevice
|
||||
|
||||
@@ -30,10 +30,8 @@ type entry struct {
|
||||
}
|
||||
|
||||
type aclManager struct {
|
||||
iptablesClient *iptables.IPTables
|
||||
wgIface iFaceMapper
|
||||
routingFwChainName string
|
||||
|
||||
iptablesClient *iptables.IPTables
|
||||
wgIface iFaceMapper
|
||||
entries aclEntries
|
||||
optionalEntries map[string][]entry
|
||||
ipsetStore *ipsetStore
|
||||
@@ -41,12 +39,10 @@ type aclManager struct {
|
||||
stateManager *statemanager.Manager
|
||||
}
|
||||
|
||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
|
||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
|
||||
m := &aclManager{
|
||||
iptablesClient: iptablesClient,
|
||||
wgIface: wgIface,
|
||||
routingFwChainName: routingFwChainName,
|
||||
|
||||
iptablesClient: iptablesClient,
|
||||
wgIface: wgIface,
|
||||
entries: make(map[string][][]string),
|
||||
optionalEntries: make(map[string][]entry),
|
||||
ipsetStore: newIpsetStore(),
|
||||
@@ -79,6 +75,7 @@ func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
func (m *aclManager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
@@ -314,9 +311,12 @@ func (m *aclManager) seedInitialEntries() {
|
||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
||||
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
|
||||
|
||||
// Inbound is handled by our ACLs, the rest is dropped.
|
||||
// For outbound we respect the FORWARD policy. However, we need to allow established/related traffic for inbound rules.
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
|
||||
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
|
||||
|
||||
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
|
||||
}
|
||||
|
||||
func (m *aclManager) seedInitialOptionalEntries() {
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -31,7 +31,7 @@ type Manager struct {
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
Address() wgaddr.Address
|
||||
IsUserspaceBind() bool
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
return nil, fmt.Errorf("create router: %w", err)
|
||||
}
|
||||
|
||||
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
|
||||
m.aclMgr, err = newAclManager(iptablesClient, wgIface)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||
}
|
||||
@@ -96,21 +96,22 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
//
|
||||
// Comment will be ignored because some system this feature is not supported
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
_ string,
|
||||
) ([]firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, action, ipsetName)
|
||||
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
@@ -125,7 +126,7 @@ func (m *Manager) AddRouteFiltering(
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
@@ -166,7 +167,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -196,13 +197,13 @@ func (m *Manager) AllowNetbird() error {
|
||||
}
|
||||
|
||||
_, err := m.AddPeerFiltering(
|
||||
nil,
|
||||
net.IP{0, 0, 0, 0},
|
||||
"all",
|
||||
nil,
|
||||
nil,
|
||||
firewall.ActionAccept,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||
@@ -226,6 +227,22 @@ func (m *Manager) DisableRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
func getConntrackEstablished() []string {
|
||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
@@ -10,15 +10,15 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
var ifaceMock = &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
@@ -31,7 +31,7 @@ var ifaceMock = &iFaceMock{
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMock struct {
|
||||
NameFunc func() string
|
||||
AddressFunc func() iface.WGAddress
|
||||
AddressFunc func() wgaddr.Address
|
||||
}
|
||||
|
||||
func (i *iFaceMock) Name() string {
|
||||
@@ -41,7 +41,7 @@ func (i *iFaceMock) Name() string {
|
||||
panic("NameFunc is not set")
|
||||
}
|
||||
|
||||
func (i *iFaceMock) Address() iface.WGAddress {
|
||||
func (i *iFaceMock) Address() wgaddr.Address {
|
||||
if i.AddressFunc != nil {
|
||||
return i.AddressFunc()
|
||||
}
|
||||
@@ -62,7 +62,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
err := manager.Reset(nil)
|
||||
err := manager.Close(nil)
|
||||
require.NoError(t, err, "clear the manager state")
|
||||
|
||||
time.Sleep(time.Second)
|
||||
@@ -75,7 +75,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
IsRange: true,
|
||||
Values: []uint16{8043, 8046},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule2 {
|
||||
@@ -97,17 +97,17 @@ func TestIptablesManager(t *testing.T) {
|
||||
// add second rule
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{Values: []uint16{5353}}
|
||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Reset(nil)
|
||||
err = manager.Close(nil)
|
||||
require.NoError(t, err, "failed to reset")
|
||||
|
||||
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
|
||||
require.NoError(t, err, "failed check chain exists")
|
||||
|
||||
if ok {
|
||||
require.NoErrorf(t, err, "chain '%v' still exists after Reset", chainNameInputRules)
|
||||
require.NoErrorf(t, err, "chain '%v' still exists after Close", chainNameInputRules)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -117,8 +117,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
@@ -136,7 +136,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
err := manager.Reset(nil)
|
||||
err := manager.Close(nil)
|
||||
require.NoError(t, err, "clear the manager state")
|
||||
|
||||
time.Sleep(time.Second)
|
||||
@@ -148,7 +148,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
port := &fw.Port{
|
||||
Values: []uint16{443},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range")
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
|
||||
for _, r := range rule2 {
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||
@@ -166,7 +166,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
err = manager.Reset(nil)
|
||||
err = manager.Close(nil)
|
||||
require.NoError(t, err, "failed to reset")
|
||||
})
|
||||
}
|
||||
@@ -184,8 +184,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
@@ -204,7 +204,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
err := manager.Reset(nil)
|
||||
err := manager.Close(nil)
|
||||
require.NoError(t, err, "clear the manager state")
|
||||
|
||||
time.Sleep(time.Second)
|
||||
@@ -216,7 +216,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
|
||||
@@ -15,7 +15,8 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
@@ -23,22 +24,38 @@ import (
|
||||
|
||||
// constants needed to manage and create iptable rules
|
||||
const (
|
||||
tableFilter = "filter"
|
||||
tableNat = "nat"
|
||||
tableMangle = "mangle"
|
||||
tableFilter = "filter"
|
||||
tableNat = "nat"
|
||||
tableMangle = "mangle"
|
||||
|
||||
chainPOSTROUTING = "POSTROUTING"
|
||||
chainPREROUTING = "PREROUTING"
|
||||
chainRTNAT = "NETBIRD-RT-NAT"
|
||||
chainRTFWD = "NETBIRD-RT-FWD"
|
||||
chainRTFWDIN = "NETBIRD-RT-FWD-IN"
|
||||
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
|
||||
chainRTPRE = "NETBIRD-RT-PRE"
|
||||
chainRTRDR = "NETBIRD-RT-RDR"
|
||||
routingFinalForwardJump = "ACCEPT"
|
||||
routingFinalNatJump = "MASQUERADE"
|
||||
|
||||
jumpPre = "jump-pre"
|
||||
jumpNat = "jump-nat"
|
||||
matchSet = "--match-set"
|
||||
jumpManglePre = "jump-mangle-pre"
|
||||
jumpNatPre = "jump-nat-pre"
|
||||
jumpNatPost = "jump-nat-post"
|
||||
markManglePre = "mark-mangle-pre"
|
||||
markManglePost = "mark-mangle-post"
|
||||
matchSet = "--match-set"
|
||||
|
||||
dnatSuffix = "_dnat"
|
||||
snatSuffix = "_snat"
|
||||
fwdSuffix = "_fwd"
|
||||
)
|
||||
|
||||
type ruleInfo struct {
|
||||
chain string
|
||||
table string
|
||||
rule []string
|
||||
}
|
||||
|
||||
type routeFilteringRuleParams struct {
|
||||
Sources []netip.Prefix
|
||||
Destination netip.Prefix
|
||||
@@ -62,6 +79,7 @@ type router struct {
|
||||
legacyManagement bool
|
||||
|
||||
stateManager *statemanager.Manager
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
}
|
||||
|
||||
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
|
||||
@@ -69,6 +87,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router,
|
||||
iptablesClient: iptablesClient,
|
||||
rules: make(map[string][]string),
|
||||
wgIface: wgIface,
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||
}
|
||||
|
||||
r.ipsetCounter = refcounter.New(
|
||||
@@ -98,12 +117,17 @@ func (r *router) init(stateManager *statemanager.Manager) error {
|
||||
return fmt.Errorf("create containers: %w", err)
|
||||
}
|
||||
|
||||
if err := r.setupDataPlaneMark(); err != nil {
|
||||
log.Errorf("failed to set up data plane mark: %v", err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
@@ -111,7 +135,7 @@ func (r *router) AddRouteFiltering(
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||
return ruleKey, nil
|
||||
}
|
||||
@@ -139,9 +163,9 @@ func (r *router) AddRouteFiltering(
|
||||
var err error
|
||||
if action == firewall.ActionDrop {
|
||||
// after the established rule
|
||||
err = r.iptablesClient.Insert(tableFilter, chainRTFWD, 2, rule...)
|
||||
err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
|
||||
} else {
|
||||
err = r.iptablesClient.Append(tableFilter, chainRTFWD, rule...)
|
||||
err = r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -156,12 +180,12 @@ func (r *router) AddRouteFiltering(
|
||||
}
|
||||
|
||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
ruleKey := rule.GetRuleID()
|
||||
ruleKey := rule.ID()
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
setName := r.findSetNameInRule(rule)
|
||||
|
||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||
return fmt.Errorf("delete route rule: %v", err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
@@ -212,6 +236,10 @@ func (r *router) deleteIpSet(setName string) error {
|
||||
|
||||
// AddNatRule inserts an iptables rule pair into the nat chain
|
||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.legacyManagement {
|
||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||
@@ -238,6 +266,10 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
|
||||
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
log.Errorf("%v", err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove nat rule: %w", err)
|
||||
}
|
||||
@@ -264,7 +296,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
|
||||
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
if err := r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||
return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
@@ -277,7 +309,7 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
@@ -305,7 +337,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
||||
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
|
||||
continue
|
||||
}
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||
} else {
|
||||
delete(r.rules, k)
|
||||
@@ -322,12 +354,16 @@ func (r *router) Reset() error {
|
||||
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
r.rules = make(map[string][]string)
|
||||
|
||||
if err := r.ipsetCounter.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.cleanupDataPlaneMark(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
r.rules = make(map[string][]string)
|
||||
r.updateState()
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
@@ -343,9 +379,11 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
chain string
|
||||
table string
|
||||
}{
|
||||
{chainRTFWD, tableFilter},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTFWDIN, tableFilter},
|
||||
{chainRTFWDOUT, tableFilter},
|
||||
{chainRTPRE, tableMangle},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTRDR, tableNat},
|
||||
} {
|
||||
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||
if err != nil {
|
||||
@@ -365,16 +403,22 @@ func (r *router) createContainers() error {
|
||||
chain string
|
||||
table string
|
||||
}{
|
||||
{chainRTFWD, tableFilter},
|
||||
{chainRTFWDIN, tableFilter},
|
||||
{chainRTFWDOUT, tableFilter},
|
||||
{chainRTPRE, tableMangle},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTRDR, tableNat},
|
||||
} {
|
||||
if err := r.createAndSetupChain(chainInfo.chain); err != nil {
|
||||
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.insertEstablishedRule(chainRTFWD); err != nil {
|
||||
if err := r.insertEstablishedRule(chainRTFWDIN); err != nil {
|
||||
return fmt.Errorf("insert established rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.insertEstablishedRule(chainRTFWDOUT); err != nil {
|
||||
return fmt.Errorf("insert established rule: %w", err)
|
||||
}
|
||||
|
||||
@@ -389,6 +433,57 @@ func (r *router) createContainers() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupDataPlaneMark configures the fwmark for the data plane
|
||||
func (r *router) setupDataPlaneMark() error {
|
||||
var merr *multierror.Error
|
||||
preRule := []string{
|
||||
"-i", r.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "NEW",
|
||||
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkIn),
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.AppendUnique(tableMangle, chainPREROUTING, preRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add mangle prerouting rule: %w", err))
|
||||
} else {
|
||||
r.rules[markManglePre] = preRule
|
||||
}
|
||||
|
||||
postRule := []string{
|
||||
"-o", r.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "NEW",
|
||||
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkOut),
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.AppendUnique(tableMangle, chainPOSTROUTING, postRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add mangle postrouting rule: %w", err))
|
||||
} else {
|
||||
r.rules[markManglePost] = postRule
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) cleanupDataPlaneMark() error {
|
||||
var merr *multierror.Error
|
||||
if preRule, exists := r.rules[markManglePre]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPREROUTING, preRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove mangle prerouting rule: %w", err))
|
||||
} else {
|
||||
delete(r.rules, markManglePre)
|
||||
}
|
||||
}
|
||||
|
||||
if postRule, exists := r.rules[markManglePost]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPOSTROUTING, postRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove mangle postrouting rule: %w", err))
|
||||
} else {
|
||||
delete(r.rules, markManglePost)
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) addPostroutingRules() error {
|
||||
// First rule for outbound masquerade
|
||||
rule1 := []string{
|
||||
@@ -415,27 +510,6 @@ func (r *router) addPostroutingRules() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) createAndSetupChain(chain string) error {
|
||||
table := r.getTableForChain(chain)
|
||||
|
||||
if err := r.iptablesClient.NewChain(table, chain); err != nil {
|
||||
return fmt.Errorf("failed creating chain %s, error: %v", chain, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) getTableForChain(chain string) string {
|
||||
switch chain {
|
||||
case chainRTNAT:
|
||||
return tableNat
|
||||
case chainRTPRE:
|
||||
return tableMangle
|
||||
default:
|
||||
return tableFilter
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) insertEstablishedRule(chain string) error {
|
||||
establishedRule := getConntrackEstablished()
|
||||
|
||||
@@ -451,31 +525,46 @@ func (r *router) insertEstablishedRule(chain string) error {
|
||||
}
|
||||
|
||||
func (r *router) addJumpRules() error {
|
||||
// Jump to NAT chain
|
||||
// Jump to nat chain
|
||||
natRule := []string{"-j", chainRTNAT}
|
||||
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
|
||||
return fmt.Errorf("add nat jump rule: %v", err)
|
||||
return fmt.Errorf("add nat postrouting jump rule: %v", err)
|
||||
}
|
||||
r.rules[jumpNat] = natRule
|
||||
r.rules[jumpNatPost] = natRule
|
||||
|
||||
// Jump to prerouting chain
|
||||
// Jump to mangle prerouting chain
|
||||
preRule := []string{"-j", chainRTPRE}
|
||||
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
|
||||
return fmt.Errorf("add prerouting jump rule: %v", err)
|
||||
return fmt.Errorf("add mangle prerouting jump rule: %v", err)
|
||||
}
|
||||
r.rules[jumpPre] = preRule
|
||||
r.rules[jumpManglePre] = preRule
|
||||
|
||||
// Jump to nat prerouting chain
|
||||
rdrRule := []string{"-j", chainRTRDR}
|
||||
if err := r.iptablesClient.Insert(tableNat, chainPREROUTING, 1, rdrRule...); err != nil {
|
||||
return fmt.Errorf("add nat prerouting jump rule: %v", err)
|
||||
}
|
||||
r.rules[jumpNatPre] = rdrRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) cleanJumpRules() error {
|
||||
for _, ruleKey := range []string{jumpNat, jumpPre} {
|
||||
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} {
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
table := tableNat
|
||||
chain := chainPOSTROUTING
|
||||
if ruleKey == jumpPre {
|
||||
var table, chain string
|
||||
switch ruleKey {
|
||||
case jumpNatPost:
|
||||
table = tableNat
|
||||
chain = chainPOSTROUTING
|
||||
case jumpManglePre:
|
||||
table = tableMangle
|
||||
chain = chainPREROUTING
|
||||
case jumpNatPre:
|
||||
table = tableNat
|
||||
chain = chainPREROUTING
|
||||
default:
|
||||
return fmt.Errorf("unknown jump rule: %s", ruleKey)
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
|
||||
@@ -520,6 +609,8 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = rule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -535,6 +626,7 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
log.Debugf("marking rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -564,6 +656,137 @@ func (r *router) updateState() {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ruleKey := rule.ID()
|
||||
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
toDestination := rule.TranslatedAddress.String()
|
||||
switch {
|
||||
case len(rule.TranslatedPort.Values) == 0:
|
||||
// no translated port, use original port
|
||||
case len(rule.TranslatedPort.Values) == 1:
|
||||
toDestination += fmt.Sprintf(":%d", rule.TranslatedPort.Values[0])
|
||||
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
|
||||
// need the "/originalport" suffix to avoid dnat port randomization
|
||||
toDestination += fmt.Sprintf(":%d-%d/%d", rule.TranslatedPort.Values[0], rule.TranslatedPort.Values[1], rule.DestinationPort.Values[0])
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
|
||||
}
|
||||
|
||||
proto := strings.ToLower(string(rule.Protocol))
|
||||
|
||||
rules := make(map[string]ruleInfo, 3)
|
||||
|
||||
// DNAT rule
|
||||
dnatRule := []string{
|
||||
"!", "-i", r.wgIface.Name(),
|
||||
"-p", proto,
|
||||
"-j", "DNAT",
|
||||
"--to-destination", toDestination,
|
||||
}
|
||||
dnatRule = append(dnatRule, applyPort("--dport", &rule.DestinationPort)...)
|
||||
rules[ruleKey+dnatSuffix] = ruleInfo{
|
||||
table: tableNat,
|
||||
chain: chainRTRDR,
|
||||
rule: dnatRule,
|
||||
}
|
||||
|
||||
// SNAT rule
|
||||
snatRule := []string{
|
||||
"-o", r.wgIface.Name(),
|
||||
"-p", proto,
|
||||
"-d", rule.TranslatedAddress.String(),
|
||||
"-j", "MASQUERADE",
|
||||
}
|
||||
snatRule = append(snatRule, applyPort("--dport", &rule.TranslatedPort)...)
|
||||
rules[ruleKey+snatSuffix] = ruleInfo{
|
||||
table: tableNat,
|
||||
chain: chainRTNAT,
|
||||
rule: snatRule,
|
||||
}
|
||||
|
||||
// Forward filtering rule, if fwd policy is DROP
|
||||
forwardRule := []string{
|
||||
"-o", r.wgIface.Name(),
|
||||
"-p", proto,
|
||||
"-d", rule.TranslatedAddress.String(),
|
||||
"-j", "ACCEPT",
|
||||
}
|
||||
forwardRule = append(forwardRule, applyPort("--dport", &rule.TranslatedPort)...)
|
||||
rules[ruleKey+fwdSuffix] = ruleInfo{
|
||||
table: tableFilter,
|
||||
chain: chainRTFWDOUT,
|
||||
rule: forwardRule,
|
||||
}
|
||||
|
||||
for key, ruleInfo := range rules {
|
||||
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
|
||||
if rollbackErr := r.rollbackRules(rules); rollbackErr != nil {
|
||||
log.Errorf("rollback failed: %v", rollbackErr)
|
||||
}
|
||||
return nil, fmt.Errorf("add rule %s: %w", key, err)
|
||||
}
|
||||
r.rules[key] = ruleInfo.rule
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (r *router) rollbackRules(rules map[string]ruleInfo) error {
|
||||
var merr *multierror.Error
|
||||
for key, ruleInfo := range rules {
|
||||
if err := r.iptablesClient.DeleteIfExists(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("rollback rule %s: %w", key, err))
|
||||
// On rollback error, add to rules map for next cleanup
|
||||
r.rules[key] = ruleInfo.rule
|
||||
}
|
||||
}
|
||||
if merr != nil {
|
||||
r.updateState()
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
log.Errorf("%v", err)
|
||||
}
|
||||
|
||||
ruleKey := rule.ID()
|
||||
|
||||
var merr *multierror.Error
|
||||
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete DNAT rule: %w", err))
|
||||
}
|
||||
delete(r.rules, ruleKey+dnatSuffix)
|
||||
}
|
||||
|
||||
if snatRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainRTNAT, snatRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete SNAT rule: %w", err))
|
||||
}
|
||||
delete(r.rules, ruleKey+snatSuffix)
|
||||
}
|
||||
|
||||
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
|
||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
|
||||
}
|
||||
delete(r.rules, ruleKey+fwdSuffix)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
||||
var rule []string
|
||||
|
||||
|
||||
@@ -39,12 +39,16 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
}()
|
||||
|
||||
// Now 5 rules:
|
||||
// 1. established rule in forward chain
|
||||
// 2. jump rule to NAT chain
|
||||
// 3. jump rule to PRE chain
|
||||
// 4. static outbound masquerade rule
|
||||
// 5. static return masquerade rule
|
||||
require.Len(t, manager.rules, 5, "should have created rules map")
|
||||
// 1. established rule forward in
|
||||
// 2. estbalished rule forward out
|
||||
// 3. jump rule to POST nat chain
|
||||
// 4. jump rule to PRE mangle chain
|
||||
// 5. jump rule to PRE nat chain
|
||||
// 6. static outbound masquerade rule
|
||||
// 7. static return masquerade rule
|
||||
// 8. mangle prerouting mark rule
|
||||
// 9. mangle postrouting mark rule
|
||||
require.Len(t, manager.rules, 9, "should have created rules map")
|
||||
|
||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||
@@ -328,18 +332,18 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddRouteFiltering failed")
|
||||
|
||||
// Check if the rule is in the internal map
|
||||
rule, ok := r.rules[ruleKey.GetRuleID()]
|
||||
rule, ok := r.rules[ruleKey.ID()]
|
||||
assert.True(t, ok, "Rule not found in internal map")
|
||||
|
||||
// Log the internal rule
|
||||
t.Logf("Internal rule: %v", rule)
|
||||
|
||||
// Check if the rule exists in iptables
|
||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...)
|
||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWDIN, rule...)
|
||||
assert.NoError(t, err, "Failed to check rule existence")
|
||||
assert.True(t, exists, "Rule not found in iptables")
|
||||
|
||||
|
||||
@@ -12,6 +12,6 @@ type Rule struct {
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *Rule) GetRuleID() string {
|
||||
func (r *Rule) ID() string {
|
||||
return r.ruleID
|
||||
}
|
||||
|
||||
@@ -4,21 +4,20 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type InterfaceState struct {
|
||||
NameStr string `json:"name"`
|
||||
WGAddress iface.WGAddress `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Name() string {
|
||||
return i.NameStr
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Address() device.WGAddress {
|
||||
func (i *InterfaceState) Address() wgaddr.Address {
|
||||
return i.WGAddress
|
||||
}
|
||||
|
||||
@@ -62,7 +61,7 @@ func (s *ShutdownState) Cleanup() error {
|
||||
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
|
||||
}
|
||||
|
||||
if err := ipt.Reset(nil); err != nil {
|
||||
if err := ipt.Close(nil); err != nil {
|
||||
return fmt.Errorf("reset iptables manager: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -26,8 +26,8 @@ const (
|
||||
// Each firewall type for different OS can use different type
|
||||
// of the properties to hold data of the created rule
|
||||
type Rule interface {
|
||||
// GetRuleID returns the rule id
|
||||
GetRuleID() string
|
||||
// ID returns the rule id
|
||||
ID() string
|
||||
}
|
||||
|
||||
// RuleDirection is the traffic direction which a rule is applied
|
||||
@@ -65,13 +65,13 @@ type Manager interface {
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
proto Protocol,
|
||||
sPort *Port,
|
||||
dPort *Port,
|
||||
action Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]Rule, error)
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
@@ -80,7 +80,15 @@ type Manager interface {
|
||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||
IsServerRouteSupported() bool
|
||||
|
||||
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error)
|
||||
AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto Protocol,
|
||||
sPort *Port,
|
||||
dPort *Port,
|
||||
action Action,
|
||||
) (Rule, error)
|
||||
|
||||
// DeleteRouteRule deletes a routing rule
|
||||
DeleteRouteRule(rule Rule) error
|
||||
@@ -94,8 +102,8 @@ type Manager interface {
|
||||
// SetLegacyManagement sets the legacy management mode
|
||||
SetLegacyManagement(legacy bool) error
|
||||
|
||||
// Reset firewall to the default state
|
||||
Reset(stateManager *statemanager.Manager) error
|
||||
// Close closes the firewall manager
|
||||
Close(stateManager *statemanager.Manager) error
|
||||
|
||||
// Flush the changes to firewall controller
|
||||
Flush() error
|
||||
@@ -105,6 +113,12 @@ type Manager interface {
|
||||
EnableRouting() error
|
||||
|
||||
DisableRouting() error
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
AddDNATRule(ForwardRule) (Rule, error)
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
DeleteDNATRule(Rule) error
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
|
||||
27
client/firewall/manager/forward_rule.go
Normal file
27
client/firewall/manager/forward_rule.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// ForwardRule todo figure out better place to this to avoid circular imports
|
||||
type ForwardRule struct {
|
||||
Protocol Protocol
|
||||
DestinationPort Port
|
||||
TranslatedAddress netip.Addr
|
||||
TranslatedPort Port
|
||||
}
|
||||
|
||||
func (r ForwardRule) ID() string {
|
||||
id := fmt.Sprintf("%s;%s;%s;%s",
|
||||
r.Protocol,
|
||||
r.DestinationPort.String(),
|
||||
r.TranslatedAddress.String(),
|
||||
r.TranslatedPort.String())
|
||||
return id
|
||||
}
|
||||
|
||||
func (r ForwardRule) String() string {
|
||||
return fmt.Sprintf("protocol: %s, destinationPort: %s, translatedAddress: %s, translatedPort: %s", r.Protocol, r.DestinationPort.String(), r.TranslatedAddress.String(), r.TranslatedPort.String())
|
||||
}
|
||||
@@ -1,30 +1,12 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// Protocol is the protocol of the port
|
||||
type Protocol string
|
||||
|
||||
const (
|
||||
// ProtocolTCP is the TCP protocol
|
||||
ProtocolTCP Protocol = "tcp"
|
||||
|
||||
// ProtocolUDP is the UDP protocol
|
||||
ProtocolUDP Protocol = "udp"
|
||||
|
||||
// ProtocolICMP is the ICMP protocol
|
||||
ProtocolICMP Protocol = "icmp"
|
||||
|
||||
// ProtocolALL cover all supported protocols
|
||||
ProtocolALL Protocol = "all"
|
||||
|
||||
// ProtocolUnknown unknown protocol
|
||||
ProtocolUnknown Protocol = "unknown"
|
||||
)
|
||||
|
||||
// Port of the address for firewall rule
|
||||
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
|
||||
type Port struct {
|
||||
// IsRange is true Values contains two values, the first is the start port, the second is the end port
|
||||
IsRange bool
|
||||
@@ -33,6 +15,25 @@ type Port struct {
|
||||
Values []uint16
|
||||
}
|
||||
|
||||
func NewPort(ports ...int) (*Port, error) {
|
||||
if len(ports) == 0 {
|
||||
return nil, fmt.Errorf("no port provided")
|
||||
}
|
||||
|
||||
ports16 := make([]uint16, len(ports))
|
||||
for i, port := range ports {
|
||||
if port < 1 || port > 65535 {
|
||||
return nil, fmt.Errorf("invalid port number: %d (must be between 1-65535)", port)
|
||||
}
|
||||
ports16[i] = uint16(port)
|
||||
}
|
||||
|
||||
return &Port{
|
||||
IsRange: len(ports) > 1,
|
||||
Values: ports16,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// String interface implementation
|
||||
func (p *Port) String() string {
|
||||
var ports string
|
||||
|
||||
19
client/firewall/manager/protocol.go
Normal file
19
client/firewall/manager/protocol.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package manager
|
||||
|
||||
// Protocol is the protocol of the port
|
||||
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
|
||||
type Protocol string
|
||||
|
||||
const (
|
||||
// ProtocolTCP is the TCP protocol
|
||||
ProtocolTCP Protocol = "tcp"
|
||||
|
||||
// ProtocolUDP is the UDP protocol
|
||||
ProtocolUDP Protocol = "udp"
|
||||
|
||||
// ProtocolICMP is the ICMP protocol
|
||||
ProtocolICMP Protocol = "icmp"
|
||||
|
||||
// ProtocolALL cover all supported protocols
|
||||
ProtocolALL Protocol = "all"
|
||||
)
|
||||
@@ -25,9 +25,10 @@ const (
|
||||
chainNameInputRules = "netbird-acl-input-rules"
|
||||
|
||||
// filter chains contains the rules that jump to the rules chains
|
||||
chainNameInputFilter = "netbird-acl-input-filter"
|
||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||
chainNamePrerouting = "netbird-rt-prerouting"
|
||||
chainNameInputFilter = "netbird-acl-input-filter"
|
||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||
chainNameManglePrerouting = "netbird-mangle-prerouting"
|
||||
chainNameManglePostrouting = "netbird-mangle-postrouting"
|
||||
|
||||
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
||||
)
|
||||
@@ -84,13 +85,13 @@ func (m *AclManager) init(workTable *nftables.Table) error {
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *AclManager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var ipset *nftables.Set
|
||||
if ipsetName != "" {
|
||||
@@ -102,7 +103,7 @@ func (m *AclManager) AddPeerFiltering(
|
||||
}
|
||||
|
||||
newRules := make([]firewall.Rule, 0, 2)
|
||||
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment)
|
||||
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -127,7 +128,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
log.Errorf("failed to delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
delete(m.rules, r.GetRuleID())
|
||||
delete(m.rules, r.ID())
|
||||
return m.rConn.Flush()
|
||||
}
|
||||
|
||||
@@ -141,7 +142,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
log.Errorf("failed to delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
delete(m.rules, r.GetRuleID())
|
||||
delete(m.rules, r.ID())
|
||||
return m.rConn.Flush()
|
||||
}
|
||||
|
||||
@@ -176,7 +177,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
return err
|
||||
}
|
||||
|
||||
delete(m.rules, r.GetRuleID())
|
||||
delete(m.rules, r.ID())
|
||||
m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name)
|
||||
|
||||
if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) {
|
||||
@@ -256,7 +257,6 @@ func (m *AclManager) addIOFiltering(
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipset *nftables.Set,
|
||||
comment string,
|
||||
) (*Rule, error) {
|
||||
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
|
||||
if r, ok := m.rules[ruleId]; ok {
|
||||
@@ -338,7 +338,7 @@ func (m *AclManager) addIOFiltering(
|
||||
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||
}
|
||||
|
||||
userData := []byte(strings.Join([]string{ruleId, comment}, " "))
|
||||
userData := []byte(ruleId)
|
||||
|
||||
chain := m.chainInputRules
|
||||
nftRule := m.rConn.AddRule(&nftables.Rule{
|
||||
@@ -463,13 +463,15 @@ func (m *AclManager) createDefaultChains() (err error) {
|
||||
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
||||
// netbird peer IP.
|
||||
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
||||
m.chainPrerouting = m.rConn.AddChain(&nftables.Chain{
|
||||
Name: chainNamePrerouting,
|
||||
// Chain is created by route manager
|
||||
// TODO: move creation to a common place
|
||||
m.chainPrerouting = &nftables.Chain{
|
||||
Name: chainNameManglePrerouting,
|
||||
Table: m.workTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
})
|
||||
}
|
||||
|
||||
m.addFwmarkToForward(chainFwFilter)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -29,7 +29,7 @@ const (
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
Address() wgaddr.Address
|
||||
IsUserspaceBind() bool
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
// We only need to record minimal interface state for potential recreation.
|
||||
// Unlike iptables, which requires tracking individual rules, nftables maintains
|
||||
// a known state (our netbird table plus a few static rules). This allows for easy
|
||||
// cleanup using Reset() without needing to store specific rules.
|
||||
// cleanup using Close() without needing to store specific rules.
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
InterfaceState: &InterfaceState{
|
||||
NameStr: m.wgIface.Name(),
|
||||
@@ -113,13 +113,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
@@ -129,10 +129,11 @@ func (m *Manager) AddPeerFiltering(
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||
}
|
||||
|
||||
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, action, ipsetName, comment)
|
||||
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
@@ -147,7 +148,7 @@ func (m *Manager) AddRouteFiltering(
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
@@ -242,7 +243,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -342,6 +343,22 @@ func (m *Manager) Flush() error {
|
||||
return m.aclManager.Flush()
|
||||
}
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
|
||||
@@ -16,15 +16,15 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
var ifaceMock = &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
@@ -37,7 +37,7 @@ var ifaceMock = &iFaceMock{
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMock struct {
|
||||
NameFunc func() string
|
||||
AddressFunc func() iface.WGAddress
|
||||
AddressFunc func() wgaddr.Address
|
||||
}
|
||||
|
||||
func (i *iFaceMock) Name() string {
|
||||
@@ -47,7 +47,7 @@ func (i *iFaceMock) Name() string {
|
||||
panic("NameFunc is not set")
|
||||
}
|
||||
|
||||
func (i *iFaceMock) Address() iface.WGAddress {
|
||||
func (i *iFaceMock) Address() wgaddr.Address {
|
||||
if i.AddressFunc != nil {
|
||||
return i.AddressFunc()
|
||||
}
|
||||
@@ -65,7 +65,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
defer func() {
|
||||
err = manager.Reset(nil)
|
||||
err = manager.Close(nil)
|
||||
require.NoError(t, err, "failed to reset")
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
@@ -74,7 +74,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
|
||||
testClient := &nftables.Conn{}
|
||||
|
||||
rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "", "")
|
||||
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Flush()
|
||||
@@ -162,7 +162,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
// established rule remains
|
||||
require.Len(t, rules, 1, "expected 1 rules after deletion")
|
||||
|
||||
err = manager.Reset(nil)
|
||||
err = manager.Close(nil)
|
||||
require.NoError(t, err, "failed to reset")
|
||||
}
|
||||
|
||||
@@ -171,8 +171,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
@@ -191,7 +191,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
defer func() {
|
||||
if err := manager.Reset(nil); err != nil {
|
||||
if err := manager.Close(nil); err != nil {
|
||||
t.Errorf("clear the manager state: %v", err)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
@@ -201,7 +201,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
if i%100 == 0 {
|
||||
@@ -274,7 +274,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := manager.Reset(nil)
|
||||
err := manager.Close(nil)
|
||||
require.NoError(t, err, "failed to reset manager state")
|
||||
|
||||
// Verify iptables output after reset
|
||||
@@ -283,10 +283,11 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
})
|
||||
|
||||
ip := net.ParseIP("100.96.0.1")
|
||||
_, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "", "test rule")
|
||||
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "failed to add peer filtering rule")
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
||||
netip.MustParsePrefix("10.1.0.0/24"),
|
||||
fw.ProtocolTCP,
|
||||
|
||||
@@ -14,23 +14,31 @@ import (
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/google/nftables/xt"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
chainNameRoutingFw = "netbird-rt-fwd"
|
||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||
chainNameForward = "FORWARD"
|
||||
tableNat = "nat"
|
||||
chainNameNatPrerouting = "PREROUTING"
|
||||
chainNameRoutingFw = "netbird-rt-fwd"
|
||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||
chainNameRoutingRdr = "netbird-rt-redirect"
|
||||
chainNameForward = "FORWARD"
|
||||
|
||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||
|
||||
dnatSuffix = "_dnat"
|
||||
snatSuffix = "_snat"
|
||||
)
|
||||
|
||||
const refreshRulesMapError = "refresh rules map: %w"
|
||||
@@ -49,16 +57,18 @@ type router struct {
|
||||
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
|
||||
|
||||
wgIface iFaceMapper
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
legacyManagement bool
|
||||
}
|
||||
|
||||
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
|
||||
r := &router{
|
||||
conn: &nftables.Conn{},
|
||||
workTable: workTable,
|
||||
chains: make(map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
wgIface: wgIface,
|
||||
conn: &nftables.Conn{},
|
||||
workTable: workTable,
|
||||
chains: make(map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
wgIface: wgIface,
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||
}
|
||||
|
||||
r.ipsetCounter = refcounter.New(
|
||||
@@ -90,6 +100,10 @@ func (r *router) init(workTable *nftables.Table) error {
|
||||
return fmt.Errorf("create containers: %w", err)
|
||||
}
|
||||
|
||||
if err := r.setupDataPlaneMark(); err != nil {
|
||||
log.Errorf("failed to set up data plane mark: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -98,7 +112,52 @@ func (r *router) Reset() error {
|
||||
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
||||
r.ipsetCounter.Clear()
|
||||
|
||||
return r.removeAcceptForwardRules()
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.removeAcceptForwardRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err))
|
||||
}
|
||||
|
||||
if err := r.removeNatPreroutingRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) removeNatPreroutingRules() error {
|
||||
table := &nftables.Table{
|
||||
Name: tableNat,
|
||||
Family: nftables.TableFamilyIPv4,
|
||||
}
|
||||
chain := &nftables.Chain{
|
||||
Name: chainNameNatPrerouting,
|
||||
Table: table,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityNATDest,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
}
|
||||
rules, err := r.conn.GetRules(table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get rules from nat table: %w", err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
// Delete rules that have our UserData suffix
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) == 0 || !strings.HasSuffix(string(rule.UserData), dnatSuffix) {
|
||||
continue
|
||||
}
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete rule %s: %w", rule.UserData, err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
@@ -133,15 +192,29 @@ func (r *router) createContainers() error {
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
// Chain is created by acl manager
|
||||
// TODO: move creation to a common place
|
||||
r.chains[chainNamePrerouting] = &nftables.Chain{
|
||||
Name: chainNamePrerouting,
|
||||
r.chains[chainNameRoutingRdr] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingRdr,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityNATDest,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
r.chains[chainNameManglePostrouting] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameManglePostrouting,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPostrouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
|
||||
r.chains[chainNameManglePrerouting] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameManglePrerouting,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
}
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
|
||||
// Add the single NAT rule that matches on mark
|
||||
if err := r.addPostroutingRules(); err != nil {
|
||||
@@ -157,7 +230,83 @@ func (r *router) createContainers() error {
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
||||
return fmt.Errorf("initialize tables: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupDataPlaneMark configures the fwmark for the data plane
|
||||
func (r *router) setupDataPlaneMark() error {
|
||||
if r.chains[chainNameManglePrerouting] == nil || r.chains[chainNameManglePostrouting] == nil {
|
||||
return errors.New("no mangle chains found")
|
||||
}
|
||||
|
||||
ctNew := getCtNewExprs()
|
||||
preExprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
preExprs = append(preExprs, ctNew...)
|
||||
preExprs = append(preExprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkIn),
|
||||
},
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
)
|
||||
|
||||
preNftRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameManglePrerouting],
|
||||
Exprs: preExprs,
|
||||
}
|
||||
r.conn.AddRule(preNftRule)
|
||||
|
||||
postExprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
postExprs = append(postExprs, ctNew...)
|
||||
postExprs = append(postExprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkOut),
|
||||
},
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
)
|
||||
|
||||
postNftRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameManglePostrouting],
|
||||
Exprs: postExprs,
|
||||
}
|
||||
r.conn.AddRule(postNftRule)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -165,6 +314,7 @@ func (r *router) createContainers() error {
|
||||
|
||||
// AddRouteFiltering appends a nftables rule to the routing chain
|
||||
func (r *router) AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
@@ -173,7 +323,7 @@ func (r *router) AddRouteFiltering(
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
|
||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||
return ruleKey, nil
|
||||
}
|
||||
@@ -281,7 +431,7 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
ruleKey := rule.GetRuleID()
|
||||
ruleKey := rule.ID()
|
||||
nftRule, exists := r.rules[ruleKey]
|
||||
if !exists {
|
||||
log.Debugf("route rule %s not found", ruleKey)
|
||||
@@ -410,6 +560,10 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
||||
|
||||
// AddNatRule appends a nftables rule pair to the nat chain
|
||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
@@ -448,26 +602,10 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
op = expr.CmpOpNeq
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
||||
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
|
||||
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
||||
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
||||
exprs := getCtNewExprs()
|
||||
exprs = append(exprs,
|
||||
// interface matching
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
@@ -478,7 +616,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
exprs = append(exprs, sourceExp...)
|
||||
exprs = append(exprs, destExp...)
|
||||
@@ -510,7 +648,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
|
||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNamePrerouting],
|
||||
Chain: r.chains[chainNameManglePrerouting],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleKey),
|
||||
})
|
||||
@@ -836,6 +974,10 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
|
||||
|
||||
// RemoveNatRule removes the prerouting mark rule
|
||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
log.Errorf("%v", err)
|
||||
}
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
@@ -896,6 +1038,269 @@ func (r *router) refreshRulesMap() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ruleKey := rule.ID()
|
||||
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
protoNum, err := protoToInt(rule.Protocol)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addDnatRedirect(rule, protoNum, ruleKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.addDnatMasq(rule, protoNum, ruleKey)
|
||||
|
||||
// Unlike iptables, there's no point in adding "out" rules in the forward chain here as our policy is ACCEPT.
|
||||
// To overcome DROP policies in other chains, we'd have to add rules to the chains there.
|
||||
// We also cannot just add "oif <iface> accept" there and filter in our own table as we don't know what is supposed to be allowed.
|
||||
// TODO: find chains with drop policies and add rules there
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush rules: %w", err)
|
||||
}
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
|
||||
func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, ruleKey string) error {
|
||||
dnatExprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{protoNum},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
}
|
||||
dnatExprs = append(dnatExprs, applyPort(&rule.DestinationPort, false)...)
|
||||
|
||||
// shifted translated port is not supported in nftables, so we hand this over to xtables
|
||||
if rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2 {
|
||||
if rule.TranslatedPort.Values[0] != rule.DestinationPort.Values[0] ||
|
||||
rule.TranslatedPort.Values[1] != rule.DestinationPort.Values[1] {
|
||||
return r.addXTablesRedirect(dnatExprs, ruleKey, rule)
|
||||
}
|
||||
}
|
||||
|
||||
additionalExprs, regProtoMin, regProtoMax, err := r.handleTranslatedPort(rule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dnatExprs = append(dnatExprs, additionalExprs...)
|
||||
|
||||
dnatExprs = append(dnatExprs,
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(nftables.TableFamilyIPv4),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: regProtoMin,
|
||||
RegProtoMax: regProtoMax,
|
||||
},
|
||||
)
|
||||
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingRdr],
|
||||
Exprs: dnatExprs,
|
||||
UserData: []byte(ruleKey + dnatSuffix),
|
||||
}
|
||||
r.conn.AddRule(dnatRule)
|
||||
r.rules[ruleKey+dnatSuffix] = dnatRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) handleTranslatedPort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||
switch {
|
||||
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
|
||||
return r.handlePortRange(rule)
|
||||
case len(rule.TranslatedPort.Values) == 0:
|
||||
return r.handleAddressOnly(rule)
|
||||
case len(rule.TranslatedPort.Values) == 1:
|
||||
return r.handleSinglePort(rule)
|
||||
default:
|
||||
return nil, 0, 0, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) handlePortRange(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||
exprs := []expr.Any{
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: rule.TranslatedAddress.AsSlice(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 3,
|
||||
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[1]),
|
||||
},
|
||||
}
|
||||
return exprs, 2, 3, nil
|
||||
}
|
||||
|
||||
func (r *router) handleAddressOnly(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||
exprs := []expr.Any{
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: rule.TranslatedAddress.AsSlice(),
|
||||
},
|
||||
}
|
||||
return exprs, 0, 0, nil
|
||||
}
|
||||
|
||||
func (r *router) handleSinglePort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||
exprs := []expr.Any{
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: rule.TranslatedAddress.AsSlice(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
|
||||
},
|
||||
}
|
||||
return exprs, 2, 0, nil
|
||||
}
|
||||
|
||||
func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule firewall.ForwardRule) error {
|
||||
dnatExprs = append(dnatExprs,
|
||||
&expr.Counter{},
|
||||
&expr.Target{
|
||||
Name: "DNAT",
|
||||
Rev: 2,
|
||||
Info: &xt.NatRange2{
|
||||
NatRange: xt.NatRange{
|
||||
Flags: uint(xt.NatRangeMapIPs | xt.NatRangeProtoSpecified | xt.NatRangeProtoOffset),
|
||||
MinIP: rule.TranslatedAddress.AsSlice(),
|
||||
MaxIP: rule.TranslatedAddress.AsSlice(),
|
||||
MinPort: rule.TranslatedPort.Values[0],
|
||||
MaxPort: rule.TranslatedPort.Values[1],
|
||||
},
|
||||
BasePort: rule.DestinationPort.Values[0],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: &nftables.Table{
|
||||
Name: tableNat,
|
||||
Family: nftables.TableFamilyIPv4,
|
||||
},
|
||||
Chain: &nftables.Chain{
|
||||
Name: chainNameNatPrerouting,
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityNATDest,
|
||||
},
|
||||
Exprs: dnatExprs,
|
||||
UserData: []byte(ruleKey + dnatSuffix),
|
||||
}
|
||||
r.conn.AddRule(dnatRule)
|
||||
r.rules[ruleKey+dnatSuffix] = dnatRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey string) {
|
||||
masqExprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{protoNum},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 16,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: rule.TranslatedAddress.AsSlice(),
|
||||
},
|
||||
}
|
||||
|
||||
masqExprs = append(masqExprs, applyPort(&rule.TranslatedPort, false)...)
|
||||
masqExprs = append(masqExprs, &expr.Masq{})
|
||||
|
||||
masqRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: masqExprs,
|
||||
UserData: []byte(ruleKey + snatSuffix),
|
||||
}
|
||||
r.conn.AddRule(masqRule)
|
||||
r.rules[ruleKey+snatSuffix] = masqRule
|
||||
}
|
||||
|
||||
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
log.Errorf("%v", err)
|
||||
}
|
||||
|
||||
ruleKey := rule.ID()
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||
if err := r.conn.DelRule(dnatRule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
||||
if err := r.conn.DelRule(masqRule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||
}
|
||||
|
||||
if merr == nil {
|
||||
delete(r.rules, ruleKey+dnatSuffix)
|
||||
delete(r.rules, ruleKey+snatSuffix)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
||||
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
|
||||
var offset uint32
|
||||
@@ -959,15 +1364,11 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
||||
if port.IsRange && len(port.Values) == 2 {
|
||||
// Handle port range
|
||||
exprs = append(exprs,
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpGte,
|
||||
&expr.Range{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(port.Values[0]),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpLte,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(port.Values[1]),
|
||||
FromData: binaryutil.BigEndian.PutUint16(port.Values[0]),
|
||||
ToData: binaryutil.BigEndian.PutUint16(port.Values[1]),
|
||||
},
|
||||
)
|
||||
} else {
|
||||
@@ -993,3 +1394,24 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
||||
|
||||
return exprs
|
||||
}
|
||||
|
||||
func getCtNewExprs() []expr.Any {
|
||||
return []expr.Any{
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
// need fw manager to init both acl mgr and router for all chains to be present
|
||||
manager, err := Create(ifaceMock)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
@@ -100,7 +100,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||
found := 0
|
||||
for _, chain := range rtr.chains {
|
||||
if chain.Name == chainNamePrerouting {
|
||||
if chain.Name == chainNameManglePrerouting {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
@@ -127,7 +127,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
manager, err := Create(ifaceMock)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
@@ -141,7 +141,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
// Verify the rule was added
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||
found := false
|
||||
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
||||
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||
require.NoError(t, err, "should list rules")
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
@@ -157,7 +157,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
|
||||
// Verify the rule was removed
|
||||
found = false
|
||||
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
||||
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||
require.NoError(t, err, "should list rules after removal")
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
@@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddRouteFiltering failed")
|
||||
|
||||
t.Cleanup(func() {
|
||||
@@ -319,7 +319,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
})
|
||||
|
||||
// Check if the rule is in the internal map
|
||||
rule, ok := r.rules[ruleKey.GetRuleID()]
|
||||
rule, ok := r.rules[ruleKey.ID()]
|
||||
assert.True(t, ok, "Rule not found in internal map")
|
||||
|
||||
t.Log("Internal rule expressions:")
|
||||
@@ -336,7 +336,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
var nftRule *nftables.Rule
|
||||
for _, rule := range rules {
|
||||
if string(rule.UserData) == ruleKey.GetRuleID() {
|
||||
if string(rule.UserData) == ruleKey.ID() {
|
||||
nftRule = rule
|
||||
break
|
||||
}
|
||||
@@ -595,16 +595,20 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
|
||||
if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
|
||||
payloadFound = true
|
||||
}
|
||||
case *expr.Cmp:
|
||||
if port.IsRange {
|
||||
if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte {
|
||||
case *expr.Range:
|
||||
if port.IsRange && len(port.Values) == 2 {
|
||||
fromPort := binary.BigEndian.Uint16(ex.FromData)
|
||||
toPort := binary.BigEndian.Uint16(ex.ToData)
|
||||
if fromPort == port.Values[0] && toPort == port.Values[1] {
|
||||
portMatchFound = true
|
||||
}
|
||||
} else {
|
||||
}
|
||||
case *expr.Cmp:
|
||||
if !port.IsRange {
|
||||
if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
|
||||
portValue := binary.BigEndian.Uint16(ex.Data)
|
||||
for _, p := range port.Values {
|
||||
if uint16(p) == portValue {
|
||||
if p == portValue {
|
||||
portMatchFound = true
|
||||
break
|
||||
}
|
||||
|
||||
@@ -16,6 +16,6 @@ type Rule struct {
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *Rule) GetRuleID() string {
|
||||
func (r *Rule) ID() string {
|
||||
return r.ruleID
|
||||
}
|
||||
|
||||
@@ -3,21 +3,20 @@ package nftables
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type InterfaceState struct {
|
||||
NameStr string `json:"name"`
|
||||
WGAddress iface.WGAddress `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Name() string {
|
||||
return i.NameStr
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Address() device.WGAddress {
|
||||
func (i *InterfaceState) Address() wgaddr.Address {
|
||||
return i.WGAddress
|
||||
}
|
||||
|
||||
@@ -39,7 +38,7 @@ func (s *ShutdownState) Cleanup() error {
|
||||
return fmt.Errorf("create nftables manager: %w", err)
|
||||
}
|
||||
|
||||
if err := nft.Reset(nil); err != nil {
|
||||
if err := nft.Close(nil); err != nil {
|
||||
return fmt.Errorf("reset nftables manager: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,39 +4,36 @@ package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.forwarder != nil {
|
||||
m.forwarder.Stop()
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
@@ -48,7 +45,7 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
if m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.Reset(stateManager)
|
||||
return m.nativeFirewall.Close(stateManager)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package uspfilter
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -22,30 +23,30 @@ const (
|
||||
)
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset(*statemanager.Manager) error {
|
||||
func (m *Manager) Close(*statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
|
||||
}
|
||||
|
||||
if m.forwarder != nil {
|
||||
m.forwarder.Stop()
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
|
||||
@@ -3,14 +3,14 @@ package common
|
||||
import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
SetFilter(device.PacketFilter) error
|
||||
Address() iface.WGAddress
|
||||
Address() wgaddr.Address
|
||||
GetWGDevice() *wgdevice.Device
|
||||
GetDevice() *device.FilteredDevice
|
||||
}
|
||||
|
||||
@@ -1,20 +1,27 @@
|
||||
// common.go
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
// BaseConnTrack provides common fields and locking for all connection types
|
||||
type BaseConnTrack struct {
|
||||
SourceIP net.IP
|
||||
DestIP net.IP
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
lastSeen atomic.Int64 // Unix nano for atomic access
|
||||
FlowId uuid.UUID
|
||||
Direction nftypes.Direction
|
||||
SourceIP netip.Addr
|
||||
DestIP netip.Addr
|
||||
lastSeen atomic.Int64
|
||||
PacketsTx atomic.Uint64
|
||||
PacketsRx atomic.Uint64
|
||||
BytesTx atomic.Uint64
|
||||
BytesRx atomic.Uint64
|
||||
}
|
||||
|
||||
// these small methods will be inlined by the compiler
|
||||
@@ -24,6 +31,17 @@ func (b *BaseConnTrack) UpdateLastSeen() {
|
||||
b.lastSeen.Store(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// UpdateCounters safely updates the packet and byte counters
|
||||
func (b *BaseConnTrack) UpdateCounters(direction nftypes.Direction, bytes int) {
|
||||
if direction == nftypes.Egress {
|
||||
b.PacketsTx.Add(1)
|
||||
b.BytesTx.Add(uint64(bytes))
|
||||
} else {
|
||||
b.PacketsRx.Add(1)
|
||||
b.BytesRx.Add(uint64(bytes))
|
||||
}
|
||||
}
|
||||
|
||||
// GetLastSeen safely gets the last seen timestamp
|
||||
func (b *BaseConnTrack) GetLastSeen() time.Time {
|
||||
return time.Unix(0, b.lastSeen.Load())
|
||||
@@ -35,92 +53,14 @@ func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
|
||||
return time.Since(lastSeen) > timeout
|
||||
}
|
||||
|
||||
// IPAddr is a fixed-size IP address to avoid allocations
|
||||
type IPAddr [16]byte
|
||||
|
||||
// MakeIPAddr creates an IPAddr from net.IP
|
||||
func MakeIPAddr(ip net.IP) (addr IPAddr) {
|
||||
// Optimization: check for v4 first as it's more common
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
copy(addr[12:], ip4)
|
||||
} else {
|
||||
copy(addr[:], ip.To16())
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
// ConnKey uniquely identifies a connection
|
||||
type ConnKey struct {
|
||||
SrcIP IPAddr
|
||||
DstIP IPAddr
|
||||
SrcIP netip.Addr
|
||||
DstIP netip.Addr
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
}
|
||||
|
||||
// makeConnKey creates a connection key
|
||||
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
|
||||
return ConnKey{
|
||||
SrcIP: MakeIPAddr(srcIP),
|
||||
DstIP: MakeIPAddr(dstIP),
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateIPs checks if IPs match without allocation
|
||||
func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
|
||||
if ip4 := pktIP.To4(); ip4 != nil {
|
||||
// Compare IPv4 addresses (last 4 bytes)
|
||||
for i := 0; i < 4; i++ {
|
||||
if connIP[12+i] != ip4[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
// Compare full IPv6 addresses
|
||||
ip6 := pktIP.To16()
|
||||
for i := 0; i < 16; i++ {
|
||||
if connIP[i] != ip6[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// PreallocatedIPs is a pool of IP byte slices to reduce allocations
|
||||
type PreallocatedIPs struct {
|
||||
sync.Pool
|
||||
}
|
||||
|
||||
// NewPreallocatedIPs creates a new IP pool
|
||||
func NewPreallocatedIPs() *PreallocatedIPs {
|
||||
return &PreallocatedIPs{
|
||||
Pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
ip := make(net.IP, 16)
|
||||
return &ip
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves an IP from the pool
|
||||
func (p *PreallocatedIPs) Get() net.IP {
|
||||
return *p.Pool.Get().(*net.IP)
|
||||
}
|
||||
|
||||
// Put returns an IP to the pool
|
||||
func (p *PreallocatedIPs) Put(ip net.IP) {
|
||||
p.Pool.Put(&ip)
|
||||
}
|
||||
|
||||
// copyIP copies an IP address efficiently
|
||||
func copyIP(dst, src net.IP) {
|
||||
if len(src) == 16 {
|
||||
copy(dst, src)
|
||||
} else {
|
||||
// Handle IPv4
|
||||
copy(dst[12:], src.To4())
|
||||
}
|
||||
func (c ConnKey) String() string {
|
||||
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
||||
}
|
||||
|
||||
@@ -1,94 +1,66 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
)
|
||||
|
||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
|
||||
func BenchmarkIPOperations(b *testing.B) {
|
||||
b.Run("MakeIPAddr", func(b *testing.B) {
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = MakeIPAddr(ip)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ValidateIPs", func(b *testing.B) {
|
||||
ip1 := net.ParseIP("192.168.1.1")
|
||||
ip2 := net.ParseIP("192.168.1.1")
|
||||
addr := MakeIPAddr(ip1)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = ValidateIPs(addr, ip2)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IPPool", func(b *testing.B) {
|
||||
pool := NewPreallocatedIPs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
ip := pool.Get()
|
||||
pool.Put(ip)
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
// Memory pressure tests
|
||||
func BenchmarkMemoryPressure(b *testing.B) {
|
||||
b.Run("TCPHighLoad", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
// Generate different IPs
|
||||
srcIPs := make([]net.IP, 100)
|
||||
dstIPs := make([]net.IP, 100)
|
||||
srcIPs := make([]netip.Addr, 100)
|
||||
dstIPs := make([]netip.Addr, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
||||
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
||||
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
|
||||
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
srcIdx := i % len(srcIPs)
|
||||
dstIdx := (i + 1) % len(dstIPs)
|
||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn)
|
||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, 0)
|
||||
|
||||
// Simulate some valid inbound packets
|
||||
if i%3 == 0 {
|
||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck)
|
||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, 0)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("UDPHighLoad", func(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
// Generate different IPs
|
||||
srcIPs := make([]net.IP, 100)
|
||||
dstIPs := make([]net.IP, 100)
|
||||
srcIPs := make([]netip.Addr, 100)
|
||||
dstIPs := make([]netip.Addr, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
||||
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
||||
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
|
||||
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
srcIdx := i % len(srcIPs)
|
||||
dstIdx := (i + 1) % len(dstIPs)
|
||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80)
|
||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, 0)
|
||||
|
||||
// Simulate some valid inbound packets
|
||||
if i%3 == 0 {
|
||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535))
|
||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), 0)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/google/uuid"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -19,18 +23,20 @@ const (
|
||||
|
||||
// ICMPConnKey uniquely identifies an ICMP connection
|
||||
type ICMPConnKey struct {
|
||||
// Supports both IPv4 and IPv6
|
||||
SrcIP [16]byte
|
||||
DstIP [16]byte
|
||||
Sequence uint16 // ICMP sequence number
|
||||
ID uint16 // ICMP identifier
|
||||
SrcIP netip.Addr
|
||||
DstIP netip.Addr
|
||||
ID uint16
|
||||
}
|
||||
|
||||
func (i ICMPConnKey) String() string {
|
||||
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
|
||||
}
|
||||
|
||||
// ICMPConnTrack represents an ICMP connection state
|
||||
type ICMPConnTrack struct {
|
||||
BaseConnTrack
|
||||
Sequence uint16
|
||||
ID uint16
|
||||
ICMPType uint8
|
||||
ICMPCode uint8
|
||||
}
|
||||
|
||||
// ICMPTracker manages ICMP connection states
|
||||
@@ -39,131 +45,202 @@ type ICMPTracker struct {
|
||||
connections map[ICMPConnKey]*ICMPConnTrack
|
||||
timeout time.Duration
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
mutex sync.RWMutex
|
||||
done chan struct{}
|
||||
ipPool *PreallocatedIPs
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
// NewICMPTracker creates a new ICMP connection tracker
|
||||
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
|
||||
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
|
||||
if timeout == 0 {
|
||||
timeout = DefaultICMPTimeout
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
tracker := &ICMPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ICMPConnKey]*ICMPConnTrack),
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
||||
done: make(chan struct{}),
|
||||
ipPool: NewPreallocatedIPs(),
|
||||
tickerCancel: cancel,
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
go tracker.cleanupRoutine()
|
||||
go tracker.cleanupRoutine(ctx)
|
||||
return tracker
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound ICMP Echo Request
|
||||
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
|
||||
key := makeICMPKey(srcIP, dstIP, id, seq)
|
||||
|
||||
t.mutex.Lock()
|
||||
conn, exists := t.connections[key]
|
||||
if !exists {
|
||||
srcIPCopy := t.ipPool.Get()
|
||||
dstIPCopy := t.ipPool.Get()
|
||||
copyIP(srcIPCopy, srcIP)
|
||||
copyIP(dstIPCopy, dstIP)
|
||||
|
||||
conn = &ICMPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
SourceIP: srcIPCopy,
|
||||
DestIP: dstIPCopy,
|
||||
},
|
||||
ID: id,
|
||||
Sequence: seq,
|
||||
}
|
||||
conn.UpdateLastSeen()
|
||||
t.connections[key] = conn
|
||||
|
||||
t.logger.Trace("New ICMP connection %v", key)
|
||||
func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, direction nftypes.Direction, size int) (ICMPConnKey, bool) {
|
||||
key := ICMPConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
ID: id,
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
conn.UpdateLastSeen()
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
|
||||
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
||||
return false
|
||||
}
|
||||
|
||||
key := makeICMPKey(dstIP, srcIP, id, seq)
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false
|
||||
if exists {
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(direction, size)
|
||||
|
||||
return key, true
|
||||
}
|
||||
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
return false
|
||||
}
|
||||
|
||||
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||
conn.ID == id &&
|
||||
conn.Sequence == seq
|
||||
return key, false
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) cleanupRoutine() {
|
||||
// TrackOutbound records an outbound ICMP connection
|
||||
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
|
||||
}
|
||||
}
|
||||
|
||||
// TrackInbound records an inbound ICMP Echo Request
|
||||
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
|
||||
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound ICMP connections
|
||||
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
|
||||
if exists {
|
||||
return
|
||||
}
|
||||
|
||||
typ, code := typecode.Type(), typecode.Code()
|
||||
|
||||
// non echo requests don't need tracking
|
||||
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
||||
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
|
||||
return
|
||||
}
|
||||
|
||||
conn := &ICMPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
FlowId: uuid.New(),
|
||||
Direction: direction,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
},
|
||||
ICMPType: typ,
|
||||
ICMPCode: code,
|
||||
}
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(direction, size)
|
||||
|
||||
t.mutex.Lock()
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
|
||||
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
||||
return false
|
||||
}
|
||||
|
||||
key := ICMPConnKey{
|
||||
SrcIP: dstIP,
|
||||
DstIP: srcIP,
|
||||
ID: id,
|
||||
}
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if !exists || conn.timeoutExceeded(t.timeout) {
|
||||
return false
|
||||
}
|
||||
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(nftypes.Ingress, size)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
|
||||
defer t.tickerCancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-t.cleanupTicker.C:
|
||||
t.cleanup()
|
||||
case <-t.done:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) cleanup() {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
for key, conn := range t.connections {
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Debug("Removed ICMP connection %v (timeout)", key)
|
||||
t.logger.Trace("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the cleanup routine and releases resources
|
||||
func (t *ICMPTracker) Close() {
|
||||
t.cleanupTicker.Stop()
|
||||
close(t.done)
|
||||
t.tickerCancel()
|
||||
|
||||
t.mutex.Lock()
|
||||
for _, conn := range t.connections {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
}
|
||||
t.connections = nil
|
||||
t.mutex.Unlock()
|
||||
}
|
||||
|
||||
// makeICMPKey creates an ICMP connection key
|
||||
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
|
||||
return ICMPConnKey{
|
||||
SrcIP: MakeIPAddr(srcIP),
|
||||
DstIP: MakeIPAddr(dstIP),
|
||||
ID: id,
|
||||
Sequence: seq,
|
||||
}
|
||||
func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []byte) {
|
||||
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: conn.FlowId,
|
||||
Type: typ,
|
||||
RuleID: ruleID,
|
||||
Direction: conn.Direction,
|
||||
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
|
||||
SourceIP: conn.SourceIP,
|
||||
DestIP: conn.DestIP,
|
||||
ICMPType: conn.ICMPType,
|
||||
ICMPCode: conn.ICMPCode,
|
||||
RxPackets: conn.PacketsRx.Load(),
|
||||
TxPackets: conn.PacketsTx.Load(),
|
||||
RxBytes: conn.BytesRx.Load(),
|
||||
TxBytes: conn.BytesTx.Load(),
|
||||
})
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, ruleID []byte, size int) {
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: uuid.New(),
|
||||
Type: nftypes.TypeStart,
|
||||
RuleID: ruleID,
|
||||
Direction: direction,
|
||||
Protocol: nftypes.ICMP,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
ICMPType: typ,
|
||||
ICMPCode: code,
|
||||
}
|
||||
if direction == nftypes.Ingress {
|
||||
fields.RxPackets = 1
|
||||
fields.RxBytes = uint64(size)
|
||||
} else {
|
||||
fields.TxPackets = 1
|
||||
fields.TxBytes = uint64(size)
|
||||
}
|
||||
t.flowLogger.StoreEvent(fields)
|
||||
}
|
||||
|
||||
@@ -1,39 +1,39 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkICMPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535))
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i))
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,12 +3,16 @@ package conntrack
|
||||
// TODO: Send RST packets for invalid/timed-out connections
|
||||
|
||||
import (
|
||||
"net"
|
||||
"context"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -19,11 +23,11 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
TCPSyn uint8 = 0x02
|
||||
TCPAck uint8 = 0x10
|
||||
TCPFin uint8 = 0x01
|
||||
TCPSyn uint8 = 0x02
|
||||
TCPRst uint8 = 0x04
|
||||
TCPPush uint8 = 0x08
|
||||
TCPAck uint8 = 0x10
|
||||
TCPUrg uint8 = 0x20
|
||||
)
|
||||
|
||||
@@ -37,7 +41,36 @@ const (
|
||||
)
|
||||
|
||||
// TCPState represents the state of a TCP connection
|
||||
type TCPState int
|
||||
type TCPState int32
|
||||
|
||||
func (s TCPState) String() string {
|
||||
switch s {
|
||||
case TCPStateNew:
|
||||
return "New"
|
||||
case TCPStateSynSent:
|
||||
return "SYN Sent"
|
||||
case TCPStateSynReceived:
|
||||
return "SYN Received"
|
||||
case TCPStateEstablished:
|
||||
return "Established"
|
||||
case TCPStateFinWait1:
|
||||
return "FIN Wait 1"
|
||||
case TCPStateFinWait2:
|
||||
return "FIN Wait 2"
|
||||
case TCPStateClosing:
|
||||
return "Closing"
|
||||
case TCPStateTimeWait:
|
||||
return "Time Wait"
|
||||
case TCPStateCloseWait:
|
||||
return "Close Wait"
|
||||
case TCPStateLastAck:
|
||||
return "Last ACK"
|
||||
case TCPStateClosed:
|
||||
return "Closed"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
TCPStateNew TCPState = iota
|
||||
@@ -53,30 +86,38 @@ const (
|
||||
TCPStateClosed
|
||||
)
|
||||
|
||||
// TCPConnKey uniquely identifies a TCP connection
|
||||
type TCPConnKey struct {
|
||||
SrcIP [16]byte
|
||||
DstIP [16]byte
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
}
|
||||
|
||||
// TCPConnTrack represents a TCP connection state
|
||||
type TCPConnTrack struct {
|
||||
BaseConnTrack
|
||||
State TCPState
|
||||
established atomic.Bool
|
||||
sync.RWMutex
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
state atomic.Int32
|
||||
tombstone atomic.Bool
|
||||
}
|
||||
|
||||
// IsEstablished safely checks if connection is established
|
||||
func (t *TCPConnTrack) IsEstablished() bool {
|
||||
return t.established.Load()
|
||||
// GetState safely retrieves the current state
|
||||
func (t *TCPConnTrack) GetState() TCPState {
|
||||
return TCPState(t.state.Load())
|
||||
}
|
||||
|
||||
// SetEstablished safely sets the established state
|
||||
func (t *TCPConnTrack) SetEstablished(state bool) {
|
||||
t.established.Store(state)
|
||||
// SetState safely updates the current state
|
||||
func (t *TCPConnTrack) SetState(state TCPState) {
|
||||
t.state.Store(int32(state))
|
||||
}
|
||||
|
||||
// CompareAndSwapState atomically changes the state from old to new if current == old
|
||||
func (t *TCPConnTrack) CompareAndSwapState(old, newState TCPState) bool {
|
||||
return t.state.CompareAndSwap(int32(old), int32(newState))
|
||||
}
|
||||
|
||||
// IsTombstone safely checks if the connection is marked for deletion
|
||||
func (t *TCPConnTrack) IsTombstone() bool {
|
||||
return t.tombstone.Load()
|
||||
}
|
||||
|
||||
// SetTombstone safely marks the connection for deletion
|
||||
func (t *TCPConnTrack) SetTombstone() {
|
||||
t.tombstone.Store(true)
|
||||
}
|
||||
|
||||
// TCPTracker manages TCP connection states
|
||||
@@ -85,192 +126,234 @@ type TCPTracker struct {
|
||||
connections map[ConnKey]*TCPConnTrack
|
||||
mutex sync.RWMutex
|
||||
cleanupTicker *time.Ticker
|
||||
done chan struct{}
|
||||
tickerCancel context.CancelFunc
|
||||
timeout time.Duration
|
||||
ipPool *PreallocatedIPs
|
||||
waitTimeout time.Duration
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
// NewTCPTracker creates a new TCP connection tracker
|
||||
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
|
||||
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
|
||||
waitTimeout := TimeWaitTimeout
|
||||
if timeout == 0 {
|
||||
timeout = DefaultTCPTimeout
|
||||
} else {
|
||||
waitTimeout = timeout / 45
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
tracker := &TCPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ConnKey]*TCPConnTrack),
|
||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||
done: make(chan struct{}),
|
||||
tickerCancel: cancel,
|
||||
timeout: timeout,
|
||||
ipPool: NewPreallocatedIPs(),
|
||||
waitTimeout: waitTimeout,
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
go tracker.cleanupRoutine()
|
||||
go tracker.cleanupRoutine(ctx)
|
||||
return tracker
|
||||
}
|
||||
|
||||
// TrackOutbound processes an outbound TCP packet and updates connection state
|
||||
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
||||
// Create key before lock
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
t.mutex.Lock()
|
||||
conn, exists := t.connections[key]
|
||||
if !exists {
|
||||
// Use preallocated IPs
|
||||
srcIPCopy := t.ipPool.Get()
|
||||
dstIPCopy := t.ipPool.Get()
|
||||
copyIP(srcIPCopy, srcIP)
|
||||
copyIP(dstIPCopy, dstIP)
|
||||
|
||||
conn = &TCPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
SourceIP: srcIPCopy,
|
||||
DestIP: dstIPCopy,
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
},
|
||||
State: TCPStateNew,
|
||||
}
|
||||
conn.UpdateLastSeen()
|
||||
conn.established.Store(false)
|
||||
t.connections[key] = conn
|
||||
|
||||
t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
// Lock individual connection for state update
|
||||
conn.Lock()
|
||||
t.updateState(conn, flags, true)
|
||||
conn.Unlock()
|
||||
conn.UpdateLastSeen()
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
|
||||
if !isValidFlagCombination(flags) {
|
||||
return false
|
||||
}
|
||||
|
||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false
|
||||
if exists {
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
return key, true
|
||||
}
|
||||
|
||||
// Handle RST packets
|
||||
if flags&TCPRst != 0 {
|
||||
conn.Lock()
|
||||
if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetEstablished(false)
|
||||
conn.Unlock()
|
||||
return true
|
||||
}
|
||||
conn.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
conn.Lock()
|
||||
t.updateState(conn, flags, false)
|
||||
conn.UpdateLastSeen()
|
||||
isEstablished := conn.IsEstablished()
|
||||
isValidState := t.isValidStateForFlags(conn.State, flags)
|
||||
conn.Unlock()
|
||||
|
||||
return isEstablished || isValidState
|
||||
return key, false
|
||||
}
|
||||
|
||||
// updateState updates the TCP connection state based on flags
|
||||
func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) {
|
||||
// Handle RST flag specially - it always causes transition to closed
|
||||
if flags&TCPRst != 0 {
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetEstablished(false)
|
||||
// TrackOutbound records an outbound TCP connection
|
||||
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
|
||||
}
|
||||
}
|
||||
|
||||
t.logger.Trace("TCP connection reset: %s:%d -> %s:%d",
|
||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
// TrackInbound processes an inbound TCP packet and updates connection state
|
||||
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound connections
|
||||
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
|
||||
if exists || flags&TCPSyn == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
switch conn.State {
|
||||
conn := &TCPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
FlowId: uuid.New(),
|
||||
Direction: direction,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
},
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
}
|
||||
|
||||
conn.tombstone.Store(false)
|
||||
conn.state.Store(int32(TCPStateNew))
|
||||
|
||||
t.logger.Trace("New %s TCP connection: %s", direction, key)
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
|
||||
t.mutex.Lock()
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||
func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool {
|
||||
key := ConnKey{
|
||||
SrcIP: dstIP,
|
||||
DstIP: srcIP,
|
||||
SrcPort: dstPort,
|
||||
DstPort: srcPort,
|
||||
}
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if !exists || conn.IsTombstone() {
|
||||
return false
|
||||
}
|
||||
|
||||
currentState := conn.GetState()
|
||||
if !t.isValidStateForFlags(currentState, flags) {
|
||||
t.logger.Warn("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
|
||||
// allow all flags for established for now
|
||||
if currentState == TCPStateEstablished {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
t.updateState(key, conn, flags, nftypes.Ingress, size)
|
||||
return true
|
||||
}
|
||||
|
||||
// updateState updates the TCP connection state based on flags
|
||||
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) {
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(packetDir, size)
|
||||
|
||||
currentState := conn.GetState()
|
||||
|
||||
if flags&TCPRst != 0 {
|
||||
if conn.CompareAndSwapState(currentState, TCPStateClosed) {
|
||||
conn.SetTombstone()
|
||||
t.logger.Trace("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var newState TCPState
|
||||
switch currentState {
|
||||
case TCPStateNew:
|
||||
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||
conn.State = TCPStateSynSent
|
||||
if conn.Direction == nftypes.Egress {
|
||||
newState = TCPStateSynSent
|
||||
} else {
|
||||
newState = TCPStateSynReceived
|
||||
}
|
||||
}
|
||||
|
||||
case TCPStateSynSent:
|
||||
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
||||
if isOutbound {
|
||||
conn.State = TCPStateSynReceived
|
||||
if packetDir != conn.Direction {
|
||||
newState = TCPStateEstablished
|
||||
} else {
|
||||
// Simultaneous open
|
||||
conn.State = TCPStateEstablished
|
||||
conn.SetEstablished(true)
|
||||
newState = TCPStateSynReceived
|
||||
}
|
||||
}
|
||||
|
||||
case TCPStateSynReceived:
|
||||
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
|
||||
conn.State = TCPStateEstablished
|
||||
conn.SetEstablished(true)
|
||||
if packetDir == conn.Direction {
|
||||
newState = TCPStateEstablished
|
||||
}
|
||||
}
|
||||
|
||||
case TCPStateEstablished:
|
||||
if flags&TCPFin != 0 {
|
||||
if isOutbound {
|
||||
conn.State = TCPStateFinWait1
|
||||
if packetDir == conn.Direction {
|
||||
newState = TCPStateFinWait1
|
||||
} else {
|
||||
conn.State = TCPStateCloseWait
|
||||
newState = TCPStateCloseWait
|
||||
}
|
||||
conn.SetEstablished(false)
|
||||
}
|
||||
|
||||
case TCPStateFinWait1:
|
||||
switch {
|
||||
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||
// Simultaneous close - both sides sent FIN
|
||||
conn.State = TCPStateClosing
|
||||
case flags&TCPFin != 0:
|
||||
conn.State = TCPStateFinWait2
|
||||
case flags&TCPAck != 0:
|
||||
conn.State = TCPStateFinWait2
|
||||
if packetDir != conn.Direction {
|
||||
switch {
|
||||
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||
newState = TCPStateClosing
|
||||
case flags&TCPFin != 0:
|
||||
newState = TCPStateClosing
|
||||
case flags&TCPAck != 0:
|
||||
newState = TCPStateFinWait2
|
||||
}
|
||||
}
|
||||
|
||||
case TCPStateFinWait2:
|
||||
if flags&TCPFin != 0 {
|
||||
conn.State = TCPStateTimeWait
|
||||
newState = TCPStateTimeWait
|
||||
}
|
||||
|
||||
case TCPStateClosing:
|
||||
if flags&TCPAck != 0 {
|
||||
conn.State = TCPStateTimeWait
|
||||
// Keep established = false from previous state
|
||||
|
||||
t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d",
|
||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
newState = TCPStateTimeWait
|
||||
}
|
||||
|
||||
case TCPStateCloseWait:
|
||||
if flags&TCPFin != 0 {
|
||||
conn.State = TCPStateLastAck
|
||||
newState = TCPStateLastAck
|
||||
}
|
||||
|
||||
case TCPStateLastAck:
|
||||
if flags&TCPAck != 0 {
|
||||
conn.State = TCPStateClosed
|
||||
|
||||
t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d",
|
||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
newState = TCPStateClosed
|
||||
}
|
||||
}
|
||||
|
||||
case TCPStateTimeWait:
|
||||
// Stay in TIME-WAIT for 2MSL before transitioning to closed
|
||||
// This is handled by the cleanup routine
|
||||
if newState != 0 && conn.CompareAndSwapState(currentState, newState) {
|
||||
t.logger.Trace("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
|
||||
|
||||
t.logger.Trace("TCP connection completed - %s:%d -> %s:%d",
|
||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
switch newState {
|
||||
case TCPStateTimeWait:
|
||||
t.logger.Trace("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
|
||||
case TCPStateClosed:
|
||||
conn.SetTombstone()
|
||||
t.logger.Trace("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -279,18 +362,22 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
||||
if !isValidFlagCombination(flags) {
|
||||
return false
|
||||
}
|
||||
if flags&TCPRst != 0 {
|
||||
if state == TCPStateSynSent {
|
||||
return flags&TCPAck != 0
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
switch state {
|
||||
case TCPStateNew:
|
||||
return flags&TCPSyn != 0 && flags&TCPAck == 0
|
||||
case TCPStateSynSent:
|
||||
// TODO: support simultaneous open
|
||||
return flags&TCPSyn != 0 && flags&TCPAck != 0
|
||||
case TCPStateSynReceived:
|
||||
return flags&TCPAck != 0
|
||||
case TCPStateEstablished:
|
||||
if flags&TCPRst != 0 {
|
||||
return true
|
||||
}
|
||||
return flags&TCPAck != 0
|
||||
case TCPStateFinWait1:
|
||||
return flags&TCPFin != 0 || flags&TCPAck != 0
|
||||
@@ -307,20 +394,20 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
||||
case TCPStateLastAck:
|
||||
return flags&TCPAck != 0
|
||||
case TCPStateClosed:
|
||||
// Accept retransmitted ACKs in closed state
|
||||
// This is important because the final ACK might be lost
|
||||
// and the peer will retransmit their FIN-ACK
|
||||
// Accept retransmitted ACKs in closed state, the final ACK might be lost and the peer will retransmit their FIN-ACK
|
||||
return flags&TCPAck != 0
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *TCPTracker) cleanupRoutine() {
|
||||
func (t *TCPTracker) cleanupRoutine(ctx context.Context) {
|
||||
defer t.cleanupTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-t.cleanupTicker.C:
|
||||
t.cleanup()
|
||||
case <-t.done:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -331,39 +418,43 @@ func (t *TCPTracker) cleanup() {
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
for key, conn := range t.connections {
|
||||
if conn.IsTombstone() {
|
||||
// Clean up tombstoned connections without sending an event
|
||||
delete(t.connections, key)
|
||||
continue
|
||||
}
|
||||
|
||||
var timeout time.Duration
|
||||
switch {
|
||||
case conn.State == TCPStateTimeWait:
|
||||
timeout = TimeWaitTimeout
|
||||
case conn.IsEstablished():
|
||||
currentState := conn.GetState()
|
||||
switch currentState {
|
||||
case TCPStateTimeWait:
|
||||
timeout = t.waitTimeout
|
||||
case TCPStateEstablished:
|
||||
timeout = t.timeout
|
||||
default:
|
||||
timeout = TCPHandshakeTimeout
|
||||
}
|
||||
|
||||
lastSeen := conn.GetLastSeen()
|
||||
if time.Since(lastSeen) > timeout {
|
||||
// Return IPs to pool
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
if conn.timeoutExceeded(timeout) {
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
t.logger.Trace("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
|
||||
// event already handled by state change
|
||||
if currentState != TCPStateTimeWait {
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the cleanup routine and releases resources
|
||||
func (t *TCPTracker) Close() {
|
||||
t.cleanupTicker.Stop()
|
||||
close(t.done)
|
||||
t.tickerCancel()
|
||||
|
||||
// Clean up all remaining IPs
|
||||
t.mutex.Lock()
|
||||
for _, conn := range t.connections {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
}
|
||||
t.connections = nil
|
||||
t.mutex.Unlock()
|
||||
}
|
||||
@@ -381,3 +472,21 @@ func isValidFlagCombination(flags uint8) bool {
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack, ruleID []byte) {
|
||||
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: conn.FlowId,
|
||||
Type: typ,
|
||||
RuleID: ruleID,
|
||||
Direction: conn.Direction,
|
||||
Protocol: nftypes.TCP,
|
||||
SourceIP: conn.SourceIP,
|
||||
DestIP: conn.DestIP,
|
||||
SourcePort: conn.SourcePort,
|
||||
DestPort: conn.DestPort,
|
||||
RxPackets: conn.PacketsRx.Load(),
|
||||
TxPackets: conn.PacketsTx.Load(),
|
||||
RxBytes: conn.BytesRx.Load(),
|
||||
TxBytes: conn.BytesTx.Load(),
|
||||
})
|
||||
}
|
||||
|
||||
83
client/firewall/uspfilter/conntrack/tcp_bench_test.go
Normal file
83
client/firewall/uspfilter/conntrack/tcp_bench_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func BenchmarkTCPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck|TCPSyn, 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
if i%2 == 0 {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||
} else {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck|TCPSyn, 0)
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark connection cleanup
|
||||
func BenchmarkCleanup(b *testing.B) {
|
||||
b.Run("TCPCleanup", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
// Pre-populate with expired connections
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
for i := 0; i < 10000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||
}
|
||||
|
||||
// Wait for connections to expire
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.cleanup()
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,19 +1,20 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTCPStateMachine(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.1")
|
||||
dstIP := net.ParseIP("100.64.0.2")
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
@@ -58,7 +59,7 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags)
|
||||
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, 0)
|
||||
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
|
||||
})
|
||||
}
|
||||
@@ -76,17 +77,17 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
// Send initial SYN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
|
||||
// Receive SYN-ACK
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||
require.True(t, valid, "SYN-ACK should be allowed")
|
||||
|
||||
// Send ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
// Test data transfer
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 0)
|
||||
require.True(t, valid, "Data should be allowed after handshake")
|
||||
},
|
||||
},
|
||||
@@ -99,18 +100,18 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Send FIN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
|
||||
// Receive ACK for FIN
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid, "ACK for FIN should be allowed")
|
||||
|
||||
// Receive FIN from other side
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid, "FIN should be allowed")
|
||||
|
||||
// Send final ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -122,11 +123,8 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Receive RST
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||
require.True(t, valid, "RST should be allowed for established connection")
|
||||
|
||||
// Connection is logically dead but we don't enforce blocking subsequent packets
|
||||
// The connection will be cleaned up by timeout
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -138,13 +136,13 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Both sides send FIN+ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid, "Simultaneous FIN should be allowed")
|
||||
|
||||
// Both sides send final ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid, "Final ACKs should be allowed")
|
||||
},
|
||||
},
|
||||
@@ -154,7 +152,7 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
tt.test(t)
|
||||
})
|
||||
}
|
||||
@@ -162,11 +160,11 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRSTHandling(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.1")
|
||||
dstIP := net.ParseIP("100.64.0.2")
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
@@ -181,12 +179,12 @@ func TestRSTHandling(t *testing.T) {
|
||||
name: "RST in established",
|
||||
setupState: func() {
|
||||
// Establish connection first
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
},
|
||||
sendRST: func() {
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||
},
|
||||
wantValid: true,
|
||||
desc: "Should accept RST for established connection",
|
||||
@@ -195,7 +193,7 @@ func TestRSTHandling(t *testing.T) {
|
||||
name: "RST without connection",
|
||||
setupState: func() {},
|
||||
sendRST: func() {
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||
},
|
||||
wantValid: false,
|
||||
desc: "Should reject RST without connection",
|
||||
@@ -208,101 +206,455 @@ func TestRSTHandling(t *testing.T) {
|
||||
tt.sendRST()
|
||||
|
||||
// Verify connection state is as expected
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
if tt.wantValid {
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateClosed, conn.State)
|
||||
require.False(t, conn.IsEstablished())
|
||||
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTCPRetransmissions(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
// Test SYN retransmission
|
||||
t.Run("SYN Retransmission", func(t *testing.T) {
|
||||
// Initial SYN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
|
||||
// Retransmit SYN (should not affect the state machine)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
|
||||
// Verify we're still in SYN-SENT state
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateSynSent, conn.GetState())
|
||||
|
||||
// Complete the handshake
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
// Verify we're in ESTABLISHED state
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
})
|
||||
|
||||
// Test ACK retransmission in established state
|
||||
t.Run("ACK Retransmission", func(t *testing.T) {
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
|
||||
// Establish connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
|
||||
// Retransmit ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
// State should remain ESTABLISHED
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
})
|
||||
|
||||
// Test FIN retransmission
|
||||
t.Run("FIN Retransmission", func(t *testing.T) {
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
|
||||
// Establish connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Send FIN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||
|
||||
// Retransmit FIN (should not change state)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||
|
||||
// Receive ACK for FIN
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPDataTransfer(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
t.Run("Data Transfer", func(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Send data
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
|
||||
|
||||
// Receive ACK for data
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 100)
|
||||
require.True(t, valid)
|
||||
|
||||
// Receive data
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1500)
|
||||
require.True(t, valid)
|
||||
|
||||
// Send ACK for received data
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||
|
||||
// State should remain ESTABLISHED
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
|
||||
assert.Equal(t, uint64(1300), conn.BytesTx.Load())
|
||||
assert.Equal(t, uint64(1700), conn.BytesRx.Load())
|
||||
assert.Equal(t, uint64(4), conn.PacketsTx.Load())
|
||||
assert.Equal(t, uint64(3), conn.PacketsRx.Load())
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPHalfClosedConnections(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
// Test half-closed connection: local end closes, remote end continues sending data
|
||||
t.Run("Local Close, Remote Data", func(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Send FIN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||
|
||||
// Receive ACK for FIN
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||
|
||||
// Remote end can still send data
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1000)
|
||||
require.True(t, valid)
|
||||
|
||||
// We can still ACK their data
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
// Receive FIN from remote end
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
|
||||
// Send final ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
// State should remain TIME-WAIT (waiting for possible retransmissions)
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
})
|
||||
|
||||
// Test half-closed connection: remote end closes, local end continues sending data
|
||||
t.Run("Remote Close, Local Data", func(t *testing.T) {
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
|
||||
// Establish connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Receive FIN from remote
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateCloseWait, conn.GetState())
|
||||
|
||||
// We can still send data
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
|
||||
|
||||
// Remote can still ACK our data
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
|
||||
// Send our FIN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||
|
||||
// Receive final ACK
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPAbnormalSequences(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
// Test handling of unsolicited RST in various states
|
||||
t.Run("Unsolicited RST in SYN-SENT", func(t *testing.T) {
|
||||
// Send SYN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
|
||||
// Receive unsolicited RST (without proper ACK)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||
require.False(t, valid, "RST without proper ACK in SYN-SENT should be rejected")
|
||||
|
||||
// Receive RST with proper ACK
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
|
||||
require.True(t, valid, "RST with proper ACK in SYN-SENT should be accepted")
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||
require.True(t, conn.IsTombstone())
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPTimeoutHandling(t *testing.T) {
|
||||
// Create tracker with a very short timeout for testing
|
||||
shortTimeout := 100 * time.Millisecond
|
||||
tracker := NewTCPTracker(shortTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
t.Run("Connection Timeout", func(t *testing.T) {
|
||||
// Establish a connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
|
||||
// Wait for the connection to timeout
|
||||
time.Sleep(2 * shortTimeout)
|
||||
|
||||
// Force cleanup
|
||||
tracker.cleanup()
|
||||
|
||||
// Connection should be removed
|
||||
_, exists := tracker.connections[key]
|
||||
require.False(t, exists, "Connection should be removed after timeout")
|
||||
})
|
||||
|
||||
t.Run("TIME_WAIT Timeout", func(t *testing.T) {
|
||||
tracker = NewTCPTracker(shortTimeout, logger, flowLogger)
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Complete the connection close to enter TIME_WAIT
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
|
||||
// TIME_WAIT should have its own timeout value (usually 2*MSL)
|
||||
// For the test, we're using a short timeout
|
||||
time.Sleep(2 * shortTimeout)
|
||||
|
||||
tracker.cleanup()
|
||||
|
||||
// Connection should be removed
|
||||
_, exists := tracker.connections[key]
|
||||
require.False(t, exists, "Connection should be removed after TIME_WAIT timeout")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSynFlood(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
basePort := uint16(10000)
|
||||
dstPort := uint16(80)
|
||||
|
||||
// Create a large number of SYN packets to simulate a SYN flood
|
||||
for i := uint16(0); i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, basePort+i, dstPort, TCPSyn, 0)
|
||||
}
|
||||
|
||||
// Check that we're tracking all connections
|
||||
require.Equal(t, 1000, len(tracker.connections))
|
||||
|
||||
// Now simulate SYN timeout
|
||||
var oldConns int
|
||||
tracker.mutex.Lock()
|
||||
for _, conn := range tracker.connections {
|
||||
if conn.GetState() == TCPStateSynSent {
|
||||
// Make the connection appear old
|
||||
conn.lastSeen.Store(time.Now().Add(-TCPHandshakeTimeout - time.Second).UnixNano())
|
||||
oldConns++
|
||||
}
|
||||
}
|
||||
tracker.mutex.Unlock()
|
||||
require.Equal(t, 1000, oldConns)
|
||||
|
||||
// Run cleanup
|
||||
tracker.cleanup()
|
||||
|
||||
// Check that stale connections were cleaned up
|
||||
require.Equal(t, 0, len(tracker.connections))
|
||||
}
|
||||
|
||||
func TestTCPInboundInitiatedConnection(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
clientIP := netip.MustParseAddr("100.64.0.1")
|
||||
serverIP := netip.MustParseAddr("100.64.0.2")
|
||||
clientPort := uint16(12345)
|
||||
serverPort := uint16(80)
|
||||
|
||||
// 1. Client sends SYN (we receive it as inbound)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100)
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: clientIP,
|
||||
DstIP: serverIP,
|
||||
SrcPort: clientPort,
|
||||
DstPort: serverPort,
|
||||
}
|
||||
|
||||
tracker.mutex.RLock()
|
||||
conn := tracker.connections[key]
|
||||
tracker.mutex.RUnlock()
|
||||
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateSynReceived, conn.GetState(), "Connection should be in SYN-RECEIVED state after inbound SYN")
|
||||
|
||||
// 2. Server sends SYN-ACK response
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||
|
||||
// 3. Client sends ACK to complete handshake
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
|
||||
|
||||
// 4. Test data transfer
|
||||
// Client sends data
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000)
|
||||
|
||||
// Server sends ACK for data
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
|
||||
|
||||
// Server sends data
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
|
||||
|
||||
// Client sends ACK for data
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
|
||||
|
||||
// Verify state and counters
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
assert.Equal(t, uint64(1300), conn.BytesRx.Load()) // 3 packets * 100 + 1000 data
|
||||
assert.Equal(t, uint64(1700), conn.BytesTx.Load()) // 2 packets * 100 + 1500 data
|
||||
assert.Equal(t, uint64(4), conn.PacketsRx.Load()) // SYN, ACK, Data
|
||||
assert.Equal(t, uint64(3), conn.PacketsTx.Load()) // SYN-ACK, Data
|
||||
}
|
||||
|
||||
// Helper to establish a TCP connection
|
||||
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
|
||||
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
|
||||
t.Helper()
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||
require.True(t, valid, "SYN-ACK should be allowed")
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
}
|
||||
|
||||
func BenchmarkTCPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
if i%2 == 0 {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
|
||||
} else {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck)
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark connection cleanup
|
||||
func BenchmarkCleanup(b *testing.B) {
|
||||
b.Run("TCPCleanup", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing
|
||||
defer tracker.Close()
|
||||
|
||||
// Pre-populate with expired connections
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
for i := 0; i < 10000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
||||
}
|
||||
|
||||
// Wait for connections to expire
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.cleanup()
|
||||
}
|
||||
})
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||
}
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"context"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -18,6 +22,8 @@ const (
|
||||
// UDPConnTrack represents a UDP connection state
|
||||
type UDPConnTrack struct {
|
||||
BaseConnTrack
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
}
|
||||
|
||||
// UDPTracker manages UDP connection states
|
||||
@@ -26,89 +32,126 @@ type UDPTracker struct {
|
||||
connections map[ConnKey]*UDPConnTrack
|
||||
timeout time.Duration
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
mutex sync.RWMutex
|
||||
done chan struct{}
|
||||
ipPool *PreallocatedIPs
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
// NewUDPTracker creates a new UDP connection tracker
|
||||
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
|
||||
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *UDPTracker {
|
||||
if timeout == 0 {
|
||||
timeout = DefaultUDPTimeout
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
tracker := &UDPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ConnKey]*UDPConnTrack),
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||
done: make(chan struct{}),
|
||||
ipPool: NewPreallocatedIPs(),
|
||||
tickerCancel: cancel,
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
go tracker.cleanupRoutine()
|
||||
go tracker.cleanupRoutine(ctx)
|
||||
return tracker
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound UDP connection
|
||||
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
t.mutex.Lock()
|
||||
conn, exists := t.connections[key]
|
||||
if !exists {
|
||||
srcIPCopy := t.ipPool.Get()
|
||||
dstIPCopy := t.ipPool.Get()
|
||||
copyIP(srcIPCopy, srcIP)
|
||||
copyIP(dstIPCopy, dstIP)
|
||||
|
||||
conn = &UDPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
SourceIP: srcIPCopy,
|
||||
DestIP: dstIPCopy,
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
},
|
||||
}
|
||||
conn.UpdateLastSeen()
|
||||
t.connections[key] = conn
|
||||
|
||||
t.logger.Trace("New UDP connection: %v", conn)
|
||||
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
conn.UpdateLastSeen()
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound packet matches a tracked connection
|
||||
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
|
||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
||||
// TrackInbound records an inbound UDP connection
|
||||
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
|
||||
}
|
||||
|
||||
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
if exists {
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(direction, size)
|
||||
return key, true
|
||||
}
|
||||
|
||||
return key, false
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound connections
|
||||
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
|
||||
if exists {
|
||||
return
|
||||
}
|
||||
|
||||
conn := &UDPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
FlowId: uuid.New(),
|
||||
Direction: direction,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
},
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
}
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(direction, size)
|
||||
|
||||
t.mutex.Lock()
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.logger.Trace("New %s UDP connection: %s", direction, key)
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound packet matches a tracked connection
|
||||
func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) bool {
|
||||
key := ConnKey{
|
||||
SrcIP: dstIP,
|
||||
DstIP: srcIP,
|
||||
SrcPort: dstPort,
|
||||
DstPort: srcPort,
|
||||
}
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if !exists || conn.timeoutExceeded(t.timeout) {
|
||||
return false
|
||||
}
|
||||
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
return false
|
||||
}
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(nftypes.Ingress, size)
|
||||
|
||||
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||
conn.DestPort == srcPort &&
|
||||
conn.SourcePort == dstPort
|
||||
return true
|
||||
}
|
||||
|
||||
// cleanupRoutine periodically removes stale connections
|
||||
func (t *UDPTracker) cleanupRoutine() {
|
||||
func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
|
||||
defer t.cleanupTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-t.cleanupTicker.C:
|
||||
t.cleanup()
|
||||
case <-t.done:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -120,44 +163,58 @@ func (t *UDPTracker) cleanup() {
|
||||
|
||||
for key, conn := range t.connections {
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace("Removed UDP connection %v (timeout)", conn)
|
||||
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the cleanup routine and releases resources
|
||||
func (t *UDPTracker) Close() {
|
||||
t.cleanupTicker.Stop()
|
||||
close(t.done)
|
||||
t.tickerCancel()
|
||||
|
||||
t.mutex.Lock()
|
||||
for _, conn := range t.connections {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
}
|
||||
t.connections = nil
|
||||
t.mutex.Unlock()
|
||||
}
|
||||
|
||||
// GetConnection safely retrieves a connection state
|
||||
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) {
|
||||
func (t *UDPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*UDPConnTrack, bool) {
|
||||
t.mutex.RLock()
|
||||
defer t.mutex.RUnlock()
|
||||
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
conn, exists := t.connections[key]
|
||||
if !exists {
|
||||
return nil, false
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
|
||||
return conn, true
|
||||
conn, exists := t.connections[key]
|
||||
return conn, exists
|
||||
}
|
||||
|
||||
// Timeout returns the configured timeout duration for the tracker
|
||||
func (t *UDPTracker) Timeout() time.Duration {
|
||||
return t.timeout
|
||||
}
|
||||
|
||||
func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack, ruleID []byte) {
|
||||
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: conn.FlowId,
|
||||
Type: typ,
|
||||
RuleID: ruleID,
|
||||
Direction: conn.Direction,
|
||||
Protocol: nftypes.UDP,
|
||||
SourceIP: conn.SourceIP,
|
||||
DestIP: conn.DestIP,
|
||||
SourcePort: conn.SourcePort,
|
||||
DestPort: conn.DestPort,
|
||||
RxPackets: conn.PacketsRx.Load(),
|
||||
TxPackets: conn.PacketsTx.Load(),
|
||||
RxBytes: conn.BytesRx.Load(),
|
||||
TxBytes: conn.BytesTx.Load(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -29,54 +30,59 @@ func TestNewUDPTracker(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tracker := NewUDPTracker(tt.timeout, logger)
|
||||
tracker := NewUDPTracker(tt.timeout, logger, flowLogger)
|
||||
assert.NotNil(t, tracker)
|
||||
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
||||
assert.NotNil(t, tracker.connections)
|
||||
assert.NotNil(t, tracker.cleanupTicker)
|
||||
assert.NotNil(t, tracker.done)
|
||||
assert.NotNil(t, tracker.tickerCancel)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.2")
|
||||
dstIP := net.ParseIP("192.168.1.3")
|
||||
srcIP := netip.MustParseAddr("192.168.1.2")
|
||||
dstIP := netip.MustParseAddr("192.168.1.3")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(53)
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
|
||||
|
||||
// Verify connection was tracked
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn, exists := tracker.connections[key]
|
||||
require.True(t, exists)
|
||||
assert.True(t, conn.SourceIP.Equal(srcIP))
|
||||
assert.True(t, conn.DestIP.Equal(dstIP))
|
||||
assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
|
||||
assert.True(t, conn.DestIP.Compare(dstIP) == 0)
|
||||
assert.Equal(t, srcPort, conn.SourcePort)
|
||||
assert.Equal(t, dstPort, conn.DestPort)
|
||||
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
||||
}
|
||||
|
||||
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
tracker := NewUDPTracker(1*time.Second, logger)
|
||||
tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.2")
|
||||
dstIP := net.ParseIP("192.168.1.3")
|
||||
srcIP := netip.MustParseAddr("192.168.1.2")
|
||||
dstIP := netip.MustParseAddr("192.168.1.3")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(53)
|
||||
|
||||
// Track outbound connection
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
srcIP net.IP
|
||||
dstIP net.IP
|
||||
srcIP netip.Addr
|
||||
dstIP netip.Addr
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
sleep time.Duration
|
||||
@@ -93,7 +99,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "invalid source IP",
|
||||
srcIP: net.ParseIP("192.168.1.4"),
|
||||
srcIP: netip.MustParseAddr("192.168.1.4"),
|
||||
dstIP: srcIP,
|
||||
srcPort: dstPort,
|
||||
dstPort: srcPort,
|
||||
@@ -103,7 +109,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
{
|
||||
name: "invalid destination IP",
|
||||
srcIP: dstIP,
|
||||
dstIP: net.ParseIP("192.168.1.4"),
|
||||
dstIP: netip.MustParseAddr("192.168.1.4"),
|
||||
srcPort: dstPort,
|
||||
dstPort: srcPort,
|
||||
sleep: 0,
|
||||
@@ -143,7 +149,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
if tt.sleep > 0 {
|
||||
time.Sleep(tt.sleep)
|
||||
}
|
||||
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort)
|
||||
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, 0)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
@@ -154,42 +160,45 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
||||
timeout := 50 * time.Millisecond
|
||||
cleanupInterval := 25 * time.Millisecond
|
||||
|
||||
ctx, tickerCancel := context.WithCancel(context.Background())
|
||||
defer tickerCancel()
|
||||
|
||||
// Create tracker with custom cleanup interval
|
||||
tracker := &UDPTracker{
|
||||
connections: make(map[ConnKey]*UDPConnTrack),
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(cleanupInterval),
|
||||
done: make(chan struct{}),
|
||||
ipPool: NewPreallocatedIPs(),
|
||||
tickerCancel: tickerCancel,
|
||||
logger: logger,
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
go tracker.cleanupRoutine()
|
||||
go tracker.cleanupRoutine(ctx)
|
||||
|
||||
// Add some connections
|
||||
connections := []struct {
|
||||
srcIP net.IP
|
||||
dstIP net.IP
|
||||
srcIP netip.Addr
|
||||
dstIP netip.Addr
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
}{
|
||||
{
|
||||
srcIP: net.ParseIP("192.168.1.2"),
|
||||
dstIP: net.ParseIP("192.168.1.3"),
|
||||
srcIP: netip.MustParseAddr("192.168.1.2"),
|
||||
dstIP: netip.MustParseAddr("192.168.1.3"),
|
||||
srcPort: 12345,
|
||||
dstPort: 53,
|
||||
},
|
||||
{
|
||||
srcIP: net.ParseIP("192.168.1.4"),
|
||||
dstIP: net.ParseIP("192.168.1.5"),
|
||||
srcIP: netip.MustParseAddr("192.168.1.4"),
|
||||
dstIP: netip.MustParseAddr("192.168.1.5"),
|
||||
srcPort: 12346,
|
||||
dstPort: 53,
|
||||
},
|
||||
}
|
||||
|
||||
for _, conn := range connections {
|
||||
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort)
|
||||
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, 0)
|
||||
}
|
||||
|
||||
// Verify initial connections
|
||||
@@ -211,33 +220,33 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
||||
|
||||
func BenchmarkUDPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000))
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
@@ -79,3 +81,10 @@ func (e *endpoint) AddHeader(*stack.PacketBuffer) {
|
||||
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type epID stack.TransportEndpointID
|
||||
|
||||
func (i epID) String() string {
|
||||
// src and remote is swapped
|
||||
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -29,6 +30,7 @@ const (
|
||||
|
||||
type Forwarder struct {
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
@@ -38,7 +40,7 @@ type Forwarder struct {
|
||||
netstack bool
|
||||
}
|
||||
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) {
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
|
||||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{
|
||||
@@ -102,9 +104,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwar
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
f := &Forwarder{
|
||||
logger: logger,
|
||||
flowLogger: flowLogger,
|
||||
stack: s,
|
||||
endpoint: endpoint,
|
||||
udpForwarder: newUDPForwarder(mtu, logger),
|
||||
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
netstack: netstack,
|
||||
|
||||
@@ -3,14 +3,30 @@ package forwarder
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
// handleICMP handles ICMP packets from the network stack
|
||||
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
||||
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
||||
icmpType := uint8(icmpHdr.Type())
|
||||
icmpCode := uint8(icmpHdr.Code())
|
||||
|
||||
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
|
||||
// dont process our own replies
|
||||
return true
|
||||
}
|
||||
|
||||
flowID := uuid.New()
|
||||
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
|
||||
|
||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -18,7 +34,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
||||
// TODO: support non-root
|
||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||
if err != nil {
|
||||
f.logger.Error("Failed to create ICMP socket for %v: %v", id, err)
|
||||
f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
|
||||
|
||||
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||
return false
|
||||
@@ -32,47 +48,31 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
dst := &net.IPAddr{IP: dstIP}
|
||||
|
||||
// Get the complete ICMP message (header + data)
|
||||
fullPacket := stack.PayloadSince(pkt.TransportHeader())
|
||||
payload := fullPacket.AsSlice()
|
||||
|
||||
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
||||
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
// For Echo Requests, send and handle response
|
||||
switch icmpHdr.Type() {
|
||||
case header.ICMPv4Echo:
|
||||
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id)
|
||||
case header.ICMPv4EchoReply:
|
||||
// dont process our own replies
|
||||
return true
|
||||
default:
|
||||
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
|
||||
f.handleEchoResponse(icmpHdr, conn, id)
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
|
||||
}
|
||||
|
||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc)
|
||||
_, err = conn.WriteTo(payload, dst)
|
||||
if err != nil {
|
||||
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
||||
id, icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool {
|
||||
if _, err := conn.WriteTo(payload, dst); err != nil {
|
||||
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
||||
id, icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
|
||||
return true
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]byte, f.endpoint.mtu)
|
||||
@@ -81,7 +81,7 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
|
||||
if !isTimeout(err) {
|
||||
f.logger.Error("Failed to read ICMP response: %v", err)
|
||||
}
|
||||
return true
|
||||
return
|
||||
}
|
||||
|
||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||
@@ -101,9 +101,27 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
|
||||
|
||||
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||
f.logger.Error("Failed to inject ICMP response: %v", err)
|
||||
return true
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP echo reply for %v", id)
|
||||
return true
|
||||
f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
|
||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||
}
|
||||
|
||||
// sendICMPEvent stores flow events for ICMP packets
|
||||
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) {
|
||||
f.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.ICMP,
|
||||
// TODO: handle ipv6
|
||||
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||
ICMPType: icmpType,
|
||||
ICMPCode: icmpCode,
|
||||
|
||||
// TODO: get packets/bytes
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,24 +5,38 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
// handleTCP is called by the TCP forwarder for new connections.
|
||||
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
id := r.ID()
|
||||
|
||||
flowID := uuid.New()
|
||||
|
||||
f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
|
||||
var success bool
|
||||
defer func() {
|
||||
if !success {
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
|
||||
}
|
||||
}()
|
||||
|
||||
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||
if err != nil {
|
||||
r.Complete(true)
|
||||
f.logger.Trace("forwarder: dial error for %v: %v", id, err)
|
||||
f.logger.Trace("forwarder: dial error for %v: %v", epID(id), err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -44,12 +58,13 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
|
||||
inConn := gonet.NewTCPConn(&wq, ep)
|
||||
|
||||
f.logger.Trace("forwarder: established TCP connection %v", id)
|
||||
success = true
|
||||
f.logger.Trace("forwarder: established TCP connection %v", epID(id))
|
||||
|
||||
go f.proxyTCP(id, inConn, outConn, ep)
|
||||
go f.proxyTCP(id, inConn, outConn, ep, flowID)
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) {
|
||||
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
|
||||
defer func() {
|
||||
if err := inConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: inConn close error: %v", err)
|
||||
@@ -58,6 +73,8 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
||||
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||
}
|
||||
ep.Close()
|
||||
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
|
||||
}()
|
||||
|
||||
// Create context for managing the proxy goroutines
|
||||
@@ -78,13 +95,38 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", id)
|
||||
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id))
|
||||
return
|
||||
case err := <-errChan:
|
||||
if err != nil && !isClosedError(err) {
|
||||
f.logger.Error("proxyTCP: copy error: %v", err)
|
||||
}
|
||||
f.logger.Trace("forwarder: tearing down TCP connection %v", id)
|
||||
f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.TCP,
|
||||
// TODO: handle ipv6
|
||||
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||
SourcePort: id.RemotePort,
|
||||
DestPort: id.LocalPort,
|
||||
}
|
||||
|
||||
if ep != nil {
|
||||
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
||||
// fields are flipped since this is the in conn
|
||||
// TODO: get bytes
|
||||
fields.RxPackets = tcpStats.SegmentsSent.Value()
|
||||
fields.TxPackets = tcpStats.SegmentsReceived.Value()
|
||||
}
|
||||
}
|
||||
|
||||
f.flowLogger.StoreEvent(fields)
|
||||
}
|
||||
|
||||
@@ -5,10 +5,12 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
@@ -16,6 +18,7 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -28,15 +31,17 @@ type udpPacketConn struct {
|
||||
lastSeen atomic.Int64
|
||||
cancel context.CancelFunc
|
||||
ep tcpip.Endpoint
|
||||
flowID uuid.UUID
|
||||
}
|
||||
|
||||
type udpForwarder struct {
|
||||
sync.RWMutex
|
||||
logger *nblog.Logger
|
||||
conns map[stack.TransportEndpointID]*udpPacketConn
|
||||
bufPool sync.Pool
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
conns map[stack.TransportEndpointID]*udpPacketConn
|
||||
bufPool sync.Pool
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
type idleConn struct {
|
||||
@@ -44,13 +49,14 @@ type idleConn struct {
|
||||
conn *udpPacketConn
|
||||
}
|
||||
|
||||
func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder {
|
||||
func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
f := &udpForwarder{
|
||||
logger: logger,
|
||||
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
flowLogger: flowLogger,
|
||||
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
bufPool: sync.Pool{
|
||||
New: func() any {
|
||||
b := make([]byte, mtu)
|
||||
@@ -72,10 +78,10 @@ func (f *udpForwarder) Stop() {
|
||||
for id, conn := range f.conns {
|
||||
conn.cancel()
|
||||
if err := conn.conn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", id, err)
|
||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
if err := conn.outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
|
||||
conn.ep.Close()
|
||||
@@ -106,10 +112,10 @@ func (f *udpForwarder) cleanup() {
|
||||
for _, idle := range idleConns {
|
||||
idle.conn.cancel()
|
||||
if err := idle.conn.conn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", idle.id, err)
|
||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(idle.id), err)
|
||||
}
|
||||
if err := idle.conn.outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", idle.id, err)
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(idle.id), err)
|
||||
}
|
||||
|
||||
idle.conn.ep.Close()
|
||||
@@ -118,7 +124,7 @@ func (f *udpForwarder) cleanup() {
|
||||
delete(f.conns, idle.id)
|
||||
f.Unlock()
|
||||
|
||||
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id)
|
||||
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -137,14 +143,24 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
_, exists := f.udpForwarder.conns[id]
|
||||
f.udpForwarder.RUnlock()
|
||||
if exists {
|
||||
f.logger.Trace("forwarder: existing UDP connection for %v", id)
|
||||
f.logger.Trace("forwarder: existing UDP connection for %v", epID(id))
|
||||
return
|
||||
}
|
||||
|
||||
flowID := uuid.New()
|
||||
|
||||
f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
|
||||
var success bool
|
||||
defer func() {
|
||||
if !success {
|
||||
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
|
||||
}
|
||||
}()
|
||||
|
||||
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||
if err != nil {
|
||||
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
|
||||
f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err)
|
||||
// TODO: Send ICMP error message
|
||||
return
|
||||
}
|
||||
@@ -155,7 +171,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
if epErr != nil {
|
||||
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -168,6 +184,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
outConn: outConn,
|
||||
cancel: connCancel,
|
||||
ep: ep,
|
||||
flowID: flowID,
|
||||
}
|
||||
pConn.updateLastSeen()
|
||||
|
||||
@@ -177,17 +194,20 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
f.udpForwarder.Unlock()
|
||||
pConn.cancel()
|
||||
if err := inConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
f.udpForwarder.conns[id] = pConn
|
||||
f.udpForwarder.Unlock()
|
||||
|
||||
f.logger.Trace("forwarder: established UDP connection to %v", id)
|
||||
success = true
|
||||
f.logger.Trace("forwarder: established UDP connection %v", epID(id))
|
||||
|
||||
go f.proxyUDP(connCtx, pConn, id, ep)
|
||||
}
|
||||
|
||||
@@ -195,10 +215,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
||||
defer func() {
|
||||
pConn.cancel()
|
||||
if err := pConn.conn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
if err := pConn.outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
|
||||
ep.Close()
|
||||
@@ -206,6 +226,8 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
||||
f.udpForwarder.Lock()
|
||||
delete(f.udpForwarder.conns, id)
|
||||
f.udpForwarder.Unlock()
|
||||
|
||||
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
|
||||
}()
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
@@ -220,17 +242,43 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", id)
|
||||
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id))
|
||||
return
|
||||
case err := <-errChan:
|
||||
if err != nil && !isClosedError(err) {
|
||||
f.logger.Error("proxyUDP: copy error: %v", err)
|
||||
}
|
||||
f.logger.Trace("forwarder: tearing down UDP connection %v", id)
|
||||
f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// sendUDPEvent stores flow events for UDP connections
|
||||
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.UDP,
|
||||
// TODO: handle ipv6
|
||||
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||
SourcePort: id.RemotePort,
|
||||
DestPort: id.LocalPort,
|
||||
}
|
||||
|
||||
if ep != nil {
|
||||
if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
|
||||
// fields are flipped since this is the in conn
|
||||
// TODO: get bytes
|
||||
fields.RxPackets = tcpStats.PacketsSent.Value()
|
||||
fields.TxPackets = tcpStats.PacketsReceived.Value()
|
||||
}
|
||||
}
|
||||
|
||||
f.flowLogger.StoreEvent(fields)
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) updateLastSeen() {
|
||||
c.lastSeen.Store(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package uspfilter
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -13,8 +14,13 @@ import (
|
||||
type localIPManager struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory)
|
||||
ipv4Bitmap [1 << 16]uint32
|
||||
// fixed-size high array for upper byte of a IPv4 address
|
||||
ipv4Bitmap [256]*ipv4LowBitmap
|
||||
}
|
||||
|
||||
// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address
|
||||
type ipv4LowBitmap struct {
|
||||
bitmap [8192]uint32
|
||||
}
|
||||
|
||||
func newLocalIPManager() *localIPManager {
|
||||
@@ -26,39 +32,59 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
|
||||
if ipv4 == nil {
|
||||
return
|
||||
}
|
||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
||||
}
|
||||
high := uint16(ipv4[0])
|
||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||
|
||||
func (m *localIPManager) checkBitmapBit(ip net.IP) bool {
|
||||
ipv4 := ip.To4()
|
||||
if ipv4 == nil {
|
||||
return false
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
|
||||
if m.ipv4Bitmap[high] == nil {
|
||||
m.ipv4Bitmap[high] = &ipv4LowBitmap{}
|
||||
}
|
||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
||||
|
||||
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
||||
}
|
||||
|
||||
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
||||
func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||
if int(high) >= len(*newIPv4Bitmap) {
|
||||
return fmt.Errorf("invalid IPv4 address: %s", ip)
|
||||
high := uint16(ipv4[0])
|
||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||
|
||||
if bitmap[high] == nil {
|
||||
bitmap[high] = &ipv4LowBitmap{}
|
||||
}
|
||||
ipStr := ip.String()
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
bitmap[high].bitmap[index] |= 1 << bit
|
||||
|
||||
ipStr := ipv4.String()
|
||||
if _, exists := ipv4Set[ipStr]; !exists {
|
||||
ipv4Set[ipStr] = struct{}{}
|
||||
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
||||
newIPv4Bitmap[high] |= 1 << (low % 32)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||
high := uint16(ip[0])
|
||||
low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3])
|
||||
|
||||
if m.ipv4Bitmap[high] == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
||||
}
|
||||
|
||||
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
||||
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||
@@ -76,7 +102,7 @@ func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||
if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||
log.Debugf("process IP failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -89,14 +115,14 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
}
|
||||
}()
|
||||
|
||||
var newIPv4Bitmap [1 << 16]uint32
|
||||
var newIPv4Bitmap [256]*ipv4LowBitmap
|
||||
ipv4Set := make(map[string]struct{})
|
||||
var ipv4Addresses []string
|
||||
|
||||
// 127.0.0.0/8
|
||||
high := uint16(127) << 8
|
||||
for i := uint16(0); i < 256; i++ {
|
||||
newIPv4Bitmap[high|i] = 0xffffffff
|
||||
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
||||
for i := 0; i < 8192; i++ {
|
||||
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
|
||||
}
|
||||
|
||||
if iface != nil {
|
||||
@@ -122,13 +148,13 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *localIPManager) IsLocalIP(ip net.IP) bool {
|
||||
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
|
||||
if !ip.Is4() {
|
||||
return false
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
return m.checkBitmapBit(ipv4)
|
||||
}
|
||||
|
||||
return false
|
||||
return m.checkBitmapBit(ip.AsSlice())
|
||||
}
|
||||
|
||||
@@ -2,90 +2,103 @@ package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
func TestLocalIPManager(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupAddr iface.WGAddress
|
||||
testIP net.IP
|
||||
setupAddr wgaddr.Address
|
||||
testIP netip.Addr
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Localhost range",
|
||||
setupAddr: iface.WGAddress{
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("127.0.0.2"),
|
||||
testIP: netip.MustParseAddr("127.0.0.2"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Localhost standard address",
|
||||
setupAddr: iface.WGAddress{
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("127.0.0.1"),
|
||||
testIP: netip.MustParseAddr("127.0.0.1"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Localhost range edge",
|
||||
setupAddr: iface.WGAddress{
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("127.255.255.255"),
|
||||
testIP: netip.MustParseAddr("127.255.255.255"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Local IP matches",
|
||||
setupAddr: iface.WGAddress{
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("192.168.1.1"),
|
||||
testIP: netip.MustParseAddr("192.168.1.1"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Local IP doesn't match",
|
||||
setupAddr: iface.WGAddress{
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("192.168.1.2"),
|
||||
testIP: netip.MustParseAddr("192.168.1.2"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Local IP doesn't match - addresses 32 apart",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: netip.MustParseAddr("192.168.1.33"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
setupAddr: iface.WGAddress{
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("fe80::1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("fe80::"),
|
||||
Mask: net.CIDRMask(64, 128),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("fe80::1"),
|
||||
testIP: netip.MustParseAddr("fe80::1"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
@@ -95,7 +108,7 @@ func TestLocalIPManager(t *testing.T) {
|
||||
manager := newLocalIPManager()
|
||||
|
||||
mock := &IFaceMock{
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return tt.setupAddr
|
||||
},
|
||||
}
|
||||
@@ -174,7 +187,7 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
|
||||
t.Logf("Testing %d IPs", len(tests))
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.ip, func(t *testing.T) {
|
||||
result := manager.IsLocalIP(net.ParseIP(tt.ip))
|
||||
result := manager.IsLocalIP(netip.MustParseAddr(tt.ip))
|
||||
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
|
||||
})
|
||||
}
|
||||
@@ -191,10 +204,8 @@ func BenchmarkIPChecks(b *testing.B) {
|
||||
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
|
||||
}
|
||||
|
||||
// Setup bitmap version
|
||||
bitmapManager := &localIPManager{
|
||||
ipv4Bitmap: [1 << 16]uint32{},
|
||||
}
|
||||
// Setup bitmap
|
||||
bitmapManager := newLocalIPManager()
|
||||
for _, ip := range interfaces[:8] { // Add half of IPs
|
||||
bitmapManager.setBitmapBit(ip)
|
||||
}
|
||||
@@ -247,7 +258,7 @@ func BenchmarkWGPosition(b *testing.B) {
|
||||
|
||||
// Create two managers - one checks WG IP first, other checks it last
|
||||
b.Run("WG_First", func(b *testing.B) {
|
||||
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||
bm := newLocalIPManager()
|
||||
bm.setBitmapBit(wgIP)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
@@ -256,7 +267,7 @@ func BenchmarkWGPosition(b *testing.B) {
|
||||
})
|
||||
|
||||
b.Run("WG_Last", func(b *testing.B) {
|
||||
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||
bm := newLocalIPManager()
|
||||
// Fill with other IPs first
|
||||
for i := 0; i < 15; i++ {
|
||||
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Package logger provides a high-performance, non-blocking logger for userspace networking
|
||||
// Package log provides a high-performance, non-blocking logger for userspace networking
|
||||
package log
|
||||
|
||||
import (
|
||||
@@ -13,13 +13,12 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
maxBatchSize = 1024 * 16 // 16KB max batch size
|
||||
maxMessageSize = 1024 * 2 // 2KB per message
|
||||
bufferSize = 1024 * 256 // 256KB ring buffer
|
||||
maxBatchSize = 1024 * 16
|
||||
maxMessageSize = 1024 * 2
|
||||
defaultFlushInterval = 2 * time.Second
|
||||
logChannelSize = 1000
|
||||
)
|
||||
|
||||
// Level represents log severity
|
||||
type Level uint32
|
||||
|
||||
const (
|
||||
@@ -42,32 +41,37 @@ var levelStrings = map[Level]string{
|
||||
LevelTrace: "TRAC",
|
||||
}
|
||||
|
||||
// Logger is a high-performance, non-blocking logger
|
||||
type Logger struct {
|
||||
output io.Writer
|
||||
level atomic.Uint32
|
||||
buffer *ringBuffer
|
||||
shutdown chan struct{}
|
||||
closeOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
|
||||
// Reusable buffer pool for formatting messages
|
||||
bufPool sync.Pool
|
||||
type logMessage struct {
|
||||
level Level
|
||||
format string
|
||||
args []any
|
||||
}
|
||||
|
||||
// Logger is a high-performance, non-blocking logger
|
||||
type Logger struct {
|
||||
output io.Writer
|
||||
level atomic.Uint32
|
||||
msgChannel chan logMessage
|
||||
shutdown chan struct{}
|
||||
closeOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
bufPool sync.Pool
|
||||
}
|
||||
|
||||
// NewFromLogrus creates a new Logger that writes to the same output as the given logrus logger
|
||||
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||
l := &Logger{
|
||||
output: logrusLogger.Out,
|
||||
buffer: newRingBuffer(bufferSize),
|
||||
shutdown: make(chan struct{}),
|
||||
output: logrusLogger.Out,
|
||||
msgChannel: make(chan logMessage, logChannelSize),
|
||||
shutdown: make(chan struct{}),
|
||||
bufPool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
// Pre-allocate buffer for message formatting
|
||||
New: func() any {
|
||||
b := make([]byte, 0, maxMessageSize)
|
||||
return &b
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
logrusLevel := logrusLogger.GetLevel()
|
||||
l.level.Store(uint32(logrusLevel))
|
||||
level := levelStrings[Level(logrusLevel)]
|
||||
@@ -79,97 +83,149 @@ func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||
return l
|
||||
}
|
||||
|
||||
// SetLevel sets the logging level
|
||||
func (l *Logger) SetLevel(level Level) {
|
||||
l.level.Store(uint32(level))
|
||||
|
||||
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
||||
}
|
||||
|
||||
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) {
|
||||
*buf = (*buf)[:0]
|
||||
|
||||
// Timestamp
|
||||
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
||||
*buf = append(*buf, ' ')
|
||||
|
||||
// Level
|
||||
*buf = append(*buf, levelStrings[level]...)
|
||||
*buf = append(*buf, ' ')
|
||||
|
||||
// Message
|
||||
if len(args) > 0 {
|
||||
*buf = append(*buf, fmt.Sprintf(format, args...)...)
|
||||
} else {
|
||||
*buf = append(*buf, format...)
|
||||
func (l *Logger) log(level Level, format string, args ...any) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: level, format: format, args: args}:
|
||||
default:
|
||||
}
|
||||
|
||||
*buf = append(*buf, '\n')
|
||||
}
|
||||
|
||||
func (l *Logger) log(level Level, format string, args ...interface{}) {
|
||||
bufp := l.bufPool.Get().(*[]byte)
|
||||
l.formatMessage(bufp, level, format, args...)
|
||||
|
||||
if len(*bufp) > maxMessageSize {
|
||||
*bufp = (*bufp)[:maxMessageSize]
|
||||
}
|
||||
_, _ = l.buffer.Write(*bufp)
|
||||
|
||||
l.bufPool.Put(bufp)
|
||||
}
|
||||
|
||||
func (l *Logger) Error(format string, args ...interface{}) {
|
||||
// Error logs a message at error level
|
||||
func (l *Logger) Error(format string, args ...any) {
|
||||
if l.level.Load() >= uint32(LevelError) {
|
||||
l.log(LevelError, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Warn(format string, args ...interface{}) {
|
||||
// Warn logs a message at warning level
|
||||
func (l *Logger) Warn(format string, args ...any) {
|
||||
if l.level.Load() >= uint32(LevelWarn) {
|
||||
l.log(LevelWarn, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Info(format string, args ...interface{}) {
|
||||
// Info logs a message at info level
|
||||
func (l *Logger) Info(format string, args ...any) {
|
||||
if l.level.Load() >= uint32(LevelInfo) {
|
||||
l.log(LevelInfo, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
||||
// Debug logs a message at debug level
|
||||
func (l *Logger) Debug(format string, args ...any) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
l.log(LevelDebug, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Trace(format string, args ...interface{}) {
|
||||
// Trace logs a message at trace level
|
||||
func (l *Logger) Trace(format string, args ...any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
l.log(LevelTrace, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// worker periodically flushes the buffer
|
||||
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...any) {
|
||||
*buf = (*buf)[:0]
|
||||
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
||||
*buf = append(*buf, ' ')
|
||||
*buf = append(*buf, levelStrings[level]...)
|
||||
*buf = append(*buf, ' ')
|
||||
|
||||
var msg string
|
||||
if len(args) > 0 {
|
||||
msg = fmt.Sprintf(format, args...)
|
||||
} else {
|
||||
msg = format
|
||||
}
|
||||
*buf = append(*buf, msg...)
|
||||
*buf = append(*buf, '\n')
|
||||
|
||||
if len(*buf) > maxMessageSize {
|
||||
*buf = (*buf)[:maxMessageSize]
|
||||
}
|
||||
}
|
||||
|
||||
// processMessage handles a single log message and adds it to the buffer
|
||||
func (l *Logger) processMessage(msg logMessage, buffer *[]byte) {
|
||||
bufp := l.bufPool.Get().(*[]byte)
|
||||
defer l.bufPool.Put(bufp)
|
||||
|
||||
l.formatMessage(bufp, msg.level, msg.format, msg.args...)
|
||||
|
||||
if len(*buffer)+len(*bufp) > maxBatchSize {
|
||||
_, _ = l.output.Write(*buffer)
|
||||
*buffer = (*buffer)[:0]
|
||||
}
|
||||
*buffer = append(*buffer, *bufp...)
|
||||
}
|
||||
|
||||
// flushBuffer writes the accumulated buffer to output
|
||||
func (l *Logger) flushBuffer(buffer *[]byte) {
|
||||
if len(*buffer) > 0 {
|
||||
_, _ = l.output.Write(*buffer)
|
||||
*buffer = (*buffer)[:0]
|
||||
}
|
||||
}
|
||||
|
||||
// processBatch processes as many messages as possible without blocking
|
||||
func (l *Logger) processBatch(buffer *[]byte) {
|
||||
for len(*buffer) < maxBatchSize {
|
||||
select {
|
||||
case msg := <-l.msgChannel:
|
||||
l.processMessage(msg, buffer)
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleShutdown manages the graceful shutdown sequence with timeout
|
||||
func (l *Logger) handleShutdown(buffer *[]byte) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg := <-l.msgChannel:
|
||||
l.processMessage(msg, buffer)
|
||||
case <-ctx.Done():
|
||||
l.flushBuffer(buffer)
|
||||
return
|
||||
}
|
||||
|
||||
if len(l.msgChannel) == 0 {
|
||||
l.flushBuffer(buffer)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// worker is the main goroutine that processes log messages
|
||||
func (l *Logger) worker() {
|
||||
defer l.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(defaultFlushInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
buf := make([]byte, 0, maxBatchSize)
|
||||
buffer := make([]byte, 0, maxBatchSize)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-l.shutdown:
|
||||
l.handleShutdown(&buffer)
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Read accumulated messages
|
||||
n, _ := l.buffer.Read(buf[:cap(buf)])
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Write batch
|
||||
_, _ = l.output.Write(buf[:n])
|
||||
l.flushBuffer(&buffer)
|
||||
case msg := <-l.msgChannel:
|
||||
l.processMessage(msg, &buffer)
|
||||
l.processBatch(&buffer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
121
client/firewall/uspfilter/log/log_test.go
Normal file
121
client/firewall/uspfilter/log/log_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package log_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
type discard struct{}
|
||||
|
||||
func (d *discard) Write(p []byte) (n int, err error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func BenchmarkLogger(b *testing.B) {
|
||||
simpleMessage := "Connection established"
|
||||
|
||||
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||
srcIP := "192.168.1.1"
|
||||
srcPort := uint16(12345)
|
||||
dstIP := "10.0.0.1"
|
||||
dstPort := uint16(443)
|
||||
state := 4 // TCPStateEstablished
|
||||
|
||||
complexMessage := "Packet inspection result: protocol=%s, direction=%s, flags=0x%x, sequence=%d, acknowledged=%d, payload_size=%d, fragmented=%v, connection_id=%s"
|
||||
protocol := "TCP"
|
||||
direction := "outbound"
|
||||
flags := uint16(0x18) // ACK + PSH
|
||||
sequence := uint32(123456789)
|
||||
acknowledged := uint32(987654321)
|
||||
payloadSize := 1460
|
||||
fragmented := false
|
||||
connID := "f7a12b3e-c456-7890-d123-456789abcdef"
|
||||
|
||||
b.Run("SimpleMessage", func(b *testing.B) {
|
||||
logger := createTestLogger()
|
||||
defer cleanupLogger(logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logger.Trace(simpleMessage)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ConntrackMessage", func(b *testing.B) {
|
||||
logger := createTestLogger()
|
||||
defer cleanupLogger(logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ComplexMessage", func(b *testing.B) {
|
||||
logger := createTestLogger()
|
||||
defer cleanupLogger(logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logger.Trace(complexMessage, protocol, direction, flags, sequence, acknowledged, payloadSize, fragmented, connID)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkLoggerParallel tests the logger under concurrent load
|
||||
func BenchmarkLoggerParallel(b *testing.B) {
|
||||
logger := createTestLogger()
|
||||
defer cleanupLogger(logger)
|
||||
|
||||
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||
srcIP := "192.168.1.1"
|
||||
srcPort := uint16(12345)
|
||||
dstIP := "10.0.0.1"
|
||||
dstPort := uint16(443)
|
||||
state := 4
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkLoggerBurst tests how the logger handles bursts of messages
|
||||
func BenchmarkLoggerBurst(b *testing.B) {
|
||||
logger := createTestLogger()
|
||||
defer cleanupLogger(logger)
|
||||
|
||||
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||
srcIP := "192.168.1.1"
|
||||
srcPort := uint16(12345)
|
||||
dstIP := "10.0.0.1"
|
||||
dstPort := uint16(443)
|
||||
state := 4
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for j := 0; j < 100; j++ {
|
||||
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func createTestLogger() *log.Logger {
|
||||
logrusLogger := logrus.New()
|
||||
logrusLogger.SetOutput(&discard{})
|
||||
logrusLogger.SetLevel(logrus.TraceLevel)
|
||||
return log.NewFromLogrus(logrusLogger)
|
||||
}
|
||||
|
||||
func cleanupLogger(logger *log.Logger) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = logger.Stop(ctx)
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
package log
|
||||
|
||||
import "sync"
|
||||
|
||||
// ringBuffer is a simple ring buffer implementation
|
||||
type ringBuffer struct {
|
||||
buf []byte
|
||||
size int
|
||||
r, w int64 // Read and write positions
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newRingBuffer(size int) *ringBuffer {
|
||||
return &ringBuffer{
|
||||
buf: make([]byte, size),
|
||||
size: size,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ringBuffer) Write(p []byte) (n int, err error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if len(p) > r.size {
|
||||
p = p[:r.size]
|
||||
}
|
||||
|
||||
n = len(p)
|
||||
|
||||
// Write data, handling wrap-around
|
||||
pos := int(r.w % int64(r.size))
|
||||
writeLen := min(len(p), r.size-pos)
|
||||
copy(r.buf[pos:], p[:writeLen])
|
||||
|
||||
// If we have more data and need to wrap around
|
||||
if writeLen < len(p) {
|
||||
copy(r.buf, p[writeLen:])
|
||||
}
|
||||
|
||||
// Update write position
|
||||
r.w += int64(n)
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *ringBuffer) Read(p []byte) (n int, err error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if r.w == r.r {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Calculate available data accounting for wraparound
|
||||
available := int(r.w - r.r)
|
||||
if available < 0 {
|
||||
available += r.size
|
||||
}
|
||||
available = min(available, r.size)
|
||||
|
||||
// Limit read to buffer size
|
||||
toRead := min(available, len(p))
|
||||
if toRead == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Read data, handling wrap-around
|
||||
pos := int(r.r % int64(r.size))
|
||||
readLen := min(toRead, r.size-pos)
|
||||
n = copy(p, r.buf[pos:pos+readLen])
|
||||
|
||||
// If we need more data and need to wrap around
|
||||
if readLen < toRead {
|
||||
n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
|
||||
}
|
||||
|
||||
// Update read position
|
||||
r.r += int64(n)
|
||||
|
||||
return n, nil
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
@@ -12,25 +11,26 @@ import (
|
||||
// PeerRule to handle management of rules
|
||||
type PeerRule struct {
|
||||
id string
|
||||
ip net.IP
|
||||
mgmtId []byte
|
||||
ip netip.Addr
|
||||
ipLayer gopacket.LayerType
|
||||
matchByIP bool
|
||||
protoLayer gopacket.LayerType
|
||||
sPort *firewall.Port
|
||||
dPort *firewall.Port
|
||||
drop bool
|
||||
comment string
|
||||
|
||||
udpHook func([]byte) bool
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *PeerRule) GetRuleID() string {
|
||||
// ID returns the rule id
|
||||
func (r *PeerRule) ID() string {
|
||||
return r.id
|
||||
}
|
||||
|
||||
type RouteRule struct {
|
||||
id string
|
||||
mgmtId []byte
|
||||
sources []netip.Prefix
|
||||
destination netip.Prefix
|
||||
proto firewall.Protocol
|
||||
@@ -39,7 +39,7 @@ type RouteRule struct {
|
||||
action firewall.Action
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *RouteRule) GetRuleID() string {
|
||||
// ID returns the rule id
|
||||
func (r *RouteRule) ID() string {
|
||||
return r.id
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
@@ -53,8 +53,8 @@ type TraceResult struct {
|
||||
}
|
||||
|
||||
type PacketTrace struct {
|
||||
SourceIP net.IP
|
||||
DestinationIP net.IP
|
||||
SourceIP netip.Addr
|
||||
DestinationIP netip.Addr
|
||||
Protocol string
|
||||
SourcePort uint16
|
||||
DestinationPort uint16
|
||||
@@ -72,8 +72,8 @@ type TCPState struct {
|
||||
}
|
||||
|
||||
type PacketBuilder struct {
|
||||
SrcIP net.IP
|
||||
DstIP net.IP
|
||||
SrcIP netip.Addr
|
||||
DstIP netip.Addr
|
||||
Protocol fw.Protocol
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
@@ -126,8 +126,8 @@ func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
||||
SrcIP: p.SrcIP,
|
||||
DstIP: p.DstIP,
|
||||
SrcIP: p.SrcIP.AsSlice(),
|
||||
DstIP: p.DstIP.AsSlice(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -260,28 +260,30 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
||||
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
|
||||
}
|
||||
|
||||
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace {
|
||||
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
|
||||
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
||||
return trace
|
||||
}
|
||||
|
||||
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
|
||||
return trace
|
||||
if m.localipmanager.IsLocalIP(dstIP) {
|
||||
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
|
||||
return trace
|
||||
}
|
||||
}
|
||||
|
||||
if !m.handleRouting(trace) {
|
||||
return trace
|
||||
}
|
||||
|
||||
if m.nativeRouter {
|
||||
if m.nativeRouter.Load() {
|
||||
return m.handleNativeRouter(trace)
|
||||
}
|
||||
|
||||
return m.handleRouteACLs(trace, d, srcIP, dstIP)
|
||||
}
|
||||
|
||||
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool {
|
||||
allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
|
||||
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||
allowed := m.isValidTrackedConnection(d, srcIP, dstIP, 0)
|
||||
msg := "No existing connection found"
|
||||
if allowed {
|
||||
msg = m.buildConntrackStateMessage(d)
|
||||
@@ -309,32 +311,46 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
|
||||
return msg
|
||||
}
|
||||
|
||||
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool {
|
||||
if !m.localForwarding {
|
||||
trace.AddResult(StageRouting, "Local forwarding disabled", false)
|
||||
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
|
||||
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
|
||||
|
||||
ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||
|
||||
strRuleId := "<no id>"
|
||||
if ruleId != nil {
|
||||
strRuleId = string(ruleId)
|
||||
}
|
||||
msg := fmt.Sprintf("Allowed by peer ACL rules (%s)", strRuleId)
|
||||
if blocked {
|
||||
msg = fmt.Sprintf("Blocked by peer ACL rules (%s)", strRuleId)
|
||||
trace.AddResult(StagePeerACL, msg, false)
|
||||
trace.AddResult(StageCompleted, "Packet dropped - ACL denied", false)
|
||||
return true
|
||||
}
|
||||
|
||||
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
|
||||
blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||
|
||||
msg := "Allowed by peer ACL rules"
|
||||
if blocked {
|
||||
msg = "Blocked by peer ACL rules"
|
||||
}
|
||||
trace.AddResult(StagePeerACL, msg, !blocked)
|
||||
trace.AddResult(StagePeerACL, msg, true)
|
||||
|
||||
// Handle netstack mode
|
||||
if m.netstack {
|
||||
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked)
|
||||
switch {
|
||||
case !m.localForwarding:
|
||||
trace.AddResult(StageCompleted, "Packet sent to virtual stack", true)
|
||||
case m.forwarder.Load() != nil:
|
||||
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", true)
|
||||
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||
default:
|
||||
trace.AddResult(StageCompleted, "Packet dropped - forwarder not initialized", false)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked)
|
||||
// In normal mode, packets are allowed through for local delivery
|
||||
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) handleRouting(trace *PacketTrace) bool {
|
||||
if !m.routingEnabled {
|
||||
if !m.routingEnabled.Load() {
|
||||
trace.AddResult(StageRouting, "Routing disabled", false)
|
||||
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
|
||||
return false
|
||||
@@ -350,18 +366,23 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
|
||||
return trace
|
||||
}
|
||||
|
||||
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace {
|
||||
proto := getProtocolFromPacket(d)
|
||||
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
|
||||
proto, _ := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||
|
||||
msg := "Allowed by route ACLs"
|
||||
strId := string(id)
|
||||
if id == nil {
|
||||
strId = "<no id>"
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("Allowed by route ACLs (%s)", strId)
|
||||
if !allowed {
|
||||
msg = "Blocked by route ACLs"
|
||||
msg = fmt.Sprintf("Blocked by route ACLs (%s)", strId)
|
||||
}
|
||||
trace.AddResult(StageRouteACL, msg, allowed)
|
||||
|
||||
if allowed && m.forwarder != nil {
|
||||
if allowed && m.forwarder.Load() != nil {
|
||||
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
||||
}
|
||||
|
||||
@@ -380,7 +401,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
|
||||
|
||||
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||
// will create or update the connection state
|
||||
dropped := m.processOutgoingHooks(packetData)
|
||||
dropped := m.processOutgoingHooks(packetData, 0)
|
||||
if dropped {
|
||||
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||
} else {
|
||||
|
||||
440
client/firewall/uspfilter/tracer_test.go
Normal file
440
client/firewall/uspfilter/tracer_test.go
Normal file
@@ -0,0 +1,440 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
func verifyTraceStages(t *testing.T, trace *PacketTrace, expectedStages []PacketStage) {
|
||||
t.Logf("Trace results: %v", trace.Results)
|
||||
actualStages := make([]PacketStage, 0, len(trace.Results))
|
||||
for _, result := range trace.Results {
|
||||
actualStages = append(actualStages, result.Stage)
|
||||
t.Logf("Stage: %s, Message: %s, Allowed: %v", result.Stage, result.Message, result.Allowed)
|
||||
}
|
||||
|
||||
require.ElementsMatch(t, expectedStages, actualStages, "Trace stages don't match expected stages")
|
||||
}
|
||||
|
||||
func verifyFinalDisposition(t *testing.T, trace *PacketTrace, expectedAllowed bool) {
|
||||
require.NotEmpty(t, trace.Results, "Trace should have results")
|
||||
lastResult := trace.Results[len(trace.Results)-1]
|
||||
require.Equal(t, StageCompleted, lastResult.Stage, "Last stage should be 'Completed'")
|
||||
require.Equal(t, expectedAllowed, lastResult.Allowed, "Final disposition incorrect")
|
||||
}
|
||||
|
||||
func TestTracePacket(t *testing.T) {
|
||||
setupTracerTest := func(statefulMode bool) *Manager {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.10.0.100"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
if !statefulMode {
|
||||
m.stateful = false
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
createPacketBuilder := func(srcIP, dstIP string, protocol fw.Protocol, srcPort, dstPort uint16, direction fw.RuleDirection) *PacketBuilder {
|
||||
builder := &PacketBuilder{
|
||||
SrcIP: netip.MustParseAddr(srcIP),
|
||||
DstIP: netip.MustParseAddr(dstIP),
|
||||
Protocol: protocol,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
Direction: direction,
|
||||
}
|
||||
|
||||
if protocol == "tcp" {
|
||||
builder.TCPState = &TCPState{SYN: true}
|
||||
}
|
||||
|
||||
return builder
|
||||
}
|
||||
|
||||
createICMPPacketBuilder := func(srcIP, dstIP string, icmpType, icmpCode uint8, direction fw.RuleDirection) *PacketBuilder {
|
||||
return &PacketBuilder{
|
||||
SrcIP: netip.MustParseAddr(srcIP),
|
||||
DstIP: netip.MustParseAddr(dstIP),
|
||||
Protocol: "icmp",
|
||||
ICMPType: icmpType,
|
||||
ICMPCode: icmpCode,
|
||||
Direction: direction,
|
||||
}
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setup func(*Manager)
|
||||
packetBuilder func() *PacketBuilder
|
||||
expectedStages []PacketStage
|
||||
expectedAllow bool
|
||||
}{
|
||||
{
|
||||
name: "LocalTraffic_ACLAllowed",
|
||||
setup: func(m *Manager) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "LocalTraffic_ACLDenied",
|
||||
setup: func(m *Manager) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
},
|
||||
{
|
||||
name: "LocalTraffic_WithForwarder",
|
||||
setup: func(m *Manager) {
|
||||
m.netstack = true
|
||||
m.localForwarding = true
|
||||
|
||||
m.forwarder.Store(&forwarder.Forwarder{})
|
||||
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageForwarding,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "LocalTraffic_WithoutForwarder",
|
||||
setup: func(m *Manager) {
|
||||
m.netstack = true
|
||||
m.localForwarding = false
|
||||
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "RoutedTraffic_ACLAllowed",
|
||||
setup: func(m *Manager) {
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(false)
|
||||
|
||||
m.forwarder.Store(&forwarder.Forwarder{})
|
||||
|
||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
|
||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageRouteACL,
|
||||
StageForwarding,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "RoutedTraffic_ACLDenied",
|
||||
setup: func(m *Manager) {
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(false)
|
||||
|
||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
|
||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageRouteACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
},
|
||||
{
|
||||
name: "RoutedTraffic_NativeRouter",
|
||||
setup: func(m *Manager) {
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(true)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageRouteACL,
|
||||
StageForwarding,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "RoutedTraffic_RoutingDisabled",
|
||||
setup: func(m *Manager) {
|
||||
m.routingEnabled.Store(false)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
},
|
||||
{
|
||||
name: "ConnectionTracking_Hit",
|
||||
setup: func(m *Manager) {
|
||||
srcIP := netip.MustParseAddr("100.10.0.100")
|
||||
dstIP := netip.MustParseAddr("1.1.1.1")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
m.tcpTracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, conntrack.TCPSyn, 0)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
pb := createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 80, 12345, fw.RuleDirectionIN)
|
||||
pb.TCPState = &TCPState{SYN: true, ACK: true}
|
||||
return pb
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "OutboundTraffic",
|
||||
setup: func(m *Manager) {
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("100.10.0.100", "1.1.1.1", "tcp", 12345, 80, fw.RuleDirectionOUT)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "ICMPEchoRequest",
|
||||
setup: func(m *Manager) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolICMP
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 8, 0, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "ICMPDestinationUnreachable",
|
||||
setup: func(m *Manager) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolICMP
|
||||
action := fw.ActionDrop
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 3, 0, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "UDPTraffic_WithoutHook",
|
||||
setup: func(m *Manager) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolUDP
|
||||
port := &fw.Port{Values: []uint16{53}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "UDPTraffic_WithHook",
|
||||
setup: func(m *Manager) {
|
||||
hookFunc := func([]byte) bool {
|
||||
return true
|
||||
}
|
||||
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
},
|
||||
{
|
||||
name: "StatefulDisabled_NoTracking",
|
||||
setup: func(m *Manager) {
|
||||
m.stateful = false
|
||||
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
m := setupTracerTest(true)
|
||||
|
||||
tc.setup(m)
|
||||
|
||||
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
|
||||
"100.10.0.100 should be recognized as a local IP")
|
||||
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")),
|
||||
"172.17.0.2 should not be recognized as a local IP")
|
||||
|
||||
pb := tc.packetBuilder()
|
||||
|
||||
trace, err := m.TracePacketFromBuilder(pb)
|
||||
require.NoError(t, err)
|
||||
|
||||
verifyTraceStages(t, trace, tc.expectedStages)
|
||||
verifyFinalDisposition(t, trace, tc.expectedAllow)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
@@ -22,6 +23,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -42,6 +44,8 @@ const (
|
||||
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
||||
)
|
||||
|
||||
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
||||
|
||||
// RuleSet is a set of rules grouped by a string key
|
||||
type RuleSet map[string]PeerRule
|
||||
|
||||
@@ -63,9 +67,9 @@ func (r RouteRules) Sort() {
|
||||
// Manager userspace firewall manager
|
||||
type Manager struct {
|
||||
// outgoingRules is used for hooks only
|
||||
outgoingRules map[string]RuleSet
|
||||
outgoingRules map[netip.Addr]RuleSet
|
||||
// incomingRules is used for filtering and hooks
|
||||
incomingRules map[string]RuleSet
|
||||
incomingRules map[netip.Addr]RuleSet
|
||||
routeRules RouteRules
|
||||
wgNetwork *net.IPNet
|
||||
decoders sync.Pool
|
||||
@@ -77,9 +81,9 @@ type Manager struct {
|
||||
// indicates whether server routes are disabled
|
||||
disableServerRoutes bool
|
||||
// indicates whether we forward packets not destined for ourselves
|
||||
routingEnabled bool
|
||||
routingEnabled atomic.Bool
|
||||
// indicates whether we leave forwarding and filtering to the native firewall
|
||||
nativeRouter bool
|
||||
nativeRouter atomic.Bool
|
||||
// indicates whether we track outbound connections
|
||||
stateful bool
|
||||
// indicates whether wireguards runs in netstack mode
|
||||
@@ -92,8 +96,9 @@ type Manager struct {
|
||||
udpTracker *conntrack.UDPTracker
|
||||
icmpTracker *conntrack.ICMPTracker
|
||||
tcpTracker *conntrack.TCPTracker
|
||||
forwarder *forwarder.Forwarder
|
||||
forwarder atomic.Pointer[forwarder.Forwarder]
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@@ -110,16 +115,16 @@ type decoder struct {
|
||||
}
|
||||
|
||||
// Create userspace firewall manager constructor
|
||||
func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) {
|
||||
return create(iface, nil, disableServerRoutes)
|
||||
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||
return create(iface, nil, disableServerRoutes, flowLogger)
|
||||
}
|
||||
|
||||
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
||||
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||
if nativeFirewall == nil {
|
||||
return nil, errors.New("native firewall is nil")
|
||||
}
|
||||
|
||||
mgr, err := create(iface, nativeFirewall, disableServerRoutes)
|
||||
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -146,7 +151,7 @@ func parseCreateEnv() (bool, bool) {
|
||||
return disableConntrack, enableLocalForwarding
|
||||
}
|
||||
|
||||
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
||||
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||
disableConntrack, enableLocalForwarding := parseCreateEnv()
|
||||
|
||||
m := &Manager{
|
||||
@@ -164,17 +169,18 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
},
|
||||
},
|
||||
nativeFirewall: nativeFirewall,
|
||||
outgoingRules: make(map[string]RuleSet),
|
||||
incomingRules: make(map[string]RuleSet),
|
||||
outgoingRules: make(map[netip.Addr]RuleSet),
|
||||
incomingRules: make(map[netip.Addr]RuleSet),
|
||||
wgIface: iface,
|
||||
localipmanager: newLocalIPManager(),
|
||||
disableServerRoutes: disableServerRoutes,
|
||||
routingEnabled: false,
|
||||
stateful: !disableConntrack,
|
||||
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
||||
flowLogger: flowLogger,
|
||||
netstack: netstack.IsEnabled(),
|
||||
localForwarding: enableLocalForwarding,
|
||||
}
|
||||
m.routingEnabled.Store(false)
|
||||
|
||||
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||
@@ -183,9 +189,9 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
if disableConntrack {
|
||||
log.Info("conntrack is disabled")
|
||||
} else {
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, flowLogger)
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger)
|
||||
}
|
||||
|
||||
// netstack needs the forwarder for local traffic
|
||||
@@ -206,7 +212,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
}
|
||||
|
||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
||||
if m.forwarder == nil {
|
||||
if m.forwarder.Load() == nil {
|
||||
return nil
|
||||
}
|
||||
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
||||
@@ -216,6 +222,7 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
||||
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||
|
||||
if _, err := m.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
||||
wgPrefix,
|
||||
firewall.ProtocolALL,
|
||||
@@ -249,20 +256,20 @@ func (m *Manager) determineRouting() error {
|
||||
|
||||
switch {
|
||||
case disableUspRouting:
|
||||
m.routingEnabled = false
|
||||
m.nativeRouter = false
|
||||
m.routingEnabled.Store(false)
|
||||
m.nativeRouter.Store(false)
|
||||
log.Info("userspace routing is disabled")
|
||||
|
||||
case m.disableServerRoutes:
|
||||
// if server routes are disabled we will let packets pass to the native stack
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = true
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(true)
|
||||
|
||||
log.Info("server routes are disabled")
|
||||
|
||||
case forceUserspaceRouter:
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = false
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(false)
|
||||
|
||||
log.Info("userspace routing is forced")
|
||||
|
||||
@@ -270,19 +277,19 @@ func (m *Manager) determineRouting() error {
|
||||
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
||||
// netstack mode won't support native routing as there is no interface
|
||||
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = true
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(true)
|
||||
|
||||
log.Info("native routing is enabled")
|
||||
|
||||
default:
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = false
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(false)
|
||||
|
||||
log.Info("userspace routing enabled by default")
|
||||
}
|
||||
|
||||
if m.routingEnabled && !m.nativeRouter {
|
||||
if m.routingEnabled.Load() && !m.nativeRouter.Load() {
|
||||
return m.initForwarder()
|
||||
}
|
||||
|
||||
@@ -291,24 +298,24 @@ func (m *Manager) determineRouting() error {
|
||||
|
||||
// initForwarder initializes the forwarder, it disables routing on errors
|
||||
func (m *Manager) initForwarder() error {
|
||||
if m.forwarder != nil {
|
||||
if m.forwarder.Load() != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
||||
intf := m.wgIface.GetWGDevice()
|
||||
if intf == nil {
|
||||
m.routingEnabled = false
|
||||
m.routingEnabled.Store(false)
|
||||
return errors.New("forwarding not supported")
|
||||
}
|
||||
|
||||
forwarder, err := forwarder.New(m.wgIface, m.logger, m.netstack)
|
||||
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
|
||||
if err != nil {
|
||||
m.routingEnabled = false
|
||||
m.routingEnabled.Store(false)
|
||||
return fmt.Errorf("create forwarder: %w", err)
|
||||
}
|
||||
|
||||
m.forwarder = forwarder
|
||||
m.forwarder.Store(forwarder)
|
||||
|
||||
log.Debug("forwarder initialized")
|
||||
|
||||
@@ -324,7 +331,7 @@ func (m *Manager) IsServerRouteSupported() bool {
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AddNatRule(pair)
|
||||
}
|
||||
|
||||
@@ -335,7 +342,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
|
||||
// RemoveNatRule removes a routing firewall rule
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.RemoveNatRule(pair)
|
||||
}
|
||||
return nil
|
||||
@@ -346,25 +353,31 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
_ string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
// TODO: fix in upper layers
|
||||
i, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid IP: %s", ip)
|
||||
}
|
||||
|
||||
i = i.Unmap()
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
mgmtId: id,
|
||||
ip: i,
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
matchByIP: true,
|
||||
drop: action == firewall.ActionDrop,
|
||||
comment: comment,
|
||||
}
|
||||
if ipNormalized := ip.To4(); ipNormalized != nil {
|
||||
if i.Is4() {
|
||||
r.ipLayer = layers.LayerTypeIPv4
|
||||
r.ip = ipNormalized
|
||||
}
|
||||
|
||||
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
|
||||
@@ -389,15 +402,16 @@ func (m *Manager) AddPeerFiltering(
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
||||
if _, ok := m.incomingRules[r.ip]; !ok {
|
||||
m.incomingRules[r.ip] = make(RuleSet)
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
m.incomingRules[r.ip][r.id] = r
|
||||
m.mutex.Unlock()
|
||||
return []firewall.Rule{&r}, nil
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
@@ -405,16 +419,15 @@ func (m *Manager) AddRouteFiltering(
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
ruleID := uuid.New().String()
|
||||
rule := RouteRule{
|
||||
// TODO: consolidate these IDs
|
||||
id: ruleID,
|
||||
mgmtId: id,
|
||||
sources: sources,
|
||||
destination: destination,
|
||||
proto: proto,
|
||||
@@ -423,21 +436,23 @@ func (m *Manager) AddRouteFiltering(
|
||||
action: action,
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
m.routeRules = append(m.routeRules, rule)
|
||||
m.routeRules.Sort()
|
||||
m.mutex.Unlock()
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
ruleID := rule.GetRuleID()
|
||||
ruleID := rule.ID()
|
||||
idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool {
|
||||
return r.id == ruleID
|
||||
})
|
||||
@@ -459,10 +474,10 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||
}
|
||||
|
||||
if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok {
|
||||
if _, ok := m.incomingRules[r.ip][r.id]; !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.incomingRules[r.ip.String()], r.id)
|
||||
delete(m.incomingRules[r.ip], r.id)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -478,14 +493,30 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil, errNatNotSupported
|
||||
}
|
||||
return m.nativeFirewall.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errNatNotSupported
|
||||
}
|
||||
return m.nativeFirewall.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
// DropOutgoing filter outgoing packets
|
||||
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
||||
return m.processOutgoingHooks(packetData)
|
||||
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
|
||||
return m.processOutgoingHooks(packetData, size)
|
||||
}
|
||||
|
||||
// DropIncoming filter incoming packets
|
||||
func (m *Manager) DropIncoming(packetData []byte) bool {
|
||||
return m.dropFilter(packetData)
|
||||
func (m *Manager) DropIncoming(packetData []byte, size int) bool {
|
||||
return m.dropFilter(packetData, size)
|
||||
}
|
||||
|
||||
// UpdateLocalIPs updates the list of local IPs
|
||||
@@ -493,10 +524,7 @@ func (m *Manager) UpdateLocalIPs() error {
|
||||
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
||||
}
|
||||
|
||||
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
@@ -509,52 +537,37 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||
}
|
||||
|
||||
srcIP, dstIP := m.extractIPs(d)
|
||||
if srcIP == nil {
|
||||
if !srcIP.IsValid() {
|
||||
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
||||
return false
|
||||
}
|
||||
|
||||
// Track all protocols if stateful mode is enabled
|
||||
if m.stateful {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeUDP:
|
||||
m.trackUDPOutbound(d, srcIP, dstIP)
|
||||
case layers.LayerTypeTCP:
|
||||
m.trackTCPOutbound(d, srcIP, dstIP)
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.trackICMPOutbound(d, srcIP, dstIP)
|
||||
}
|
||||
if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Process UDP hooks even if stateful mode is disabled
|
||||
if d.decoded[1] == layers.LayerTypeUDP {
|
||||
return m.checkUDPHooks(d, dstIP, packetData)
|
||||
if m.stateful {
|
||||
m.trackOutbound(d, srcIP, dstIP, size)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) {
|
||||
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) {
|
||||
switch d.decoded[0] {
|
||||
case layers.LayerTypeIPv4:
|
||||
return d.ip4.SrcIP, d.ip4.DstIP
|
||||
src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
|
||||
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
|
||||
return src, dst
|
||||
case layers.LayerTypeIPv6:
|
||||
return d.ip6.SrcIP, d.ip6.DstIP
|
||||
src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
|
||||
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
|
||||
return src, dst
|
||||
default:
|
||||
return nil, nil
|
||||
return netip.Addr{}, netip.Addr{}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
m.tcpTracker.TrackOutbound(
|
||||
srcIP,
|
||||
dstIP,
|
||||
uint16(d.tcp.SrcPort),
|
||||
uint16(d.tcp.DstPort),
|
||||
flags,
|
||||
)
|
||||
}
|
||||
|
||||
func getTCPFlags(tcp *layers.TCP) uint8 {
|
||||
var flags uint8
|
||||
if tcp.SYN {
|
||||
@@ -578,100 +591,153 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
|
||||
return flags
|
||||
}
|
||||
|
||||
func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||
m.udpTracker.TrackOutbound(
|
||||
srcIP,
|
||||
dstIP,
|
||||
uint16(d.udp.SrcPort),
|
||||
uint16(d.udp.DstPort),
|
||||
)
|
||||
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
|
||||
transport := d.decoded[1]
|
||||
switch transport {
|
||||
case layers.LayerTypeUDP:
|
||||
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
|
||||
case layers.LayerTypeTCP:
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
|
||||
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
|
||||
if rules, exists := m.outgoingRules[ipKey]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byte, size int) {
|
||||
transport := d.decoded[1]
|
||||
switch transport {
|
||||
case layers.LayerTypeUDP:
|
||||
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size)
|
||||
case layers.LayerTypeTCP:
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size)
|
||||
}
|
||||
}
|
||||
|
||||
// udpHooksDrop checks if any UDP hooks should drop the packet
|
||||
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
// Check specific destination IP first
|
||||
if rules, exists := m.outgoingRules[dstIP]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||
if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest {
|
||||
m.icmpTracker.TrackOutbound(
|
||||
srcIP,
|
||||
dstIP,
|
||||
d.icmp4.Id,
|
||||
d.icmp4.Seq,
|
||||
)
|
||||
// Check IPv4 unspecified address
|
||||
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check IPv6 unspecified address
|
||||
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// dropFilter implements filtering logic for incoming packets.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) dropFilter(packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
if !m.isValidPacket(d, packetData) {
|
||||
valid, fragment := m.isValidPacket(d, packetData)
|
||||
if !valid {
|
||||
return true
|
||||
}
|
||||
|
||||
srcIP, dstIP := m.extractIPs(d)
|
||||
if srcIP == nil {
|
||||
if !srcIP.IsValid() {
|
||||
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
||||
return true
|
||||
}
|
||||
|
||||
// TODO: pass fragments of routed packets to forwarder
|
||||
if fragment {
|
||||
m.logger.Trace("packet is a fragment: src=%v dst=%v id=%v flags=%v",
|
||||
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
|
||||
return false
|
||||
}
|
||||
|
||||
// For all inbound traffic, first check if it matches a tracked connection.
|
||||
// This must happen before any other filtering because the packets are statefully tracked.
|
||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
|
||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
|
||||
return false
|
||||
}
|
||||
|
||||
if m.localipmanager.IsLocalIP(dstIP) {
|
||||
return m.handleLocalTraffic(d, srcIP, dstIP, packetData)
|
||||
return m.handleLocalTraffic(d, srcIP, dstIP, packetData, size)
|
||||
}
|
||||
|
||||
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
|
||||
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData, size)
|
||||
}
|
||||
|
||||
// handleLocalTraffic handles local traffic.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
||||
if m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) {
|
||||
m.logger.Trace("Dropping local packet (ACL denied): src=%s dst=%s",
|
||||
srcIP, dstIP)
|
||||
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
|
||||
ruleID, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||
if blocked {
|
||||
_, pnum := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: uuid.New(),
|
||||
Type: nftypes.TypeDrop,
|
||||
RuleID: ruleID,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: pnum,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
// TODO: icmp type/code
|
||||
RxPackets: 1,
|
||||
RxBytes: uint64(size),
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
// if running in netstack mode we need to pass this to the forwarder
|
||||
if m.netstack {
|
||||
if m.netstack && m.localForwarding {
|
||||
return m.handleNetstackLocalTraffic(packetData)
|
||||
}
|
||||
|
||||
// track inbound packets to get the correct direction and session id for flows
|
||||
m.trackInbound(d, srcIP, dstIP, ruleID, size)
|
||||
|
||||
// pass to either native or virtual stack (to be picked up by listeners)
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
||||
if !m.localForwarding {
|
||||
// pass to virtual tcp/ip stack to be picked up by listeners
|
||||
return false
|
||||
}
|
||||
|
||||
if m.forwarder == nil {
|
||||
fwd := m.forwarder.Load()
|
||||
if fwd == nil {
|
||||
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
||||
return true
|
||||
}
|
||||
|
||||
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
||||
if err := fwd.InjectIncomingPacket(packetData); err != nil {
|
||||
m.logger.Error("Failed to inject local packet: %v", err)
|
||||
}
|
||||
|
||||
@@ -681,47 +747,68 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
||||
|
||||
// handleRoutedTraffic handles routed traffic.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
|
||||
// Drop if routing is disabled
|
||||
if !m.routingEnabled {
|
||||
if !m.routingEnabled.Load() {
|
||||
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||
srcIP, dstIP)
|
||||
return true
|
||||
}
|
||||
|
||||
// Pass to native stack if native router is enabled or forced
|
||||
if m.nativeRouter {
|
||||
if m.nativeRouter.Load() {
|
||||
m.trackInbound(d, srcIP, dstIP, nil, size)
|
||||
return false
|
||||
}
|
||||
|
||||
proto := getProtocolFromPacket(d)
|
||||
proto, pnum := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
if !m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) {
|
||||
m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v",
|
||||
srcIP, srcPort, dstIP, dstPort, proto)
|
||||
if ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
|
||||
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: uuid.New(),
|
||||
Type: nftypes.TypeDrop,
|
||||
RuleID: ruleID,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: pnum,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
// TODO: icmp type/code
|
||||
RxPackets: 1,
|
||||
RxBytes: uint64(size),
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
// Let forwarder handle the packet if it passed route ACLs
|
||||
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
||||
m.logger.Error("Failed to inject incoming packet: %v", err)
|
||||
fwd := m.forwarder.Load()
|
||||
if fwd == nil {
|
||||
m.logger.Trace("failed to forward routed packet (forwarder not initialized)")
|
||||
} else {
|
||||
if err := fwd.InjectIncomingPacket(packetData); err != nil {
|
||||
m.logger.Error("Failed to inject routed packet: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Forwarded packets shouldn't reach the native stack, hence they won't be visible in a packet capture
|
||||
return true
|
||||
}
|
||||
|
||||
func getProtocolFromPacket(d *decoder) firewall.Protocol {
|
||||
func getProtocolFromPacket(d *decoder) (firewall.Protocol, nftypes.Protocol) {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
return firewall.ProtocolTCP
|
||||
return firewall.ProtocolTCP, nftypes.TCP
|
||||
case layers.LayerTypeUDP:
|
||||
return firewall.ProtocolUDP
|
||||
return firewall.ProtocolUDP, nftypes.UDP
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return firewall.ProtocolICMP
|
||||
return firewall.ProtocolICMP, nftypes.ICMP
|
||||
default:
|
||||
return firewall.ProtocolALL
|
||||
return firewall.ProtocolALL, nftypes.ProtocolUnknown
|
||||
}
|
||||
}
|
||||
|
||||
@@ -736,20 +823,35 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
||||
// isValidPacket checks if the packet is valid.
|
||||
// It returns true, false if the packet is valid and not a fragment.
|
||||
// It returns true, true if the packet is a fragment and valid.
|
||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
m.logger.Trace("couldn't decode packet, err: %s", err)
|
||||
return false
|
||||
return false, false
|
||||
}
|
||||
|
||||
if len(d.decoded) < 2 {
|
||||
m.logger.Trace("packet doesn't have network and transport layers")
|
||||
return false
|
||||
l := len(d.decoded)
|
||||
|
||||
// L3 and L4 are mandatory
|
||||
if l >= 2 {
|
||||
return true, false
|
||||
}
|
||||
return true
|
||||
|
||||
// Fragments are also valid
|
||||
if l == 1 && d.decoded[0] == layers.LayerTypeIPv4 {
|
||||
ip4 := d.ip4
|
||||
if ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0 {
|
||||
return true, true
|
||||
}
|
||||
}
|
||||
|
||||
m.logger.Trace("packet doesn't have network and transport layers")
|
||||
return false, false
|
||||
}
|
||||
|
||||
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
|
||||
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size int) bool {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
return m.tcpTracker.IsValidInbound(
|
||||
@@ -758,6 +860,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
|
||||
uint16(d.tcp.SrcPort),
|
||||
uint16(d.tcp.DstPort),
|
||||
getTCPFlags(&d.tcp),
|
||||
size,
|
||||
)
|
||||
|
||||
case layers.LayerTypeUDP:
|
||||
@@ -766,6 +869,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
|
||||
dstIP,
|
||||
uint16(d.udp.SrcPort),
|
||||
uint16(d.udp.DstPort),
|
||||
size,
|
||||
)
|
||||
|
||||
case layers.LayerTypeICMPv4:
|
||||
@@ -773,8 +877,8 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
|
||||
srcIP,
|
||||
dstIP,
|
||||
d.icmp4.Id,
|
||||
d.icmp4.Seq,
|
||||
d.icmp4.TypeCode.Type(),
|
||||
size,
|
||||
)
|
||||
|
||||
// TODO: ICMPv6
|
||||
@@ -794,25 +898,27 @@ func (m *Manager) isSpecialICMP(d *decoder) bool {
|
||||
icmpType == layers.ICMPv4TypeTimeExceeded
|
||||
}
|
||||
|
||||
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
|
||||
func (m *Manager) peerACLsBlock(srcIP netip.Addr, packetData []byte, rules map[netip.Addr]RuleSet, d *decoder) ([]byte, bool) {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
if m.isSpecialICMP(d) {
|
||||
return false
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
|
||||
return filter
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP], d); ok {
|
||||
return mgmtId, filter
|
||||
}
|
||||
|
||||
if filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok {
|
||||
return filter
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv4Unspecified()], d); ok {
|
||||
return mgmtId, filter
|
||||
}
|
||||
|
||||
if filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok {
|
||||
return filter
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv6Unspecified()], d); ok {
|
||||
return mgmtId, filter
|
||||
}
|
||||
|
||||
// Default policy: DROP ALL
|
||||
return true
|
||||
return nil, true
|
||||
}
|
||||
|
||||
func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
||||
@@ -832,15 +938,15 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) (bool, bool) {
|
||||
func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
|
||||
payloadLayer := d.decoded[1]
|
||||
for _, rule := range rules {
|
||||
if rule.matchByIP && !ip.Equal(rule.ip) {
|
||||
if rule.matchByIP && ip.Compare(rule.ip) != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if rule.protoLayer == layerTypeAll {
|
||||
return rule.drop, true
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
|
||||
if payloadLayer != rule.protoLayer {
|
||||
@@ -850,39 +956,36 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *de
|
||||
switch payloadLayer {
|
||||
case layers.LayerTypeTCP:
|
||||
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
|
||||
return rule.drop, true
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
case layers.LayerTypeUDP:
|
||||
// if rule has UDP hook (and if we are here we match this rule)
|
||||
// we ignore rule.drop and call this hook
|
||||
if rule.udpHook != nil {
|
||||
return rule.udpHook(packetData), true
|
||||
return rule.mgmtId, rule.udpHook(packetData), true
|
||||
}
|
||||
|
||||
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||
return rule.drop, true
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return rule.drop, true
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
}
|
||||
return false, false
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
// routeACLsPass returns treu if the packet is allowed by the route ACLs
|
||||
func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||
// routeACLsPass returns true if the packet is allowed by the route ACLs
|
||||
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
srcAddr := netip.AddrFrom4([4]byte(srcIP.To4()))
|
||||
dstAddr := netip.AddrFrom4([4]byte(dstIP.To4()))
|
||||
|
||||
for _, rule := range m.routeRules {
|
||||
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) {
|
||||
return rule.action == firewall.ActionAccept
|
||||
if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
|
||||
return rule.mgmtId, rule.action == firewall.ActionAccept
|
||||
}
|
||||
}
|
||||
return false
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||
@@ -922,36 +1025,32 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||
func (m *Manager) AddUDPPacketHook(
|
||||
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
|
||||
) string {
|
||||
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
protoLayer: layers.LayerTypeUDP,
|
||||
dPort: &firewall.Port{Values: []uint16{dPort}},
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
|
||||
udpHook: hook,
|
||||
}
|
||||
|
||||
if ip.To4() != nil {
|
||||
if ip.Is4() {
|
||||
r.ipLayer = layers.LayerTypeIPv4
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
if in {
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||
if _, ok := m.incomingRules[r.ip]; !ok {
|
||||
m.incomingRules[r.ip] = make(map[string]PeerRule)
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
m.incomingRules[r.ip][r.id] = r
|
||||
} else {
|
||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||
m.outgoingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||
if _, ok := m.outgoingRules[r.ip]; !ok {
|
||||
m.outgoingRules[r.ip] = make(map[string]PeerRule)
|
||||
}
|
||||
m.outgoingRules[r.ip.String()][r.id] = r
|
||||
m.outgoingRules[r.ip][r.id] = r
|
||||
}
|
||||
|
||||
m.mutex.Unlock()
|
||||
|
||||
return r.id
|
||||
@@ -999,20 +1098,21 @@ func (m *Manager) DisableRouting() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.forwarder == nil {
|
||||
fwder := m.forwarder.Load()
|
||||
if fwder == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.routingEnabled = false
|
||||
m.nativeRouter = false
|
||||
m.routingEnabled.Store(false)
|
||||
m.nativeRouter.Store(false)
|
||||
|
||||
// don't stop forwarder if in use by netstack
|
||||
if m.netstack && m.localForwarding {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.forwarder.Stop()
|
||||
m.forwarder = nil
|
||||
fwder.Stop()
|
||||
m.forwarder.Store(nil)
|
||||
|
||||
log.Debug("forwarder stopped")
|
||||
|
||||
|
||||
@@ -93,8 +93,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
stateful: false,
|
||||
setupFunc: func(m *Manager) {
|
||||
// Single rule allowing all traffic
|
||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
|
||||
fw.ActionAccept, "", "allow all")
|
||||
_, err := m.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
|
||||
require.NoError(b, err)
|
||||
},
|
||||
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
||||
@@ -114,10 +113,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
// Add explicit rules matching return traffic pattern
|
||||
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
||||
ip := generateRandomIPs(1)[0]
|
||||
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
|
||||
_, err := m.AddPeerFiltering(
|
||||
nil,
|
||||
ip,
|
||||
fw.ProtocolTCP,
|
||||
&fw.Port{Values: []uint16{uint16(1024 + i)}},
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
fw.ActionAccept, "", "explicit return")
|
||||
fw.ActionAccept,
|
||||
"",
|
||||
)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
},
|
||||
@@ -128,8 +132,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
stateful: true,
|
||||
setupFunc: func(m *Manager) {
|
||||
// Add some basic rules but rely on state for established connections
|
||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
|
||||
fw.ActionDrop, "", "default drop")
|
||||
_, err := m.AddPeerFiltering(
|
||||
nil,
|
||||
net.ParseIP("0.0.0.0"),
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionDrop,
|
||||
"",
|
||||
)
|
||||
require.NoError(b, err)
|
||||
},
|
||||
desc: "Connection tracking with established connections",
|
||||
@@ -158,9 +169,9 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
// Create manager and basic setup
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false)
|
||||
}, false, flowLogger)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
@@ -182,13 +193,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
|
||||
// For stateful scenarios, establish the connection
|
||||
if sc.stateful {
|
||||
manager.processOutgoingHooks(outbound)
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
}
|
||||
|
||||
// Measure inbound packet processing
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound)
|
||||
manager.dropFilter(inbound, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -203,9 +214,9 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false)
|
||||
}, false, flowLogger)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
@@ -219,7 +230,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
for i := 0; i < count; i++ {
|
||||
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, layers.IPProtocolTCP)
|
||||
manager.processOutgoingHooks(outbound)
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
}
|
||||
|
||||
// Test packet
|
||||
@@ -227,11 +238,11 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
||||
|
||||
// First establish our test connection
|
||||
manager.processOutgoingHooks(testOut)
|
||||
manager.processOutgoingHooks(testOut, 0)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(testIn)
|
||||
manager.dropFilter(testIn, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -251,9 +262,9 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false)
|
||||
}, false, flowLogger)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
@@ -267,12 +278,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
||||
|
||||
if sc.established {
|
||||
manager.processOutgoingHooks(outbound)
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound)
|
||||
manager.dropFilter(inbound, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -450,9 +461,9 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false)
|
||||
}, false, flowLogger)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
// Setup scenario
|
||||
@@ -466,25 +477,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
// For stateful cases and established connections
|
||||
if !strings.Contains(sc.name, "allow_non_wg") ||
|
||||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
||||
manager.processOutgoingHooks(outbound)
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
|
||||
// For TCP post-handshake, simulate full handshake
|
||||
if sc.state == "post_handshake" {
|
||||
// SYN
|
||||
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
||||
manager.processOutgoingHooks(syn)
|
||||
manager.processOutgoingHooks(syn, 0)
|
||||
// SYN-ACK
|
||||
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack)
|
||||
manager.dropFilter(synack, 0)
|
||||
// ACK
|
||||
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
||||
manager.processOutgoingHooks(ack)
|
||||
manager.processOutgoingHooks(ack, 0)
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound)
|
||||
manager.dropFilter(inbound, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -577,9 +588,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false)
|
||||
}, false, flowLogger)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
@@ -590,10 +601,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
nil,
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -616,17 +624,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
// Initial SYN
|
||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||
manager.processOutgoingHooks(syn)
|
||||
manager.processOutgoingHooks(syn, 0)
|
||||
|
||||
// SYN-ACK
|
||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack)
|
||||
manager.dropFilter(synack, 0)
|
||||
|
||||
// ACK
|
||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||
manager.processOutgoingHooks(ack)
|
||||
manager.processOutgoingHooks(ack, 0)
|
||||
}
|
||||
|
||||
// Prepare test packets simulating bidirectional traffic
|
||||
@@ -647,9 +655,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
|
||||
// Simulate bidirectional traffic
|
||||
// First outbound data
|
||||
manager.processOutgoingHooks(outPackets[connIdx])
|
||||
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
||||
// Then inbound response - this is what we're actually measuring
|
||||
manager.dropFilter(inPackets[connIdx])
|
||||
manager.dropFilter(inPackets[connIdx], 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -668,9 +676,9 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false)
|
||||
}, false, flowLogger)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
@@ -681,10 +689,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
nil,
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -756,19 +761,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
p := patterns[connIdx]
|
||||
|
||||
// Connection establishment
|
||||
manager.processOutgoingHooks(p.syn)
|
||||
manager.dropFilter(p.synAck)
|
||||
manager.processOutgoingHooks(p.ack)
|
||||
manager.processOutgoingHooks(p.syn, 0)
|
||||
manager.dropFilter(p.synAck, 0)
|
||||
manager.processOutgoingHooks(p.ack, 0)
|
||||
|
||||
// Data transfer
|
||||
manager.processOutgoingHooks(p.request)
|
||||
manager.dropFilter(p.response)
|
||||
manager.processOutgoingHooks(p.request, 0)
|
||||
manager.dropFilter(p.response, 0)
|
||||
|
||||
// Connection teardown
|
||||
manager.processOutgoingHooks(p.finClient)
|
||||
manager.dropFilter(p.ackServer)
|
||||
manager.dropFilter(p.finServer)
|
||||
manager.processOutgoingHooks(p.ackClient)
|
||||
manager.processOutgoingHooks(p.finClient, 0)
|
||||
manager.dropFilter(p.ackServer, 0)
|
||||
manager.dropFilter(p.finServer, 0)
|
||||
manager.processOutgoingHooks(p.ackClient, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -787,9 +792,9 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false)
|
||||
}, false, flowLogger)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
@@ -799,10 +804,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
nil,
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -824,15 +826,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
for i := 0; i < sc.connCount; i++ {
|
||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||
manager.processOutgoingHooks(syn)
|
||||
manager.processOutgoingHooks(syn, 0)
|
||||
|
||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack)
|
||||
manager.dropFilter(synack, 0)
|
||||
|
||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||
manager.processOutgoingHooks(ack)
|
||||
manager.processOutgoingHooks(ack, 0)
|
||||
}
|
||||
|
||||
// Pre-generate test packets
|
||||
@@ -854,8 +856,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
counter++
|
||||
|
||||
// Simulate bidirectional traffic
|
||||
manager.processOutgoingHooks(outPackets[connIdx])
|
||||
manager.dropFilter(inPackets[connIdx])
|
||||
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
||||
manager.dropFilter(inPackets[connIdx], 0)
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -875,9 +877,9 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false)
|
||||
}, false, flowLogger)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
@@ -886,10 +888,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
})
|
||||
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
nil,
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -951,17 +950,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
p := patterns[connIdx]
|
||||
|
||||
// Full connection lifecycle
|
||||
manager.processOutgoingHooks(p.syn)
|
||||
manager.dropFilter(p.synAck)
|
||||
manager.processOutgoingHooks(p.ack)
|
||||
manager.processOutgoingHooks(p.syn, 0)
|
||||
manager.dropFilter(p.synAck, 0)
|
||||
manager.processOutgoingHooks(p.ack, 0)
|
||||
|
||||
manager.processOutgoingHooks(p.request)
|
||||
manager.dropFilter(p.response)
|
||||
manager.processOutgoingHooks(p.request, 0)
|
||||
manager.dropFilter(p.response, 0)
|
||||
|
||||
manager.processOutgoingHooks(p.finClient)
|
||||
manager.dropFilter(p.ackServer)
|
||||
manager.dropFilter(p.finServer)
|
||||
manager.processOutgoingHooks(p.ackClient)
|
||||
manager.processOutgoingHooks(p.finClient, 0)
|
||||
manager.dropFilter(p.ackServer, 0)
|
||||
manager.dropFilter(p.finServer, 0)
|
||||
manager.processOutgoingHooks(p.ackClient, 0)
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -1033,14 +1032,7 @@ func BenchmarkRouteACLs(b *testing.B) {
|
||||
}
|
||||
|
||||
for _, r := range rules {
|
||||
_, err := manager.AddRouteFiltering(
|
||||
r.sources,
|
||||
r.dest,
|
||||
r.proto,
|
||||
nil,
|
||||
r.port,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
@@ -1062,8 +1054,8 @@ func BenchmarkRouteACLs(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, tc := range cases {
|
||||
srcIP := net.ParseIP(tc.srcIP)
|
||||
dstIP := net.ParseIP(tc.dstIP)
|
||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,9 +12,9 @@ import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
func TestPeerACLFiltering(t *testing.T) {
|
||||
@@ -26,20 +26,20 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: localIP,
|
||||
Network: wgNet,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false)
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, manager)
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = wgNet
|
||||
@@ -192,20 +192,20 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
|
||||
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
||||
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
|
||||
isDropped := manager.DropIncoming(packet)
|
||||
isDropped := manager.DropIncoming(packet, 0)
|
||||
require.True(t, isDropped, "Packet should be dropped when no rules exist")
|
||||
})
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
rules, err := manager.AddPeerFiltering(
|
||||
nil,
|
||||
net.ParseIP(tc.ruleIP),
|
||||
tc.ruleProto,
|
||||
tc.ruleSrcPort,
|
||||
tc.ruleDstPort,
|
||||
tc.ruleAction,
|
||||
"",
|
||||
tc.name,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, rules)
|
||||
@@ -217,7 +217,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
})
|
||||
|
||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
isDropped := manager.DropIncoming(packet)
|
||||
isDropped := manager.DropIncoming(packet, 0)
|
||||
require.Equal(t, tc.shouldBeBlocked, isDropped)
|
||||
})
|
||||
}
|
||||
@@ -288,8 +288,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: localIP,
|
||||
Network: wgNet,
|
||||
}
|
||||
@@ -302,15 +302,15 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false)
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
require.NoError(tb, manager.EnableRouting())
|
||||
require.NoError(tb, err)
|
||||
require.NotNil(tb, manager)
|
||||
require.True(tb, manager.routingEnabled)
|
||||
require.False(tb, manager.nativeRouter)
|
||||
require.True(tb, manager.routingEnabled.Load())
|
||||
require.False(tb, manager.nativeRouter.Load())
|
||||
|
||||
tb.Cleanup(func() {
|
||||
require.NoError(tb, manager.Reset(nil))
|
||||
require.NoError(tb, manager.Close(nil))
|
||||
})
|
||||
|
||||
return manager
|
||||
@@ -803,6 +803,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
nil,
|
||||
tc.rule.sources,
|
||||
tc.rule.dest,
|
||||
tc.rule.proto,
|
||||
@@ -817,12 +818,12 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||
})
|
||||
|
||||
srcIP := net.ParseIP(tc.srcIP)
|
||||
dstIP := net.ParseIP(tc.dstIP)
|
||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||
|
||||
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
|
||||
// to the forwarder
|
||||
isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
require.Equal(t, tc.shouldPass, isAllowed)
|
||||
})
|
||||
}
|
||||
@@ -985,6 +986,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
var rules []fw.Rule
|
||||
for _, r := range tc.rules {
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
nil,
|
||||
r.sources,
|
||||
r.dest,
|
||||
r.proto,
|
||||
@@ -1004,10 +1006,10 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
})
|
||||
|
||||
for i, p := range tc.packets {
|
||||
srcIP := net.ParseIP(p.srcIP)
|
||||
dstIP := net.ParseIP(p.dstIP)
|
||||
srcIP := netip.MustParseAddr(p.srcIP)
|
||||
dstIP := netip.MustParseAddr(p.dstIP)
|
||||
|
||||
isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
|
||||
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -3,6 +3,7 @@ package uspfilter
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -16,15 +17,17 @@ import (
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
)
|
||||
|
||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
type IFaceMock struct {
|
||||
SetFilterFunc func(device.PacketFilter) error
|
||||
AddressFunc func() iface.WGAddress
|
||||
AddressFunc func() wgaddr.Address
|
||||
GetWGDeviceFunc func() *wgdevice.Device
|
||||
GetDeviceFunc func() *device.FilteredDevice
|
||||
}
|
||||
@@ -50,9 +53,9 @@ func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
|
||||
return i.SetFilterFunc(iface)
|
||||
}
|
||||
|
||||
func (i *IFaceMock) Address() iface.WGAddress {
|
||||
func (i *IFaceMock) Address() wgaddr.Address {
|
||||
if i.AddressFunc == nil {
|
||||
return iface.WGAddress{}
|
||||
return wgaddr.Address{}
|
||||
}
|
||||
return i.AddressFunc()
|
||||
}
|
||||
@@ -62,7 +65,7 @@ func TestManagerCreate(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false)
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -82,7 +85,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false)
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -92,9 +95,8 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||
rule, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -116,26 +118,25 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false)
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
ip := netip.MustParseAddr("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule 2"
|
||||
|
||||
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, r := range rule2 {
|
||||
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
|
||||
if _, ok := m.incomingRules[ip][r.ID()]; !ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
}
|
||||
}
|
||||
@@ -149,7 +150,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, r := range rule2 {
|
||||
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; ok {
|
||||
if _, ok := m.incomingRules[ip][r.ID()]; ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
}
|
||||
}
|
||||
@@ -160,7 +161,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
name string
|
||||
in bool
|
||||
expDir fw.RuleDirection
|
||||
ip net.IP
|
||||
ip netip.Addr
|
||||
dPort uint16
|
||||
hook func([]byte) bool
|
||||
expectedID string
|
||||
@@ -169,7 +170,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
name: "Test Outgoing UDP Packet Hook",
|
||||
in: false,
|
||||
expDir: fw.RuleDirectionOUT,
|
||||
ip: net.IPv4(10, 168, 0, 1),
|
||||
ip: netip.MustParseAddr("10.168.0.1"),
|
||||
dPort: 8000,
|
||||
hook: func([]byte) bool { return true },
|
||||
},
|
||||
@@ -177,7 +178,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
name: "Test Incoming UDP Packet Hook",
|
||||
in: true,
|
||||
expDir: fw.RuleDirectionIN,
|
||||
ip: net.IPv6loopback,
|
||||
ip: netip.MustParseAddr("::1"),
|
||||
dPort: 9000,
|
||||
hook: func([]byte) bool { return false },
|
||||
},
|
||||
@@ -187,18 +188,18 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false)
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||
|
||||
var addedRule PeerRule
|
||||
if tt.in {
|
||||
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
||||
if len(manager.incomingRules[tt.ip]) != 1 {
|
||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||
return
|
||||
}
|
||||
for _, rule := range manager.incomingRules[tt.ip.String()] {
|
||||
for _, rule := range manager.incomingRules[tt.ip] {
|
||||
addedRule = rule
|
||||
}
|
||||
} else {
|
||||
@@ -206,12 +207,12 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
||||
return
|
||||
}
|
||||
for _, rule := range manager.outgoingRules[tt.ip.String()] {
|
||||
for _, rule := range manager.outgoingRules[tt.ip] {
|
||||
addedRule = rule
|
||||
}
|
||||
}
|
||||
|
||||
if !tt.ip.Equal(addedRule.ip) {
|
||||
if tt.ip.Compare(addedRule.ip) != 0 {
|
||||
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
||||
return
|
||||
}
|
||||
@@ -236,7 +237,7 @@ func TestManagerReset(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false)
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -246,15 +247,14 @@ func TestManagerReset(t *testing.T) {
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||
_, err = m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = m.Reset(nil)
|
||||
err = m.Close(nil)
|
||||
if err != nil {
|
||||
t.Errorf("failed to reset Manager: %v", err)
|
||||
return
|
||||
@@ -268,8 +268,8 @@ func TestManagerReset(t *testing.T) {
|
||||
func TestNotMatchByIP(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.10.0.100"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
@@ -279,7 +279,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false)
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -292,9 +292,8 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
ip := net.ParseIP("0.0.0.0")
|
||||
proto := fw.ProtocolUDP
|
||||
action := fw.ActionAccept
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment)
|
||||
_, err = m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -328,12 +327,12 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if m.dropFilter(buf.Bytes()) {
|
||||
if m.dropFilter(buf.Bytes(), 0) {
|
||||
t.Errorf("expected packet to be accepted")
|
||||
return
|
||||
}
|
||||
|
||||
if err = m.Reset(nil); err != nil {
|
||||
if err = m.Close(nil); err != nil {
|
||||
t.Errorf("failed to reset Manager: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -347,17 +346,17 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
}
|
||||
|
||||
// creating manager instance
|
||||
manager, err := Create(iface, false)
|
||||
manager, err := Create(iface, false, flowLogger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Manager: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Add a UDP packet hook
|
||||
hookFunc := func(data []byte) bool { return true }
|
||||
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc)
|
||||
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
|
||||
|
||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||
found := false
|
||||
@@ -393,7 +392,7 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
func TestProcessOutgoingHooks(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false)
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
@@ -401,9 +400,9 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
manager.udpTracker.Close()
|
||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger)
|
||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
manager.decoders = sync.Pool{
|
||||
@@ -423,7 +422,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
hookCalled := false
|
||||
hookID := manager.AddUDPPacketHook(
|
||||
false,
|
||||
net.ParseIP("100.10.0.100"),
|
||||
netip.MustParseAddr("100.10.0.100"),
|
||||
53,
|
||||
func([]byte) bool {
|
||||
hookCalled = true
|
||||
@@ -458,7 +457,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test hook gets called
|
||||
result := manager.processOutgoingHooks(buf.Bytes())
|
||||
result := manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||
require.True(t, result)
|
||||
require.True(t, hookCalled)
|
||||
|
||||
@@ -468,7 +467,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
result = manager.processOutgoingHooks(buf.Bytes())
|
||||
result = manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||
require.False(t, result)
|
||||
}
|
||||
|
||||
@@ -479,12 +478,12 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
manager, err := Create(ifaceMock, false)
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
if err := manager.Reset(nil); err != nil {
|
||||
if err := manager.Close(nil); err != nil {
|
||||
t.Errorf("clear the manager state: %v", err)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
@@ -494,7 +493,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
@@ -506,7 +505,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false)
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
@@ -515,7 +514,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
}
|
||||
|
||||
manager.udpTracker.Close() // Close the existing tracker
|
||||
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger)
|
||||
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
|
||||
manager.decoders = sync.Pool{
|
||||
New: func() any {
|
||||
d := &decoder{
|
||||
@@ -530,12 +529,12 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
},
|
||||
}
|
||||
defer func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Set up packet parameters
|
||||
srcIP := net.ParseIP("100.10.0.1")
|
||||
dstIP := net.ParseIP("100.10.0.100")
|
||||
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||
dstIP := netip.MustParseAddr("100.10.0.100")
|
||||
srcPort := uint16(51334)
|
||||
dstPort := uint16(53)
|
||||
|
||||
@@ -543,8 +542,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
outboundIPv4 := &layers.IPv4{
|
||||
TTL: 64,
|
||||
Version: 4,
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcIP: srcIP.AsSlice(),
|
||||
DstIP: dstIP.AsSlice(),
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
outboundUDP := &layers.UDP{
|
||||
@@ -569,15 +568,15 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Process outbound packet and verify connection tracking
|
||||
drop := manager.DropOutgoing(outboundBuf.Bytes())
|
||||
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
|
||||
require.False(t, drop, "Initial outbound packet should not be dropped")
|
||||
|
||||
// Verify connection was tracked
|
||||
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
require.True(t, exists, "Connection should be tracked after outbound packet")
|
||||
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match")
|
||||
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match")
|
||||
require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
|
||||
require.True(t, dstIP.Compare(conn.DestIP) == 0, "Destination IP should match")
|
||||
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
|
||||
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
|
||||
|
||||
@@ -585,8 +584,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
inboundIPv4 := &layers.IPv4{
|
||||
TTL: 64,
|
||||
Version: 4,
|
||||
SrcIP: dstIP, // Original destination is now source
|
||||
DstIP: srcIP, // Original source is now destination
|
||||
SrcIP: dstIP.AsSlice(), // Original destination is now source
|
||||
DstIP: srcIP.AsSlice(), // Original source is now destination
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
inboundUDP := &layers.UDP{
|
||||
@@ -636,7 +635,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
for _, cp := range checkPoints {
|
||||
time.Sleep(cp.sleep)
|
||||
|
||||
drop = manager.dropFilter(inboundBuf.Bytes())
|
||||
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
|
||||
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||
|
||||
// If the connection should still be valid, verify it exists
|
||||
@@ -685,7 +684,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create a new outbound connection for invalid tests
|
||||
drop = manager.processOutgoingHooks(outboundBuf.Bytes())
|
||||
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
|
||||
require.False(t, drop, "Second outbound packet should not be dropped")
|
||||
|
||||
for _, tc := range invalidCases {
|
||||
@@ -707,7 +706,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the invalid packet is dropped
|
||||
drop = manager.dropFilter(testBuf.Bytes())
|
||||
drop = manager.dropFilter(testBuf.Bytes(), 0)
|
||||
require.True(t, drop, tc.description)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/stun/v2"
|
||||
@@ -14,6 +13,8 @@ import (
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type RecvMessage struct {
|
||||
@@ -52,9 +53,10 @@ type ICEBind struct {
|
||||
|
||||
muUDPMux sync.Mutex
|
||||
udpMux *UniversalUDPMuxDefault
|
||||
address wgaddr.Address
|
||||
}
|
||||
|
||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
|
||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
|
||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||
ib := &ICEBind{
|
||||
StdNetBind: b,
|
||||
@@ -64,6 +66,7 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
closedChan: make(chan struct{}),
|
||||
closed: true,
|
||||
address: address,
|
||||
}
|
||||
|
||||
rc := receiverCreator{
|
||||
@@ -108,35 +111,17 @@ func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
||||
return s.udpMux, nil
|
||||
}
|
||||
|
||||
func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) {
|
||||
fakeUDPAddr, err := fakeAddress(peerAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// force IPv4
|
||||
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to convert IP to netip.Addr")
|
||||
}
|
||||
|
||||
func (b *ICEBind) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
|
||||
b.endpointsMu.Lock()
|
||||
b.endpoints[fakeAddr] = conn
|
||||
b.endpoints[fakeIP] = conn
|
||||
b.endpointsMu.Unlock()
|
||||
|
||||
return fakeUDPAddr, nil
|
||||
}
|
||||
|
||||
func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) {
|
||||
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
|
||||
if !ok {
|
||||
log.Warnf("failed to convert IP to netip.Addr")
|
||||
return
|
||||
}
|
||||
|
||||
func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
|
||||
b.endpointsMu.Lock()
|
||||
defer b.endpointsMu.Unlock()
|
||||
delete(b.endpoints, fakeAddr)
|
||||
|
||||
delete(b.endpoints, fakeIP)
|
||||
}
|
||||
|
||||
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||
@@ -161,9 +146,10 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
||||
|
||||
s.udpMux = NewUniversalUDPMuxDefault(
|
||||
UniversalUDPMuxParams{
|
||||
UDPConn: conn,
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
UDPConn: conn,
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
WGAddress: s.address,
|
||||
},
|
||||
)
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
@@ -275,21 +261,6 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
|
||||
}
|
||||
}
|
||||
|
||||
// fakeAddress returns a fake address that is used to as an identifier for the peer.
|
||||
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
|
||||
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
|
||||
octets := strings.Split(peerAddress.IP.String(), ".")
|
||||
if len(octets) != 4 {
|
||||
return nil, fmt.Errorf("invalid IP format")
|
||||
}
|
||||
|
||||
newAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
|
||||
Port: peerAddress.Port,
|
||||
}
|
||||
return newAddr, nil
|
||||
}
|
||||
|
||||
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
|
||||
return msgsPool.Get().(*[]ipv6.Message)
|
||||
}
|
||||
|
||||
@@ -150,7 +150,7 @@ func isZeros(ip net.IP) bool {
|
||||
// NewUDPMuxDefault creates an implementation of UDPMux
|
||||
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||
if params.Logger == nil {
|
||||
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
||||
params.Logger = getLogger()
|
||||
}
|
||||
|
||||
mux := &UDPMuxDefault{
|
||||
@@ -455,3 +455,9 @@ func newBufferHolder(size int) *bufferHolder {
|
||||
buf: make([]byte, size),
|
||||
}
|
||||
}
|
||||
|
||||
func getLogger() logging.LeveledLogger {
|
||||
fac := logging.NewDefaultLoggerFactory()
|
||||
fac.Writer = log.StandardLogger().Writer()
|
||||
return fac.NewLogger("ice")
|
||||
}
|
||||
|
||||
@@ -17,6 +17,8 @@ import (
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/stun/v2"
|
||||
"github.com/pion/transport/v3"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// FilterFn is a function that filters out candidates based on the address.
|
||||
@@ -41,12 +43,13 @@ type UniversalUDPMuxParams struct {
|
||||
XORMappedAddrCacheTTL time.Duration
|
||||
Net transport.Net
|
||||
FilterFn FilterFn
|
||||
WGAddress wgaddr.Address
|
||||
}
|
||||
|
||||
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
|
||||
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
|
||||
if params.Logger == nil {
|
||||
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
||||
params.Logger = getLogger()
|
||||
}
|
||||
if params.XORMappedAddrCacheTTL == 0 {
|
||||
params.XORMappedAddrCacheTTL = time.Second * 25
|
||||
@@ -64,6 +67,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
||||
mux: m,
|
||||
logger: params.Logger,
|
||||
filterFn: params.FilterFn,
|
||||
address: params.WGAddress,
|
||||
}
|
||||
|
||||
// embed UDPMux
|
||||
@@ -118,6 +122,7 @@ type udpConn struct {
|
||||
filterFn FilterFn
|
||||
// TODO: reset cache on route changes
|
||||
addrCache sync.Map
|
||||
address wgaddr.Address
|
||||
}
|
||||
|
||||
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
@@ -159,6 +164,11 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if u.address.Network.Contains(a.AsSlice()) {
|
||||
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
}
|
||||
|
||||
if isRouted, prefix, err := u.filterFn(a); err != nil {
|
||||
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
|
||||
} else {
|
||||
|
||||
@@ -357,7 +357,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||
|
||||
func getFwmark() int {
|
||||
if nbnet.AdvancedRouting() {
|
||||
return nbnet.NetbirdFwmark
|
||||
return nbnet.ControlPlaneMark
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -9,13 +9,14 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type WGTunDevice interface {
|
||||
Create() (device.WGConfigurer, error)
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(address WGAddress) error
|
||||
WgAddress() WGAddress
|
||||
UpdateAddr(address wgaddr.Address) error
|
||||
WgAddress() wgaddr.Address
|
||||
DeviceName() string
|
||||
Close() error
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
|
||||
@@ -13,11 +13,12 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
|
||||
type WGTunDevice struct {
|
||||
address WGAddress
|
||||
address wgaddr.Address
|
||||
port int
|
||||
key string
|
||||
mtu int
|
||||
@@ -31,7 +32,7 @@ type WGTunDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewTunDevice(address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
|
||||
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
|
||||
return &WGTunDevice{
|
||||
address: address,
|
||||
port: port,
|
||||
@@ -93,7 +94,7 @@ func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return udpMux, nil
|
||||
}
|
||||
|
||||
func (t *WGTunDevice) UpdateAddr(addr WGAddress) error {
|
||||
func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
|
||||
// todo implement
|
||||
return nil
|
||||
}
|
||||
@@ -123,7 +124,7 @@ func (t *WGTunDevice) DeviceName() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
func (t *WGTunDevice) WgAddress() WGAddress {
|
||||
func (t *WGTunDevice) WgAddress() wgaddr.Address {
|
||||
return t.address
|
||||
}
|
||||
|
||||
|
||||
@@ -13,11 +13,12 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type TunDevice struct {
|
||||
name string
|
||||
address WGAddress
|
||||
address wgaddr.Address
|
||||
port int
|
||||
key string
|
||||
mtu int
|
||||
@@ -29,7 +30,7 @@ type TunDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
|
||||
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
|
||||
return &TunDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
@@ -85,7 +86,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return udpMux, nil
|
||||
}
|
||||
|
||||
func (t *TunDevice) UpdateAddr(address WGAddress) error {
|
||||
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
|
||||
t.address = address
|
||||
return t.assignAddr()
|
||||
}
|
||||
@@ -106,7 +107,7 @@ func (t *TunDevice) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TunDevice) WgAddress() WGAddress {
|
||||
func (t *TunDevice) WgAddress() wgaddr.Address {
|
||||
return t.address
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package device
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
@@ -10,16 +11,16 @@ import (
|
||||
// PacketFilter interface for firewall abilities
|
||||
type PacketFilter interface {
|
||||
// DropOutgoing filter outgoing packets from host to external destinations
|
||||
DropOutgoing(packetData []byte) bool
|
||||
DropOutgoing(packetData []byte, size int) bool
|
||||
|
||||
// DropIncoming filter incoming packets from external sources to host
|
||||
DropIncoming(packetData []byte) bool
|
||||
DropIncoming(packetData []byte, size int) bool
|
||||
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
// Hook function returns flag which indicates should be the matched package dropped or not.
|
||||
// Hook function receives raw network packet data as argument.
|
||||
AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string
|
||||
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
|
||||
|
||||
// RemovePacketHook removes hook by ID
|
||||
RemovePacketHook(hookID string) error
|
||||
@@ -57,7 +58,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
|
||||
}
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) {
|
||||
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
||||
bufs = append(bufs[:i], bufs[i+1:]...)
|
||||
sizes = append(sizes[:i], sizes[i+1:]...)
|
||||
n--
|
||||
@@ -81,7 +82,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||
filteredBufs := make([][]byte, 0, len(bufs))
|
||||
dropped := 0
|
||||
for _, buf := range bufs {
|
||||
if !filter.DropIncoming(buf[offset:]) {
|
||||
if !filter.DropIncoming(buf[offset:], len(buf)) {
|
||||
filteredBufs = append(filteredBufs, buf)
|
||||
dropped++
|
||||
}
|
||||
|
||||
@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
||||
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
|
||||
|
||||
filter := mocks.NewMockPacketFilter(ctrl)
|
||||
filter.EXPECT().DropIncoming(gomock.Any()).Return(true)
|
||||
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
|
||||
|
||||
wrapped := newDeviceFilter(tun)
|
||||
wrapped.filter = filter
|
||||
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
||||
return 1, nil
|
||||
})
|
||||
filter := mocks.NewMockPacketFilter(ctrl)
|
||||
filter.EXPECT().DropOutgoing(gomock.Any()).Return(true)
|
||||
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
|
||||
|
||||
wrapped := newDeviceFilter(tun)
|
||||
wrapped.filter = filter
|
||||
|
||||
@@ -14,11 +14,12 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type TunDevice struct {
|
||||
name string
|
||||
address WGAddress
|
||||
address wgaddr.Address
|
||||
port int
|
||||
key string
|
||||
iceBind *bind.ICEBind
|
||||
@@ -30,7 +31,7 @@ type TunDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewTunDevice(name string, address WGAddress, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
|
||||
func NewTunDevice(name string, address wgaddr.Address, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
|
||||
return &TunDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
@@ -120,11 +121,11 @@ func (t *TunDevice) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TunDevice) WgAddress() WGAddress {
|
||||
func (t *TunDevice) WgAddress() wgaddr.Address {
|
||||
return t.address
|
||||
}
|
||||
|
||||
func (t *TunDevice) UpdateAddr(addr WGAddress) error {
|
||||
func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
|
||||
// todo implement
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -14,12 +14,13 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/sharedsock"
|
||||
)
|
||||
|
||||
type TunKernelDevice struct {
|
||||
name string
|
||||
address WGAddress
|
||||
address wgaddr.Address
|
||||
wgPort int
|
||||
key string
|
||||
mtu int
|
||||
@@ -34,7 +35,7 @@ type TunKernelDevice struct {
|
||||
filterFn bind.FilterFn
|
||||
}
|
||||
|
||||
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
|
||||
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &TunKernelDevice{
|
||||
ctx: ctx,
|
||||
@@ -99,9 +100,10 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return nil, err
|
||||
}
|
||||
bindParams := bind.UniversalUDPMuxParams{
|
||||
UDPConn: rawSock,
|
||||
Net: t.transportNet,
|
||||
FilterFn: t.filterFn,
|
||||
UDPConn: rawSock,
|
||||
Net: t.transportNet,
|
||||
FilterFn: t.filterFn,
|
||||
WGAddress: t.address,
|
||||
}
|
||||
mux := bind.NewUniversalUDPMuxDefault(bindParams)
|
||||
go mux.ReadFromConn(t.ctx)
|
||||
@@ -112,7 +114,7 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return t.udpMux, nil
|
||||
}
|
||||
|
||||
func (t *TunKernelDevice) UpdateAddr(address WGAddress) error {
|
||||
func (t *TunKernelDevice) UpdateAddr(address wgaddr.Address) error {
|
||||
t.address = address
|
||||
return t.assignAddr()
|
||||
}
|
||||
@@ -145,7 +147,7 @@ func (t *TunKernelDevice) Close() error {
|
||||
return closErr
|
||||
}
|
||||
|
||||
func (t *TunKernelDevice) WgAddress() WGAddress {
|
||||
func (t *TunKernelDevice) WgAddress() wgaddr.Address {
|
||||
return t.address
|
||||
}
|
||||
|
||||
|
||||
@@ -13,12 +13,13 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type TunNetstackDevice struct {
|
||||
name string
|
||||
address WGAddress
|
||||
address wgaddr.Address
|
||||
port int
|
||||
key string
|
||||
mtu int
|
||||
@@ -34,7 +35,7 @@ type TunNetstackDevice struct {
|
||||
net *netstack.Net
|
||||
}
|
||||
|
||||
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
|
||||
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
|
||||
return &TunNetstackDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
@@ -97,7 +98,7 @@ func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return udpMux, nil
|
||||
}
|
||||
|
||||
func (t *TunNetstackDevice) UpdateAddr(WGAddress) error {
|
||||
func (t *TunNetstackDevice) UpdateAddr(wgaddr.Address) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -116,7 +117,7 @@ func (t *TunNetstackDevice) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TunNetstackDevice) WgAddress() WGAddress {
|
||||
func (t *TunNetstackDevice) WgAddress() wgaddr.Address {
|
||||
return t.address
|
||||
}
|
||||
|
||||
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type USPDevice struct {
|
||||
name string
|
||||
address WGAddress
|
||||
address wgaddr.Address
|
||||
port int
|
||||
key string
|
||||
mtu int
|
||||
@@ -28,7 +29,7 @@ type USPDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
|
||||
func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
|
||||
log.Infof("using userspace bind mode")
|
||||
|
||||
return &USPDevice{
|
||||
@@ -93,7 +94,7 @@ func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return udpMux, nil
|
||||
}
|
||||
|
||||
func (t *USPDevice) UpdateAddr(address WGAddress) error {
|
||||
func (t *USPDevice) UpdateAddr(address wgaddr.Address) error {
|
||||
t.address = address
|
||||
return t.assignAddr()
|
||||
}
|
||||
@@ -113,7 +114,7 @@ func (t *USPDevice) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *USPDevice) WgAddress() WGAddress {
|
||||
func (t *USPDevice) WgAddress() wgaddr.Address {
|
||||
return t.address
|
||||
}
|
||||
|
||||
|
||||
@@ -13,13 +13,14 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}"
|
||||
|
||||
type TunDevice struct {
|
||||
name string
|
||||
address WGAddress
|
||||
address wgaddr.Address
|
||||
port int
|
||||
key string
|
||||
mtu int
|
||||
@@ -32,7 +33,7 @@ type TunDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
|
||||
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
|
||||
return &TunDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
@@ -118,7 +119,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return udpMux, nil
|
||||
}
|
||||
|
||||
func (t *TunDevice) UpdateAddr(address WGAddress) error {
|
||||
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
|
||||
t.address = address
|
||||
return t.assignAddr()
|
||||
}
|
||||
@@ -139,7 +140,7 @@ func (t *TunDevice) Close() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (t *TunDevice) WgAddress() WGAddress {
|
||||
func (t *TunDevice) WgAddress() wgaddr.Address {
|
||||
return t.address
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/freebsd"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type wgLink struct {
|
||||
@@ -56,7 +57,7 @@ func (l *wgLink) up() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *wgLink) assignAddr(address WGAddress) error {
|
||||
func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
||||
link, err := freebsd.LinkByName(l.name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("link by name: %w", err)
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type wgLink struct {
|
||||
@@ -90,7 +92,7 @@ func (l *wgLink) up() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *wgLink) assignAddr(address WGAddress) error {
|
||||
func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
||||
//delete existing addresses
|
||||
list, err := netlink.AddrList(l, 0)
|
||||
if err != nil {
|
||||
|
||||
@@ -7,13 +7,14 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type WGTunDevice interface {
|
||||
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(address WGAddress) error
|
||||
WgAddress() WGAddress
|
||||
UpdateAddr(address wgaddr.Address) error
|
||||
WgAddress() wgaddr.Address
|
||||
DeviceName() string
|
||||
Close() error
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
@@ -28,8 +29,6 @@ const (
|
||||
WgInterfaceDefault = configurer.WgInterfaceDefault
|
||||
)
|
||||
|
||||
type WGAddress = device.WGAddress
|
||||
|
||||
type wgProxyFactory interface {
|
||||
GetProxy() wgproxy.Proxy
|
||||
Free() error
|
||||
@@ -72,7 +71,7 @@ func (w *WGIface) Name() string {
|
||||
}
|
||||
|
||||
// Address returns the interface address
|
||||
func (w *WGIface) Address() device.WGAddress {
|
||||
func (w *WGIface) Address() wgaddr.Address {
|
||||
return w.tun.WgAddress()
|
||||
}
|
||||
|
||||
@@ -103,7 +102,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
addr, err := device.ParseWGAddress(newAddr)
|
||||
addr, err := wgaddr.ParseWGAddress(newAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -3,17 +3,18 @@ package iface
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
|
||||
@@ -6,17 +6,18 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||
|
||||
var tun WGTunDevice
|
||||
if netstack.IsEnabled() {
|
||||
|
||||
@@ -5,17 +5,18 @@ package iface
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||
|
||||
wgIFace := &WGIface{
|
||||
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd),
|
||||
|
||||
@@ -8,12 +8,13 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -21,7 +22,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgIFace := &WGIface{}
|
||||
|
||||
if netstack.IsEnabled() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
||||
@@ -34,7 +35,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
return wgIFace, nil
|
||||
}
|
||||
if device.ModuleTunIsLoaded() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
||||
|
||||
@@ -4,16 +4,17 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
wgaddr "github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||
|
||||
var tun WGTunDevice
|
||||
if netstack.IsEnabled() {
|
||||
|
||||
@@ -6,6 +6,7 @@ package mocks
|
||||
|
||||
import (
|
||||
net "net"
|
||||
"net/netip"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
@@ -35,7 +36,7 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
||||
}
|
||||
|
||||
// AddUDPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func([]byte) bool) string {
|
||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].(string)
|
||||
@@ -49,31 +50,31 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
|
||||
}
|
||||
|
||||
// DropIncoming mocks base method.
|
||||
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
|
||||
func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DropIncoming", arg0)
|
||||
ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DropIncoming indicates an expected call of DropIncoming.
|
||||
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
|
||||
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
|
||||
}
|
||||
|
||||
// DropOutgoing mocks base method.
|
||||
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
|
||||
func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
|
||||
ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DropOutgoing indicates an expected call of DropOutgoing.
|
||||
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
|
||||
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
|
||||
}
|
||||
|
||||
// RemovePacketHook mocks base method.
|
||||
|
||||
@@ -55,7 +55,7 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
||||
|
||||
skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy))
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse NB_ETSTACK_SKIP_PROXY: %s", err)
|
||||
log.Errorf("failed to parse %s: %s", EnvSkipProxy, err)
|
||||
}
|
||||
if skipProxy {
|
||||
return nsTunDev, tunNet, nil
|
||||
|
||||
@@ -1,29 +1,29 @@
|
||||
package device
|
||||
package wgaddr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
// WGAddress WireGuard parsed address
|
||||
type WGAddress struct {
|
||||
// Address WireGuard parsed address
|
||||
type Address struct {
|
||||
IP net.IP
|
||||
Network *net.IPNet
|
||||
}
|
||||
|
||||
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
||||
func ParseWGAddress(address string) (WGAddress, error) {
|
||||
func ParseWGAddress(address string) (Address, error) {
|
||||
ip, network, err := net.ParseCIDR(address)
|
||||
if err != nil {
|
||||
return WGAddress{}, err
|
||||
return Address{}, err
|
||||
}
|
||||
return WGAddress{
|
||||
return Address{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (addr WGAddress) String() string {
|
||||
func (addr Address) String() string {
|
||||
maskSize, _ := addr.Network.Mask.Size()
|
||||
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -16,13 +17,13 @@ import (
|
||||
type ProxyBind struct {
|
||||
Bind *bind.ICEBind
|
||||
|
||||
wgAddr *net.UDPAddr
|
||||
wgEndpoint *bind.Endpoint
|
||||
remoteConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
fakeNetIP *netip.AddrPort
|
||||
wgBindEndpoint *bind.Endpoint
|
||||
remoteConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
@@ -33,20 +34,24 @@ type ProxyBind struct {
|
||||
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
|
||||
// WireGuard configuration.
|
||||
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
|
||||
addr, err := p.Bind.SetEndpoint(nbAddr, remoteConn)
|
||||
fakeNetIP, err := fakeAddress(nbAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.wgAddr = addr
|
||||
p.wgEndpoint = addrToEndpoint(addr)
|
||||
p.fakeNetIP = fakeNetIP
|
||||
p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
|
||||
p.remoteConn = remoteConn
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
return err
|
||||
return nil
|
||||
|
||||
}
|
||||
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
|
||||
return p.wgAddr
|
||||
return &net.UDPAddr{
|
||||
IP: p.fakeNetIP.Addr().AsSlice(),
|
||||
Port: int(p.fakeNetIP.Port()),
|
||||
Zone: p.fakeNetIP.Addr().Zone(),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProxyBind) Work() {
|
||||
@@ -54,6 +59,8 @@ func (p *ProxyBind) Work() {
|
||||
return
|
||||
}
|
||||
|
||||
p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn)
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.paused = false
|
||||
p.pausedMu.Unlock()
|
||||
@@ -93,7 +100,7 @@ func (p *ProxyBind) close() error {
|
||||
|
||||
p.cancel()
|
||||
|
||||
p.Bind.RemoveEndpoint(p.wgAddr)
|
||||
p.Bind.RemoveEndpoint(p.fakeNetIP.Addr())
|
||||
|
||||
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
|
||||
return rErr
|
||||
@@ -126,7 +133,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
}
|
||||
|
||||
msg := bind.RecvMessage{
|
||||
Endpoint: p.wgEndpoint,
|
||||
Endpoint: p.wgBindEndpoint,
|
||||
Buffer: buf[:n],
|
||||
}
|
||||
p.Bind.RecvChan <- msg
|
||||
@@ -134,8 +141,19 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
|
||||
ip, _ := netip.AddrFromSlice(addr.IP.To4())
|
||||
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
|
||||
return &bind.Endpoint{AddrPort: addrPort}
|
||||
// fakeAddress returns a fake address that is used to as an identifier for the peer.
|
||||
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
|
||||
func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
|
||||
octets := strings.Split(peerAddress.IP.String(), ".")
|
||||
if len(octets) != 4 {
|
||||
return nil, fmt.Errorf("invalid IP format")
|
||||
}
|
||||
|
||||
fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3]))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse new IP: %w", err)
|
||||
}
|
||||
|
||||
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
|
||||
return &netipAddr, nil
|
||||
}
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
!define DESCRIPTION "A WireGuard®-based mesh network that connects your devices into a single private network"
|
||||
!define INSTALLER_NAME "netbird-installer.exe"
|
||||
!define MAIN_APP_EXE "Netbird"
|
||||
!define ICON "ui\\netbird.ico"
|
||||
!define BANNER "ui\\banner.bmp"
|
||||
!define ICON "ui\\assets\\netbird.ico"
|
||||
!define BANNER "ui\\build\\banner.bmp"
|
||||
!define LICENSE_DATA "..\\LICENSE"
|
||||
|
||||
!define INSTALL_DIR "$PROGRAMFILES64\${APP_NAME}"
|
||||
@@ -22,6 +22,8 @@
|
||||
!define UI_REG_APP_PATH "Software\Microsoft\Windows\CurrentVersion\App Paths\${UI_APP_EXE}"
|
||||
!define UI_UNINSTALL_PATH "Software\Microsoft\Windows\CurrentVersion\Uninstall\${UI_APP_NAME}"
|
||||
|
||||
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
|
||||
|
||||
Unicode True
|
||||
|
||||
######################################################################
|
||||
@@ -68,6 +70,9 @@ ShowInstDetails Show
|
||||
|
||||
!insertmacro MUI_PAGE_DIRECTORY
|
||||
|
||||
; Custom page for autostart checkbox
|
||||
Page custom AutostartPage AutostartPageLeave
|
||||
|
||||
!insertmacro MUI_PAGE_INSTFILES
|
||||
|
||||
!insertmacro MUI_PAGE_FINISH
|
||||
@@ -80,8 +85,36 @@ ShowInstDetails Show
|
||||
|
||||
!insertmacro MUI_LANGUAGE "English"
|
||||
|
||||
; Variables for autostart option
|
||||
Var AutostartCheckbox
|
||||
Var AutostartEnabled
|
||||
|
||||
######################################################################
|
||||
|
||||
; Function to create the autostart options page
|
||||
Function AutostartPage
|
||||
!insertmacro MUI_HEADER_TEXT "Startup Options" "Configure how ${APP_NAME} launches with Windows."
|
||||
|
||||
nsDialogs::Create 1018
|
||||
Pop $0
|
||||
|
||||
${If} $0 == error
|
||||
Abort
|
||||
${EndIf}
|
||||
|
||||
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
|
||||
Pop $AutostartCheckbox
|
||||
${NSD_Check} $AutostartCheckbox ; Default to checked
|
||||
StrCpy $AutostartEnabled "1" ; Default to enabled
|
||||
|
||||
nsDialogs::Show
|
||||
FunctionEnd
|
||||
|
||||
; Function to handle leaving the autostart page
|
||||
Function AutostartPageLeave
|
||||
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
|
||||
FunctionEnd
|
||||
|
||||
Function GetAppFromCommand
|
||||
Exch $1
|
||||
Push $2
|
||||
@@ -163,6 +196,16 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
|
||||
|
||||
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
|
||||
|
||||
; Create autostart registry entry based on checkbox
|
||||
DetailPrint "Autostart enabled: $AutostartEnabled"
|
||||
${If} $AutostartEnabled == "1"
|
||||
WriteRegStr HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" "$INSTDIR\${UI_APP_EXE}.exe"
|
||||
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
|
||||
${Else}
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
DetailPrint "Autostart not enabled by user"
|
||||
${EndIf}
|
||||
|
||||
EnVar::SetHKLM
|
||||
EnVar::AddValueEx "path" "$INSTDIR"
|
||||
|
||||
@@ -186,7 +229,10 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
|
||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
||||
|
||||
# kill ui client
|
||||
ExecWait `taskkill /im ${UI_APP_EXE}.exe`
|
||||
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
||||
|
||||
; Remove autostart registry entry
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
|
||||
# wait the service uninstall take unblock the executable
|
||||
Sleep 3000
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
type RuleID string
|
||||
|
||||
func (r RuleID) GetRuleID() string {
|
||||
func (r RuleID) ID() string {
|
||||
return string(r)
|
||||
}
|
||||
|
||||
|
||||
@@ -28,6 +28,11 @@ type Manager interface {
|
||||
ApplyFiltering(networkMap *mgmProto.NetworkMap)
|
||||
}
|
||||
|
||||
type protoMatch struct {
|
||||
ips map[string]int
|
||||
policyID []byte
|
||||
}
|
||||
|
||||
// DefaultManager uses firewall manager to handle
|
||||
type DefaultManager struct {
|
||||
firewall firewall.Manager
|
||||
@@ -240,12 +245,12 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul
|
||||
|
||||
dPorts := convertPortInfo(rule.PortInfo)
|
||||
|
||||
addedRule, err := d.firewall.AddRouteFiltering(sources, destination, protocol, nil, dPorts, action)
|
||||
addedRule, err := d.firewall.AddRouteFiltering(rule.PolicyID, sources, destination, protocol, nil, dPorts, action)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("add route rule: %w", err)
|
||||
}
|
||||
|
||||
return id.RuleID(addedRule.GetRuleID()), nil
|
||||
return id.RuleID(addedRule.ID()), nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
@@ -281,7 +286,7 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
}
|
||||
}
|
||||
|
||||
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action, "")
|
||||
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action)
|
||||
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
|
||||
return ruleID, rulesPair, nil
|
||||
}
|
||||
@@ -289,11 +294,11 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
var rules []firewall.Rule
|
||||
switch r.Direction {
|
||||
case mgmProto.RuleDirection_IN:
|
||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
||||
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||
case mgmProto.RuleDirection_OUT:
|
||||
// TODO: Remove this soon. Outbound rules are obsolete.
|
||||
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
|
||||
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
||||
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||
default:
|
||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||
}
|
||||
@@ -322,14 +327,14 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addInRules(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
rule, err := d.firewall.AddPeerFiltering(ip, protocol, nil, port, action, ipsetName, comment)
|
||||
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, nil, port, action, ipsetName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
@@ -338,18 +343,18 @@ func (d *DefaultManager) addInRules(
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addOutRules(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
if shouldSkipInvertedRule(protocol, port) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rule, err := d.firewall.AddPeerFiltering(ip, protocol, port, nil, action, ipsetName, comment)
|
||||
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, port, nil, action, ipsetName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
@@ -364,9 +369,8 @@ func (d *DefaultManager) getPeerRuleID(
|
||||
direction int,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
comment string,
|
||||
) id.RuleID {
|
||||
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
|
||||
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action))
|
||||
if port != nil {
|
||||
idStr += port.String()
|
||||
}
|
||||
@@ -389,10 +393,8 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
}
|
||||
}
|
||||
|
||||
type protoMatch map[mgmProto.RuleProtocol]map[string]int
|
||||
|
||||
in := protoMatch{}
|
||||
out := protoMatch{}
|
||||
in := map[mgmProto.RuleProtocol]*protoMatch{}
|
||||
out := map[mgmProto.RuleProtocol]*protoMatch{}
|
||||
|
||||
// trace which type of protocols was squashed
|
||||
squashedRules := []*mgmProto.FirewallRule{}
|
||||
@@ -405,14 +407,18 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
// 2. Any of rule contains Port.
|
||||
//
|
||||
// We zeroed this to notify squash function that this protocol can't be squashed.
|
||||
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols protoMatch) {
|
||||
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
|
||||
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
|
||||
if drop {
|
||||
protocols[r.Protocol] = map[string]int{}
|
||||
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
|
||||
return
|
||||
}
|
||||
if _, ok := protocols[r.Protocol]; !ok {
|
||||
protocols[r.Protocol] = map[string]int{}
|
||||
protocols[r.Protocol] = &protoMatch{
|
||||
ips: map[string]int{},
|
||||
// store the first encountered PolicyID for this protocol
|
||||
policyID: r.PolicyID,
|
||||
}
|
||||
}
|
||||
|
||||
// special case, when we receive this all network IP address
|
||||
@@ -424,7 +430,7 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
return
|
||||
}
|
||||
|
||||
ipset := protocols[r.Protocol]
|
||||
ipset := protocols[r.Protocol].ips
|
||||
|
||||
if _, ok := ipset[r.PeerIP]; ok {
|
||||
return
|
||||
@@ -450,9 +456,10 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
mgmProto.RuleProtocol_UDP,
|
||||
}
|
||||
|
||||
squash := func(matches protoMatch, direction mgmProto.RuleDirection) {
|
||||
squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) {
|
||||
for _, protocol := range protocolOrders {
|
||||
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
|
||||
match, ok := matches[protocol]
|
||||
if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 {
|
||||
// don't squash if :
|
||||
// 1. Rules not cover all peers in the network
|
||||
// 2. Rules cover only one peer in the network.
|
||||
@@ -465,6 +472,7 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
Direction: direction,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: protocol,
|
||||
PolicyID: match.policyID,
|
||||
})
|
||||
squashedProtocols[protocol] = struct{}{}
|
||||
|
||||
@@ -493,9 +501,9 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
// if we also have other not squashed rules.
|
||||
for i, r := range networkMap.FirewallRules {
|
||||
if _, ok := squashedProtocols[r.Protocol]; ok {
|
||||
if m, ok := in[r.Protocol]; ok && m[r.PeerIP] == i {
|
||||
if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i {
|
||||
continue
|
||||
} else if m, ok := out[r.Protocol]; ok && m[r.PeerIP] == i {
|
||||
} else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i {
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -515,7 +523,7 @@ func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
|
||||
for _, rules := range newRulePairs {
|
||||
for _, rule := range rules {
|
||||
if err := d.firewall.DeletePeerRule(rule); err != nil {
|
||||
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
|
||||
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.ID(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,11 +8,14 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
func TestDefaultManager(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
@@ -45,20 +48,20 @@ func TestDefaultManager(t *testing.T) {
|
||||
}
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||
if err != nil {
|
||||
t.Errorf("create firewall: %v", err)
|
||||
return
|
||||
}
|
||||
defer func(fw manager.Manager) {
|
||||
_ = fw.Reset(nil)
|
||||
_ = fw.Close(nil)
|
||||
}(fw)
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
@@ -74,7 +77,7 @@ func TestDefaultManager(t *testing.T) {
|
||||
t.Run("add extra rules", func(t *testing.T) {
|
||||
existedPairs := map[string]struct{}{}
|
||||
for id := range acl.peerRulesPairs {
|
||||
existedPairs[id.GetRuleID()] = struct{}{}
|
||||
existedPairs[id.ID()] = struct{}{}
|
||||
}
|
||||
|
||||
// remove first rule
|
||||
@@ -100,7 +103,7 @@ func TestDefaultManager(t *testing.T) {
|
||||
// check that old rule was removed
|
||||
previousCount := 0
|
||||
for id := range acl.peerRulesPairs {
|
||||
if _, ok := existedPairs[id.GetRuleID()]; ok {
|
||||
if _, ok := existedPairs[id.ID()]; ok {
|
||||
previousCount++
|
||||
}
|
||||
}
|
||||
@@ -339,20 +342,20 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
}
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||
if err != nil {
|
||||
t.Errorf("create firewall: %v", err)
|
||||
return
|
||||
}
|
||||
defer func(fw manager.Manager) {
|
||||
_ = fw.Reset(nil)
|
||||
_ = fw.Close(nil)
|
||||
}(fw)
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user