Compare commits

..

101 Commits

Author SHA1 Message Date
Pascal Fischer
4d2c774378 refactor networm map generation 2025-03-13 14:29:59 +01:00
Pascal Fischer
ab2e3fec72 expose resource type consts 2025-03-12 13:49:24 +01:00
Hakan Sariman
47f88f7057 Refactor routeIDLookup methods to use Addr() for resolved IP operations 2025-03-11 19:43:58 +08:00
Hakan Sariman
ee33a6ed7c Refactor RemoveLocalPeerStateRoute to eliminate resourceId parameter 2025-03-11 13:19:30 +08:00
Hakan Sariman
da662cfd08 Add source and destination resource IDs to FlowFields 2025-03-11 13:12:54 +08:00
Hakan Sariman
ed2ee1ee9d Merge branch 'feature/flow' into feat/flow-resid 2025-03-11 13:08:11 +08:00
Viktor Liu
76d73548d6 Fix more conflicts 2025-03-10 18:46:01 +01:00
Viktor Liu
11828a064a Fix conflict 2025-03-10 18:35:32 +01:00
Viktor Liu
0c2a3dd937 Merge branch 'main' into feature/flow 2025-03-10 18:30:45 +01:00
Viktor Liu
47dcf8d68c Fix forwarder IP source/destination (#3463) 2025-03-10 14:55:07 +01:00
Bethuel Mmbaga
cc8f6bcaf3 [management] Fix tests circular dependency (#3460)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-03-10 15:54:36 +03:00
Hakan Sariman
92286b2541 Implement routeIDLookup for managing local and remote route IDs 2025-03-10 15:58:45 +08:00
Maycon Santos
d8bcf745b0 update integrations 2025-03-09 19:32:38 +01:00
Maycon Santos
8430139d80 fix missing method 2025-03-09 19:03:57 +01:00
Maycon Santos
a2962b4ce0 sync go.sum 2025-03-09 18:50:20 +01:00
Maycon Santos
16fffdb75b sync changes from #3426 2025-03-09 18:48:48 +01:00
Maycon Santos
036cecbf46 update integrations and go mod 2025-03-09 18:47:05 +01:00
Maycon Santos
3482852bb6 sync proto and sum 2025-03-09 18:02:33 +01:00
Maycon Santos
fd62665b1f Merge branch 'main' into feature/flow
# Conflicts:
#	client/cmd/testutil_test.go
#	client/firewall/iptables/router_linux.go
#	client/firewall/nftables/router_linux.go
#	client/firewall/uspfilter/allow_netbird.go
#	client/firewall/uspfilter/allow_netbird_windows.go
#	client/firewall/uspfilter/uspfilter_test.go
#	client/internal/engine.go
#	client/internal/engine_test.go
#	client/server/server_test.go
#	go.mod
#	go.sum
#	management/client/client_test.go
#	management/cmd/management.go
#	management/proto/management.pb.go
#	management/proto/management.proto
#	management/server/account.go
#	management/server/account_test.go
#	management/server/dns_test.go
#	management/server/http/handler.go
#	management/server/http/testing/testing_tools/tools.go
#	management/server/integrations/port_forwarding/controller.go
#	management/server/management_proto_test.go
#	management/server/management_test.go
#	management/server/nameserver_test.go
#	management/server/peer.go
#	management/server/peer_test.go
#	management/server/route_test.go
2025-03-09 17:42:16 +01:00
Hakan Sariman
1ffe48f0d4 Add nil check in CheckRoutes to prevent potential panic 2025-03-08 12:54:33 +03:00
Hakan Sariman
a3b8a21385 Refactor CheckRoutes to return resource IDs for matching source and destination addresses 2025-03-08 12:26:53 +03:00
Hakan Sariman
86492b88c4 Refactor route handling to simplify route information and improve state management 2025-03-08 12:25:35 +03:00
Hakan Sariman
d08a629f9e Merge branch 'feature/flow' into feat/flow-resid 2025-03-08 12:18:02 +03:00
Viktor Liu
36da464413 Fix tracer test 2025-03-07 17:19:10 +01:00
Hakan Sariman
268e3404d3 Merge branch 'feature/flow' into feat/flow-resid 2025-03-07 18:52:11 +03:00
Hakan Sariman
54d0591833 Refactor route handling to use RouteWithResourceId for improved state management 2025-03-07 18:43:49 +03:00
Viktor Liu
86370a0e7b Use bytes for flows event id (#3439) 2025-03-07 16:12:47 +01:00
Viktor Liu
cb16d0f45f Align packet tracer behavior with actual code paths (#3424) 2025-03-07 14:03:45 +01:00
Viktor Liu
e8d8bd8f18 Add peer traffic rule IDs to allowed connections in flows (#3442) 2025-03-07 13:56:26 +01:00
Viktor Liu
8b07f21c28 Don't track intercepted packets (#3448) 2025-03-07 13:56:16 +01:00
Viktor Liu
54be772ffd Handle flow updates (#3455) 2025-03-07 13:56:00 +01:00
Viktor Liu
3c3a454e61 Fix merge regression 2025-03-06 16:54:15 +01:00
Viktor Liu
5ff77b3595 Add flow userspace counters (#3438) 2025-03-06 16:52:56 +01:00
Viktor Liu
b180edbe5c Track icmp with id only (#3447) 2025-03-06 14:51:23 +01:00
Hakan Sariman
de3b5c78d7 Fix nil pointer dereference in CheckRoutes method 2025-03-06 14:10:31 +03:00
Hakan Sariman
0b42f40cf6 Refactor route management to include resource IDs in state handling 2025-03-06 13:51:46 +03:00
Viktor Liu
0a042ac36d Fix merge conflict 2025-03-05 19:11:20 +01:00
Hakan Sariman
e7f921d787 [client] add resource id fields to netflow events 2025-03-05 20:35:52 +03:00
Viktor Liu
e9f11fb11b Replace net.IP with netip.Addr (#3425) 2025-03-05 18:28:05 +01:00
hakansa
419ed275fa Handle TCP RST flag to transition connection state to closed (#3432) 2025-03-05 18:25:42 +01:00
Viktor Liu
2d4fcaf186 Fix proto numbering (#3436) 2025-03-04 16:57:25 +01:00
Viktor Liu
acf172b52c Add kernel conntrack counters (#3434) 2025-03-04 16:46:03 +01:00
Viktor Liu
8c81a823fa Add flow ACL IDs (#3421) 2025-03-04 16:43:07 +01:00
Maycon Santos
619c549547 sync port forwarding 2025-03-04 16:29:59 +01:00
Maycon Santos
9a713a0987 Merge branch 'feature/port-forwarding' into feature/flow
# Conflicts:
#	go.mod
#	go.sum
2025-03-04 16:28:57 +01:00
Pascal Fischer
c4945cd565 add cleanup scheduler + metrics 2025-03-04 16:21:52 +01:00
Viktor Liu
1e10c17ecb Fix tcp state (#3431) 2025-03-04 11:19:54 +01:00
Viktor Liu
96d5190436 Add icmp type and code to forwarder flow event (#3413) 2025-02-28 21:04:07 +01:00
Viktor Liu
d19c26df06 Fix log direction (#3412) 2025-02-28 21:03:40 +01:00
Viktor Liu
36e36414d9 Fix forwarder log displaying (#3411) 2025-02-28 20:53:01 +01:00
bcmmbaga
7e69589e05 Update management-integrations
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-28 19:49:56 +00:00
bcmmbaga
aa613ab79a Update golang.org/x/crypto/ssh
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-28 19:27:46 +00:00
Viktor Liu
6ead0ff95e Fix log format 2025-02-28 20:24:23 +01:00
Viktor Liu
0db65a8984 Add routed packet drop flow (#3410) 2025-02-28 20:04:59 +01:00
Pascal Fischer
c138807e95 remove log message 2025-02-28 19:54:50 +01:00
Viktor Liu
637c0c8949 Add icmp type and code (#3409) 2025-02-28 19:16:42 +01:00
Viktor Liu
c72e13d8e6 Add conntrack flows (#3406) 2025-02-28 19:16:29 +01:00
Maycon Santos
f6d7bccfa0 Add flow client with sender/receiver (#3405)
add an initial version of receiver client and flow manager receiver and sender
2025-02-28 17:16:18 +00:00
bcmmbaga
e3ed01cafb go mod tidy
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-28 17:10:44 +00:00
Viktor Liu
fa748a7ec2 Add userspace flow implementation (#3393) 2025-02-28 11:08:35 +01:00
Maycon Santos
cccc615783 update flow proto package generated code 2025-02-28 03:09:09 +00:00
Maycon Santos
2021463ca0 update flow proto package name 2025-02-28 02:51:57 +00:00
Maycon Santos
f48cfd52e9 fix logger stop (#3403)
* fix logger stop

* use context to stop receiver

* update test
2025-02-28 00:28:17 +00:00
Pascal Fischer
6838f53f40 add getPeerByIp store method 2025-02-27 19:01:05 +01:00
Maycon Santos
8276236dfa Add netflow manager (#3398)
* Add netflow manager

* fix linter issues
2025-02-27 12:05:20 +00:00
Viktor Liu
994b923d56 Move proto and rename port and icmp info (#3399) 2025-02-27 12:52:33 +01:00
Viktor Liu
59e2432231 Add event proto fields (#3397) 2025-02-27 12:29:50 +01:00
Pascal Fischer
eee0d123e4 [management] add flow settings and credentials (#3389) 2025-02-27 12:17:07 +01:00
Viktor Liu
e943203ae2 Add event fields (#3390)
Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2025-02-26 12:06:06 +01:00
Pedro Costa
6a775217cf rename flow proto messages 2025-02-25 16:29:54 +00:00
Maycon Santos
175674749f Add memory flow store (#3386) 2025-02-25 15:23:43 +00:00
Pascal Fischer
1e534cecf6 [management] Add flow proto (#3384) 2025-02-25 13:03:27 +01:00
Pedro Costa
aa3aa8c6a8 [management] flow proto 2025-02-25 11:22:54 +00:00
Pascal Fischer
fbdfe45c25 fix merge conflicts on management 2025-02-25 11:57:25 +01:00
Viktor Liu
81ee172db8 Fix route conflict 2025-02-25 11:44:21 +01:00
Viktor Liu
f8fd65a65f Merge branch 'main' into feature/port-forwarding 2025-02-25 11:37:52 +01:00
Bethuel Mmbaga
62b978c050 [management] Add support for tcp/udp allocations (#3381)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-25 10:11:50 +00:00
Bethuel Mmbaga
4ebf1410c6 [management] Add support to allocate same port for public and internal (#3347)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-21 11:16:24 +03:00
Viktor Liu
630edf2480 Remove unused var 2025-02-20 13:24:37 +01:00
Viktor Liu
ea469d28d7 Merge branch 'main' into feature/port-forwarding 2025-02-20 13:24:05 +01:00
Pascal Fischer
597f1d47b8 fix management test suite 2025-02-20 13:08:18 +01:00
Viktor Liu
fcc96417f9 Merge branch 'main' into feature/port-forwarding 2025-02-20 11:45:30 +01:00
Viktor Liu
8755211a60 Merge branch 'main' into feature/port-forwarding 2025-02-20 11:39:06 +01:00
Pascal Fischer
e6d4653b08 [management] add cloud tag to get ingress ports api spec (#3300)
* fix tag for get endpoint

* update labels
2025-02-12 16:11:54 +01:00
Zoltan Papp
eb69f2de78 Fix nil pointer exception when load empty list and try to cast it (#3282) 2025-02-06 10:28:42 +01:00
Viktor Liu
206420c085 [client] Fix grouping of peer ACLs with different port ranges (#3289) 2025-02-06 10:28:42 +01:00
Christian Stewart
88a864c195 [relay] Use new upstream for nhooyr.io/websocket package (#3287)
The nhooyr.io/websocket package was renamed to github.com/coder/websocket when
the project was transferred to "coder" as the new maintainer.

Use the new import path and update go.mod and go.sum accordingly.

Signed-off-by: Christian Stewart <christian@aperture.us>
2025-02-06 10:28:42 +01:00
Pascal Fischer
a789e9e6d8 [management] fix duplication detection (#3286) 2025-02-05 21:42:09 +01:00
Viktor Liu
9930913e4e Merge branch 'main' into feature/port-forwarding 2025-02-05 18:55:59 +01:00
Viktor Liu
48675f579f Merge branch 'main' into feature/port-forwarding 2025-02-05 17:44:01 +01:00
Pascal Fischer
afec455f86 [management] copy port info (#3283) 2025-02-05 17:30:42 +01:00
Pascal Fischer
035c5d9f23 [management merge only unique entries on network map merge (#3277) 2025-02-05 16:50:45 +01:00
Viktor Liu
b2a5b29fb2 Merge branch 'main' into feature/port-forwarding 2025-02-05 10:15:37 +01:00
Bethuel Mmbaga
9ec61206c2 [management] Add support for filtering peers by name and IP (#3279)
* add peers ip and name filters

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add get peers filter

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix get account peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Extend GetAccountPeers store to support filtering by name and IP

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix get peers references

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-05 00:33:15 +03:00
Zoltan Papp
1b011a2d85 [client] Manage the IP forwarding sysctl setting in global way (#3270)
Add new package ipfwdstate that implements reference counting for IP forwarding
state management. This allows multiple usage to safely request IP forwarding
without interfering with each other.
2025-02-03 12:27:18 +01:00
Pascal Fischer
a85ea1ddb0 [manager] ingress ports manager support (#3268)
* add peers manager

* Extend peers manager to support retrieving all peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add network map calc

* move integrations interface

* update management-integrations

* merge main and fix

* go mod tidy

* [management] port forwarding add peer manager fix network map (#3264)

* [management] fix testing tools (#3265)

* Fix net.IPv4 conversion to []byte

* update test to check ipv4

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Zoltán Papp <zoltan.pmail@gmail.com>
2025-02-03 09:37:37 +01:00
Zoltán Papp
829e40d2aa Fix ingress manager unnecessary creation 2025-02-01 10:58:47 +01:00
Pascal Fischer
6344e34880 [management] renamed ingress port endpoints (#3263) 2025-02-01 00:40:33 +01:00
Pascal Fischer
a76ca8c565 Merge branch 'main' into feature/port-forwarding 2025-01-29 22:28:10 +01:00
Zoltan Papp
26693e4ea8 Feature/port forwarding client ingress (#3242)
Client-side forward handling

Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>

---------

Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>
2025-01-29 16:04:33 +01:00
Pascal Fischer
f6a71f4193 [management] add openapi specs and generate types for port forwarding proxy (#3236) 2025-01-27 17:47:40 +01:00
454 changed files with 13560 additions and 31341 deletions

View File

@@ -1,27 +0,0 @@
# More info around this file at https://www.git-town.com/configuration-file
[branches]
main = "main"
perennials = []
perennial-regex = ""
[create]
new-branch-type = "feature"
push-new-branches = false
[hosting]
dev-remote = "origin"
# platform = ""
# origin-hostname = ""
[ship]
delete-tracking-branch = false
strategy = "squash-merge"
[sync]
feature-strategy = "merge"
perennial-strategy = "rebase"
prototype-strategy = "merge"
push-hook = true
tags = true
upstream = false

View File

@@ -37,22 +37,17 @@ If yes, which one?
**Debug output** **Debug output**
To help us resolve the problem, please attach the following anonymized status output To help us resolve the problem, please attach the following debug output
netbird status -dA netbird status -dA
Create and upload a debug bundle, and share the returned file key: As well as the file created by
netbird debug for 1m -AS -U
*Uploaded files are automatically deleted after 30 days.*
Alternatively, create the file only and attach it here manually:
netbird debug for 1m -AS netbird debug for 1m -AS
We advise reviewing the anonymized output for any remaining personal information.
**Screenshots** **Screenshots**
If applicable, add screenshots to help explain your problem. If applicable, add screenshots to help explain your problem.
@@ -62,10 +57,8 @@ If applicable, add screenshots to help explain your problem.
Add any other context about the problem here. Add any other context about the problem here.
**Have you tried these troubleshooting steps?** **Have you tried these troubleshooting steps?**
- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable)
- [ ] Checked for newer NetBird versions - [ ] Checked for newer NetBird versions
- [ ] Searched for similar issues on GitHub (including closed ones) - [ ] Searched for similar issues on GitHub (including closed ones)
- [ ] Restarted the NetBird client - [ ] Restarted the NetBird client
- [ ] Disabled other VPN software - [ ] Disabled other VPN software
- [ ] Checked firewall settings - [ ] Checked firewall settings

View File

@@ -2,10 +2,6 @@
## Issue ticket number and link ## Issue ticket number and link
## Stack
<!-- branch-stack -->
### Checklist ### Checklist
- [ ] Is it a bug fix - [ ] Is it a bug fix
- [ ] Is a typo/documentation fix - [ ] Is a typo/documentation fix
@@ -13,5 +9,3 @@
- [ ] It is a refactor - [ ] It is a refactor
- [ ] Created tests that fail without the change (if possible) - [ ] Created tests that fail without the change (if possible)
- [ ] Extended the README / documentation, if necessary - [ ] Extended the README / documentation, if necessary
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).

View File

@@ -0,0 +1,46 @@
name: "Darwin"
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test:
name: "Client / Unit"
runs-on: macos-latest
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Cache Go modules
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: macos-gotest-${{ hashFiles('**/go.sum') }}
restore-keys: |
macos-gotest-
macos-go-
- name: Install libpcap
run: brew install libpcap
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)

View File

@@ -0,0 +1,46 @@
name: "FreeBSD"
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test:
name: "Client / Unit"
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- name: Test in FreeBSD
id: test
uses: vmactions/freebsd-vm@v1
with:
usesh: true
copyback: false
release: "14.1"
prepare: |
pkg install -y go pkgconf xorg
# -x - to print all executed commands
# -e - to faile on first error
run: |
set -e -x
time go build -o netbird client/main.go
# check all component except management, since we do not support management server on freebsd
time go test -timeout 1m -failfast ./base62/...
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
time go test -timeout 8m -failfast -p 1 ./client/...
time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/...
time go test -timeout 1m -failfast ./formatter/...
time go test -timeout 1m -failfast ./client/iface/...
time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./signal/...
time go test -timeout 1m -failfast ./util/...
time go test -timeout 1m -failfast ./version/...

593
.github/workflows/golang-test-linux.yml vendored Normal file
View File

@@ -0,0 +1,593 @@
name: Linux
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
build-cache:
name: "Build Cache"
runs-on: ubuntu-22.04
outputs:
management: ${{ steps.filter.outputs.management }}
steps:
- name: Checkout code
uses: actions/checkout@v4
- uses: dorny/paths-filter@v3
id: filter
with:
filters: |
management:
- 'management/**'
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@v4
id: cache
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: steps.cache.outputs.cache-hit != 'true'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Build client
if: steps.cache.outputs.cache-hit != 'true'
working-directory: client
run: CGO_ENABLED=1 go build .
- name: Build client 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: client
run: CGO_ENABLED=1 GOARCH=386 go build -o client-386 .
- name: Build management
if: steps.cache.outputs.cache-hit != 'true'
working-directory: management
run: CGO_ENABLED=1 go build .
- name: Build management 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: management
run: CGO_ENABLED=1 GOARCH=386 go build -o management-386 .
- name: Build signal
if: steps.cache.outputs.cache-hit != 'true'
working-directory: signal
run: CGO_ENABLED=1 go build .
- name: Build signal 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: signal
run: CGO_ENABLED=1 GOARCH=386 go build -o signal-386 .
- name: Build relay
if: steps.cache.outputs.cache-hit != 'true'
working-directory: relay
run: CGO_ENABLED=1 go build .
- name: Build relay 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: relay
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
test:
name: "Client / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
test_relay:
name: "Relay / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \
-exec 'sudo' \
-timeout 10m ./signal/...
test_signal:
name: "Signal / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \
-exec 'sudo' \
-timeout 10m ./signal/...
test_management:
name: "Management / Unit"
needs: [ build-cache ]
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
go test -tags=devcert \
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
-timeout 20m ./management/...
benchmark:
name: "Management / Benchmark"
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
go test -tags devcert -run=^$ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./...
api_benchmark:
name: "Management / Benchmark (API)"
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
go test -tags=benchmark \
-run=^$ \
-bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/...
api_integration_test:
name: "Management / Integration"
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres']
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
go test -tags=integration \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/...
test_client_on_docker:
name: "Client (Docker) / Unit"
needs: [ build-cache ]
runs-on: ubuntu-20.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Generate Shared Sock Test bin
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
- name: Generate RouteManager Test bin
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager
- name: Generate SystemOps Test bin
run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops
- name: Generate nftables Manager Test bin
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
- name: Generate Engine Test bin
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
- name: Generate Peer Test bin
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/
- run: chmod +x *testing.bin
- name: Run Shared Sock tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Iface tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/...
- name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
- name: Run SystemOps tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1
- name: Run nftables Manager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Engine tests in docker with file store
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Engine tests in docker with sqlite store
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Peer tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -0,0 +1,72 @@
name: "Windows"
on:
push:
branches:
- main
pull_request:
env:
downloadPath: '${{ github.workspace }}\temp'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test:
name: "Client / Unit"
runs-on: windows-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
id: go
with:
go-version: "1.23.x"
cache: false
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $env:GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-
${{ runner.os }}-go-
- name: Download wintun
uses: carlosperate/download-file-action@v2
id: download-wintun
with:
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
file-name: wintun.zip
location: ${{ env.downloadPath }}
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
- name: Decompressing wintun files
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'
- run: choco install -y sysinternals --ignore-checksums
- run: choco install -y mingw
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
- name: test output
if: ${{ always() }}
run: Get-Content test-out.txt

59
.github/workflows/golangci-lint.yml vendored Normal file
View File

@@ -0,0 +1,59 @@
name: Lint
on: [pull_request]
permissions:
contents: read
pull-requests: read
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
codespell:
name: codespell
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin
skip: go.mod,go.sum
only_warn: 1
golangci:
strategy:
fail-fast: false
matrix:
os: [macos-latest, windows-latest, ubuntu-latest]
include:
- os: macos-latest
display_name: Darwin
- os: windows-latest
display_name: Windows
- os: ubuntu-latest
display_name: Linux
name: ${{ matrix.display_name }}
runs-on: ${{ matrix.os }}
timeout-minutes: 15
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest'
run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Install dependencies
if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: golangci-lint
uses: golangci/golangci-lint-action@v4
with:
version: latest
args: --timeout=12m --out-format colored-line-number

View File

@@ -0,0 +1,37 @@
name: Test installation
on:
push:
branches:
- main
pull_request:
paths:
- "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test-install-script:
strategy:
fail-fast: false
max-parallel: 2
matrix:
os: [ubuntu-latest, macos-latest]
skip_ui_mode: [true, false]
install_binary: [true, false]
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: run install script
env:
SKIP_UI_APP: ${{ matrix.skip_ui_mode }}
USE_BIN_INSTALL: ${{ matrix.install_binary }}
GITHUB_TOKEN: ${{ secrets.RO_API_CALLER_TOKEN }}
run: |
[ "$SKIP_UI_APP" == "false" ] && export XDG_CURRENT_DESKTOP="none"
cat release_files/install.sh | sh -x
- name: check cli binary
run: command -v netbird

View File

@@ -0,0 +1,67 @@
name: Mobile
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
android_build:
name: "Android / Build"
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
- name: Setup Android SDK
uses: android-actions/setup-android@v3
with:
cmdline-tools-version: 8512546
- name: Setup Java
uses: actions/setup-java@v4
with:
java-version: "11"
distribution: "adopt"
- name: NDK Cache
id: ndk-cache
uses: actions/cache@v4
with:
path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620
- name: Setup NDK
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init
run: gomobile init
- name: build android netbird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
env:
CGO_ENABLED: 0
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
ios_build:
name: "iOS / Build"
runs-on: macos-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init
run: gomobile init
- name: build iOS netbird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o ./NetBirdSDK.xcframework ./client/ios/NetBirdSDK
env:
CGO_ENABLED: 0

View File

@@ -55,23 +55,16 @@ jobs:
run: go mod tidy run: go mod tidy
- name: check git status - name: check git status
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
# - name: Set up QEMU - name: Set up QEMU
# uses: docker/setup-qemu-action@v2 uses: docker/setup-qemu-action@v2
# - name: Set up Docker Buildx - name: Set up Docker Buildx
# uses: docker/setup-buildx-action@v2 uses: docker/setup-buildx-action@v2
# - name: Login to Docker hub - name: Login to Docker hub
# if: github.event_name != 'pull_request' if: github.event_name != 'pull_request'
# uses: docker/login-action@v1 uses: docker/login-action@v1
# with: with:
# username: ${{ secrets.DOCKER_USER }} username: ${{ secrets.DOCKER_USER }}
# password: ${{ secrets.DOCKER_TOKEN }} password: ${{ secrets.DOCKER_TOKEN }}
# - name: Log in to the GitHub container registry
# if: github.event_name != 'pull_request'
# uses: docker/login-action@v3
# with:
# registry: ghcr.io
# username: ${{ github.actor }}
# password: ${{ secrets.CI_DOCKER_PUSH_GITHUB_TOKEN }}
- name: Install OS build dependencies - name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
@@ -94,25 +87,25 @@ jobs:
with: with:
name: release name: release
path: dist/ path: dist/
retention-days: 7 retention-days: 3
- name: upload linux packages - name: upload linux packages
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: linux-packages name: linux-packages
path: dist/netbird_linux** path: dist/netbird_linux**
retention-days: 7 retention-days: 3
- name: upload windows packages - name: upload windows packages
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: windows-packages name: windows-packages
path: dist/netbird_windows** path: dist/netbird_windows**
retention-days: 7 retention-days: 3
- name: upload macos packages - name: upload macos packages
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: macos-packages name: macos-packages
path: dist/netbird_darwin** path: dist/netbird_darwin**
retention-days: 7 retention-days: 3
release_ui: release_ui:
runs-on: ubuntu-latest runs-on: ubuntu-latest

22
.github/workflows/sync-main.yml vendored Normal file
View File

@@ -0,0 +1,22 @@
name: sync main
on:
push:
branches:
- main
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
trigger_sync_main:
runs-on: ubuntu-latest
steps:
- name: Trigger main branch sync
uses: benc-uk/workflow-dispatch@v1
with:
workflow: sync-main.yml
repo: ${{ secrets.UPSTREAM_REPO }}
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "sha": "${{ github.sha }}" }'

23
.github/workflows/sync-tag.yml vendored Normal file
View File

@@ -0,0 +1,23 @@
name: sync tag
on:
push:
tags:
- 'v*'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
trigger_sync_tag:
runs-on: ubuntu-latest
steps:
- name: Trigger release tag sync
uses: benc-uk/workflow-dispatch@v1
with:
workflow: sync-tag.yml
ref: main
repo: ${{ secrets.UPSTREAM_REPO }}
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref_name }}" }'

View File

@@ -0,0 +1,308 @@
name: Test Infrastructure files
on:
push:
branches:
- main
pull_request:
paths:
- 'infrastructure_files/**'
- '.github/workflows/test-infrastructure-files.yml'
- 'management/cmd/**'
- 'signal/cmd/**'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test-docker-compose:
runs-on: ubuntu-latest
strategy:
matrix:
store: [ 'sqlite', 'postgres', 'mysql' ]
services:
postgres:
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
env:
POSTGRES_USER: netbird
POSTGRES_PASSWORD: postgres
POSTGRES_DB: netbird
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
ports:
- 5432:5432
mysql:
image: ${{ (matrix.store == 'mysql') && 'mysql' || '' }}
env:
MYSQL_USER: netbird
MYSQL_PASSWORD: mysql
MYSQL_ROOT_PASSWORD: mysqlroot
MYSQL_DATABASE: netbird
options: >-
--health-cmd "mysqladmin ping --silent"
--health-interval 10s
--health-timeout 5s
ports:
- 3306:3306
steps:
- name: Set Database Connection String
run: |
if [ "${{ matrix.store }}" == "postgres" ]; then
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN=host=$(hostname -I | awk '{print $1}') user=netbird password=postgres dbname=netbird port=5432" >> $GITHUB_ENV
else
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN==" >> $GITHUB_ENV
fi
if [ "${{ matrix.store }}" == "mysql" ]; then
echo "NETBIRD_STORE_ENGINE_MYSQL_DSN=netbird:mysql@tcp($(hostname -I | awk '{print $1}'):3306)/netbird" >> $GITHUB_ENV
else
echo "NETBIRD_STORE_ENGINE_MYSQL_DSN==" >> $GITHUB_ENV
fi
- name: Install jq
run: sudo apt-get install -y jq
- name: Install curl
run: sudo apt-get install -y curl
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
- name: Cache Go modules
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v4
- name: cp setup.env
run: cp infrastructure_files/tests/setup.env infrastructure_files/
- name: run configure
working-directory: infrastructure_files
run: bash -x configure.sh
env:
CI_NETBIRD_DOMAIN: localhost
CI_NETBIRD_AUTH_CLIENT_ID: testing.client.id
CI_NETBIRD_AUTH_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_AUTH_AUDIENCE: testing.ci
CI_NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT: https://example.eu.auth0.com/.well-known/openid-configuration
CI_NETBIRD_USE_AUTH0: true
CI_NETBIRD_MGMT_IDP: "none"
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
NETBIRD_STORE_ENGINE_POSTGRES_DSN: ${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}
NETBIRD_STORE_ENGINE_MYSQL_DSN: ${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
- name: check values
working-directory: infrastructure_files/artifacts
env:
CI_NETBIRD_DOMAIN: localhost
CI_NETBIRD_AUTH_CLIENT_ID: testing.client.id
CI_NETBIRD_AUTH_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_AUTH_AUDIENCE: testing.ci
CI_NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT: https://example.eu.auth0.com/.well-known/openid-configuration
CI_NETBIRD_USE_AUTH0: true
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
CI_NETBIRD_AUTH_AUTHORITY: https://example.eu.auth0.com/
CI_NETBIRD_AUTH_JWT_CERTS: https://example.eu.auth0.com/.well-known/jwks.json
CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token
CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code
CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT: https://example.eu.auth0.com/authorize
CI_NETBIRD_AUTH_REDIRECT_URI: "/peers"
CI_NETBIRD_TOKEN_SOURCE: "idToken"
CI_NETBIRD_AUTH_USER_ID_CLAIM: "email"
CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE: "super"
CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE: "openid email"
CI_NETBIRD_MGMT_IDP: "none"
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_SIGNAL_PORT: 12345
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
run: |
set -x
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
grep AUTH_CLIENT_SECRET docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_SECRET
grep AUTH_AUTHORITY docker-compose.yml | grep $CI_NETBIRD_AUTH_AUTHORITY
grep AUTH_AUDIENCE docker-compose.yml | grep $CI_NETBIRD_AUTH_AUDIENCE
grep AUTH_SUPPORTED_SCOPES docker-compose.yml | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
grep USE_AUTH0 docker-compose.yml | grep $CI_NETBIRD_USE_AUTH0
grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073"
grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI
grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$'
grep $CI_NETBIRD_SIGNAL_PORT docker-compose.yml | grep ':80'
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
grep -A 3 DeviceAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
grep -A 3 DeviceAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
grep Engine management.json | grep "$CI_NETBIRD_STORE_CONFIG_ENGINE"
grep IdpSignKeyRefreshEnabled management.json | grep "$CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH"
grep UseIDToken management.json | grep false
grep -A 1 IdpManagerConfig management.json | grep ManagerType | grep $CI_NETBIRD_MGMT_IDP
grep -A 3 IdpManagerConfig management.json | grep -A 1 ClientConfig | grep Issuer | grep $CI_NETBIRD_AUTH_AUTHORITY
grep -A 4 IdpManagerConfig management.json | grep -A 2 ClientConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
grep -A 5 IdpManagerConfig management.json | grep -A 3 ClientConfig | grep ClientID | grep $CI_NETBIRD_IDP_MGMT_CLIENT_ID
grep -A 6 IdpManagerConfig management.json | grep -A 4 ClientConfig | grep ClientSecret | grep $CI_NETBIRD_IDP_MGMT_CLIENT_SECRET
grep -A 7 IdpManagerConfig management.json | grep -A 5 ClientConfig | grep GrantType | grep client_credentials
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_AUDIENCE
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep ClientID | grep $CI_NETBIRD_AUTH_CLIENT_ID
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep ClientSecret | grep $CI_NETBIRD_AUTH_CLIENT_SECRET
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep AuthorizationEndpoint | grep $CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
# check relay values
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
grep '33445:33445' docker-compose.yml
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Build management binary
working-directory: management
run: CGO_ENABLED=1 go build -o netbird-mgmt main.go
- name: Build management docker image
working-directory: management
run: |
docker build -t netbirdio/management:latest .
- name: Build signal binary
working-directory: signal
run: CGO_ENABLED=0 go build -o netbird-signal main.go
- name: Build signal docker image
working-directory: signal
run: |
docker build -t netbirdio/signal:latest .
- name: Build relay binary
working-directory: relay
run: CGO_ENABLED=0 go build -o netbird-relay main.go
- name: Build relay docker image
working-directory: relay
run: |
docker build -t netbirdio/relay:latest .
- name: run docker compose up
working-directory: infrastructure_files/artifacts
run: |
docker compose up -d
sleep 5
docker compose ps
docker compose logs --tail=20
- name: test running containers
run: |
count=$(docker compose ps --format json | jq '. | select(.Name | contains("artifacts")) | .State' | grep -c running)
test $count -eq 5 || docker compose logs
working-directory: infrastructure_files/artifacts
- name: test geolocation databases
working-directory: infrastructure_files/artifacts
run: |
sleep 30
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City_[0-9]*.mmdb
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames_[0-9]*.db
test-getting-started-script:
runs-on: ubuntu-latest
steps:
- name: Install jq
run: sudo apt-get install -y jq
- name: Checkout code
uses: actions/checkout@v4
- name: run script with Zitadel PostgreSQL
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
- name: test Caddy file gen postgres
run: test -f Caddyfile
- name: test docker-compose file gen postgres
run: test -f docker-compose.yml
- name: test management.json file gen postgres
run: test -f management.json
- name: test turnserver.conf file gen postgres
run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
- name: test zitadel.env file gen postgres
run: test -f zitadel.env
- name: test dashboard.env file gen postgres
run: test -f dashboard.env
- name: test relay.env file gen postgres
run: test -f relay.env
- name: test zdb.env file gen postgres
run: test -f zdb.env
- name: Postgres run cleanup
run: |
docker compose down --volumes --rmi all
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
- name: run script with Zitadel CockroachDB
run: bash -x infrastructure_files/getting-started-with-zitadel.sh
env:
NETBIRD_DOMAIN: use-ip
ZITADEL_DATABASE: cockroach
- name: test Caddy file gen CockroachDB
run: test -f Caddyfile
- name: test docker-compose file gen CockroachDB
run: test -f docker-compose.yml
- name: test management.json file gen CockroachDB
run: test -f management.json
- name: test turnserver.conf file gen CockroachDB
run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
- name: test zitadel.env file gen CockroachDB
run: test -f zitadel.env
- name: test dashboard.env file gen CockroachDB
run: test -f dashboard.env
- name: test relay.env file gen CockroachDB
run: test -f relay.env

22
.github/workflows/update-docs.yml vendored Normal file
View File

@@ -0,0 +1,22 @@
name: update docs
on:
push:
tags:
- 'v*'
paths:
- 'management/server/http/api/openapi.yml'
jobs:
trigger_docs_api_update:
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Trigger API pages generation
uses: benc-uk/workflow-dispatch@v1
with:
workflow: generate api pages
repo: netbirdio/docs
ref: "refs/heads/main"
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}" }'

View File

@@ -96,20 +96,6 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-upload
dir: upload-server
env: [CGO_ENABLED=0]
binary: netbird-upload
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
universal_binaries: universal_binaries:
- id: netbird - id: netbird
@@ -146,541 +132,383 @@ nfpms:
scripts: scripts:
postinstall: "release_files/post_install.sh" postinstall: "release_files/post_install.sh"
preremove: "release_files/pre_remove.sh" preremove: "release_files/pre_remove.sh"
# dockers: dockers:
# - image_templates: - image_templates:
# - netbirdio/netbird:{{ .Version }}-amd64 - netbirdio/netbird:{{ .Version }}-amd64
# - ghcr.io/netbirdio/netbird:{{ .Version }}-amd64 ids:
# ids: - netbird
# - netbird goarch: amd64
# goarch: amd64 use: buildx
# use: buildx dockerfile: client/Dockerfile
# dockerfile: client/Dockerfile build_flag_templates:
# build_flag_templates: - "--platform=linux/amd64"
# - "--platform=linux/amd64" - "--label=org.opencontainers.image.created={{.Date}}"
# - "--label=org.opencontainers.image.created={{.Date}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
# - "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# - "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.version={{.Version}}" - "--label=maintainer=dev@netbird.io"
# - "--label=maintainer=dev@netbird.io" - image_templates:
# - image_templates: - netbirdio/netbird:{{ .Version }}-arm64v8
# - netbirdio/netbird:{{ .Version }}-arm64v8 ids:
# - ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8 - netbird
# ids: goarch: arm64
# - netbird use: buildx
# goarch: arm64 dockerfile: client/Dockerfile
# use: buildx build_flag_templates:
# dockerfile: client/Dockerfile - "--platform=linux/arm64"
# build_flag_templates: - "--label=org.opencontainers.image.created={{.Date}}"
# - "--platform=linux/arm64" - "--label=org.opencontainers.image.title={{.ProjectName}}"
# - "--label=org.opencontainers.image.created={{.Date}}" - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# - "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=maintainer=dev@netbird.io"
# - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - image_templates:
# - "--label=maintainer=dev@netbird.io" - netbirdio/netbird:{{ .Version }}-arm
# - image_templates: ids:
# - netbirdio/netbird:{{ .Version }}-arm - netbird
# - ghcr.io/netbirdio/netbird:{{ .Version }}-arm goarch: arm
# ids: goarm: 6
# - netbird use: buildx
# goarch: arm dockerfile: client/Dockerfile
# goarm: 6 build_flag_templates:
# use: buildx - "--platform=linux/arm"
# dockerfile: client/Dockerfile - "--label=org.opencontainers.image.created={{.Date}}"
# build_flag_templates: - "--label=org.opencontainers.image.title={{.ProjectName}}"
# - "--platform=linux/arm" - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.created={{.Date}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# - "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.version={{.Version}}" - "--label=maintainer=dev@netbird.io"
# - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
# - "--label=maintainer=dev@netbird.io"
# - image_templates: - image_templates:
# - netbirdio/netbird:{{ .Version }}-rootless-amd64 - netbirdio/netbird:{{ .Version }}-rootless-amd64
# - ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64 ids:
# ids: - netbird
# - netbird goarch: amd64
# goarch: amd64 use: buildx
# use: buildx dockerfile: client/Dockerfile-rootless
# dockerfile: client/Dockerfile-rootless build_flag_templates:
# build_flag_templates: - "--platform=linux/amd64"
# - "--platform=linux/amd64" - "--label=org.opencontainers.image.created={{.Date}}"
# - "--label=org.opencontainers.image.created={{.Date}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
# - "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# - "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=maintainer=dev@netbird.io"
# - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - image_templates:
# - "--label=maintainer=dev@netbird.io" - netbirdio/netbird:{{ .Version }}-rootless-arm64v8
# - image_templates: ids:
# - netbirdio/netbird:{{ .Version }}-rootless-arm64v8 - netbird
# - ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8 goarch: arm64
# ids: use: buildx
# - netbird dockerfile: client/Dockerfile-rootless
# goarch: arm64 build_flag_templates:
# use: buildx - "--platform=linux/arm64"
# dockerfile: client/Dockerfile-rootless - "--label=org.opencontainers.image.created={{.Date}}"
# build_flag_templates: - "--label=org.opencontainers.image.title={{.ProjectName}}"
# - "--platform=linux/arm64" - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.created={{.Date}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# - "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=maintainer=dev@netbird.io"
# - "--label=org.opencontainers.image.version={{.Version}}" - image_templates:
# - "--label=org.opencontainers.image.revision={{.FullCommit}}" - netbirdio/netbird:{{ .Version }}-rootless-arm
# - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" ids:
# - "--label=maintainer=dev@netbird.io" - netbird
# - image_templates: goarch: arm
# - netbirdio/netbird:{{ .Version }}-rootless-arm goarm: 6
# - ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm use: buildx
# ids: dockerfile: client/Dockerfile-rootless
# - netbird build_flag_templates:
# goarch: arm - "--platform=linux/arm"
# goarm: 6 - "--label=org.opencontainers.image.created={{.Date}}"
# use: buildx - "--label=org.opencontainers.image.title={{.ProjectName}}"
# dockerfile: client/Dockerfile-rootless - "--label=org.opencontainers.image.version={{.Version}}"
# build_flag_templates: - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# - "--platform=linux/arm" - "--label=maintainer=dev@netbird.io"
# - "--label=org.opencontainers.image.created={{.Date}}"
# - "--label=org.opencontainers.image.title={{.ProjectName}}"
# - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
# - "--label=maintainer=dev@netbird.io"
# - image_templates: - image_templates:
# - netbirdio/relay:{{ .Version }}-amd64 - netbirdio/relay:{{ .Version }}-amd64
# - ghcr.io/netbirdio/relay:{{ .Version }}-amd64 ids:
# ids: - netbird-relay
# - netbird-relay goarch: amd64
# goarch: amd64 use: buildx
# use: buildx dockerfile: relay/Dockerfile
# dockerfile: relay/Dockerfile build_flag_templates:
# build_flag_templates: - "--platform=linux/amd64"
# - "--platform=linux/amd64" - "--label=org.opencontainers.image.created={{.Date}}"
# - "--label=org.opencontainers.image.created={{.Date}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
# - "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# - "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - "--label=maintainer=dev@netbird.io"
# - "--label=maintainer=dev@netbird.io" - image_templates:
# - image_templates: - netbirdio/relay:{{ .Version }}-arm64v8
# - netbirdio/relay:{{ .Version }}-arm64v8 ids:
# - ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8 - netbird-relay
# ids: goarch: arm64
# - netbird-relay use: buildx
# goarch: arm64 dockerfile: relay/Dockerfile
# use: buildx build_flag_templates:
# dockerfile: relay/Dockerfile - "--platform=linux/arm64"
# build_flag_templates: - "--label=org.opencontainers.image.created={{.Date}}"
# - "--platform=linux/arm64" - "--label=org.opencontainers.image.title={{.ProjectName}}"
# - "--label=org.opencontainers.image.created={{.Date}}" - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# - "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=maintainer=dev@netbird.io"
# - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - image_templates:
# - "--label=maintainer=dev@netbird.io" - netbirdio/relay:{{ .Version }}-arm
# - image_templates: ids:
# - netbirdio/relay:{{ .Version }}-arm - netbird-relay
# - ghcr.io/netbirdio/relay:{{ .Version }}-arm goarch: arm
# ids: goarm: 6
# - netbird-relay use: buildx
# goarch: arm dockerfile: relay/Dockerfile
# goarm: 6 build_flag_templates:
# use: buildx - "--platform=linux/arm"
# dockerfile: relay/Dockerfile - "--label=org.opencontainers.image.created={{.Date}}"
# build_flag_templates: - "--label=org.opencontainers.image.title={{.ProjectName}}"
# - "--platform=linux/arm" - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.created={{.Date}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# - "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.version={{.Version}}" - "--label=maintainer=dev@netbird.io"
# - "--label=org.opencontainers.image.revision={{.FullCommit}}" - image_templates:
# - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - netbirdio/signal:{{ .Version }}-amd64
# - "--label=maintainer=dev@netbird.io" ids:
# - image_templates: - netbird-signal
# - netbirdio/signal:{{ .Version }}-amd64 goarch: amd64
# - ghcr.io/netbirdio/signal:{{ .Version }}-amd64 use: buildx
# ids: dockerfile: signal/Dockerfile
# - netbird-signal build_flag_templates:
# goarch: amd64 - "--platform=linux/amd64"
# use: buildx - "--label=org.opencontainers.image.created={{.Date}}"
# dockerfile: signal/Dockerfile - "--label=org.opencontainers.image.title={{.ProjectName}}"
# build_flag_templates: - "--label=org.opencontainers.image.version={{.Version}}"
# - "--platform=linux/amd64" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# - "--label=org.opencontainers.image.created={{.Date}}" - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=maintainer=dev@netbird.io"
# - "--label=org.opencontainers.image.version={{.Version}}" - image_templates:
# - "--label=org.opencontainers.image.revision={{.FullCommit}}" - netbirdio/signal:{{ .Version }}-arm64v8
# - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" ids:
# - "--label=maintainer=dev@netbird.io" - netbird-signal
# - image_templates: goarch: arm64
# - netbirdio/signal:{{ .Version }}-arm64v8 use: buildx
# - ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8 dockerfile: signal/Dockerfile
# ids: build_flag_templates:
# - netbird-signal - "--platform=linux/arm64"
# goarch: arm64 - "--label=org.opencontainers.image.created={{.Date}}"
# use: buildx - "--label=org.opencontainers.image.title={{.ProjectName}}"
# dockerfile: signal/Dockerfile - "--label=org.opencontainers.image.version={{.Version}}"
# build_flag_templates: - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# - "--platform=linux/arm64" - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.created={{.Date}}" - "--label=maintainer=dev@netbird.io"
# - "--label=org.opencontainers.image.title={{.ProjectName}}" - image_templates:
# - "--label=org.opencontainers.image.version={{.Version}}" - netbirdio/signal:{{ .Version }}-arm
# - "--label=org.opencontainers.image.revision={{.FullCommit}}" ids:
# - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - netbird-signal
# - "--label=maintainer=dev@netbird.io" goarch: arm
# - image_templates: goarm: 6
# - netbirdio/signal:{{ .Version }}-arm use: buildx
# - ghcr.io/netbirdio/signal:{{ .Version }}-arm dockerfile: signal/Dockerfile
# ids: build_flag_templates:
# - netbird-signal - "--platform=linux/arm"
# goarch: arm - "--label=org.opencontainers.image.created={{.Date}}"
# goarm: 6 - "--label=org.opencontainers.image.title={{.ProjectName}}"
# use: buildx - "--label=org.opencontainers.image.version={{.Version}}"
# dockerfile: signal/Dockerfile - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# build_flag_templates: - "--label=org.opencontainers.image.version={{.Version}}"
# - "--platform=linux/arm" - "--label=maintainer=dev@netbird.io"
# - "--label=org.opencontainers.image.created={{.Date}}" - image_templates:
# - "--label=org.opencontainers.image.title={{.ProjectName}}" - netbirdio/management:{{ .Version }}-amd64
# - "--label=org.opencontainers.image.version={{.Version}}" ids:
# - "--label=org.opencontainers.image.revision={{.FullCommit}}" - netbird-mgmt
# - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" goarch: amd64
# - "--label=maintainer=dev@netbird.io" use: buildx
# - image_templates: dockerfile: management/Dockerfile
# - netbirdio/management:{{ .Version }}-amd64 build_flag_templates:
# - ghcr.io/netbirdio/management:{{ .Version }}-amd64 - "--platform=linux/amd64"
# ids: - "--label=org.opencontainers.image.created={{.Date}}"
# - netbird-mgmt - "--label=org.opencontainers.image.title={{.ProjectName}}"
# goarch: amd64 - "--label=org.opencontainers.image.version={{.Version}}"
# use: buildx - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# dockerfile: management/Dockerfile - "--label=org.opencontainers.image.version={{.Version}}"
# build_flag_templates: - "--label=maintainer=dev@netbird.io"
# - "--platform=linux/amd64" - image_templates:
# - "--label=org.opencontainers.image.created={{.Date}}" - netbirdio/management:{{ .Version }}-arm64v8
# - "--label=org.opencontainers.image.title={{.ProjectName}}" ids:
# - "--label=org.opencontainers.image.version={{.Version}}" - netbird-mgmt
# - "--label=org.opencontainers.image.revision={{.FullCommit}}" goarch: arm64
# - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" use: buildx
# - "--label=maintainer=dev@netbird.io" dockerfile: management/Dockerfile
# - image_templates: build_flag_templates:
# - netbirdio/management:{{ .Version }}-arm64v8 - "--platform=linux/arm64"
# - ghcr.io/netbirdio/management:{{ .Version }}-arm64v8 - "--label=org.opencontainers.image.created={{.Date}}"
# ids: - "--label=org.opencontainers.image.title={{.ProjectName}}"
# - netbird-mgmt - "--label=org.opencontainers.image.version={{.Version}}"
# goarch: arm64 - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# use: buildx - "--label=org.opencontainers.image.version={{.Version}}"
# dockerfile: management/Dockerfile - "--label=maintainer=dev@netbird.io"
# build_flag_templates: - image_templates:
# - "--platform=linux/arm64" - netbirdio/management:{{ .Version }}-arm
# - "--label=org.opencontainers.image.created={{.Date}}" ids:
# - "--label=org.opencontainers.image.title={{.ProjectName}}" - netbird-mgmt
# - "--label=org.opencontainers.image.version={{.Version}}" goarch: arm
# - "--label=org.opencontainers.image.revision={{.FullCommit}}" goarm: 6
# - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" use: buildx
# - "--label=maintainer=dev@netbird.io" dockerfile: management/Dockerfile
# - image_templates: build_flag_templates:
# - netbirdio/management:{{ .Version }}-arm - "--platform=linux/arm"
# - ghcr.io/netbirdio/management:{{ .Version }}-arm - "--label=org.opencontainers.image.created={{.Date}}"
# ids: - "--label=org.opencontainers.image.title={{.ProjectName}}"
# - netbird-mgmt - "--label=org.opencontainers.image.version={{.Version}}"
# goarch: arm - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# goarm: 6 - "--label=org.opencontainers.image.version={{.Version}}"
# use: buildx - "--label=maintainer=dev@netbird.io"
# dockerfile: management/Dockerfile - image_templates:
# build_flag_templates: - netbirdio/management:{{ .Version }}-debug-amd64
# - "--platform=linux/arm" ids:
# - "--label=org.opencontainers.image.created={{.Date}}" - netbird-mgmt
# - "--label=org.opencontainers.image.title={{.ProjectName}}" goarch: amd64
# - "--label=org.opencontainers.image.version={{.Version}}" use: buildx
# - "--label=org.opencontainers.image.revision={{.FullCommit}}" dockerfile: management/Dockerfile.debug
# - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" build_flag_templates:
# - "--label=maintainer=dev@netbird.io" - "--platform=linux/amd64"
# - image_templates: - "--label=org.opencontainers.image.created={{.Date}}"
# - netbirdio/management:{{ .Version }}-debug-amd64 - "--label=org.opencontainers.image.title={{.ProjectName}}"
# - ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64 - "--label=org.opencontainers.image.version={{.Version}}"
# ids: - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# - netbird-mgmt - "--label=org.opencontainers.image.version={{.Version}}"
# goarch: amd64 - "--label=maintainer=dev@netbird.io"
# use: buildx - image_templates:
# dockerfile: management/Dockerfile.debug - netbirdio/management:{{ .Version }}-debug-arm64v8
# build_flag_templates: ids:
# - "--platform=linux/amd64" - netbird-mgmt
# - "--label=org.opencontainers.image.created={{.Date}}" goarch: arm64
# - "--label=org.opencontainers.image.title={{.ProjectName}}" use: buildx
# - "--label=org.opencontainers.image.version={{.Version}}" dockerfile: management/Dockerfile.debug
# - "--label=org.opencontainers.image.revision={{.FullCommit}}" build_flag_templates:
# - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - "--platform=linux/arm64"
# - "--label=maintainer=dev@netbird.io" - "--label=org.opencontainers.image.created={{.Date}}"
# - image_templates: - "--label=org.opencontainers.image.title={{.ProjectName}}"
# - netbirdio/management:{{ .Version }}-debug-arm64v8 - "--label=org.opencontainers.image.version={{.Version}}"
# - ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8 - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# ids: - "--label=org.opencontainers.image.version={{.Version}}"
# - netbird-mgmt - "--label=maintainer=dev@netbird.io"
# goarch: arm64
# use: buildx
# dockerfile: management/Dockerfile.debug
# build_flag_templates:
# - "--platform=linux/arm64"
# - "--label=org.opencontainers.image.created={{.Date}}"
# - "--label=org.opencontainers.image.title={{.ProjectName}}"
# - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
# - "--label=maintainer=dev@netbird.io"
# - image_templates: - image_templates:
# - netbirdio/management:{{ .Version }}-debug-arm - netbirdio/management:{{ .Version }}-debug-arm
# - ghcr.io/netbirdio/management:{{ .Version }}-debug-arm ids:
# ids: - netbird-mgmt
# - netbird-mgmt goarch: arm
# goarch: arm goarm: 6
# goarm: 6 use: buildx
# use: buildx dockerfile: management/Dockerfile.debug
# dockerfile: management/Dockerfile.debug build_flag_templates:
# build_flag_templates: - "--platform=linux/arm"
# - "--platform=linux/arm" - "--label=org.opencontainers.image.created={{.Date}}"
# - "--label=org.opencontainers.image.created={{.Date}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
# - "--label=org.opencontainers.image.title={{.ProjectName}}" - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# - "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - "--label=maintainer=dev@netbird.io"
# - "--label=maintainer=dev@netbird.io" docker_manifests:
# - image_templates: - name_template: netbirdio/netbird:{{ .Version }}
# - netbirdio/upload:{{ .Version }}-amd64 image_templates:
# - ghcr.io/netbirdio/upload:{{ .Version }}-amd64 - netbirdio/netbird:{{ .Version }}-arm64v8
# ids: - netbirdio/netbird:{{ .Version }}-arm
# - netbird-upload - netbirdio/netbird:{{ .Version }}-amd64
# goarch: amd64
# use: buildx
# dockerfile: upload-server/Dockerfile
# build_flag_templates:
# - "--platform=linux/amd64"
# - "--label=org.opencontainers.image.created={{.Date}}"
# - "--label=org.opencontainers.image.title={{.ProjectName}}"
# - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
# - "--label=maintainer=dev@netbird.io"
# - image_templates:
# - netbirdio/upload:{{ .Version }}-arm64v8
# - ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
# ids:
# - netbird-upload
# goarch: arm64
# use: buildx
# dockerfile: upload-server/Dockerfile
# build_flag_templates:
# - "--platform=linux/arm64"
# - "--label=org.opencontainers.image.created={{.Date}}"
# - "--label=org.opencontainers.image.title={{.ProjectName}}"
# - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
# - "--label=maintainer=dev@netbird.io"
# - image_templates:
# - netbirdio/upload:{{ .Version }}-arm
# - ghcr.io/netbirdio/upload:{{ .Version }}-arm
# ids:
# - netbird-upload
# goarch: arm
# goarm: 6
# use: buildx
# dockerfile: upload-server/Dockerfile
# build_flag_templates:
# - "--platform=linux/arm"
# - "--label=org.opencontainers.image.created={{.Date}}"
# - "--label=org.opencontainers.image.title={{.ProjectName}}"
# - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
# - "--label=maintainer=dev@netbird.io"
# docker_manifests:
# - name_template: netbirdio/netbird:{{ .Version }}
# image_templates:
# - netbirdio/netbird:{{ .Version }}-arm64v8
# - netbirdio/netbird:{{ .Version }}-arm
# - netbirdio/netbird:{{ .Version }}-amd64
#
# - name_template: netbirdio/netbird:latest
# image_templates:
# - netbirdio/netbird:{{ .Version }}-arm64v8
# - netbirdio/netbird:{{ .Version }}-arm
# - netbirdio/netbird:{{ .Version }}-amd64
#
# - name_template: netbirdio/netbird:{{ .Version }}-rootless
# image_templates:
# - netbirdio/netbird:{{ .Version }}-rootless-arm64v8
# - netbirdio/netbird:{{ .Version }}-rootless-arm
# - netbirdio/netbird:{{ .Version }}-rootless-amd64
#
# - name_template: netbirdio/netbird:rootless-latest
# image_templates:
# - netbirdio/netbird:{{ .Version }}-rootless-arm64v8
# - netbirdio/netbird:{{ .Version }}-rootless-arm
# - netbirdio/netbird:{{ .Version }}-rootless-amd64
#
# - name_template: netbirdio/relay:{{ .Version }}
# image_templates:
# - netbirdio/relay:{{ .Version }}-arm64v8
# - netbirdio/relay:{{ .Version }}-arm
# - netbirdio/relay:{{ .Version }}-amd64
#
# - name_template: netbirdio/relay:latest
# image_templates:
# - netbirdio/relay:{{ .Version }}-arm64v8
# - netbirdio/relay:{{ .Version }}-arm
# - netbirdio/relay:{{ .Version }}-amd64
#
# - name_template: netbirdio/signal:{{ .Version }}
# image_templates:
# - netbirdio/signal:{{ .Version }}-arm64v8
# - netbirdio/signal:{{ .Version }}-arm
# - netbirdio/signal:{{ .Version }}-amd64
#
# - name_template: netbirdio/signal:latest
# image_templates:
# - netbirdio/signal:{{ .Version }}-arm64v8
# - netbirdio/signal:{{ .Version }}-arm
# - netbirdio/signal:{{ .Version }}-amd64
#
# - name_template: netbirdio/management:{{ .Version }}
# image_templates:
# - netbirdio/management:{{ .Version }}-arm64v8
# - netbirdio/management:{{ .Version }}-arm
# - netbirdio/management:{{ .Version }}-amd64
#
# - name_template: netbirdio/management:latest
# image_templates:
# - netbirdio/management:{{ .Version }}-arm64v8
# - netbirdio/management:{{ .Version }}-arm
# - netbirdio/management:{{ .Version }}-amd64
#
# - name_template: netbirdio/management:debug-latest
# image_templates:
# - netbirdio/management:{{ .Version }}-debug-arm64v8
# - netbirdio/management:{{ .Version }}-debug-arm
# - netbirdio/management:{{ .Version }}-debug-amd64
# - name_template: netbirdio/upload:{{ .Version }}
# image_templates:
# - netbirdio/upload:{{ .Version }}-arm64v8
# - netbirdio/upload:{{ .Version }}-arm
# - netbirdio/upload:{{ .Version }}-amd64
#
# - name_template: netbirdio/upload:latest
# image_templates:
# - netbirdio/upload:{{ .Version }}-arm64v8
# - netbirdio/upload:{{ .Version }}-arm
# - netbirdio/upload:{{ .Version }}-amd64
#
# - name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
# image_templates:
# - ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
# - ghcr.io/netbirdio/netbird:{{ .Version }}-arm
# - ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
#
# - name_template: ghcr.io/netbirdio/netbird:latest
# image_templates:
# - ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
# - ghcr.io/netbirdio/netbird:{{ .Version }}-arm
# - ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
#
# - name_template: ghcr.io/netbirdio/netbird:{{ .Version }}-rootless
# image_templates:
# - ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
# - ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
# - ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
#
# - name_template: ghcr.io/netbirdio/netbird:rootless-latest
# image_templates:
# - ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
# - ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
# - ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
#
# - name_template: ghcr.io/netbirdio/relay:{{ .Version }}
# image_templates:
# - ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
# - ghcr.io/netbirdio/relay:{{ .Version }}-arm
# - ghcr.io/netbirdio/relay:{{ .Version }}-amd64
#
# - name_template: ghcr.io/netbirdio/relay:latest
# image_templates:
# - ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
# - ghcr.io/netbirdio/relay:{{ .Version }}-arm
# - ghcr.io/netbirdio/relay:{{ .Version }}-amd64
#
# - name_template: ghcr.io/netbirdio/signal:{{ .Version }}
# image_templates:
# - ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
# - ghcr.io/netbirdio/signal:{{ .Version }}-arm
# - ghcr.io/netbirdio/signal:{{ .Version }}-amd64
#
# - name_template: ghcr.io/netbirdio/signal:latest
# image_templates:
# - ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
# - ghcr.io/netbirdio/signal:{{ .Version }}-arm
# - ghcr.io/netbirdio/signal:{{ .Version }}-amd64
#
# - name_template: ghcr.io/netbirdio/management:{{ .Version }}
# image_templates:
# - ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
# - ghcr.io/netbirdio/management:{{ .Version }}-arm
# - ghcr.io/netbirdio/management:{{ .Version }}-amd64
#
# - name_template: ghcr.io/netbirdio/management:latest
# image_templates:
# - ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
# - ghcr.io/netbirdio/management:{{ .Version }}-arm
# - ghcr.io/netbirdio/management:{{ .Version }}-amd64
#
# - name_template: ghcr.io/netbirdio/management:debug-latest
# image_templates:
# - ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
# - ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
# - ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
#
# - name_template: ghcr.io/netbirdio/upload:{{ .Version }}
# image_templates:
# - ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
# - ghcr.io/netbirdio/upload:{{ .Version }}-arm
# - ghcr.io/netbirdio/upload:{{ .Version }}-amd64
#
# - name_template: ghcr.io/netbirdio/upload:latest
# image_templates:
# - ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
# - ghcr.io/netbirdio/upload:{{ .Version }}-arm
# - ghcr.io/netbirdio/upload:{{ .Version }}-amd64
# brews:
# - ids:
# - default
# repository:
# owner: netbirdio
# name: homebrew-tap
# token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}"
# commit_author:
# name: Netbird
# email: dev@netbird.io
# description: Netbird project.
# download_strategy: CurlDownloadStrategy
# homepage: https://netbird.io/
# license: "BSD3"
# test: |
# system "#{bin}/{{ .ProjectName }} version"
# uploads: - name_template: netbirdio/netbird:latest
# - name: debian image_templates:
# ids: - netbirdio/netbird:{{ .Version }}-arm64v8
# - netbird-deb - netbirdio/netbird:{{ .Version }}-arm
# mode: archive - netbirdio/netbird:{{ .Version }}-amd64
# target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
# username: dev@wiretrustee.com
# method: PUT
# - name: yum - name_template: netbirdio/netbird:{{ .Version }}-rootless
# ids: image_templates:
# - netbird-rpm - netbirdio/netbird:{{ .Version }}-rootless-arm64v8
# mode: archive - netbirdio/netbird:{{ .Version }}-rootless-arm
# target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }} - netbirdio/netbird:{{ .Version }}-rootless-amd64
# username: dev@wiretrustee.com
# method: PUT - name_template: netbirdio/netbird:rootless-latest
image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- netbirdio/netbird:{{ .Version }}-rootless-arm
- netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: netbirdio/relay:{{ .Version }}
image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- netbirdio/relay:{{ .Version }}-arm
- netbirdio/relay:{{ .Version }}-amd64
- name_template: netbirdio/relay:latest
image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- netbirdio/relay:{{ .Version }}-arm
- netbirdio/relay:{{ .Version }}-amd64
- name_template: netbirdio/signal:{{ .Version }}
image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8
- netbirdio/signal:{{ .Version }}-arm
- netbirdio/signal:{{ .Version }}-amd64
- name_template: netbirdio/signal:latest
image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8
- netbirdio/signal:{{ .Version }}-arm
- netbirdio/signal:{{ .Version }}-amd64
- name_template: netbirdio/management:{{ .Version }}
image_templates:
- netbirdio/management:{{ .Version }}-arm64v8
- netbirdio/management:{{ .Version }}-arm
- netbirdio/management:{{ .Version }}-amd64
- name_template: netbirdio/management:latest
image_templates:
- netbirdio/management:{{ .Version }}-arm64v8
- netbirdio/management:{{ .Version }}-arm
- netbirdio/management:{{ .Version }}-amd64
- name_template: netbirdio/management:debug-latest
image_templates:
- netbirdio/management:{{ .Version }}-debug-arm64v8
- netbirdio/management:{{ .Version }}-debug-arm
- netbirdio/management:{{ .Version }}-debug-amd64
brews:
- ids:
- default
repository:
owner: netbirdio
name: homebrew-tap
token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}"
commit_author:
name: Netbird
email: dev@netbird.io
description: Netbird project.
download_strategy: CurlDownloadStrategy
homepage: https://netbird.io/
license: "BSD3"
test: |
system "#{bin}/{{ .ProjectName }} version"
uploads:
- name: debian
ids:
- netbird-deb
mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com
method: PUT
- name: yum
ids:
- netbird-rpm
mode: archive
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
username: dev@wiretrustee.com
method: PUT
checksum: checksum:
extra_files: extra_files:

View File

@@ -79,19 +79,19 @@ nfpms:
dependencies: dependencies:
- netbird - netbird
# uploads: uploads:
# - name: debian - name: debian
# ids: ids:
# - netbird-ui-deb - netbird-ui-deb
# mode: archive mode: archive
# target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package= target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
# username: dev@wiretrustee.com username: dev@wiretrustee.com
# method: PUT method: PUT
# - name: yum - name: yum
# ids: ids:
# - netbird-ui-rpm - netbird-ui-rpm
# mode: archive mode: archive
# target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }} target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
# username: dev@wiretrustee.com username: dev@wiretrustee.com
# method: PUT method: PUT

View File

@@ -1,64 +1,148 @@
## Contributor License Agreement # Contributor License Agreement
This Contributor License Agreement (referred to as the "Agreement") is entered into by the individual We are incredibly thankful for the contributions we receive from the community.
submitting this Agreement and NetBird GmbH, c/o Max-Beer-Straße 2-4 Münzstraße 12 10178 Berlin, Germany, We require our external contributors to sign a Contributor License Agreement ("CLA") in
referred to as "NetBird" (collectively, the "Parties"). The Agreement outlines the terms and conditions order to ensure that our projects remain licensed under Free and Open Source licenses such
under which NetBird may utilize software contributions provided by the Contributor for inclusion in as BSD-3 while allowing NetBird to build a sustainable business.
its software development projects. By submitting this Agreement, the Contributor confirms their acceptance
of the terms and conditions outlined below. The Contributor further represents that they are authorized to NetBird is committed to having a true Open Source Software ("OSS") license for
complete this process as described herein. our software. A CLA enables NetBird to safely commercialize our products
while keeping a standard OSS license with all the rights that license grants to users: the
ability to use the project in their own projects or businesses, to republish modified
source, or to completely fork the project.
This page gives a human-friendly summary of our CLA, details on why we require a CLA, how
contributors can sign our CLA, and more. You may view the full legal CLA document (below).
# Human-friendly summary
This is a human-readable summary of (and not a substitute for) the full agreement (below).
This highlights only some of key terms of the CLA. It has no legal value and you should
carefully review all the terms of the actual CLA before agreeing.
<li>Grant of copyright license. You give NetBird permission to use your copyrighted work
in commercial products.
</li>
<li>Grant of patent license. If your contributed work uses a patent, you give NetBird a
license to use that patent including within commercial products. You also agree that you
have permission to grant this license.
</li>
<li>No Warranty or Support Obligations.
By making a contribution, you are not obligating yourself to provide support for the
contribution, and you are not taking on any warranty obligations or providing any
assurances about how it will perform.
</li>
The CLA does not change the terms of the standard open source license used by our software
such as BSD-3 or MIT.
You are still free to use our projects within your own projects or businesses, republish
modified source, and more.
Please reference the appropriate license for the project you're contributing to to learn
more.
# Why require a CLA?
Agreeing to a CLA explicitly states that you are entitled to provide a contribution, that you cannot withdraw permission
to use your contribution at a later date, and that NetBird has permission to use your contribution in our commercial
products.
This removes any ambiguities or uncertainties caused by not having a CLA and allows users and customers to confidently
adopt our projects. At the same time, the CLA ensures that all contributions to our open source projects are licensed
under the project's respective open source license, such as BSD-3.
Requiring a CLA is a common and well-accepted practice in open source. Major open source projects require CLAs such as
Apache Software Foundation projects, Facebook projects (such as React), Google projects (including Go), Python, Django,
and more. Each of these projects remains licensed under permissive OSS licenses such as MIT, Apache, BSD, and more.
# Signing the CLA
Open a pull request ("PR") to any of our open source projects to sign the CLA. A bot will comment on the PR asking you
to sign the CLA if you haven't already.
Follow the steps given by the bot to sign the CLA. This will require you to log in with GitHub (we only request public
information from your account) and to fill in a few additional details such as your name and email address. We will only
use this information for CLA tracking; none of your submitted information will be used for marketing purposes.
You only have to sign the CLA once. Once you've signed the CLA, future contributions to any NetBird project will not
require you to sign again.
# Legal Terms and Agreement
In order to clarify the intellectual property license granted with Contributions from any person or entity, NetBird
GmbH ("NetBird") must have a Contributor License Agreement ("CLA") on file that has been signed
by each Contributor, indicating agreement to the license terms below. This license does not change your rights to use
your own Contributions for any other purpose.
You accept and agree to the following terms and conditions for Your present and future Contributions submitted to
NetBird. Except for the license granted herein to NetBird and recipients of software distributed by NetBird,
You reserve all right, title, and interest in and to Your Contributions.
1. Definitions.
```
"You" (or "Your") shall mean the copyright owner or legal entity authorized by the copyright owner
that is making this Agreement with NetBird. For legal entities, the entity making a Contribution and all other
entities that control, are controlled by, or are under common control with that entity are considered
to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect,
to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty
percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
```
```
"Contribution" shall mean any original work of authorship, including any modifications or additions to
an existing work, that is or previously has been intentionally submitted by You to NetBird for inclusion in,
or documentation of, any of the products owned or managed by NetBird (the "Work").
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication
sent to NetBird or its representatives, including but not limited to communication on electronic mailing lists,
source code control systems, and issue tracking systems that are managed by, or on behalf of,
NetBird for the purpose of discussing and improving the Work, but excluding communication that is conspicuously
marked or otherwise designated in writing by You as "Not a Contribution."
```
2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird
and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge,
royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly
perform, sublicense, and distribute Your Contributions and such derivative works.
## 1 Preamble 3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird and
In order to clarify the IP Rights situation with regard to Contributions from any person or entity, NetBird to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
must have a contributor license agreement on file to be signed by each Contributor, containing the license irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import,
terms below. This license serves as protection for both the Contributor as well as NetBird and its software users; and otherwise transfer the Work, where such license applies only to those patent claims licensable by You that are
it does not change Contributors rights to use his/her own Contributions for any other purpose. necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Work to which
such Contribution(s) was submitted. If any entity institutes patent litigation against You or any other entity (
including a cross-claim or counterclaim in a lawsuit) alleging that your Contribution, or the Work to which you have
contributed, constitutes direct or contributory patent infringement, then any patent licenses granted to that entity
under this Agreement for that Contribution or Work shall terminate as of the date such litigation is filed.
## 2 Definitions
2.1 “IP Rights” shall mean all industrial and intellectual property rights, whether registered or not registered, whether created by Contributor or acquired by Contributor from third parties, and similar rights, including (but not limited to) semiconductor property rights, design rights, copyrights (including in the form of database rights and rights to software), all neighbouring rights (Leistungsschutzrechte), trademarks, service marks, titles, internet domain names, trade names and other labelling rights, rights deriving from corresponding applications and registrations of such rights as well as any licenses (Nutzungsrechte) under and entitlements to any such intellectual and industrial property rights.
2.2 "Contribution" shall mean any original work of authorship, including any modifications or additions to an existing work, that is or previously has been intentionally Submitted by Contributor to NetBird for inclusion in, or documentation of any Work. 4. You represent that you are legally entitled to grant the above license. If your employer(s) has rights to
intellectual property that you create that includes your Contributions, you represent that you have received
permission to make Contributions on behalf of that employer, that you will have received permission from your current
and future employers for all future Contributions, that your applicable employer has waived such rights for all of
your current and future Contributions to NetBird, or that your employer has executed a separate Corporate CLA
with NetBird.
2.3 "Contributor" shall mean the copyright owner or legal entity authorized by the copyright owner that is concluding this Agreement with NetBird. For legal entities, the entity making a Contribution and all other entities that control, are controlled by, or are under common control with that entity are considered to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
2.4 "Submitted" shall mean any form of electronic, verbal, or written communication sent to NetBird or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, NetBird for the purpose of discussing and improving the Work, but excluding communication that is marked or otherwise designated in writing by Contributor as "Not a Contribution". 5. You represent that each of Your Contributions is Your original creation (see section 7 for submissions on behalf of
others). You represent that Your Contribution submissions include complete details of any third-party license or
other restriction (including, but not limited to, related patents and trademarks) of which you are personally aware
and which are associated with any part of Your Contributions.
2.5 "Work" means any of the products owned or managed by NetBird, in particular, but not exclusively, software.
## 3 Licenses 6. You are not expected to provide support for Your Contributions, except to the extent You desire to provide support.
3.1 Subject to the terms and conditions of this agreement, Contributor hereby grants to NetBird and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable license to reproduce by any means and in any form, in whole or in part, permanently or temporarily, the Contributions (including loading, displaying, executing, transmitting or storing works for the purpose of executing and processing data or transferring them to video, audio and other data carriers), including the right to distribute, display and present such Contributions and make them available to the public (e.g. via the internet) and to transmit and display such Contributions by any means. The license also includes the right to modify, translate, adapt, edit and otherwise alter the Contributions and to use these results in the same manner as the original Contributions and derivative works. Except for licenses in patents acc. to Sec. 3, such license refers to any IP Rights in the Contributions and derivative works. The Contributor acknowledges that NetBird is not required to credit them by name for their Contribution and agrees to waive any moral rights associated with their Contribution in relation to NetBird or its sublicensees. You may provide support for free, for a fee, or not at all. Unless required by applicable law or agreed to in
writing, You provide Your Contributions on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
express or implied, including, without limitation, any warranties or conditions of TITLE, NON- INFRINGEMENT,
MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE.
3.2 Subject to the terms and conditions of this agreement, Contributor hereby grants to NetBird and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license in the Contributions to make, have made, use, sell, offer to sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by the Contributor which are necessarily infringed by Contributors Contribution(s) alone or by combination of Contributors Contribution(s) with the Work to which such Contribution(s) was Submitted.
3.3 NetBird hereby accepts such licenses. 7. Should You wish to submit work that is not Your original creation, You may submit it to NetBird separately from
any Contribution, identifying the complete details of its source and of any license or other restriction (including,
but not limited to, related patents, trademarks, and license agreements) of which you are personally aware, and
conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
## 4 Contributors Representations
4.1 Contributor represents that Contributor is legally entitled to grant the above license. If Contributors employer has IP Rights to Contributors Contributions, Contributor represent that he/she has received permission to make Contributions on behalf of such employer, that such employer has waived such IP Rights to the Contributions of Contributor to NetBird, or that such employer has executed a separate contributor license agreement with NetBird.
4.2 Contributor represents that any Contribution is his/her original creation.
4.3 Contributor represents to his/her best knowledge that any Contribution does not violate any third party IP Rights.
4.4 Contributor represents that any Contribution submission includes complete details of any third-party license or other restriction (including, but not limited to, related patents and trademarks) of which Contributor is personally aware and which are associated with any part of the Contribution.
4.5 The Contributor represents that their Contribution does not include any work distributed under a copyleft license.
## 5 Information obligation
Contributor agrees to notify NetBird of any facts or circumstances of which Contributor become aware that would make these representations inaccurate in any respect.
## 6 Submission of Third-Party works
Should Contributor wish to submit work that is not Contributors original creation, Contributor may submit it to NetBird separately from any Contribution, identifying the complete details of its source and of any license or other restriction (including, but not limited to, related patents, trademarks, and license agreements) of which Contributor are personally aware, and conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
## 7 No Consideration
Unless compensation is mandatory under statutory law, no compensation for any license under this agreement shall be payable.
## 8 Final Provisions
8.1 Laws. This Agreement is governed by the laws of the Federal Republic of Germany.
8.2 Venue. Place of jurisdiction shall, to the extent legally permissible, be Berlin, Germany.
8.3 Severability. If any provision in this agreement is unlawful, invalid or ineffective, it shall not affect the enforceability or effectiveness of the remainder of this agreement. The parties agree to replace any unlawful, invalid or ineffective provision with a provision that comes as close as possible to the commercial intent and purpose of the original provision. This section also applies accordingly to any gaps in the contract.
8.4 Variations. Any variations, amendments or supplements to this Agreement must be in writing. This also applies to any variation of this Section 8.4.
8. You agree to notify NetBird of any facts or circumstances of which you become aware that would make these
representations inaccurate in any respect.

View File

@@ -12,7 +12,7 @@
<img src="https://img.shields.io/badge/license-BSD--3-blue" /> <img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a> </a>
<br> <br>
<a href="https://docs.netbird.io/slack-url"> <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/> <img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a> </a>
<br> <br>
@@ -29,13 +29,13 @@
<br/> <br/>
See <a href="https://netbird.io/docs/">Documentation</a> See <a href="https://netbird.io/docs/">Documentation</a>
<br/> <br/>
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">Slack channel</a>
<br/> <br/>
</strong> </strong>
<br> <br>
<a href="https://github.com/netbirdio/kubernetes-operator"> <a href="https://netbird.io/webinars/achieve-zero-trust-access-to-k8s?utm_source=github&utm_campaign=2502%20-%20webinar%20-%20How%20to%20Achieve%20Zero%20Trust%20Access%20to%20Kubernetes%20-%20Effortlessly&utm_medium=github">
New: NetBird Kubernetes Operator Webinar: Securely Access Kubernetes without Port Forwarding and Jump Hosts
</a> </a>
</p> </p>
@@ -57,16 +57,16 @@
### Key features ### Key features
| Connectivity | Management | Security | Automation| Platforms | | Connectivity | Management | Security | Automation | Platforms |
|----|----|----|----|----| |------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> | | <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
| <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> | | <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
| <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/audit-events-logging)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> | | <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
| <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> | | <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | <ul><li> - \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
| <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> | | <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | <ul><li> - \[x] Peer-to-peer encryption </ul></li> | | <ul><li> - \[x] iOS </ul></li> |
||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> | | | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
||| <ul><li>- \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ui></li> || <ul><li>- \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)</ui></li> | | | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
||||| <ul><li>- \[x] Docker</ui></li> | | | | | | <ul><li> - \[x] Docker </ul></li> |
### Quickstart with NetBird Cloud ### Quickstart with NetBird Cloud

View File

@@ -1,9 +1,5 @@
FROM alpine:3.21.3 FROM alpine:3.21.3
# iproute2: busybox doesn't display ip rules properly RUN apk add --no-cache ca-certificates iptables ip6tables
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
ARG NETBIRD_BINARY=netbird
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
ENV NB_FOREGROUND_MODE=true ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"] ENTRYPOINT [ "/usr/local/bin/netbird","up"]
COPY netbird /usr/local/bin/netbird

View File

@@ -1,7 +1,6 @@
FROM alpine:3.21.0 FROM alpine:3.21.0
ARG NETBIRD_BINARY=netbird COPY netbird /usr/local/bin/netbird
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
RUN apk add --no-cache ca-certificates \ RUN apk add --no-cache ca-certificates \
&& adduser -D -h /var/lib/netbird netbird && adduser -D -h /var/lib/netbird netbird

View File

@@ -59,8 +59,6 @@ type Client struct {
deviceName string deviceName string
uiVersion string uiVersion string
networkChangeListener listener.NetworkChangeListener networkChangeListener listener.NetworkChangeListener
connectClient *internal.ConnectClient
} }
// NewClient instantiate a new Client // NewClient instantiate a new Client
@@ -108,8 +106,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
} }
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -134,8 +132,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources
@@ -176,53 +174,6 @@ func (c *Client) PeersList() *PeerInfoArray {
return &PeerInfoArray{items: peerInfos} return &PeerInfoArray{items: peerInfos}
} }
func (c *Client) Networks() *NetworkArray {
if c.connectClient == nil {
log.Error("not connected")
return nil
}
engine := c.connectClient.Engine()
if engine == nil {
log.Error("could not get engine")
return nil
}
routeManager := engine.GetRouteManager()
if routeManager == nil {
log.Error("could not get route manager")
return nil
}
networkArray := &NetworkArray{
items: make([]Network, 0),
}
for id, routes := range routeManager.GetClientRoutesWithNetID() {
if len(routes) == 0 {
continue
}
if routes[0].IsDynamic() {
continue
}
peer, err := c.recorder.GetPeer(routes[0].Peer)
if err != nil {
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
continue
}
network := Network{
Name: string(id),
Network: routes[0].Network.String(),
Peer: peer.FQDN,
Status: peer.ConnStatus.String(),
}
networkArray.Add(network)
}
return networkArray
}
// OnUpdatedHostDNS update the DNS servers addresses for root zones // OnUpdatedHostDNS update the DNS servers addresses for root zones
func (c *Client) OnUpdatedHostDNS(list *DNSList) error { func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
dnsServer, err := dns.GetServerDns() dnsServer, err := dns.GetServerDns()

View File

@@ -1,27 +0,0 @@
//go:build android
package android
type Network struct {
Name string
Network string
Peer string
Status string
}
type NetworkArray struct {
items []Network
}
func (array *NetworkArray) Add(s Network) *NetworkArray {
array.items = append(array.items, s)
return array
}
func (array *NetworkArray) Get(i int) *Network {
return &array.items[i]
}
func (array *NetworkArray) Size() int {
return len(array.items)
}

View File

@@ -7,23 +7,30 @@ type PeerInfo struct {
ConnStatus string // Todo replace to enum ConnStatus string // Todo replace to enum
} }
// PeerInfoArray is a wrapper of []PeerInfo // PeerInfoCollection made for Java layer to get non default types as collection
type PeerInfoCollection interface {
Add(s string) PeerInfoCollection
Get(i int) string
Size() int
}
// PeerInfoArray is the implementation of the PeerInfoCollection
type PeerInfoArray struct { type PeerInfoArray struct {
items []PeerInfo items []PeerInfo
} }
// Add new PeerInfo to the collection // Add new PeerInfo to the collection
func (array *PeerInfoArray) Add(s PeerInfo) *PeerInfoArray { func (array PeerInfoArray) Add(s PeerInfo) PeerInfoArray {
array.items = append(array.items, s) array.items = append(array.items, s)
return array return array
} }
// Get return an element of the collection // Get return an element of the collection
func (array *PeerInfoArray) Get(i int) *PeerInfo { func (array PeerInfoArray) Get(i int) *PeerInfo {
return &array.items[i] return &array.items[i]
} }
// Size return with the size of the collection // Size return with the size of the collection
func (array *PeerInfoArray) Size() int { func (array PeerInfoArray) Size() int {
return len(array.items) return len(array.items)
} }

View File

@@ -4,12 +4,12 @@ import (
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
) )
// Preferences exports a subset of the internal config for gomobile // Preferences export a subset of the internal config for gomobile
type Preferences struct { type Preferences struct {
configInput internal.ConfigInput configInput internal.ConfigInput
} }
// NewPreferences creates a new Preferences instance // NewPreferences create new Preferences instance
func NewPreferences(configPath string) *Preferences { func NewPreferences(configPath string) *Preferences {
ci := internal.ConfigInput{ ci := internal.ConfigInput{
ConfigPath: configPath, ConfigPath: configPath,
@@ -17,7 +17,7 @@ func NewPreferences(configPath string) *Preferences {
return &Preferences{ci} return &Preferences{ci}
} }
// GetManagementURL reads URL from config file // GetManagementURL read url from config file
func (p *Preferences) GetManagementURL() (string, error) { func (p *Preferences) GetManagementURL() (string, error) {
if p.configInput.ManagementURL != "" { if p.configInput.ManagementURL != "" {
return p.configInput.ManagementURL, nil return p.configInput.ManagementURL, nil
@@ -30,12 +30,12 @@ func (p *Preferences) GetManagementURL() (string, error) {
return cfg.ManagementURL.String(), err return cfg.ManagementURL.String(), err
} }
// SetManagementURL stores the given URL and waits for commit // SetManagementURL store the given url and wait for commit
func (p *Preferences) SetManagementURL(url string) { func (p *Preferences) SetManagementURL(url string) {
p.configInput.ManagementURL = url p.configInput.ManagementURL = url
} }
// GetAdminURL reads URL from config file // GetAdminURL read url from config file
func (p *Preferences) GetAdminURL() (string, error) { func (p *Preferences) GetAdminURL() (string, error) {
if p.configInput.AdminURL != "" { if p.configInput.AdminURL != "" {
return p.configInput.AdminURL, nil return p.configInput.AdminURL, nil
@@ -48,12 +48,12 @@ func (p *Preferences) GetAdminURL() (string, error) {
return cfg.AdminURL.String(), err return cfg.AdminURL.String(), err
} }
// SetAdminURL stores the given URL and waits for commit // SetAdminURL store the given url and wait for commit
func (p *Preferences) SetAdminURL(url string) { func (p *Preferences) SetAdminURL(url string) {
p.configInput.AdminURL = url p.configInput.AdminURL = url
} }
// GetPreSharedKey reads pre-shared key from config file // GetPreSharedKey read preshared key from config file
func (p *Preferences) GetPreSharedKey() (string, error) { func (p *Preferences) GetPreSharedKey() (string, error) {
if p.configInput.PreSharedKey != nil { if p.configInput.PreSharedKey != nil {
return *p.configInput.PreSharedKey, nil return *p.configInput.PreSharedKey, nil
@@ -66,160 +66,12 @@ func (p *Preferences) GetPreSharedKey() (string, error) {
return cfg.PreSharedKey, err return cfg.PreSharedKey, err
} }
// SetPreSharedKey stores the given key and waits for commit // SetPreSharedKey store the given key and wait for commit
func (p *Preferences) SetPreSharedKey(key string) { func (p *Preferences) SetPreSharedKey(key string) {
p.configInput.PreSharedKey = &key p.configInput.PreSharedKey = &key
} }
// SetRosenpassEnabled stores whether Rosenpass is enabled // Commit write out the changes into config file
func (p *Preferences) SetRosenpassEnabled(enabled bool) {
p.configInput.RosenpassEnabled = &enabled
}
// GetRosenpassEnabled reads Rosenpass enabled status from config file
func (p *Preferences) GetRosenpassEnabled() (bool, error) {
if p.configInput.RosenpassEnabled != nil {
return *p.configInput.RosenpassEnabled, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.RosenpassEnabled, err
}
// SetRosenpassPermissive stores the given permissive setting and waits for commit
func (p *Preferences) SetRosenpassPermissive(permissive bool) {
p.configInput.RosenpassPermissive = &permissive
}
// GetRosenpassPermissive reads Rosenpass permissive setting from config file
func (p *Preferences) GetRosenpassPermissive() (bool, error) {
if p.configInput.RosenpassPermissive != nil {
return *p.configInput.RosenpassPermissive, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.RosenpassPermissive, err
}
// GetDisableClientRoutes reads disable client routes setting from config file
func (p *Preferences) GetDisableClientRoutes() (bool, error) {
if p.configInput.DisableClientRoutes != nil {
return *p.configInput.DisableClientRoutes, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableClientRoutes, err
}
// SetDisableClientRoutes stores the given value and waits for commit
func (p *Preferences) SetDisableClientRoutes(disable bool) {
p.configInput.DisableClientRoutes = &disable
}
// GetDisableServerRoutes reads disable server routes setting from config file
func (p *Preferences) GetDisableServerRoutes() (bool, error) {
if p.configInput.DisableServerRoutes != nil {
return *p.configInput.DisableServerRoutes, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableServerRoutes, err
}
// SetDisableServerRoutes stores the given value and waits for commit
func (p *Preferences) SetDisableServerRoutes(disable bool) {
p.configInput.DisableServerRoutes = &disable
}
// GetDisableDNS reads disable DNS setting from config file
func (p *Preferences) GetDisableDNS() (bool, error) {
if p.configInput.DisableDNS != nil {
return *p.configInput.DisableDNS, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableDNS, err
}
// SetDisableDNS stores the given value and waits for commit
func (p *Preferences) SetDisableDNS(disable bool) {
p.configInput.DisableDNS = &disable
}
// GetDisableFirewall reads disable firewall setting from config file
func (p *Preferences) GetDisableFirewall() (bool, error) {
if p.configInput.DisableFirewall != nil {
return *p.configInput.DisableFirewall, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableFirewall, err
}
// SetDisableFirewall stores the given value and waits for commit
func (p *Preferences) SetDisableFirewall(disable bool) {
p.configInput.DisableFirewall = &disable
}
// GetServerSSHAllowed reads server SSH allowed setting from config file
func (p *Preferences) GetServerSSHAllowed() (bool, error) {
if p.configInput.ServerSSHAllowed != nil {
return *p.configInput.ServerSSHAllowed, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
if cfg.ServerSSHAllowed == nil {
// Default to false for security on Android
return false, nil
}
return *cfg.ServerSSHAllowed, err
}
// SetServerSSHAllowed stores the given value and waits for commit
func (p *Preferences) SetServerSSHAllowed(allowed bool) {
p.configInput.ServerSSHAllowed = &allowed
}
// GetBlockInbound reads block inbound setting from config file
func (p *Preferences) GetBlockInbound() (bool, error) {
if p.configInput.BlockInbound != nil {
return *p.configInput.BlockInbound, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.BlockInbound, err
}
// SetBlockInbound stores the given value and waits for commit
func (p *Preferences) SetBlockInbound(block bool) {
p.configInput.BlockInbound = &block
}
// Commit writes out the changes to the config file
func (p *Preferences) Commit() error { func (p *Preferences) Commit() error {
_, err := internal.UpdateOrCreateConfig(p.configInput) _, err := internal.UpdateOrCreateConfig(p.configInput)
return err return err

View File

@@ -26,7 +26,7 @@ type Anonymizer struct {
} }
func DefaultAddresses() (netip.Addr, netip.Addr) { func DefaultAddresses() (netip.Addr, netip.Addr) {
// 198.51.100.0, 100:: // 192.51.100.0, 100::
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01}) return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
} }
@@ -69,22 +69,6 @@ func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
return a.ipAnonymizer[ip] return a.ipAnonymizer[ip]
} }
func (a *Anonymizer) AnonymizeUDPAddr(addr net.UDPAddr) net.UDPAddr {
// Convert IP to netip.Addr
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return addr
}
anonIP := a.AnonymizeIP(ip)
return net.UDPAddr{
IP: anonIP.AsSlice(),
Port: addr.Port,
Zone: addr.Zone,
}
}
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs // isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool { func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 { if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {

View File

@@ -11,12 +11,9 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server" "github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status" nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/management/proto"
) )
const errCloseConnection = "Failed to close connection: %v" const errCloseConnection = "Failed to close connection: %v"
@@ -87,27 +84,16 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
}() }()
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
request := &proto.DebugBundleRequest{ resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag, Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag), Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: debugSystemInfoFlag, SystemInfo: debugSystemInfoFlag,
} })
if debugUploadBundle {
request.UploadURL = debugUploadBundleURL
}
resp, err := client.DebugBundle(cmd.Context(), request)
if err != nil { if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
} }
cmd.Printf("Local file:\n%s\n", resp.GetPath())
if resp.GetUploadFailureReason() != "" { cmd.Println(resp.GetPath())
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
}
if debugUploadBundle {
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
}
return nil return nil
} }
@@ -222,19 +208,23 @@ func runForDuration(cmd *cobra.Command, args []string) error {
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration) headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag)) statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
request := &proto.DebugBundleRequest{
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag, Anonymize: anonymizeFlag,
Status: statusOutput, Status: statusOutput,
SystemInfo: debugSystemInfoFlag, SystemInfo: debugSystemInfoFlag,
} })
if debugUploadBundle {
request.UploadURL = debugUploadBundleURL
}
resp, err := client.DebugBundle(cmd.Context(), request)
if err != nil { if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
} }
// Disable network map persistence after creating the debug bundle
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
Enabled: false,
}); err != nil {
return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message())
}
if stateWasDown { if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
@@ -249,15 +239,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Log level restored to", initialLogLevel.GetLevel()) cmd.Println("Log level restored to", initialLogLevel.GetLevel())
} }
cmd.Printf("Local file:\n%s\n", resp.GetPath()) cmd.Println(resp.GetPath())
if resp.GetUploadFailureReason() != "" {
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
}
if debugUploadBundle {
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
}
return nil return nil
} }
@@ -344,34 +326,3 @@ func formatDuration(d time.Duration) string {
s := d / time.Second s := d / time.Second
return fmt.Sprintf("%02d:%02d:%02d", h, m, s) return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
} }
func generateDebugBundle(config *internal.Config, recorder *peer.Status, connectClient *internal.ConnectClient, logFilePath string) {
var networkMap *mgmProto.NetworkMap
var err error
if connectClient != nil {
networkMap, err = connectClient.GetLatestNetworkMap()
if err != nil {
log.Warnf("Failed to get latest network map: %v", err)
}
}
bundleGenerator := debug.NewBundleGenerator(
debug.GeneratorDependencies{
InternalConfig: config,
StatusRecorder: recorder,
NetworkMap: networkMap,
LogFile: logFilePath,
},
debug.BundleConfig{
IncludeSystemInfo: true,
},
)
path, err := bundleGenerator.Generate()
if err != nil {
log.Errorf("Failed to generate debug bundle: %v", err)
return
}
log.Infof("Generated debug bundle from SIGUSR1 at: %s", path)
}

View File

@@ -1,39 +0,0 @@
//go:build unix
package cmd
import (
"context"
"os"
"os/signal"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
)
func SetupDebugHandler(
ctx context.Context,
config *internal.Config,
recorder *peer.Status,
connectClient *internal.ConnectClient,
logFilePath string,
) {
usr1Ch := make(chan os.Signal, 1)
signal.Notify(usr1Ch, syscall.SIGUSR1)
go func() {
for {
select {
case <-ctx.Done():
return
case <-usr1Ch:
log.Info("Received SIGUSR1. Triggering debug bundle generation.")
go generateDebugBundle(config, recorder, connectClient, logFilePath)
}
}
}()
}

View File

@@ -1,126 +0,0 @@
package cmd
import (
"context"
"errors"
"os"
"strconv"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
)
const (
envListenEvent = "NB_LISTEN_DEBUG_EVENT"
debugTriggerEventName = `Global\NetbirdDebugTriggerEvent`
waitTimeout = 5 * time.Second
)
// SetupDebugHandler sets up a Windows event to listen for a signal to generate a debug bundle.
// Example usage with PowerShell:
// $evt = [System.Threading.EventWaitHandle]::OpenExisting("Global\NetbirdDebugTriggerEvent")
// $evt.Set()
// $evt.Close()
func SetupDebugHandler(
ctx context.Context,
config *internal.Config,
recorder *peer.Status,
connectClient *internal.ConnectClient,
logFilePath string,
) {
env := os.Getenv(envListenEvent)
if env == "" {
return
}
listenEvent, err := strconv.ParseBool(env)
if err != nil {
log.Errorf("Failed to parse %s: %v", envListenEvent, err)
return
}
if !listenEvent {
return
}
eventNamePtr, err := windows.UTF16PtrFromString(debugTriggerEventName)
if err != nil {
log.Errorf("Failed to convert event name '%s' to UTF16: %v", debugTriggerEventName, err)
return
}
// TODO: restrict access by ACL
eventHandle, err := windows.CreateEvent(nil, 1, 0, eventNamePtr)
if err != nil {
if errors.Is(err, windows.ERROR_ALREADY_EXISTS) {
log.Warnf("Debug trigger event '%s' already exists. Attempting to open.", debugTriggerEventName)
// SYNCHRONIZE is needed for WaitForSingleObject, EVENT_MODIFY_STATE for ResetEvent.
eventHandle, err = windows.OpenEvent(windows.SYNCHRONIZE|windows.EVENT_MODIFY_STATE, false, eventNamePtr)
if err != nil {
log.Errorf("Failed to open existing debug trigger event '%s': %v", debugTriggerEventName, err)
return
}
log.Infof("Successfully opened existing debug trigger event '%s'.", debugTriggerEventName)
} else {
log.Errorf("Failed to create debug trigger event '%s': %v", debugTriggerEventName, err)
return
}
}
if eventHandle == windows.InvalidHandle {
log.Errorf("Obtained an invalid handle for debug trigger event '%s'", debugTriggerEventName)
return
}
log.Infof("Debug handler waiting for signal on event: %s", debugTriggerEventName)
go waitForEvent(ctx, config, recorder, connectClient, logFilePath, eventHandle)
}
func waitForEvent(
ctx context.Context,
config *internal.Config,
recorder *peer.Status,
connectClient *internal.ConnectClient,
logFilePath string,
eventHandle windows.Handle,
) {
defer func() {
if err := windows.CloseHandle(eventHandle); err != nil {
log.Errorf("Failed to close debug event handle '%s': %v", debugTriggerEventName, err)
}
}()
for {
if ctx.Err() != nil {
return
}
status, err := windows.WaitForSingleObject(eventHandle, uint32(waitTimeout.Milliseconds()))
switch status {
case windows.WAIT_OBJECT_0:
log.Info("Received signal on debug event. Triggering debug bundle generation.")
// reset the event so it can be triggered again later (manual reset == 1)
if err := windows.ResetEvent(eventHandle); err != nil {
log.Errorf("Failed to reset debug event '%s': %v", debugTriggerEventName, err)
}
go generateDebugBundle(config, recorder, connectClient, logFilePath)
case uint32(windows.WAIT_TIMEOUT):
default:
log.Errorf("Unexpected status %d from WaitForSingleObject for debug event '%s': %v", status, debugTriggerEventName, err)
select {
case <-time.After(5 * time.Second):
case <-ctx.Done():
return
}
}
}
}

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"runtime"
"strings" "strings"
"time" "time"
@@ -20,10 +19,6 @@ import (
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
func init() {
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
}
var loginCmd = &cobra.Command{ var loginCmd = &cobra.Command{
Use: "login", Use: "login",
Short: "login to the Netbird Management Service (first run)", Short: "login to the Netbird Management Service (first run)",
@@ -56,9 +51,6 @@ var loginCmd = &cobra.Command{
return err return err
} }
// update host's static platform and system information
system.UpdateStaticInfo()
ic := internal.ConfigInput{ ic := internal.ConfigInput{
ManagementURL: managementURL, ManagementURL: managementURL,
AdminURL: adminURL, AdminURL: adminURL,
@@ -99,11 +91,11 @@ var loginCmd = &cobra.Command{
} }
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey, SetupKey: providedSetupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
IsUnixDesktopClient: isUnixRunningDesktop(), IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName, Hostname: hostName,
DnsLabels: dnsLabelsReq, DnsLabels: dnsLabelsReq,
} }
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
@@ -135,7 +127,7 @@ var loginCmd = &cobra.Command{
} }
if loginResp.NeedsSSOLogin { if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser) openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName}) _, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil { if err != nil {
@@ -196,7 +188,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
} }
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) { func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop()) oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -206,7 +198,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err) return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
} }
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser) openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout) waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
@@ -220,34 +212,23 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
return &tokenInfo, nil return &tokenInfo, nil
} }
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) { func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
var codeMsg string var codeMsg string
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) { if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode) codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
} }
if noBrowser { cmd.Println("Please do the SSO login in your browser. \n" +
cmd.Println("Use this URL to log in:\n\n" + verificationURIComplete + " " + codeMsg) "If your browser didn't open automatically, use this URL to log in:\n\n" +
} else { verificationURIComplete + " " + codeMsg)
cmd.Println("Please do the SSO login in your browser. \n" +
"If your browser didn't open automatically, use this URL to log in:\n\n" +
verificationURIComplete + " " + codeMsg)
}
cmd.Println("") cmd.Println("")
if err := open.Run(verificationURIComplete); err != nil {
if !noBrowser { cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
if err := open.Run(verificationURIComplete); err != nil { "https://docs.netbird.io/how-to/register-machines-using-setup-keys")
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
}
} }
} }
// isUnixRunningDesktop checks if a Linux OS is running desktop environment // isLinuxRunningDesktop checks if a Linux OS is running desktop environment
func isUnixRunningDesktop() bool { func isLinuxRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
return false
}
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != "" return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
} }

View File

@@ -22,26 +22,23 @@ import (
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/upload-server/types"
) )
const ( const (
externalIPMapFlag = "external-ip-map" externalIPMapFlag = "external-ip-map"
dnsResolverAddress = "dns-resolver-address" dnsResolverAddress = "dns-resolver-address"
enableRosenpassFlag = "enable-rosenpass" enableRosenpassFlag = "enable-rosenpass"
rosenpassPermissiveFlag = "rosenpass-permissive" rosenpassPermissiveFlag = "rosenpass-permissive"
preSharedKeyFlag = "preshared-key" preSharedKeyFlag = "preshared-key"
interfaceNameFlag = "interface-name" interfaceNameFlag = "interface-name"
wireguardPortFlag = "wireguard-port" wireguardPortFlag = "wireguard-port"
networkMonitorFlag = "network-monitor" networkMonitorFlag = "network-monitor"
disableAutoConnectFlag = "disable-auto-connect" disableAutoConnectFlag = "disable-auto-connect"
serverSSHAllowedFlag = "allow-server-ssh" serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist" extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval" dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info" systemInfoFlag = "system-info"
enableLazyConnectionFlag = "enable-lazy-connection" blockLANAccessFlag = "block-lan-access"
uploadBundle = "upload-bundle"
uploadBundleURL = "upload-bundle-url"
) )
var ( var (
@@ -77,9 +74,7 @@ var (
anonymizeFlag bool anonymizeFlag bool
debugSystemInfoFlag bool debugSystemInfoFlag bool
dnsRouteInterval time.Duration dnsRouteInterval time.Duration
debugUploadBundle bool blockLANAccess bool
debugUploadBundleURL string
lazyConnEnabled bool
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "netbird", Use: "netbird",
@@ -184,11 +179,8 @@ func init() {
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.") upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted") upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.") upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.")
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle") debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", false, "Adds system information to the debug bundle")
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
} }
// SetupCloseHandler handles SIGTERM signal and exits with success // SetupCloseHandler handles SIGTERM signal and exits with success

View File

@@ -2,7 +2,6 @@ package cmd
import ( import (
"context" "context"
"runtime"
"sync" "sync"
"github.com/kardianos/service" "github.com/kardianos/service"
@@ -28,19 +27,12 @@ func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
} }
func newSVCConfig() *service.Config { func newSVCConfig() *service.Config {
config := &service.Config{ return &service.Config{
Name: serviceName, Name: serviceName,
DisplayName: "Netbird", DisplayName: "Netbird",
Description: "Netbird mesh network client", Description: "A WireGuard-based mesh network that connects your devices into a single private network.",
Option: make(service.KeyValue), Option: make(service.KeyValue),
EnvVars: make(map[string]string),
} }
if runtime.GOOS == "linux" {
config.EnvVars["SYSTEMD_UNIT"] = serviceName
}
return config
} }
func newSVC(prg *program, conf *service.Config) (service.Service, error) { func newSVC(prg *program, conf *service.Config) (service.Service, error) {

View File

@@ -16,17 +16,12 @@ import (
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server" "github.com/netbirdio/netbird/client/server"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
func (p *program) Start(svc service.Service) error { func (p *program) Start(svc service.Service) error {
// Start should not block. Do the actual work async. // Start should not block. Do the actual work async.
log.Info("starting Netbird service") //nolint log.Info("starting Netbird service") //nolint
// Collect static system and platform information
system.UpdateStaticInfo()
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API. // in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
p.serv = grpc.NewServer() p.serv = grpc.NewServer()
@@ -120,7 +115,6 @@ var runCmd = &cobra.Command{
ctx, cancel := context.WithCancel(cmd.Context()) ctx, cancel := context.WithCancel(cmd.Context())
SetupCloseHandler(ctx, cancel) SetupCloseHandler(ctx, cancel)
SetupDebugHandler(ctx, nil, nil, nil, logFile)
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig()) s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
if err != nil { if err != nil {

View File

@@ -39,7 +39,7 @@ var installCmd = &cobra.Command{
svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL) svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL)
} }
if logFile != "" { if logFile != "console" {
svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile) svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile)
} }

View File

@@ -44,7 +44,7 @@ func init() {
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4") statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200") statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud") statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected") statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
} }
func statusFunc(cmd *cobra.Command, args []string) error { func statusFunc(cmd *cobra.Command, args []string) error {
@@ -69,10 +69,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return err return err
} }
status := resp.GetStatus() if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
status == string(internal.StatusSessionExpired) {
cmd.Printf("Daemon status: %s\n\n"+ cmd.Printf("Daemon status: %s\n\n"+
"Run UP command to log in with SSO (interactive login):\n\n"+ "Run UP command to log in with SSO (interactive login):\n\n"+
" netbird up \n\n"+ " netbird up \n\n"+
@@ -130,12 +127,12 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
func parseFilters() error { func parseFilters() error {
switch strings.ToLower(statusFilter) { switch strings.ToLower(statusFilter) {
case "", "idle", "connecting", "connected": case "", "disconnected", "connected":
if strings.ToLower(statusFilter) != "" { if strings.ToLower(statusFilter) != "" {
enableDetailFlagWhenFilterFlag() enableDetailFlagWhenFilterFlag()
} }
default: default:
return fmt.Errorf("wrong status filter, should be one of connected|connecting|idle, got: %s", statusFilter) return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter)
} }
if len(ipsFilter) > 0 { if len(ipsFilter) > 0 {

View File

@@ -6,8 +6,6 @@ const (
disableServerRoutesFlag = "disable-server-routes" disableServerRoutesFlag = "disable-server-routes"
disableDNSFlag = "disable-dns" disableDNSFlag = "disable-dns"
disableFirewallFlag = "disable-firewall" disableFirewallFlag = "disable-firewall"
blockLANAccessFlag = "block-lan-access"
blockInboundFlag = "block-inbound"
) )
var ( var (
@@ -15,8 +13,6 @@ var (
disableServerRoutes bool disableServerRoutes bool
disableDNS bool disableDNS bool
disableFirewall bool disableFirewall bool
blockLANAccess bool
blockInbound bool
) )
func init() { func init() {
@@ -32,11 +28,4 @@ func init() {
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false, upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
"Disable firewall configuration. If enabled, the client won't modify firewall rules.") "Disable firewall configuration. If enabled, the client won't modify firewall rules.")
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false,
"Block access to local networks (LAN) when using this peer as a router or exit node")
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
"This overrides any policies received from the management service.")
} }

View File

@@ -6,17 +6,14 @@ import (
"testing" "testing"
"time" "time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
@@ -34,7 +31,7 @@ import (
func startTestingServices(t *testing.T) string { func startTestingServices(t *testing.T) string {
t.Helper() t.Helper()
config := &types.Config{} config := &mgmt.Config{}
_, err := util.ReadJson("../testdata/management.json", config) _, err := util.ReadJson("../testdata/management.json", config)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -69,7 +66,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
return s, lis return s, lis
} }
func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc.Server, net.Listener) { func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) {
t.Helper() t.Helper()
lis, err := net.Listen("tcp", ":0") lis, err := net.Listen("tcp", ":0")
@@ -92,24 +89,14 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err) require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl) accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settings.NewManagerMock())
permissionsManagerMock := permissions.NewMockManager(ctrl)
settingsMockManager.EXPECT().
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil).
AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManagerMock(), peersUpdateManager, secretsManager, nil, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -17,7 +17,7 @@ var traceCmd = &cobra.Command{
Example: ` Example: `
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53 netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --icmp-type 8 --icmp-code 0 netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`, netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
Args: cobra.ExactArgs(3), Args: cobra.ExactArgs(3),
RunE: tracePacket, RunE: tracePacket,
@@ -118,7 +118,7 @@ func tracePacket(cmd *cobra.Command, args []string) error {
} }
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) { func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
cmd.Printf("Packet trace %s:%d %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto)) cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
for _, stage := range resp.Stages { for _, stage := range resp.Stages {
if stage.ForwardingDetails != nil { if stage.ForwardingDetails != nil {

View File

@@ -32,16 +32,12 @@ const (
const ( const (
dnsLabelsFlag = "extra-dns-labels" dnsLabelsFlag = "extra-dns-labels"
noBrowserFlag = "no-browser"
noBrowserDesc = "do not open the browser for SSO login"
) )
var ( var (
foregroundMode bool foregroundMode bool
dnsLabels []string dnsLabels []string
dnsLabelsValidated domain.List dnsLabelsValidated domain.List
noBrowser bool
upCmd = &cobra.Command{ upCmd = &cobra.Command{
Use: "up", Use: "up",
@@ -55,11 +51,12 @@ func init() {
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name") upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port") upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor, upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+ `Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`, `E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
) )
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening") upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval") upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil, upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
`Sets DNS labels`+ `Sets DNS labels`+
@@ -68,9 +65,6 @@ func init() {
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+ `E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
`or --extra-dns-labels ""`, `or --extra-dns-labels ""`,
) )
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
} }
func upFunc(cmd *cobra.Command, args []string) error { func upFunc(cmd *cobra.Command, args []string) error {
@@ -118,124 +112,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
return err return err
} }
ic, err := setupConfig(customDNSAddressConverted, cmd)
if err != nil {
return fmt.Errorf("setup config: %v", err)
}
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
config, err := internal.UpdateOrCreateConfig(*ic)
if err != nil {
return fmt.Errorf("get config file: %v", err)
}
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel)
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil)
}
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
if err != nil {
return err
}
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
defer func() {
err := conn.Close()
if err != nil {
log.Warnf("failed closing daemon gRPC client connection %v", err)
return
}
}()
client := proto.NewDaemonServiceClient(conn)
status, err := client.Status(ctx, &proto.StatusRequest{})
if err != nil {
return fmt.Errorf("unable to get daemon status: %v", err)
}
if status.Status == string(internal.StatusConnected) {
cmd.Println("Already connected")
return nil
}
providedSetupKey, err := getSetupKey()
if err != nil {
return fmt.Errorf("get setup key: %v", err)
}
loginRequest, err := setupLoginRequest(providedSetupKey, customDNSAddressConverted, cmd)
if err != nil {
return fmt.Errorf("setup login request: %v", err)
}
var loginErr error
var loginResp *proto.LoginResponse
err = WithBackOff(func() error {
var backOffErr error
loginResp, backOffErr = client.Login(ctx, loginRequest)
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
s.Code() == codes.PermissionDenied ||
s.Code() == codes.NotFound ||
s.Code() == codes.Unimplemented) {
loginErr = backOffErr
return nil
}
return backOffErr
})
if err != nil {
return fmt.Errorf("login backoff cycle failed: %v", err)
}
if loginErr != nil {
return fmt.Errorf("login failed: %v", loginErr)
}
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {
return fmt.Errorf("waiting sso login failed with: %v", err)
}
}
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("call service up method: %v", err)
}
cmd.Println("Connected")
return nil
}
func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command) (*internal.ConfigInput, error) {
ic := internal.ConfigInput{ ic := internal.ConfigInput{
ManagementURL: managementURL, ManagementURL: managementURL,
AdminURL: adminURL, AdminURL: adminURL,
@@ -260,7 +136,7 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command) (*interna
if cmd.Flag(interfaceNameFlag).Changed { if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil { if err := parseInterfaceName(interfaceName); err != nil {
return nil, err return err
} }
ic.InterfaceName = &interfaceName ic.InterfaceName = &interfaceName
} }
@@ -311,29 +187,83 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command) (*interna
ic.BlockLANAccess = &blockLANAccess ic.BlockLANAccess = &blockLANAccess
} }
if cmd.Flag(blockInboundFlag).Changed { providedSetupKey, err := getSetupKey()
ic.BlockInbound = &blockInbound if err != nil {
return err
} }
if cmd.Flag(enableLazyConnectionFlag).Changed { config, err := internal.UpdateOrCreateConfig(ic)
ic.LazyConnectionEnabled = &lazyConnEnabled if err != nil {
return fmt.Errorf("get config file: %v", err)
} }
return &ic, nil
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel)
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
return connectClient.Run(nil)
} }
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) { func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
if err != nil {
return err
}
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
defer func() {
err := conn.Close()
if err != nil {
log.Warnf("failed closing daemon gRPC client connection %v", err)
return
}
}()
client := proto.NewDaemonServiceClient(conn)
status, err := client.Status(ctx, &proto.StatusRequest{})
if err != nil {
return fmt.Errorf("unable to get daemon status: %v", err)
}
if status.Status == string(internal.StatusConnected) {
cmd.Println("Already connected")
return nil
}
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey, SetupKey: providedSetupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
AdminURL: adminURL, AdminURL: adminURL,
NatExternalIPs: natExternalIPs, NatExternalIPs: natExternalIPs,
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0, CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
CustomDNSAddress: customDNSAddressConverted, CustomDNSAddress: customDNSAddressConverted,
IsUnixDesktopClient: isUnixRunningDesktop(), IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName, Hostname: hostName,
ExtraIFaceBlacklist: extraIFaceBlackList, ExtraIFaceBlacklist: extraIFaceBlackList,
DnsLabels: dnsLabels, DnsLabels: dnsLabels,
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0, CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
} }
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
@@ -358,7 +288,7 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
if cmd.Flag(interfaceNameFlag).Changed { if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil { if err := parseInterfaceName(interfaceName); err != nil {
return nil, err return err
} }
loginRequest.InterfaceName = &interfaceName loginRequest.InterfaceName = &interfaceName
} }
@@ -393,14 +323,45 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.BlockLanAccess = &blockLANAccess loginRequest.BlockLanAccess = &blockLANAccess
} }
if cmd.Flag(blockInboundFlag).Changed { var loginErr error
loginRequest.BlockInbound = &blockInbound
var loginResp *proto.LoginResponse
err = WithBackOff(func() error {
var backOffErr error
loginResp, backOffErr = client.Login(ctx, &loginRequest)
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
s.Code() == codes.PermissionDenied ||
s.Code() == codes.NotFound ||
s.Code() == codes.Unimplemented) {
loginErr = backOffErr
return nil
}
return backOffErr
})
if err != nil {
return fmt.Errorf("login backoff cycle failed: %v", err)
} }
if cmd.Flag(enableLazyConnectionFlag).Changed { if loginErr != nil {
loginRequest.LazyConnectionEnabled = &lazyConnEnabled return fmt.Errorf("login failed: %v", loginErr)
} }
return &loginRequest, nil
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {
return fmt.Errorf("waiting sso login failed with: %v", err)
}
}
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("call service up method: %v", err)
}
cmd.Println("Connected")
return nil
} }
func validateNATExternalIPs(list []string) error { func validateNATExternalIPs(list []string) error {

View File

@@ -113,16 +113,17 @@ func (m *Manager) AddPeerFiltering(
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
id []byte, id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination firewall.Network, destination netip.Prefix,
proto firewall.Protocol, proto firewall.Protocol,
sPort, dPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action, action firewall.Action,
) (firewall.Rule, error) { ) (firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() { if !destination.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String()) return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
} }
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
@@ -147,10 +148,6 @@ func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) IsStateful() bool {
return true
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -202,7 +199,7 @@ func (m *Manager) AllowNetbird() error {
_, err := m.AddPeerFiltering( _, err := m.AddPeerFiltering(
nil, nil,
net.IP{0, 0, 0, 0}, net.IP{0, 0, 0, 0},
firewall.ProtocolALL, "all",
nil, nil,
nil, nil,
firewall.ActionAccept, firewall.ActionAccept,
@@ -223,16 +220,10 @@ func (m *Manager) SetLogLevel(log.Level) {
} }
func (m *Manager) EnableRouting() error { func (m *Manager) EnableRouting() error {
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IP forwarding: %w", err)
}
return nil return nil
} }
func (m *Manager) DisableRouting() error { func (m *Manager) DisableRouting() error {
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
return fmt.Errorf("disable IP forwarding: %w", err)
}
return nil return nil
} }
@@ -252,14 +243,6 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
return m.router.DeleteDNATRule(rule) return m.router.DeleteDNATRule(rule)
} }
// UpdateSet updates the set with the given prefixes
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.UpdateSet(set, prefixes)
}
func getConntrackEstablished() []string { func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
} }

View File

@@ -2,7 +2,7 @@ package iptables
import ( import (
"fmt" "fmt"
"net/netip" "net"
"testing" "testing"
"time" "time"
@@ -19,8 +19,11 @@ var ifaceMock = &iFaceMock{
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: netip.MustParseAddr("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: netip.MustParsePrefix("10.20.0.0/24"), Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -67,12 +70,12 @@ func TestIptablesManager(t *testing.T) {
var rule2 []fw.Rule var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) { t.Run("add second rule", func(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{ port := &fw.Port{
IsRange: true, IsRange: true,
Values: []uint16{8043, 8046}, Values: []uint16{8043, 8046},
} }
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "") rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
for _, r := range rule2 { for _, r := range rule2 {
@@ -92,9 +95,9 @@ func TestIptablesManager(t *testing.T) {
t.Run("reset check", func(t *testing.T) { t.Run("reset check", func(t *testing.T) {
// add second rule // add second rule
ip := netip.MustParseAddr("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []uint16{5353}} port := &fw.Port{Values: []uint16{5353}}
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Close(nil) err = manager.Close(nil)
@@ -116,8 +119,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: netip.MustParseAddr("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: netip.MustParsePrefix("10.20.0.0/24"), Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -138,11 +144,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
var rule2 []fw.Rule var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) { t.Run("add second rule", func(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{ port := &fw.Port{
Values: []uint16{443}, Values: []uint16{443},
} }
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default") rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
for _, r := range rule2 { for _, r := range rule2 {
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set") require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
@@ -180,8 +186,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: netip.MustParseAddr("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: netip.MustParsePrefix("10.20.0.0/24"), Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -203,11 +212,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
ip := netip.MustParseAddr("10.20.0.100") ip := net.ParseIP("10.20.0.100")
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
} }

View File

@@ -38,12 +38,10 @@ const (
routingFinalForwardJump = "ACCEPT" routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE" routingFinalNatJump = "MASQUERADE"
jumpManglePre = "jump-mangle-pre" jumpManglePre = "jump-mangle-pre"
jumpNatPre = "jump-nat-pre" jumpNatPre = "jump-nat-pre"
jumpNatPost = "jump-nat-post" jumpNatPost = "jump-nat-post"
markManglePre = "mark-mangle-pre" matchSet = "--match-set"
markManglePost = "mark-mangle-post"
matchSet = "--match-set"
dnatSuffix = "_dnat" dnatSuffix = "_dnat"
snatSuffix = "_snat" snatSuffix = "_snat"
@@ -57,18 +55,18 @@ type ruleInfo struct {
} }
type routeFilteringRuleParams struct { type routeFilteringRuleParams struct {
Source firewall.Network Sources []netip.Prefix
Destination firewall.Network Destination netip.Prefix
Proto firewall.Protocol Proto firewall.Protocol
SPort *firewall.Port SPort *firewall.Port
DPort *firewall.Port DPort *firewall.Port
Direction firewall.RuleDirection Direction firewall.RuleDirection
Action firewall.Action Action firewall.Action
SetName string
} }
type routeRules map[string][]string type routeRules map[string][]string
// the ipset library currently does not support comments, so we use the name only (string)
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}] type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
type router struct { type router struct {
@@ -117,10 +115,6 @@ func (r *router) init(stateManager *statemanager.Manager) error {
return fmt.Errorf("create containers: %w", err) return fmt.Errorf("create containers: %w", err)
} }
if err := r.setupDataPlaneMark(); err != nil {
log.Errorf("failed to set up data plane mark: %v", err)
}
r.updateState() r.updateState()
return nil return nil
@@ -129,7 +123,7 @@ func (r *router) init(stateManager *statemanager.Manager) error {
func (r *router) AddRouteFiltering( func (r *router) AddRouteFiltering(
id []byte, id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination firewall.Network, destination netip.Prefix,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
@@ -140,28 +134,27 @@ func (r *router) AddRouteFiltering(
return ruleKey, nil return ruleKey, nil
} }
var source firewall.Network var setName string
if len(sources) > 1 { if len(sources) > 1 {
source.Set = firewall.NewPrefixSet(sources) setName = firewall.GenerateSetName(sources)
} else if len(sources) > 0 { if _, err := r.ipsetCounter.Increment(setName, sources); err != nil {
source.Prefix = sources[0] return nil, fmt.Errorf("create or get ipset: %w", err)
}
} }
params := routeFilteringRuleParams{ params := routeFilteringRuleParams{
Source: source, Sources: sources,
Destination: destination, Destination: destination,
Proto: proto, Proto: proto,
SPort: sPort, SPort: sPort,
DPort: dPort, DPort: dPort,
Action: action, Action: action,
SetName: setName,
} }
rule, err := r.genRouteRuleSpec(params, sources) rule := genRouteFilteringRuleSpec(params)
if err != nil {
return nil, fmt.Errorf("generate route rule spec: %w", err)
}
// Insert DROP rules at the beginning, append ACCEPT rules at the end // Insert DROP rules at the beginning, append ACCEPT rules at the end
var err error
if action == firewall.ActionDrop { if action == firewall.ActionDrop {
// after the established rule // after the established rule
err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...) err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
@@ -184,13 +177,17 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
ruleKey := rule.ID() ruleKey := rule.ID()
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
setName := r.findSetNameInRule(rule)
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil { if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil {
return fmt.Errorf("delete route rule: %v", err) return fmt.Errorf("delete route rule: %v", err)
} }
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil { if setName != "" {
return fmt.Errorf("decrement ipset counter: %w", err) if _, err := r.ipsetCounter.Decrement(setName); err != nil {
return fmt.Errorf("failed to remove ipset: %w", err)
}
} }
} else { } else {
log.Debugf("route rule %s not found", ruleKey) log.Debugf("route rule %s not found", ruleKey)
@@ -201,26 +198,13 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
return nil return nil
} }
func (r *router) decrementSetCounter(rule []string) error { func (r *router) findSetNameInRule(rule []string) string {
sets := r.findSets(rule)
var merr *multierror.Error
for _, setName := range sets {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) findSets(rule []string) []string {
var sets []string
for i, arg := range rule { for i, arg := range rule {
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet { if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
sets = append(sets, rule[i+3]) return rule[i+3]
} }
} }
return sets return ""
} }
func (r *router) createIpSet(setName string, sources []netip.Prefix) error { func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
@@ -241,13 +225,15 @@ func (r *router) deleteIpSet(setName string) error {
if err := ipset.Destroy(setName); err != nil { if err := ipset.Destroy(setName); err != nil {
return fmt.Errorf("destroy set %s: %w", setName, err) return fmt.Errorf("destroy set %s: %w", setName, err)
} }
log.Debugf("Deleted unused ipset %s", setName)
return nil return nil
} }
// AddNatRule inserts an iptables rule pair into the nat chain // AddNatRule inserts an iptables rule pair into the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error { func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if r.legacyManagement { if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination) log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil { if err := r.addLegacyRouteRule(pair); err != nil {
@@ -274,14 +260,16 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains // RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error { func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if pair.Masquerade { if err := r.ipFwdState.ReleaseForwarding(); err != nil {
if err := r.removeNatRule(pair); err != nil { log.Errorf("%v", err)
return fmt.Errorf("remove nat rule: %w", err) }
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err) return fmt.Errorf("remove nat rule: %w", err)
} }
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err)
} }
if err := r.removeLegacyRouteRule(pair); err != nil { if err := r.removeLegacyRouteRule(pair); err != nil {
@@ -319,10 +307,8 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
} }
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
} else {
if err := r.decrementSetCounter(rule); err != nil { log.Debugf("legacy forwarding rule %s not found", ruleKey)
return fmt.Errorf("decrement ipset counter: %w", err)
}
} }
return nil return nil
@@ -362,16 +348,12 @@ func (r *router) Reset() error {
if err := r.cleanUpDefaultForwardRules(); err != nil { if err := r.cleanUpDefaultForwardRules(); err != nil {
merr = multierror.Append(merr, err) merr = multierror.Append(merr, err)
} }
r.rules = make(map[string][]string)
if err := r.ipsetCounter.Flush(); err != nil { if err := r.ipsetCounter.Flush(); err != nil {
merr = multierror.Append(merr, err) merr = multierror.Append(merr, err)
} }
if err := r.cleanupDataPlaneMark(); err != nil {
merr = multierror.Append(merr, err)
}
r.rules = make(map[string][]string)
r.updateState() r.updateState()
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
@@ -441,57 +423,6 @@ func (r *router) createContainers() error {
return nil return nil
} }
// setupDataPlaneMark configures the fwmark for the data plane
func (r *router) setupDataPlaneMark() error {
var merr *multierror.Error
preRule := []string{
"-i", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "NEW",
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkIn),
}
if err := r.iptablesClient.AppendUnique(tableMangle, chainPREROUTING, preRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add mangle prerouting rule: %w", err))
} else {
r.rules[markManglePre] = preRule
}
postRule := []string{
"-o", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "NEW",
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkOut),
}
if err := r.iptablesClient.AppendUnique(tableMangle, chainPOSTROUTING, postRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add mangle postrouting rule: %w", err))
} else {
r.rules[markManglePost] = postRule
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) cleanupDataPlaneMark() error {
var merr *multierror.Error
if preRule, exists := r.rules[markManglePre]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPREROUTING, preRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove mangle prerouting rule: %w", err))
} else {
delete(r.rules, markManglePre)
}
}
if postRule, exists := r.rules[markManglePost]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPOSTROUTING, postRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove mangle postrouting rule: %w", err))
} else {
delete(r.rules, markManglePost)
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) addPostroutingRules() error { func (r *router) addPostroutingRules() error {
// First rule for outbound masquerade // First rule for outbound masquerade
rule1 := []string{ rule1 := []string{
@@ -533,7 +464,7 @@ func (r *router) insertEstablishedRule(chain string) error {
} }
func (r *router) addJumpRules() error { func (r *router) addJumpRules() error {
// Jump to nat chain // Jump to NAT chain
natRule := []string{"-j", chainRTNAT} natRule := []string{"-j", chainRTNAT}
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil { if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
return fmt.Errorf("add nat postrouting jump rule: %v", err) return fmt.Errorf("add nat postrouting jump rule: %v", err)
@@ -607,26 +538,12 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
rule = append(rule, rule = append(rule,
"-m", "conntrack", "-m", "conntrack",
"--ctstate", "NEW", "--ctstate", "NEW",
) "-s", pair.Source.String(),
sourceExp, err := r.applyNetwork("-s", pair.Source, nil) "-d", pair.Destination.String(),
if err != nil {
return fmt.Errorf("apply network -s: %w", err)
}
destExp, err := r.applyNetwork("-d", pair.Destination, nil)
if err != nil {
return fmt.Errorf("apply network -d: %w", err)
}
rule = append(rule, sourceExp...)
rule = append(rule, destExp...)
rule = append(rule,
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue), "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
) )
// Ensure nat rules come first, so the mark can be overwritten. if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
if err := r.iptablesClient.Insert(tableMangle, chainRTPRE, 1, rule...); err != nil {
// TODO: rollback ipset counter
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err) return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
} }
@@ -644,10 +561,6 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err) return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
} }
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement ipset counter: %w", err)
}
} else { } else {
log.Debugf("marking rule %s not found", ruleKey) log.Debugf("marking rule %s not found", ruleKey)
} }
@@ -813,21 +726,17 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []netip.Prefix) ([]string, error) { func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
var rule []string var rule []string
sourceExp, err := r.applyNetwork("-s", params.Source, sources) if params.SetName != "" {
if err != nil { rule = append(rule, "-m", "set", matchSet, params.SetName, "src")
return nil, fmt.Errorf("apply network -s: %w", err) } else if len(params.Sources) > 0 {
source := params.Sources[0]
} rule = append(rule, "-s", source.String())
destExp, err := r.applyNetwork("-d", params.Destination, nil)
if err != nil {
return nil, fmt.Errorf("apply network -d: %w", err)
} }
rule = append(rule, sourceExp...) rule = append(rule, "-d", params.Destination.String())
rule = append(rule, destExp...)
if params.Proto != firewall.ProtocolALL { if params.Proto != firewall.ProtocolALL {
rule = append(rule, "-p", strings.ToLower(string(params.Proto))) rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
@@ -837,47 +746,7 @@ func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []net
rule = append(rule, "-j", actionToStr(params.Action)) rule = append(rule, "-j", actionToStr(params.Action))
return rule, nil return rule
}
func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
direction := "src"
if flag == "-d" {
direction = "dst"
}
if network.IsSet() {
if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil
}
if network.IsPrefix() {
return []string{flag, network.Prefix.String()}, nil
}
// nolint:nilnil
return nil, nil
}
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
var merr *multierror.Error
for _, prefix := range prefixes {
// TODO: Implement IPv6 support
if prefix.Addr().Is6() {
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err))
}
}
if merr == nil {
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
}
return nberrors.FormatErrorOrNil(merr)
} }
func applyPort(flag string, port *firewall.Port) []string { func applyPort(flag string, port *firewall.Port) []string {

View File

@@ -46,9 +46,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
// 5. jump rule to PRE nat chain // 5. jump rule to PRE nat chain
// 6. static outbound masquerade rule // 6. static outbound masquerade rule
// 7. static return masquerade rule // 7. static return masquerade rule
// 8. mangle prerouting mark rule require.Len(t, manager.rules, 7, "should have created rules map")
// 9. mangle postrouting mark rule
require.Len(t, manager.rules, 9, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT) exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
@@ -60,8 +58,8 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
pair := firewall.RouterPair{ pair := firewall.RouterPair{
ID: "abc", ID: "abc",
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")}, Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.0/24")}, Destination: netip.MustParsePrefix("100.100.100.0/24"),
Masquerade: true, Masquerade: true,
} }
@@ -332,7 +330,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action) ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed") require.NoError(t, err, "AddRouteFiltering failed")
// Check if the rule is in the internal map // Check if the rule is in the internal map
@@ -347,29 +345,23 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
assert.NoError(t, err, "Failed to check rule existence") assert.NoError(t, err, "Failed to check rule existence")
assert.True(t, exists, "Rule not found in iptables") assert.True(t, exists, "Rule not found in iptables")
var source firewall.Network
if len(tt.sources) > 1 {
source.Set = firewall.NewPrefixSet(tt.sources)
} else if len(tt.sources) > 0 {
source.Prefix = tt.sources[0]
}
// Verify rule content // Verify rule content
params := routeFilteringRuleParams{ params := routeFilteringRuleParams{
Source: source, Sources: tt.sources,
Destination: firewall.Network{Prefix: tt.destination}, Destination: tt.destination,
Proto: tt.proto, Proto: tt.proto,
SPort: tt.sPort, SPort: tt.sPort,
DPort: tt.dPort, DPort: tt.dPort,
Action: tt.action, Action: tt.action,
SetName: "",
} }
expectedRule, err := r.genRouteRuleSpec(params, nil) expectedRule := genRouteFilteringRuleSpec(params)
require.NoError(t, err, "Failed to generate expected rule spec")
if tt.expectSet { if tt.expectSet {
setName := firewall.NewPrefixSet(tt.sources).HashedName() setName := firewall.GenerateSetName(tt.sources)
expectedRule, err = r.genRouteRuleSpec(params, nil) params.SetName = setName
require.NoError(t, err, "Failed to generate expected rule spec with set") expectedRule = genRouteFilteringRuleSpec(params)
// Check if the set was created // Check if the set was created
_, exists := r.ipsetCounter.Get(setName) _, exists := r.ipsetCounter.Get(setName)
@@ -384,62 +376,3 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
}) })
} }
} }
func TestFindSetNameInRule(t *testing.T) {
r := &router{}
testCases := []struct {
name string
rule []string
expected []string
}{
{
name: "Basic rule with two sets",
rule: []string{
"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-m", "set", "--match-set", "nb-2e5a2a05", "src",
"-m", "set", "--match-set", "nb-349ae051", "dst", "-m", "tcp", "--dport", "8080", "-j", "ACCEPT",
},
expected: []string{"nb-2e5a2a05", "nb-349ae051"},
},
{
name: "No sets",
rule: []string{"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-j", "ACCEPT"},
expected: []string{},
},
{
name: "Multiple sets with different positions",
rule: []string{
"-m", "set", "--match-set", "set1", "src", "-p", "tcp",
"-m", "set", "--match-set", "set-abc123", "dst", "-j", "ACCEPT",
},
expected: []string{"set1", "set-abc123"},
},
{
name: "Boundary case - sequence appears at end",
rule: []string{"-p", "tcp", "-m", "set", "--match-set", "final-set"},
expected: []string{"final-set"},
},
{
name: "Incomplete pattern - missing set name",
rule: []string{"-p", "tcp", "-m", "set", "--match-set"},
expected: []string{},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := r.findSets(tc.rule)
if len(result) != len(tc.expected) {
t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result)
return
}
for i, set := range result {
if set != tc.expected[i] {
t.Errorf("Expected set %q at position %d, got %q", tc.expected[i], i, set)
}
}
})
}
}

View File

@@ -1,10 +1,13 @@
package manager package manager
import ( import (
"crypto/sha256"
"encoding/hex"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"sort" "sort"
"strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -40,18 +43,6 @@ const (
// Action is the action to be taken on a rule // Action is the action to be taken on a rule
type Action int type Action int
// String returns the string representation of the action
func (a Action) String() string {
switch a {
case ActionAccept:
return "accept"
case ActionDrop:
return "drop"
default:
return "unknown"
}
}
const ( const (
// ActionAccept is the action to accept a packet // ActionAccept is the action to accept a packet
ActionAccept Action = iota ActionAccept Action = iota
@@ -59,33 +50,6 @@ const (
ActionDrop ActionDrop
) )
// Network is a rule destination, either a set or a prefix
type Network struct {
Set Set
Prefix netip.Prefix
}
// String returns the string representation of the destination
func (d Network) String() string {
if d.Prefix.IsValid() {
return d.Prefix.String()
}
if d.IsSet() {
return d.Set.HashedName()
}
return "<invalid network>"
}
// IsSet returns true if the destination is a set
func (d Network) IsSet() bool {
return d.Set != Set{}
}
// IsPrefix returns true if the destination is a valid prefix
func (d Network) IsPrefix() bool {
return d.Prefix.IsValid()
}
// Manager is the high level abstraction of a firewall manager // Manager is the high level abstraction of a firewall manager
// //
// It declares methods which handle actions required by the // It declares methods which handle actions required by the
@@ -116,14 +80,13 @@ type Manager interface {
// IsServerRouteSupported returns true if the firewall supports server side routing operations // IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool IsServerRouteSupported() bool
IsStateful() bool
AddRouteFiltering( AddRouteFiltering(
id []byte, id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination Network, destination netip.Prefix,
proto Protocol, proto Protocol,
sPort, dPort *Port, sPort *Port,
dPort *Port,
action Action, action Action,
) (Rule, error) ) (Rule, error)
@@ -156,9 +119,6 @@ type Manager interface {
// DeleteDNATRule deletes a DNAT rule // DeleteDNATRule deletes a DNAT rule
DeleteDNATRule(Rule) error DeleteDNATRule(Rule) error
// UpdateSet updates the set with the given prefixes
UpdateSet(hash Set, prefixes []netip.Prefix) error
} }
func GenKey(format string, pair RouterPair) string { func GenKey(format string, pair RouterPair) string {
@@ -193,6 +153,22 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
return nil return nil
} }
// GenerateSetName generates a unique name for an ipset based on the given sources.
func GenerateSetName(sources []netip.Prefix) string {
// sort for consistent naming
SortPrefixes(sources)
var sourcesStr strings.Builder
for _, src := range sources {
sourcesStr.WriteString(src.String())
}
hash := sha256.Sum256([]byte(sourcesStr.String()))
shortHash := hex.EncodeToString(hash[:])[:8]
return fmt.Sprintf("nb-%s", shortHash)
}
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix // MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix { func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
if len(prefixes) == 0 { if len(prefixes) == 0 {

View File

@@ -20,8 +20,8 @@ func TestGenerateSetName(t *testing.T) {
netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("192.168.1.0/24"),
} }
result1 := manager.NewPrefixSet(prefixes1) result1 := manager.GenerateSetName(prefixes1)
result2 := manager.NewPrefixSet(prefixes2) result2 := manager.GenerateSetName(prefixes2)
if result1 != result2 { if result1 != result2 {
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2) t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
@@ -34,9 +34,9 @@ func TestGenerateSetName(t *testing.T) {
netip.MustParsePrefix("10.0.0.0/8"), netip.MustParsePrefix("10.0.0.0/8"),
} }
result := manager.NewPrefixSet(prefixes) result := manager.GenerateSetName(prefixes)
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result.HashedName()) matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
if err != nil { if err != nil {
t.Fatalf("Error matching regex: %v", err) t.Fatalf("Error matching regex: %v", err)
} }
@@ -46,8 +46,8 @@ func TestGenerateSetName(t *testing.T) {
}) })
t.Run("Empty input produces consistent result", func(t *testing.T) { t.Run("Empty input produces consistent result", func(t *testing.T) {
result1 := manager.NewPrefixSet([]netip.Prefix{}) result1 := manager.GenerateSetName([]netip.Prefix{})
result2 := manager.NewPrefixSet([]netip.Prefix{}) result2 := manager.GenerateSetName([]netip.Prefix{})
if result1 != result2 { if result1 != result2 {
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2) t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
@@ -64,8 +64,8 @@ func TestGenerateSetName(t *testing.T) {
netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("192.168.1.0/24"),
} }
result1 := manager.NewPrefixSet(prefixes1) result1 := manager.GenerateSetName(prefixes1)
result2 := manager.NewPrefixSet(prefixes2) result2 := manager.GenerateSetName(prefixes2)
if result1 != result2 { if result1 != result2 {
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2) t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)

View File

@@ -1,13 +1,15 @@
package manager package manager
import ( import (
"net/netip"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
type RouterPair struct { type RouterPair struct {
ID route.ID ID route.ID
Source Network Source netip.Prefix
Destination Network Destination netip.Prefix
Masquerade bool Masquerade bool
Inverse bool Inverse bool
} }

View File

@@ -1,74 +0,0 @@
package manager
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net/netip"
"slices"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/domain"
)
type Set struct {
hash [4]byte
comment string
}
// String returns the string representation of the set: hashed name and comment
func (h Set) String() string {
if h.comment == "" {
return h.HashedName()
}
return h.HashedName() + ": " + h.comment
}
// HashedName returns the string representation of the hash
func (h Set) HashedName() string {
return fmt.Sprintf(
"nb-%s",
hex.EncodeToString(h.hash[:]),
)
}
// Comment returns the comment of the set
func (h Set) Comment() string {
return h.comment
}
// NewPrefixSet generates a unique name for an ipset based on the given prefixes.
func NewPrefixSet(prefixes []netip.Prefix) Set {
// sort for consistent naming
SortPrefixes(prefixes)
hash := sha256.New()
for _, src := range prefixes {
bytes, err := src.MarshalBinary()
if err != nil {
log.Warnf("failed to marshal prefix %s: %v", src, err)
}
hash.Write(bytes)
}
var set Set
copy(set.hash[:], hash.Sum(nil)[:4])
return set
}
// NewDomainSet generates a unique name for an ipset based on the given domains.
func NewDomainSet(domains domain.List) Set {
slices.Sort(domains)
hash := sha256.New()
for _, d := range domains {
hash.Write([]byte(d.PunycodeString()))
}
set := Set{
comment: domains.SafeString(),
}
copy(set.hash[:], hash.Sum(nil)[:4])
return set
}

View File

@@ -25,10 +25,9 @@ const (
chainNameInputRules = "netbird-acl-input-rules" chainNameInputRules = "netbird-acl-input-rules"
// filter chains contains the rules that jump to the rules chains // filter chains contains the rules that jump to the rules chains
chainNameInputFilter = "netbird-acl-input-filter" chainNameInputFilter = "netbird-acl-input-filter"
chainNameForwardFilter = "netbird-acl-forward-filter" chainNameForwardFilter = "netbird-acl-forward-filter"
chainNameManglePrerouting = "netbird-mangle-prerouting" chainNamePrerouting = "netbird-rt-prerouting"
chainNameManglePostrouting = "netbird-mangle-postrouting"
allowNetbirdInputRuleID = "allow Netbird incoming traffic" allowNetbirdInputRuleID = "allow Netbird incoming traffic"
) )
@@ -463,15 +462,13 @@ func (m *AclManager) createDefaultChains() (err error) {
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the // go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP. // netbird peer IP.
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error { func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
// Chain is created by route manager m.chainPrerouting = m.rConn.AddChain(&nftables.Chain{
// TODO: move creation to a common place Name: chainNamePrerouting,
m.chainPrerouting = &nftables.Chain{
Name: chainNameManglePrerouting,
Table: m.workTable, Table: m.workTable,
Type: nftables.ChainTypeFilter, Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting, Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle, Priority: nftables.ChainPriorityMangle,
} })
m.addFwmarkToForward(chainFwFilter) m.addFwmarkToForward(chainFwFilter)

View File

@@ -135,16 +135,17 @@ func (m *Manager) AddPeerFiltering(
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
id []byte, id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination firewall.Network, destination netip.Prefix,
proto firewall.Protocol, proto firewall.Protocol,
sPort, dPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action, action firewall.Action,
) (firewall.Rule, error) { ) (firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() { if !destination.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String()) return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
} }
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
@@ -170,10 +171,6 @@ func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) IsStateful() bool {
return true
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -245,7 +242,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy) return firewall.SetLegacyManagement(m.router, isLegacy)
} }
// Close closes the firewall manager // Reset firewall to the default state
func (m *Manager) Close(stateManager *statemanager.Manager) error { func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -328,16 +325,10 @@ func (m *Manager) SetLogLevel(log.Level) {
} }
func (m *Manager) EnableRouting() error { func (m *Manager) EnableRouting() error {
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IP forwarding: %w", err)
}
return nil return nil
} }
func (m *Manager) DisableRouting() error { func (m *Manager) DisableRouting() error {
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
return fmt.Errorf("disable IP forwarding: %w", err)
}
return nil return nil
} }
@@ -368,14 +359,6 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
return m.router.DeleteDNATRule(rule) return m.router.DeleteDNATRule(rule)
} }
// UpdateSet updates the set with the given prefixes
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.UpdateSet(set, prefixes)
}
func (m *Manager) createWorkTable() (*nftables.Table, error) { func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil { if err != nil {

View File

@@ -3,6 +3,7 @@ package nftables
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"os/exec" "os/exec"
"testing" "testing"
@@ -24,8 +25,11 @@ var ifaceMock = &iFaceMock{
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: netip.MustParseAddr("100.96.0.1"), IP: net.ParseIP("100.96.0.1"),
Network: netip.MustParsePrefix("100.96.0.0/16"), Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -66,11 +70,11 @@ func TestNftablesManager(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
ip := netip.MustParseAddr("100.96.0.1").Unmap() ip := net.ParseIP("100.96.0.1")
testClient := &nftables.Conn{} testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "") rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Flush() err = manager.Flush()
@@ -105,6 +109,8 @@ func TestNftablesManager(t *testing.T) {
} }
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1) compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
expectedExprs2 := []expr.Any{ expectedExprs2 := []expr.Any{
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
@@ -126,7 +132,7 @@ func TestNftablesManager(t *testing.T) {
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Register: 1, Register: 1,
Data: ip.AsSlice(), Data: add.AsSlice(),
}, },
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
@@ -167,8 +173,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: netip.MustParseAddr("100.96.0.1"), IP: net.ParseIP("100.96.0.1"),
Network: netip.MustParsePrefix("100.96.0.0/16"), Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -188,11 +197,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
ip := netip.MustParseAddr("10.20.0.100") ip := net.ParseIP("10.20.0.100")
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
if i%100 == 0 { if i%100 == 0 {
@@ -273,14 +282,14 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
verifyIptablesOutput(t, stdout, stderr) verifyIptablesOutput(t, stdout, stderr)
}) })
ip := netip.MustParseAddr("100.96.0.1") ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "failed to add peer filtering rule") require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering( _, err = manager.AddRouteFiltering(
nil, nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")}, []netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")}, netip.MustParsePrefix("10.1.0.0/24"),
fw.ProtocolTCP, fw.ProtocolTCP,
nil, nil,
&fw.Port{Values: []uint16{443}}, &fw.Port{Values: []uint16{443}},
@@ -289,8 +298,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
require.NoError(t, err, "failed to add route filtering rule") require.NoError(t, err, "failed to add route filtering rule")
pair := fw.RouterPair{ pair := fw.RouterPair{
Source: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, Source: netip.MustParsePrefix("192.168.1.0/24"),
Destination: fw.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")}, Destination: netip.MustParsePrefix("10.0.0.0/24"),
Masquerade: true, Masquerade: true,
} }
err = manager.AddNatRule(pair) err = manager.AddNatRule(pair)

View File

@@ -10,6 +10,7 @@ import (
"strings" "strings"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/davecgh/go-spew/spew"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil" "github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
@@ -43,14 +44,9 @@ const (
const refreshRulesMapError = "refresh rules map: %w" const refreshRulesMapError = "refresh rules map: %w"
var ( var (
errFilterTableNotFound = fmt.Errorf("'filter' table not found") errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
) )
type setInput struct {
set firewall.Set
prefixes []netip.Prefix
}
type router struct { type router struct {
conn *nftables.Conn conn *nftables.Conn
workTable *nftables.Table workTable *nftables.Table
@@ -58,7 +54,7 @@ type router struct {
chains map[string]*nftables.Chain chains map[string]*nftables.Chain
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules // rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
rules map[string]*nftables.Rule rules map[string]*nftables.Rule
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set] ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
wgIface iFaceMapper wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState ipFwdState *ipfwdstate.IPForwardingState
@@ -104,10 +100,6 @@ func (r *router) init(workTable *nftables.Table) error {
return fmt.Errorf("create containers: %w", err) return fmt.Errorf("create containers: %w", err)
} }
if err := r.setupDataPlaneMark(); err != nil {
log.Errorf("failed to set up data plane mark: %v", err)
}
return nil return nil
} }
@@ -167,7 +159,7 @@ func (r *router) removeNatPreroutingRules() error {
func (r *router) loadFilterTable() (*nftables.Table, error) { func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to list tables: %v", err) return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
} }
for _, table := range tables { for _, table := range tables {
@@ -204,21 +196,15 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeNAT, Type: nftables.ChainTypeNAT,
}) })
r.chains[chainNameManglePostrouting] = r.conn.AddChain(&nftables.Chain{ // Chain is created by acl manager
Name: chainNameManglePostrouting, // TODO: move creation to a common place
Table: r.workTable, r.chains[chainNamePrerouting] = &nftables.Chain{
Hooknum: nftables.ChainHookPostrouting, Name: chainNamePrerouting,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
})
r.chains[chainNameManglePrerouting] = r.conn.AddChain(&nftables.Chain{
Name: chainNameManglePrerouting,
Table: r.workTable, Table: r.workTable,
Hooknum: nftables.ChainHookPrerouting, Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle, Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter, Type: nftables.ChainTypeFilter,
}) }
// Add the single NAT rule that matches on mark // Add the single NAT rule that matches on mark
if err := r.addPostroutingRules(); err != nil { if err := r.addPostroutingRules(); err != nil {
@@ -234,83 +220,7 @@ func (r *router) createContainers() error {
} }
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
return fmt.Errorf("initialize tables: %v", err) return fmt.Errorf("nftables: unable to initialize table: %v", err)
}
return nil
}
// setupDataPlaneMark configures the fwmark for the data plane
func (r *router) setupDataPlaneMark() error {
if r.chains[chainNameManglePrerouting] == nil || r.chains[chainNameManglePostrouting] == nil {
return errors.New("no mangle chains found")
}
ctNew := getCtNewExprs()
preExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
preExprs = append(preExprs, ctNew...)
preExprs = append(preExprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkIn),
},
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
SourceRegister: true,
},
)
preNftRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameManglePrerouting],
Exprs: preExprs,
}
r.conn.AddRule(preNftRule)
postExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
postExprs = append(postExprs, ctNew...)
postExprs = append(postExprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkOut),
},
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
SourceRegister: true,
},
)
postNftRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameManglePostrouting],
Exprs: postExprs,
}
r.conn.AddRule(postNftRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush: %w", err)
} }
return nil return nil
@@ -320,7 +230,7 @@ func (r *router) setupDataPlaneMark() error {
func (r *router) AddRouteFiltering( func (r *router) AddRouteFiltering(
id []byte, id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination firewall.Network, destination netip.Prefix,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
@@ -335,29 +245,23 @@ func (r *router) AddRouteFiltering(
chain := r.chains[chainNameRoutingFw] chain := r.chains[chainNameRoutingFw]
var exprs []expr.Any var exprs []expr.Any
var source firewall.Network
switch { switch {
case len(sources) == 1 && sources[0].Bits() == 0: case len(sources) == 1 && sources[0].Bits() == 0:
// If it's 0.0.0.0/0, we don't need to add any source matching // If it's 0.0.0.0/0, we don't need to add any source matching
case len(sources) == 1: case len(sources) == 1:
// If there's only one source, we can use it directly // If there's only one source, we can use it directly
source.Prefix = sources[0] exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...)
default: default:
// If there are multiple sources, use a set // If there are multiple sources, create or get an ipset
source.Set = firewall.NewPrefixSet(sources) var err error
exprs, err = r.getIpSetExprs(sources, exprs)
if err != nil {
return nil, fmt.Errorf("get ipset expressions: %w", err)
}
} }
sourceExp, err := r.applyNetwork(source, sources, true) // Handle destination
if err != nil { exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
return nil, fmt.Errorf("apply source: %w", err)
}
exprs = append(exprs, sourceExp...)
destExp, err := r.applyNetwork(destination, nil, false)
if err != nil {
return nil, fmt.Errorf("apply destination: %w", err)
}
exprs = append(exprs, destExp...)
// Handle protocol // Handle protocol
if proto != firewall.ProtocolALL { if proto != firewall.ProtocolALL {
@@ -401,27 +305,39 @@ func (r *router) AddRouteFiltering(
rule = r.conn.AddRule(rule) rule = r.conn.AddRule(rule)
} }
log.Tracef("Adding route rule %s", spew.Sdump(rule))
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err) return nil, fmt.Errorf(flushError, err)
} }
r.rules[string(ruleKey)] = rule r.rules[string(ruleKey)] = rule
log.Debugf("added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action) log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
return ruleKey, nil return ruleKey, nil
} }
func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bool) ([]expr.Any, error) { func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
ref, err := r.ipsetCounter.Increment(set.HashedName(), setInput{ setName := firewall.GenerateSetName(sources)
set: set, ref, err := r.ipsetCounter.Increment(setName, sources)
prefixes: prefixes,
})
if err != nil { if err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err) return nil, fmt.Errorf("create or get ipset for sources: %w", err)
} }
return getIpSetExprs(ref, isSource) exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: ref.Out.Name,
SetID: ref.Out.ID,
},
)
return exprs, nil
} }
func (r *router) DeleteRouteRule(rule firewall.Rule) error { func (r *router) DeleteRouteRule(rule firewall.Rule) error {
@@ -440,54 +356,42 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
return fmt.Errorf("route rule %s has no handle", ruleKey) return fmt.Errorf("route rule %s has no handle", ruleKey)
} }
setName := r.findSetNameInRule(nftRule)
if err := r.deleteNftRule(nftRule, ruleKey); err != nil { if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
return fmt.Errorf("delete: %w", err) return fmt.Errorf("delete: %w", err)
} }
if setName != "" {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
return fmt.Errorf("decrement ipset reference: %w", err)
}
}
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err) return fmt.Errorf(flushError, err)
} }
if err := r.decrementSetCounter(nftRule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
return nil return nil
} }
func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, error) { func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
// overlapping prefixes will result in an error, so we need to merge them // overlapping prefixes will result in an error, so we need to merge them
prefixes := firewall.MergeIPRanges(input.prefixes) sources = firewall.MergeIPRanges(sources)
nfset := &nftables.Set{ set := &nftables.Set{
Name: setName, Name: setName,
Comment: input.set.Comment(), Table: r.workTable,
Table: r.workTable,
// required for prefixes // required for prefixes
Interval: true, Interval: true,
KeyType: nftables.TypeIPAddr, KeyType: nftables.TypeIPAddr,
} }
elements := convertPrefixesToSet(prefixes)
if err := r.conn.AddSet(nfset, elements); err != nil {
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
return nfset, nil
}
func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
var elements []nftables.SetElement var elements []nftables.SetElement
for _, prefix := range prefixes { for _, prefix := range sources {
// TODO: Implement IPv6 support // TODO: Implement IPv6 support
if prefix.Addr().Is6() { if prefix.Addr().Is6() {
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue continue
} }
@@ -503,7 +407,18 @@ func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true}, nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
) )
} }
return elements
if err := r.conn.AddSet(set, elements); err != nil {
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
return set, nil
} }
// calculateLastIP determines the last IP in a given prefix. // calculateLastIP determines the last IP in a given prefix.
@@ -527,8 +442,8 @@ func uint32ToBytes(ip uint32) [4]byte {
return b return b
} }
func (r *router) deleteIpSet(setName string, nfset *nftables.Set) error { func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
r.conn.DelSet(nfset) r.conn.DelSet(set)
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err) return fmt.Errorf(flushError, err)
} }
@@ -537,27 +452,13 @@ func (r *router) deleteIpSet(setName string, nfset *nftables.Set) error {
return nil return nil
} }
func (r *router) decrementSetCounter(rule *nftables.Rule) error { func (r *router) findSetNameInRule(rule *nftables.Rule) string {
sets := r.findSets(rule)
var merr *multierror.Error
for _, setName := range sets {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
merr = multierror.Append(merr, fmt.Errorf("decrement set counter: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) findSets(rule *nftables.Rule) []string {
var sets []string
for _, e := range rule.Exprs { for _, e := range rule.Exprs {
if lookup, ok := e.(*expr.Lookup); ok { if lookup, ok := e.(*expr.Lookup); ok {
sets = append(sets, lookup.SetName) return lookup.SetName
} }
} }
return sets return ""
} }
func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error { func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
@@ -573,6 +474,10 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
// AddNatRule appends a nftables rule pair to the nat chain // AddNatRule appends a nftables rule pair to the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error { func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
@@ -595,8 +500,7 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
} }
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
// TODO: rollback ipset counter return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err)
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
} }
return nil return nil
@@ -604,15 +508,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
// addNatRule inserts a nftables rule to the conn client flush queue // addNatRule inserts a nftables rule to the conn client flush queue
func (r *router) addNatRule(pair firewall.RouterPair) error { func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true) sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
if err != nil { destExp := generateCIDRMatcherExpressions(false, pair.Destination)
return fmt.Errorf("apply source: %w", err)
}
destExp, err := r.applyNetwork(pair.Destination, nil, false)
if err != nil {
return fmt.Errorf("apply destination: %w", err)
}
op := expr.CmpOpEq op := expr.CmpOpEq
if pair.Inverse { if pair.Inverse {
@@ -620,6 +517,26 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
} }
exprs := []expr.Any{ exprs := []expr.Any{
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
// interface matching
&expr.Meta{ &expr.Meta{
Key: expr.MetaKeyIIFNAME, Key: expr.MetaKeyIIFNAME,
Register: 1, Register: 1,
@@ -630,9 +547,6 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
Data: ifname(r.wgIface.Name()), Data: ifname(r.wgIface.Name()),
}, },
} }
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
exprs = append(exprs, getCtNewExprs()...)
exprs = append(exprs, sourceExp...) exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...) exprs = append(exprs, destExp...)
@@ -662,11 +576,9 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
} }
} }
// Ensure nat rules come first, so the mark can be overwritten. r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
Table: r.workTable, Table: r.workTable,
Chain: r.chains[chainNameManglePrerouting], Chain: r.chains[chainNamePrerouting],
Exprs: exprs, Exprs: exprs,
UserData: []byte(ruleKey), UserData: []byte(ruleKey),
}) })
@@ -747,15 +659,8 @@ func (r *router) addPostroutingRules() error {
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls // addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true) sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
if err != nil { destExp := generateCIDRMatcherExpressions(false, pair.Destination)
return fmt.Errorf("apply source: %w", err)
}
destExp, err := r.applyNetwork(pair.Destination, nil, false)
if err != nil {
return fmt.Errorf("apply destination: %w", err)
}
exprs := []expr.Any{ exprs := []expr.Any{
&expr.Counter{}, &expr.Counter{},
@@ -764,8 +669,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
}, },
} }
exprs = append(exprs, sourceExp...) expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
exprs = append(exprs, destExp...)
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
@@ -778,7 +682,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable, Table: r.workTable,
Chain: r.chains[chainNameRoutingFw], Chain: r.chains[chainNameRoutingFw],
Exprs: exprs, Exprs: expression,
UserData: []byte(ruleKey), UserData: []byte(ruleKey),
}) })
return nil return nil
@@ -793,13 +697,11 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
} }
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination) log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
} else {
if err := r.decrementSetCounter(rule); err != nil { log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey)
return fmt.Errorf("decrement set counter: %w", err)
}
} }
return nil return nil
@@ -1002,18 +904,20 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
// RemoveNatRule removes the prerouting mark rule // RemoveNatRule removes the prerouting mark rule
func (r *router) RemoveNatRule(pair firewall.RouterPair) error { func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
if pair.Masquerade { if err := r.removeNatRule(pair); err != nil {
if err := r.removeNatRule(pair); err != nil { return fmt.Errorf("remove prerouting rule: %w", err)
return fmt.Errorf("remove prerouting rule: %w", err) }
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse prerouting rule: %w", err) return fmt.Errorf("remove inverse prerouting rule: %w", err)
}
} }
if err := r.removeLegacyRouteRule(pair); err != nil { if err := r.removeLegacyRouteRule(pair); err != nil {
@@ -1021,10 +925,10 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
} }
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
// TODO: rollback set counter return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
} }
log.Debugf("nftables: removed nat rules for %s", pair.Destination)
return nil return nil
} }
@@ -1032,19 +936,16 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair) ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
if err := r.conn.DelRule(rule); err != nil { err := r.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err) return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
} }
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination) log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
} else { } else {
log.Debugf("prerouting rule %s not found", ruleKey) log.Debugf("nftables: prerouting rule %s not found", ruleKey)
} }
return nil return nil
@@ -1056,7 +957,7 @@ func (r *router) refreshRulesMap() error {
for _, chain := range r.chains { for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain) rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil { if err != nil {
return fmt.Errorf(" unable to list rules: %v", err) return fmt.Errorf("nftables: unable to list rules: %v", err)
} }
for _, rule := range rules { for _, rule := range rules {
if len(rule.UserData) > 0 { if len(rule.UserData) > 0 {
@@ -1330,54 +1231,13 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { // generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
nfset, err := r.conn.GetSetByName(r.workTable, set.HashedName()) func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
if err != nil { var offset uint32
return fmt.Errorf("get set %s: %w", set.HashedName(), err) if source {
} offset = 12 // src offset
} else {
elements := convertPrefixesToSet(prefixes) offset = 16 // dst offset
if err := r.conn.SetAddElements(nfset, elements); err != nil {
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
return nil
}
// applyNetwork generates nftables expressions for networks (CIDR) or sets
func (r *router) applyNetwork(
network firewall.Network,
setPrefixes []netip.Prefix,
isSource bool,
) ([]expr.Any, error) {
if network.IsSet() {
exprs, err := r.getIpSet(network.Set, setPrefixes, isSource)
if err != nil {
return nil, fmt.Errorf("source: %w", err)
}
return exprs, nil
}
if network.IsPrefix() {
return applyPrefix(network.Prefix, isSource), nil
}
return nil, nil
}
// applyPrefix generates nftables expressions for a CIDR prefix
func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
// dst offset
offset := uint32(16)
if isSource {
// src offset
offset = 12
} }
ones := prefix.Bits() ones := prefix.Bits()
@@ -1464,48 +1324,3 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
return exprs return exprs
} }
func getCtNewExprs() []expr.Any {
return []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
}
}
func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
// dst offset
offset := uint32(16)
if isSource {
// src offset
offset = 12
}
return []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offset,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: ref.Out.Name,
SetID: ref.Out.ID,
},
}, nil
}

View File

@@ -88,8 +88,8 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
} }
// Build CIDR matching expressions // Build CIDR matching expressions
sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true) sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false) destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
// Combine all expressions in the correct order // Combine all expressions in the correct order
// nolint:gocritic // nolint:gocritic
@@ -100,7 +100,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair) natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := 0 found := 0
for _, chain := range rtr.chains { for _, chain := range rtr.chains {
if chain.Name == chainNameManglePrerouting { if chain.Name == chainNamePrerouting {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain) rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules { for _, rule := range rules {
@@ -141,7 +141,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
// Verify the rule was added // Verify the rule was added
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair) natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := false found := false
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting]) rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
require.NoError(t, err, "should list rules") require.NoError(t, err, "should list rules")
for _, rule := range rules { for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
@@ -157,7 +157,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
// Verify the rule was removed // Verify the rule was removed
found = false found = false
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting]) rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
require.NoError(t, err, "should list rules after removal") require.NoError(t, err, "should list rules after removal")
for _, rule := range rules { for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
@@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action) ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed") require.NoError(t, err, "AddRouteFiltering failed")
t.Cleanup(func() { t.Cleanup(func() {
@@ -441,8 +441,8 @@ func TestNftablesCreateIpSet(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
setName := firewall.NewPrefixSet(tt.sources).HashedName() setName := firewall.GenerateSetName(tt.sources)
set, err := r.createIpSet(setName, setInput{prefixes: tt.sources}) set, err := r.createIpSet(setName, tt.sources)
if err != nil { if err != nil {
t.Logf("Failed to create IP set: %v", err) t.Logf("Failed to create IP set: %v", err)
printNftSets() printNftSets()

View File

@@ -15,8 +15,8 @@ var (
Name: "Insert Forwarding IPV4 Rule", Name: "Insert Forwarding IPV4 Rule",
InputPair: firewall.RouterPair{ InputPair: firewall.RouterPair{
ID: "zxa", ID: "zxa",
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")}, Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")}, Destination: netip.MustParsePrefix("100.100.200.0/24"),
Masquerade: false, Masquerade: false,
}, },
}, },
@@ -24,8 +24,8 @@ var (
Name: "Insert Forwarding And Nat IPV4 Rules", Name: "Insert Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{ InputPair: firewall.RouterPair{
ID: "zxa", ID: "zxa",
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")}, Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")}, Destination: netip.MustParsePrefix("100.100.200.0/24"),
Masquerade: true, Masquerade: true,
}, },
}, },
@@ -40,8 +40,8 @@ var (
Name: "Remove Forwarding And Nat IPV4 Rules", Name: "Remove Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{ InputPair: firewall.RouterPair{
ID: "zxa", ID: "zxa",
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")}, Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")}, Destination: netip.MustParsePrefix("100.100.200.0/24"),
Masquerade: true, Masquerade: true,
}, },
}, },

View File

@@ -12,7 +12,7 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
// Close cleans up the firewall manager by removing all rules and closing trackers // Reset firewall to the default state
func (m *Manager) Close(stateManager *statemanager.Manager) error { func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()

View File

@@ -10,6 +10,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -21,7 +22,7 @@ const (
firewallRuleName = "Netbird" firewallRuleName = "Netbird"
) )
// Close cleans up the firewall manager by removing all rules and closing trackers // Reset firewall to the default state
func (m *Manager) Close(*statemanager.Manager) error { func (m *Manager) Close(*statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -31,14 +32,17 @@ func (m *Manager) Close(*statemanager.Manager) error {
if m.udpTracker != nil { if m.udpTracker != nil {
m.udpTracker.Close() m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
} }
if m.icmpTracker != nil { if m.icmpTracker != nil {
m.icmpTracker.Close() m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
} }
if m.tcpTracker != nil { if m.tcpTracker != nil {
m.tcpTracker.Close() m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
} }
if fwder := m.forwarder.Load(); fwder != nil { if fwder := m.forwarder.Load(); fwder != nil {

View File

@@ -62,5 +62,5 @@ type ConnKey struct {
} }
func (c ConnKey) String() string { func (c ConnKey) String() string {
return fmt.Sprintf("%s:%d %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort) return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
} }

View File

@@ -1,6 +1,7 @@
package conntrack package conntrack
import ( import (
"context"
"net/netip" "net/netip"
"testing" "testing"
@@ -11,7 +12,7 @@ import (
) )
var logger = log.NewFromLogrus(logrus.StandardLogger()) var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger() var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
// Memory pressure tests // Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) { func BenchmarkMemoryPressure(b *testing.B) {

View File

@@ -3,7 +3,6 @@ package conntrack
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"sync" "sync"
"time" "time"
@@ -20,10 +19,6 @@ const (
DefaultICMPTimeout = 30 * time.Second DefaultICMPTimeout = 30 * time.Second
// ICMPCleanupInterval is how often we check for stale ICMP connections // ICMPCleanupInterval is how often we check for stale ICMP connections
ICMPCleanupInterval = 15 * time.Second ICMPCleanupInterval = 15 * time.Second
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info,
// which includes the IP header (20 bytes) and transport header (8 bytes)
MaxICMPPayloadLength = 28
) )
// ICMPConnKey uniquely identifies an ICMP connection // ICMPConnKey uniquely identifies an ICMP connection
@@ -34,7 +29,7 @@ type ICMPConnKey struct {
} }
func (i ICMPConnKey) String() string { func (i ICMPConnKey) String() string {
return fmt.Sprintf("%s %s (id %d)", i.SrcIP, i.DstIP, i.ID) return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
} }
// ICMPConnTrack represents an ICMP connection state // ICMPConnTrack represents an ICMP connection state
@@ -55,72 +50,6 @@ type ICMPTracker struct {
flowLogger nftypes.FlowLogger flowLogger nftypes.FlowLogger
} }
// ICMPInfo holds ICMP type, code, and payload for lazy string formatting in logs
type ICMPInfo struct {
TypeCode layers.ICMPv4TypeCode
PayloadData [MaxICMPPayloadLength]byte
// actual length of valid data
PayloadLen int
}
// String implements fmt.Stringer for lazy evaluation in log messages
func (info ICMPInfo) String() string {
if info.isErrorMessage() && info.PayloadLen >= MaxICMPPayloadLength {
if origInfo := info.parseOriginalPacket(); origInfo != "" {
return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo)
}
}
return info.TypeCode.String()
}
// isErrorMessage returns true if this ICMP type carries original packet info
func (info ICMPInfo) isErrorMessage() bool {
typ := info.TypeCode.Type()
return typ == 3 || // Destination Unreachable
typ == 5 || // Redirect
typ == 11 || // Time Exceeded
typ == 12 // Parameter Problem
}
// parseOriginalPacket extracts info about the original packet from ICMP payload
func (info ICMPInfo) parseOriginalPacket() string {
if info.PayloadLen < MaxICMPPayloadLength {
return ""
}
// TODO: handle IPv6
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
return ""
}
protocol := info.PayloadData[9]
srcIP := net.IP(info.PayloadData[12:16])
dstIP := net.IP(info.PayloadData[16:20])
transportData := info.PayloadData[20:]
switch nftypes.Protocol(protocol) {
case nftypes.TCP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
case nftypes.UDP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
case nftypes.ICMP:
icmpType := transportData[0]
icmpCode := transportData[1]
return fmt.Sprintf("ICMP %s → %s (type %d code %d)", srcIP, dstIP, icmpType, icmpCode)
default:
return fmt.Sprintf("Proto %d %s → %s", protocol, srcIP, dstIP)
}
}
// NewICMPTracker creates a new ICMP connection tracker // NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker { func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
if timeout == 0 { if timeout == 0 {
@@ -164,64 +93,30 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
} }
// TrackOutbound records an outbound ICMP connection // TrackOutbound records an outbound ICMP connection
func (t *ICMPTracker) TrackOutbound( func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
srcIP netip.Addr,
dstIP netip.Addr,
id uint16,
typecode layers.ICMPv4TypeCode,
payload []byte,
size int,
) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists { if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction // if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, payload, size) t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
} }
} }
// TrackInbound records an inbound ICMP Echo Request // TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound( func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
srcIP netip.Addr, t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
dstIP netip.Addr,
id uint16,
typecode layers.ICMPv4TypeCode,
ruleId []byte,
payload []byte,
size int,
) {
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, payload, size)
} }
// track is the common implementation for tracking both inbound and outbound ICMP connections // track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track( func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
srcIP netip.Addr,
dstIP netip.Addr,
id uint16,
typecode layers.ICMPv4TypeCode,
direction nftypes.Direction,
ruleId []byte,
payload []byte,
size int,
) {
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size) key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
if exists { if exists {
return return
} }
typ, code := typecode.Type(), typecode.Code() typ, code := typecode.Type(), typecode.Code()
icmpInfo := ICMPInfo{
TypeCode: typecode,
}
if len(payload) > 0 {
icmpInfo.PayloadLen = len(payload)
if icmpInfo.PayloadLen > MaxICMPPayloadLength {
icmpInfo.PayloadLen = MaxICMPPayloadLength
}
copy(icmpInfo.PayloadData[:], payload[:icmpInfo.PayloadLen])
}
// non echo requests don't need tracking // non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) { if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo) t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size) t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return return
} }
@@ -237,13 +132,12 @@ func (t *ICMPTracker) track(
ICMPCode: code, ICMPCode: code,
} }
conn.UpdateLastSeen() conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
t.mutex.Lock() t.mutex.Lock()
t.connections[key] = conn t.connections[key] = conn
t.mutex.Unlock() t.mutex.Unlock()
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo) t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendEvent(nftypes.TypeStart, conn, ruleId) t.sendEvent(nftypes.TypeStart, conn, ruleId)
} }
@@ -294,7 +188,7 @@ func (t *ICMPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) { if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]", t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil) t.sendEvent(nftypes.TypeEnd, conn, nil)
} }

View File

@@ -15,7 +15,7 @@ func BenchmarkICMPTracker(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, []byte{}, 0) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
} }
}) })
@@ -28,7 +28,7 @@ func BenchmarkICMPTracker(b *testing.B) {
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, []byte{}, 0) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
} }
b.ResetTimer() b.ResetTimer()

View File

@@ -23,11 +23,11 @@ const (
) )
const ( const (
TCPFin uint8 = 0x01
TCPSyn uint8 = 0x02 TCPSyn uint8 = 0x02
TCPAck uint8 = 0x10
TCPFin uint8 = 0x01
TCPRst uint8 = 0x04 TCPRst uint8 = 0x04
TCPPush uint8 = 0x08 TCPPush uint8 = 0x08
TCPAck uint8 = 0x10
TCPUrg uint8 = 0x20 TCPUrg uint8 = 0x20
) )
@@ -41,7 +41,7 @@ const (
) )
// TCPState represents the state of a TCP connection // TCPState represents the state of a TCP connection
type TCPState int32 type TCPState int
func (s TCPState) String() string { func (s TCPState) String() string {
switch s { switch s {
@@ -89,25 +89,22 @@ const (
// TCPConnTrack represents a TCP connection state // TCPConnTrack represents a TCP connection state
type TCPConnTrack struct { type TCPConnTrack struct {
BaseConnTrack BaseConnTrack
SourcePort uint16 SourcePort uint16
DestPort uint16 DestPort uint16
state atomic.Int32 State TCPState
tombstone atomic.Bool established atomic.Bool
tombstone atomic.Bool
sync.RWMutex
} }
// GetState safely retrieves the current state // IsEstablished safely checks if connection is established
func (t *TCPConnTrack) GetState() TCPState { func (t *TCPConnTrack) IsEstablished() bool {
return TCPState(t.state.Load()) return t.established.Load()
} }
// SetState safely updates the current state // SetEstablished safely sets the established state
func (t *TCPConnTrack) SetState(state TCPState) { func (t *TCPConnTrack) SetEstablished(state bool) {
t.state.Store(int32(state)) t.established.Store(state)
}
// CompareAndSwapState atomically changes the state from old to new if current == old
func (t *TCPConnTrack) CompareAndSwapState(old, newState TCPState) bool {
return t.state.CompareAndSwap(int32(old), int32(newState))
} }
// IsTombstone safely checks if the connection is marked for deletion // IsTombstone safely checks if the connection is marked for deletion
@@ -128,17 +125,13 @@ type TCPTracker struct {
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
tickerCancel context.CancelFunc tickerCancel context.CancelFunc
timeout time.Duration timeout time.Duration
waitTimeout time.Duration
flowLogger nftypes.FlowLogger flowLogger nftypes.FlowLogger
} }
// NewTCPTracker creates a new TCP connection tracker // NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker { func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
waitTimeout := TimeWaitTimeout
if timeout == 0 { if timeout == 0 {
timeout = DefaultTCPTimeout timeout = DefaultTCPTimeout
} else {
waitTimeout = timeout / 45
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@@ -149,7 +142,6 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
cleanupTicker: time.NewTicker(TCPCleanupInterval), cleanupTicker: time.NewTicker(TCPCleanupInterval),
tickerCancel: cancel, tickerCancel: cancel,
timeout: timeout, timeout: timeout,
waitTimeout: waitTimeout,
flowLogger: flowLogger, flowLogger: flowLogger,
} }
@@ -157,7 +149,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
return tracker return tracker
} }
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) { func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{ key := ConnKey{
SrcIP: srcIP, SrcIP: srcIP,
DstIP: dstIP, DstIP: dstIP,
@@ -170,7 +162,12 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
t.mutex.RUnlock() t.mutex.RUnlock()
if exists { if exists {
t.updateState(key, conn, flags, direction, size) conn.Lock()
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
conn.Unlock()
conn.UpdateCounters(direction, size)
return key, true return key, true
} }
@@ -178,22 +175,22 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
} }
// TrackOutbound records an outbound TCP connection // TrackOutbound records an outbound TCP connection
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) { func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists { if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, 0, 0); !exists {
// if (inverted direction) conn is not tracked, track this direction // if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size) t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
} }
} }
// TrackInbound processes an inbound TCP packet and updates connection state // TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) { func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size) t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
} }
// track is the common implementation for tracking both inbound and outbound connections // track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) { func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
if exists || flags&TCPSyn == 0 { if exists {
return return
} }
@@ -208,11 +205,11 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
DestPort: dstPort, DestPort: dstPort,
} }
conn.established.Store(false)
conn.tombstone.Store(false) conn.tombstone.Store(false)
conn.state.Store(int32(TCPStateNew))
t.logger.Trace("New %s TCP connection: %s", direction, key) t.logger.Trace("New %s TCP connection: %s", direction, key)
t.updateState(key, conn, flags, direction, size) t.updateState(key, conn, flags, direction == nftypes.Egress)
t.mutex.Lock() t.mutex.Lock()
t.connections[key] = conn t.connections[key] = conn
@@ -222,7 +219,7 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
} }
// IsValidInbound checks if an inbound TCP packet matches a tracked connection // IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool { func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) bool {
key := ConnKey{ key := ConnKey{
SrcIP: dstIP, SrcIP: dstIP,
DstIP: srcIP, DstIP: srcIP,
@@ -234,125 +231,129 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
conn, exists := t.connections[key] conn, exists := t.connections[key]
t.mutex.RUnlock() t.mutex.RUnlock()
if !exists || conn.IsTombstone() { if !exists {
return false return false
} }
currentState := conn.GetState() // Handle RST flag specially - it always causes transition to closed
if !t.isValidStateForFlags(currentState, flags) { if flags&TCPRst != 0 {
t.logger.Warn("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key) if conn.IsTombstone() {
// allow all flags for established for now
if currentState == TCPStateEstablished {
return true return true
} }
return false
conn.Lock()
conn.SetTombstone()
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock()
conn.UpdateCounters(nftypes.Ingress, size)
t.logger.Trace("TCP connection reset: %s", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
return true
} }
t.updateState(key, conn, flags, nftypes.Ingress, size) conn.Lock()
return true t.updateState(key, conn, flags, false)
isEstablished := conn.IsEstablished()
isValidState := t.isValidStateForFlags(conn.State, flags)
conn.Unlock()
return isEstablished || isValidState
} }
// updateState updates the TCP connection state based on flags // updateState updates the TCP connection state based on flags
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) { func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
conn.UpdateLastSeen() conn.UpdateLastSeen()
conn.UpdateCounters(packetDir, size)
currentState := conn.GetState() state := conn.State
defer func() {
if flags&TCPRst != 0 { if state != conn.State {
if conn.CompareAndSwapState(currentState, TCPStateClosed) { t.logger.Trace("TCP connection %s transitioned from %s to %s", key, state, conn.State)
conn.SetTombstone()
t.logger.Trace("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
return }()
}
var newState TCPState switch state {
switch currentState {
case TCPStateNew: case TCPStateNew:
if flags&TCPSyn != 0 && flags&TCPAck == 0 { if flags&TCPSyn != 0 && flags&TCPAck == 0 {
if conn.Direction == nftypes.Egress { conn.State = TCPStateSynSent
newState = TCPStateSynSent
} else {
newState = TCPStateSynReceived
}
} }
case TCPStateSynSent: case TCPStateSynSent:
if flags&TCPSyn != 0 && flags&TCPAck != 0 { if flags&TCPSyn != 0 && flags&TCPAck != 0 {
if packetDir != conn.Direction { if isOutbound {
newState = TCPStateEstablished conn.State = TCPStateEstablished
conn.SetEstablished(true)
} else { } else {
// Simultaneous open // Simultaneous open
newState = TCPStateSynReceived conn.State = TCPStateSynReceived
} }
} }
case TCPStateSynReceived: case TCPStateSynReceived:
if flags&TCPAck != 0 && flags&TCPSyn == 0 { if flags&TCPAck != 0 && flags&TCPSyn == 0 {
if packetDir == conn.Direction { conn.State = TCPStateEstablished
newState = TCPStateEstablished conn.SetEstablished(true)
}
} }
case TCPStateEstablished: case TCPStateEstablished:
if flags&TCPFin != 0 { if flags&TCPFin != 0 {
if packetDir == conn.Direction { if isOutbound {
newState = TCPStateFinWait1 conn.State = TCPStateFinWait1
} else { } else {
newState = TCPStateCloseWait conn.State = TCPStateCloseWait
} }
conn.SetEstablished(false)
} else if flags&TCPRst != 0 {
conn.State = TCPStateClosed
conn.SetTombstone()
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateFinWait1: case TCPStateFinWait1:
if packetDir != conn.Direction { switch {
switch { case flags&TCPFin != 0 && flags&TCPAck != 0:
case flags&TCPFin != 0 && flags&TCPAck != 0: conn.State = TCPStateClosing
newState = TCPStateClosing case flags&TCPFin != 0:
case flags&TCPFin != 0: conn.State = TCPStateFinWait2
newState = TCPStateClosing case flags&TCPAck != 0:
case flags&TCPAck != 0: conn.State = TCPStateFinWait2
newState = TCPStateFinWait2 case flags&TCPRst != 0:
} conn.State = TCPStateClosed
conn.SetTombstone()
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateFinWait2: case TCPStateFinWait2:
if flags&TCPFin != 0 { if flags&TCPFin != 0 {
newState = TCPStateTimeWait conn.State = TCPStateTimeWait
t.logger.Trace("TCP connection %s completed", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateClosing: case TCPStateClosing:
if flags&TCPAck != 0 { if flags&TCPAck != 0 {
newState = TCPStateTimeWait conn.State = TCPStateTimeWait
// Keep established = false from previous state
t.logger.Trace("TCP connection %s closed (simultaneous)", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateCloseWait: case TCPStateCloseWait:
if flags&TCPFin != 0 { if flags&TCPFin != 0 {
newState = TCPStateLastAck conn.State = TCPStateLastAck
} }
case TCPStateLastAck: case TCPStateLastAck:
if flags&TCPAck != 0 { if flags&TCPAck != 0 {
newState = TCPStateClosed conn.State = TCPStateClosed
}
}
if newState != 0 && conn.CompareAndSwapState(currentState, newState) {
t.logger.Trace("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
switch newState {
case TCPStateTimeWait:
t.logger.Trace("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
case TCPStateClosed:
conn.SetTombstone() conn.SetTombstone()
t.logger.Trace("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) // Send close event for gracefully closed connections
t.sendEvent(nftypes.TypeEnd, conn, nil) t.sendEvent(nftypes.TypeEnd, conn, nil)
t.logger.Trace("TCP connection %s closed gracefully", key)
} }
} }
} }
@@ -362,22 +363,18 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
if !isValidFlagCombination(flags) { if !isValidFlagCombination(flags) {
return false return false
} }
if flags&TCPRst != 0 {
if state == TCPStateSynSent {
return flags&TCPAck != 0
}
return true
}
switch state { switch state {
case TCPStateNew: case TCPStateNew:
return flags&TCPSyn != 0 && flags&TCPAck == 0 return flags&TCPSyn != 0 && flags&TCPAck == 0
case TCPStateSynSent: case TCPStateSynSent:
// TODO: support simultaneous open
return flags&TCPSyn != 0 && flags&TCPAck != 0 return flags&TCPSyn != 0 && flags&TCPAck != 0
case TCPStateSynReceived: case TCPStateSynReceived:
return flags&TCPAck != 0 return flags&TCPAck != 0
case TCPStateEstablished: case TCPStateEstablished:
if flags&TCPRst != 0 {
return true
}
return flags&TCPAck != 0 return flags&TCPAck != 0
case TCPStateFinWait1: case TCPStateFinWait1:
return flags&TCPFin != 0 || flags&TCPAck != 0 return flags&TCPFin != 0 || flags&TCPAck != 0
@@ -394,7 +391,9 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
case TCPStateLastAck: case TCPStateLastAck:
return flags&TCPAck != 0 return flags&TCPAck != 0
case TCPStateClosed: case TCPStateClosed:
// Accept retransmitted ACKs in closed state, the final ACK might be lost and the peer will retransmit their FIN-ACK // Accept retransmitted ACKs in closed state
// This is important because the final ACK might be lost
// and the peer will retransmit their FIN-ACK
return flags&TCPAck != 0 return flags&TCPAck != 0
} }
return false return false
@@ -425,24 +424,23 @@ func (t *TCPTracker) cleanup() {
} }
var timeout time.Duration var timeout time.Duration
currentState := conn.GetState() switch {
switch currentState { case conn.State == TCPStateTimeWait:
case TCPStateTimeWait: timeout = TimeWaitTimeout
timeout = t.waitTimeout case conn.IsEstablished():
case TCPStateEstablished:
timeout = t.timeout timeout = t.timeout
default: default:
timeout = TCPHandshakeTimeout timeout = TCPHandshakeTimeout
} }
if conn.timeoutExceeded(timeout) { if conn.timeoutExceeded(timeout) {
// Return IPs to pool
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]", t.logger.Trace("Cleaned up timed-out TCP connection %s", key)
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
// event already handled by state change // event already handled by state change
if currentState != TCPStateTimeWait { if conn.State != TCPStateTimeWait {
t.sendEvent(nftypes.TypeEnd, conn, nil) t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
} }

View File

@@ -1,83 +0,0 @@
package conntrack
import (
"net/netip"
"testing"
"time"
)
func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck|TCPSyn, 0)
}
})
b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
} else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck|TCPSyn, 0)
}
i++
}
})
})
}
// Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger)
defer tracker.Close()
// Pre-populate with expired connections
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
}
// Wait for connections to expire
time.Sleep(200 * time.Millisecond)
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.cleanup()
}
})
}

View File

@@ -5,7 +5,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -125,6 +124,9 @@ func TestTCPStateMachine(t *testing.T) {
// Receive RST // Receive RST
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
require.True(t, valid, "RST should be allowed for established connection") require.True(t, valid, "RST should be allowed for established connection")
// Connection is logically dead but we don't enforce blocking subsequent packets
// The connection will be cleaned up by timeout
}, },
}, },
{ {
@@ -215,446 +217,97 @@ func TestRSTHandling(t *testing.T) {
conn := tracker.connections[key] conn := tracker.connections[key]
if tt.wantValid { if tt.wantValid {
require.NotNil(t, conn) require.NotNil(t, conn)
require.Equal(t, TCPStateClosed, conn.GetState()) require.Equal(t, TCPStateClosed, conn.State)
require.False(t, conn.IsEstablished())
} }
}) })
} }
} }
func TestTCPRetransmissions(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
// Test SYN retransmission
t.Run("SYN Retransmission", func(t *testing.T) {
// Initial SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
// Retransmit SYN (should not affect the state machine)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
// Verify we're still in SYN-SENT state
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
require.Equal(t, TCPStateSynSent, conn.GetState())
// Complete the handshake
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
require.True(t, valid)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// Verify we're in ESTABLISHED state
require.Equal(t, TCPStateEstablished, conn.GetState())
})
// Test ACK retransmission in established state
t.Run("ACK Retransmission", func(t *testing.T) {
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
// Establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
require.Equal(t, TCPStateEstablished, conn.GetState())
// Retransmit ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// State should remain ESTABLISHED
require.Equal(t, TCPStateEstablished, conn.GetState())
})
// Test FIN retransmission
t.Run("FIN Retransmission", func(t *testing.T) {
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
// Establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Retransmit FIN (should not change state)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateFinWait2, conn.GetState())
})
}
func TestTCPDataTransfer(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Data Transfer", func(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Send data
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
// Receive ACK for data
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 100)
require.True(t, valid)
// Receive data
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1500)
require.True(t, valid)
// Send ACK for received data
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
// State should remain ESTABLISHED
require.Equal(t, TCPStateEstablished, conn.GetState())
assert.Equal(t, uint64(1300), conn.BytesTx.Load())
assert.Equal(t, uint64(1700), conn.BytesRx.Load())
assert.Equal(t, uint64(4), conn.PacketsTx.Load())
assert.Equal(t, uint64(3), conn.PacketsRx.Load())
})
}
func TestTCPHalfClosedConnections(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
// Test half-closed connection: local end closes, remote end continues sending data
t.Run("Local Close, Remote Data", func(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateFinWait2, conn.GetState())
// Remote end can still send data
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1000)
require.True(t, valid)
// We can still ACK their data
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// Receive FIN from remote end
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// State should remain TIME-WAIT (waiting for possible retransmissions)
require.Equal(t, TCPStateTimeWait, conn.GetState())
})
// Test half-closed connection: remote end closes, local end continues sending data
t.Run("Remote Close, Local Data", func(t *testing.T) {
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
// Establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Receive FIN from remote
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateCloseWait, conn.GetState())
// We can still send data
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
// Remote can still ACK our data
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
// Send our FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateLastAck, conn.GetState())
// Receive final ACK
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateClosed, conn.GetState())
})
}
func TestTCPAbnormalSequences(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
// Test handling of unsolicited RST in various states
t.Run("Unsolicited RST in SYN-SENT", func(t *testing.T) {
// Send SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
// Receive unsolicited RST (without proper ACK)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
require.False(t, valid, "RST without proper ACK in SYN-SENT should be rejected")
// Receive RST with proper ACK
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
require.True(t, valid, "RST with proper ACK in SYN-SENT should be accepted")
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.Equal(t, TCPStateClosed, conn.GetState())
require.True(t, conn.IsTombstone())
})
}
func TestTCPTimeoutHandling(t *testing.T) {
// Create tracker with a very short timeout for testing
shortTimeout := 100 * time.Millisecond
tracker := NewTCPTracker(shortTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Connection Timeout", func(t *testing.T) {
// Establish a connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
require.Equal(t, TCPStateEstablished, conn.GetState())
// Wait for the connection to timeout
time.Sleep(2 * shortTimeout)
// Force cleanup
tracker.cleanup()
// Connection should be removed
_, exists := tracker.connections[key]
require.False(t, exists, "Connection should be removed after timeout")
})
t.Run("TIME_WAIT Timeout", func(t *testing.T) {
tracker = NewTCPTracker(shortTimeout, logger, flowLogger)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Complete the connection close to enter TIME_WAIT
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.Equal(t, TCPStateTimeWait, conn.GetState())
// TIME_WAIT should have its own timeout value (usually 2*MSL)
// For the test, we're using a short timeout
time.Sleep(2 * shortTimeout)
tracker.cleanup()
// Connection should be removed
_, exists := tracker.connections[key]
require.False(t, exists, "Connection should be removed after TIME_WAIT timeout")
})
}
func TestSynFlood(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
basePort := uint16(10000)
dstPort := uint16(80)
// Create a large number of SYN packets to simulate a SYN flood
for i := uint16(0); i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, basePort+i, dstPort, TCPSyn, 0)
}
// Check that we're tracking all connections
require.Equal(t, 1000, len(tracker.connections))
// Now simulate SYN timeout
var oldConns int
tracker.mutex.Lock()
for _, conn := range tracker.connections {
if conn.GetState() == TCPStateSynSent {
// Make the connection appear old
conn.lastSeen.Store(time.Now().Add(-TCPHandshakeTimeout - time.Second).UnixNano())
oldConns++
}
}
tracker.mutex.Unlock()
require.Equal(t, 1000, oldConns)
// Run cleanup
tracker.cleanup()
// Check that stale connections were cleaned up
require.Equal(t, 0, len(tracker.connections))
}
func TestTCPInboundInitiatedConnection(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
clientIP := netip.MustParseAddr("100.64.0.1")
serverIP := netip.MustParseAddr("100.64.0.2")
clientPort := uint16(12345)
serverPort := uint16(80)
// 1. Client sends SYN (we receive it as inbound)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100)
key := ConnKey{
SrcIP: clientIP,
DstIP: serverIP,
SrcPort: clientPort,
DstPort: serverPort,
}
tracker.mutex.RLock()
conn := tracker.connections[key]
tracker.mutex.RUnlock()
require.NotNil(t, conn)
require.Equal(t, TCPStateSynReceived, conn.GetState(), "Connection should be in SYN-RECEIVED state after inbound SYN")
// 2. Server sends SYN-ACK response
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
// 3. Client sends ACK to complete handshake
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
// 4. Test data transfer
// Client sends data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000)
// Server sends ACK for data
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
// Server sends data
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
// Client sends ACK for data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
// Verify state and counters
require.Equal(t, TCPStateEstablished, conn.GetState())
assert.Equal(t, uint64(1300), conn.BytesRx.Load()) // 3 packets * 100 + 1000 data
assert.Equal(t, uint64(1700), conn.BytesTx.Load()) // 2 packets * 100 + 1500 data
assert.Equal(t, uint64(4), conn.PacketsRx.Load()) // SYN, ACK, Data
assert.Equal(t, uint64(3), conn.PacketsTx.Load()) // SYN-ACK, Data
}
// Helper to establish a TCP connection // Helper to establish a TCP connection
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) { func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
t.Helper() t.Helper()
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
require.True(t, valid, "SYN-ACK should be allowed") require.True(t, valid, "SYN-ACK should be allowed")
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
}
func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, 0)
}
})
b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
} else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, 0)
}
i++
}
})
})
}
// Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger) // Short timeout for testing
defer tracker.Close()
// Pre-populate with expired connections
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
}
// Wait for connections to expire
time.Sleep(200 * time.Millisecond)
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.cleanup()
}
})
} }

View File

@@ -110,7 +110,6 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
DestPort: dstPort, DestPort: dstPort,
} }
conn.UpdateLastSeen() conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
t.mutex.Lock() t.mutex.Lock()
t.connections[key] = conn t.connections[key] = conn
@@ -165,7 +164,7 @@ func (t *UDPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) { if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]", t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil) t.sendEvent(nftypes.TypeEnd, conn, nil)
} }

View File

@@ -86,5 +86,5 @@ type epID stack.TransportEndpointID
func (i epID) String() string { func (i epID) String() string {
// src and remote is swapped // src and remote is swapped
return fmt.Sprintf("%s:%d %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort) return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
} }

View File

@@ -4,9 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"runtime" "runtime"
"sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/buffer"
@@ -19,7 +17,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common" "github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
@@ -32,16 +29,14 @@ const (
) )
type Forwarder struct { type Forwarder struct {
logger *nblog.Logger logger *nblog.Logger
flowLogger nftypes.FlowLogger flowLogger nftypes.FlowLogger
// ruleIdMap is used to store the rule ID for a given connection
ruleIdMap sync.Map
stack *stack.Stack stack *stack.Stack
endpoint *endpoint endpoint *endpoint
udpForwarder *udpForwarder udpForwarder *udpForwarder
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
ip tcpip.Address ip net.IP
netstack bool netstack bool
} }
@@ -71,11 +66,12 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
return nil, fmt.Errorf("failed to create NIC: %v", err) return nil, fmt.Errorf("failed to create NIC: %v", err)
} }
ones, _ := iface.Address().Network.Mask.Size()
protoAddr := tcpip.ProtocolAddress{ protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber, Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{ AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()), Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
PrefixLen: iface.Address().Network.Bits(), PrefixLen: ones,
}, },
} }
@@ -115,7 +111,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
netstack: netstack, netstack: netstack,
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()), ip: iface.Address().IP,
} }
receiveWindow := defaultReceiveWindow receiveWindow := defaultReceiveWindow
@@ -166,39 +162,8 @@ func (f *Forwarder) Stop() {
} }
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP { func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
if f.netstack && f.ip.Equal(addr) { if f.netstack && f.ip.Equal(addr.AsSlice()) {
return net.IPv4(127, 0, 0, 1) return net.IPv4(127, 0, 0, 1)
} }
return addr.AsSlice() return addr.AsSlice()
} }
func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) {
key := buildKey(srcIP, dstIP, srcPort, dstPort)
f.ruleIdMap.LoadOrStore(key, ruleID)
}
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
return value.([]byte), true
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {
return value.([]byte), true
}
return nil, false
}
func (f *Forwarder) DeleteRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
if _, ok := f.ruleIdMap.LoadAndDelete(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
return
}
f.ruleIdMap.LoadAndDelete(buildKey(dstIP, srcIP, dstPort, srcPort))
}
func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKey {
return conntrack.ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
}

View File

@@ -25,7 +25,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
} }
flowID := uuid.New() flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0) f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second) ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel() defer cancel()
@@ -34,14 +34,14 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
// TODO: support non-root // TODO: support non-root
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil { if err != nil {
f.logger.Error("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err) f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
// This will make netstack reply on behalf of the original destination, that's ok for now // This will make netstack reply on behalf of the original destination, that's ok for now
return false return false
} }
defer func() { defer func() {
if err := conn.Close(); err != nil { if err := conn.Close(); err != nil {
f.logger.Debug("forwarder: Failed to close ICMP socket: %v", err) f.logger.Debug("Failed to close ICMP socket: %v", err)
} }
}() }()
@@ -52,37 +52,36 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
payload := fullPacket.AsSlice() payload := fullPacket.AsSlice()
if _, err = conn.WriteTo(payload, dst); err != nil { if _, err = conn.WriteTo(payload, dst); err != nil {
f.logger.Error("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err) f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
return true return true
} }
f.logger.Trace("forwarder: Forwarded ICMP packet %v type %v code %v", f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code()) epID(id), icmpHdr.Type(), icmpHdr.Code())
// For Echo Requests, send and handle response // For Echo Requests, send and handle response
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo { if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
rxBytes := pkt.Size() f.handleEchoResponse(icmpHdr, conn, id)
txBytes := f.handleEchoResponse(icmpHdr, conn, id) f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
} }
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing // For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
return true return true
} }
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int { func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error("forwarder: Failed to set read deadline for ICMP response: %v", err) f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
return 0 return
} }
response := make([]byte, f.endpoint.mtu) response := make([]byte, f.endpoint.mtu)
n, _, err := conn.ReadFrom(response) n, _, err := conn.ReadFrom(response)
if err != nil { if err != nil {
if !isTimeout(err) { if !isTimeout(err) {
f.logger.Error("forwarder: Failed to read ICMP response: %v", err) f.logger.Error("Failed to read ICMP response: %v", err)
} }
return 0 return
} }
ipHdr := make([]byte, header.IPv4MinimumSize) ipHdr := make([]byte, header.IPv4MinimumSize)
@@ -101,54 +100,28 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
fullPacket = append(fullPacket, response[:n]...) fullPacket = append(fullPacket, response[:n]...)
if err := f.InjectIncomingPacket(fullPacket); err != nil { if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error("forwarder: Failed to inject ICMP response: %v", err) f.logger.Error("Failed to inject ICMP response: %v", err)
return 0 return
} }
f.logger.Trace("forwarder: Forwarded ICMP echo reply for %v type %v code %v", f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code()) epID(id), icmpHdr.Type(), icmpHdr.Code())
return len(fullPacket)
} }
// sendICMPEvent stores flow events for ICMP packets // sendICMPEvent stores flow events for ICMP packets
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, rxBytes, txBytes uint64) { func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) {
var rxPackets, txPackets uint64 f.flowLogger.StoreEvent(nftypes.EventFields{
if rxBytes > 0 {
rxPackets = 1
}
if txBytes > 0 {
txPackets = 1
}
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
fields := nftypes.EventFields{
FlowID: flowID, FlowID: flowID,
Type: typ, Type: typ,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: nftypes.ICMP, Protocol: nftypes.ICMP,
// TODO: handle ipv6 // TODO: handle ipv6
SourceIP: srcIp, SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: dstIp, DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
ICMPType: icmpType, ICMPType: icmpType,
ICMPCode: icmpCode, ICMPCode: icmpCode,
RxBytes: rxBytes, // TODO: get packets/bytes
TxBytes: txBytes, })
RxPackets: rxPackets,
TxPackets: txPackets,
}
if typ == nftypes.TypeStart {
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
fields.RuleID = ruleId
}
} else {
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
}
f.flowLogger.StoreEvent(fields)
} }

View File

@@ -6,10 +6,8 @@ import (
"io" "io"
"net" "net"
"net/netip" "net/netip"
"sync"
"github.com/google/uuid" "github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -25,11 +23,11 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
flowID := uuid.New() flowID := uuid.New()
f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0) f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
var success bool var success bool
defer func() { defer func() {
if !success { if !success {
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0) f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
} }
}() }()
@@ -67,97 +65,67 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
} }
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) { func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
defer func() {
if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: inConn close error: %v", err)
}
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: outConn close error: %v", err)
}
ep.Close()
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
}()
// Create context for managing the proxy goroutines
ctx, cancel := context.WithCancel(f.ctx) ctx, cancel := context.WithCancel(f.ctx)
defer cancel() defer cancel()
go func() { errChan := make(chan error, 2)
<-ctx.Done()
// Close connections and endpoint.
if err := inConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug("forwarder: inConn close error: %v", err)
}
if err := outConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug("forwarder: outConn close error: %v", err)
}
ep.Close()
}()
var wg sync.WaitGroup
wg.Add(2)
var (
bytesFromInToOut int64 // bytes from client to server (tx for client)
bytesFromOutToIn int64 // bytes from server to client (rx for client)
errInToOut error
errOutToIn error
)
go func() { go func() {
bytesFromInToOut, errInToOut = io.Copy(outConn, inConn) _, err := io.Copy(outConn, inConn)
cancel() errChan <- err
wg.Done()
}() }()
go func() { go func() {
_, err := io.Copy(inConn, outConn)
bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn) errChan <- err
cancel()
wg.Done()
}() }()
wg.Wait() select {
case <-ctx.Done():
if errInToOut != nil { f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id))
if !isClosedError(errInToOut) { return
f.logger.Error("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut) case err := <-errChan:
if err != nil && !isClosedError(err) {
f.logger.Error("proxyTCP: copy error: %v", err)
} }
f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id))
return
} }
if errOutToIn != nil {
if !isClosedError(errOutToIn) {
f.logger.Error("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn)
}
}
var rxPackets, txPackets uint64
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
// fields are flipped since this is the in conn
rxPackets = tcpStats.SegmentsSent.Value()
txPackets = tcpStats.SegmentsReceived.Value()
}
f.logger.Trace("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
} }
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) { func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
fields := nftypes.EventFields{ fields := nftypes.EventFields{
FlowID: flowID, FlowID: flowID,
Type: typ, Type: typ,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: nftypes.TCP, Protocol: nftypes.TCP,
// TODO: handle ipv6 // TODO: handle ipv6
SourceIP: srcIp, SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: dstIp, DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
SourcePort: id.RemotePort, SourcePort: id.RemotePort,
DestPort: id.LocalPort, DestPort: id.LocalPort,
RxBytes: rxBytes,
TxBytes: txBytes,
RxPackets: rxPackets,
TxPackets: txPackets,
} }
if typ == nftypes.TypeStart { if ep != nil {
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok { if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
fields.RuleID = ruleId // fields are flipped since this is the in conn
// TODO: get bytes
fields.RxPackets = tcpStats.SegmentsSent.Value()
fields.TxPackets = tcpStats.SegmentsReceived.Value()
} }
} else {
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
} }
f.flowLogger.StoreEvent(fields) f.flowLogger.StoreEvent(fields)

View File

@@ -149,11 +149,11 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
flowID := uuid.New() flowID := uuid.New()
f.sendUDPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0) f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
var success bool var success bool
defer func() { defer func() {
if !success { if !success {
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0) f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
} }
}() }()
@@ -199,6 +199,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if err := outConn.Close(); err != nil { if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err) f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
} }
return return
} }
f.udpForwarder.conns[id] = pConn f.udpForwarder.conns[id] = pConn
@@ -211,94 +212,68 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
} }
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) { func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
defer func() {
ctx, cancel := context.WithCancel(f.ctx)
defer cancel()
go func() {
<-ctx.Done()
pConn.cancel() pConn.cancel()
if err := pConn.conn.Close(); err != nil && !isClosedError(err) { if err := pConn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err) f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
} }
if err := pConn.outConn.Close(); err != nil && !isClosedError(err) { if err := pConn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err) f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
} }
ep.Close() ep.Close()
f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id)
f.udpForwarder.Unlock()
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
}() }()
var wg sync.WaitGroup errChan := make(chan error, 2)
wg.Add(2)
var txBytes, rxBytes int64
var outboundErr, inboundErr error
// outbound->inbound: copy from pConn.conn to pConn.outConn
go func() { go func() {
defer wg.Done() errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
txBytes, outboundErr = pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
}() }()
// inbound->outbound: copy from pConn.outConn to pConn.conn
go func() { go func() {
defer wg.Done() errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
rxBytes, inboundErr = pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
}() }()
wg.Wait() select {
case <-ctx.Done():
if outboundErr != nil && !isClosedError(outboundErr) { f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id))
f.logger.Error("proxyUDP: copy error (outbound→inbound) for %s: %v", epID(id), outboundErr) return
case err := <-errChan:
if err != nil && !isClosedError(err) {
f.logger.Error("proxyUDP: copy error: %v", err)
}
f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id))
return
} }
if inboundErr != nil && !isClosedError(inboundErr) {
f.logger.Error("proxyUDP: copy error (inbound→outbound) for %s: %v", epID(id), inboundErr)
}
var rxPackets, txPackets uint64
if udpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
// fields are flipped since this is the in conn
rxPackets = udpStats.PacketsSent.Value()
txPackets = udpStats.PacketsReceived.Value()
}
f.logger.Trace("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id)
f.udpForwarder.Unlock()
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, uint64(rxBytes), uint64(txBytes), rxPackets, txPackets)
} }
// sendUDPEvent stores flow events for UDP connections // sendUDPEvent stores flow events for UDP connections
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) { func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
fields := nftypes.EventFields{ fields := nftypes.EventFields{
FlowID: flowID, FlowID: flowID,
Type: typ, Type: typ,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: nftypes.UDP, Protocol: nftypes.UDP,
// TODO: handle ipv6 // TODO: handle ipv6
SourceIP: srcIp, SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: dstIp, DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
SourcePort: id.RemotePort, SourcePort: id.RemotePort,
DestPort: id.LocalPort, DestPort: id.LocalPort,
RxBytes: rxBytes,
TxBytes: txBytes,
RxPackets: rxPackets,
TxPackets: txPackets,
} }
if typ == nftypes.TypeStart { if ep != nil {
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok { if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
fields.RuleID = ruleId // fields are flipped since this is the in conn
// TODO: get bytes
fields.RxPackets = tcpStats.PacketsSent.Value()
fields.TxPackets = tcpStats.PacketsReceived.Value()
} }
} else {
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
} }
f.flowLogger.StoreEvent(fields) f.flowLogger.StoreEvent(fields)
@@ -313,20 +288,18 @@ func (c *udpPacketConn) getIdleDuration() time.Duration {
return time.Since(lastSeen) return time.Since(lastSeen)
} }
// copy reads from src and writes to dst. func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) (int64, error) {
bufp := bufPool.Get().(*[]byte) bufp := bufPool.Get().(*[]byte)
defer bufPool.Put(bufp) defer bufPool.Put(bufp)
buffer := *bufp buffer := *bufp
var totalBytes int64 = 0
for { for {
if ctx.Err() != nil { if ctx.Err() != nil {
return totalBytes, ctx.Err() return ctx.Err()
} }
if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil { if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil {
return totalBytes, fmt.Errorf("set read deadline: %w", err) return fmt.Errorf("set read deadline: %w", err)
} }
n, err := src.Read(buffer) n, err := src.Read(buffer)
@@ -334,15 +307,14 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu
if isTimeout(err) { if isTimeout(err) {
continue continue
} }
return totalBytes, fmt.Errorf("read from %s: %w", direction, err) return fmt.Errorf("read from %s: %w", direction, err)
} }
nWritten, err := dst.Write(buffer[:n]) _, err = dst.Write(buffer[:n])
if err != nil { if err != nil {
return totalBytes, fmt.Errorf("write to %s: %w", direction, err) return fmt.Errorf("write to %s: %w", direction, err)
} }
totalBytes += int64(nWritten)
c.updateLastSeen() c.updateLastSeen()
} }
} }

View File

@@ -14,13 +14,8 @@ import (
type localIPManager struct { type localIPManager struct {
mu sync.RWMutex mu sync.RWMutex
// fixed-size high array for upper byte of a IPv4 address // Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory)
ipv4Bitmap [256]*ipv4LowBitmap ipv4Bitmap [1 << 16]uint32
}
// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address
type ipv4LowBitmap struct {
bitmap [8192]uint32
} }
func newLocalIPManager() *localIPManager { func newLocalIPManager() *localIPManager {
@@ -32,61 +27,35 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
if ipv4 == nil { if ipv4 == nil {
return return
} }
high := uint16(ipv4[0]) high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3]) low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
m.ipv4Bitmap[high] |= 1 << (low % 32)
index := low / 32
bit := low % 32
if m.ipv4Bitmap[high] == nil {
m.ipv4Bitmap[high] = &ipv4LowBitmap{}
}
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
}
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
if !ip.Is4() {
return
}
ipv4 := ip.AsSlice()
high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
if bitmap[high] == nil {
bitmap[high] = &ipv4LowBitmap{}
}
index := low / 32
bit := low % 32
bitmap[high].bitmap[index] |= 1 << bit
if _, exists := ipv4Set[ip]; !exists {
ipv4Set[ip] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ip)
}
} }
func (m *localIPManager) checkBitmapBit(ip []byte) bool { func (m *localIPManager) checkBitmapBit(ip []byte) bool {
high := uint16(ip[0]) high := (uint16(ip[0]) << 8) | uint16(ip[1])
low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3]) low := (uint16(ip[2]) << 8) | uint16(ip[3])
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
if m.ipv4Bitmap[high] == nil {
return false
}
index := low / 32
bit := low % 32
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
} }
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error { func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses) if ipv4 := ip.To4(); ipv4 != nil {
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
if int(high) >= len(*newIPv4Bitmap) {
return fmt.Errorf("invalid IPv4 address: %s", ip)
}
ipStr := ip.String()
if _, exists := ipv4Set[ipStr]; !exists {
ipv4Set[ipStr] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ipStr)
newIPv4Bitmap[high] |= 1 << (low % 32)
}
}
return nil return nil
} }
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) { func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
addrs, err := iface.Addrs() addrs, err := iface.Addrs()
if err != nil { if err != nil {
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err) log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
@@ -104,13 +73,7 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
continue continue
} }
addr, ok := netip.AddrFromSlice(ip) if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
if !ok {
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
continue
}
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
log.Debugf("process IP failed: %v", err) log.Debugf("process IP failed: %v", err)
} }
} }
@@ -123,14 +86,14 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
} }
}() }()
var newIPv4Bitmap [256]*ipv4LowBitmap var newIPv4Bitmap [1 << 16]uint32
ipv4Set := make(map[netip.Addr]struct{}) ipv4Set := make(map[string]struct{})
var ipv4Addresses []netip.Addr var ipv4Addresses []string
// 127.0.0.0/8 // 127.0.0.0/8
newIPv4Bitmap[127] = &ipv4LowBitmap{} high := uint16(127) << 8
for i := 0; i < 8192; i++ { for i := uint16(0); i < 256; i++ {
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF newIPv4Bitmap[high|i] = 0xffffffff
} }
if iface != nil { if iface != nil {
@@ -157,12 +120,12 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
} }
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool { func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
if !ip.Is4() {
return false
}
m.mu.RLock() m.mu.RLock()
defer m.mu.RUnlock() defer m.mu.RUnlock()
return m.checkBitmapBit(ip.AsSlice()) if ip.Is4() {
return m.checkBitmapBit(ip.AsSlice())
}
return false
} }

View File

@@ -20,8 +20,11 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Localhost range", name: "Localhost range",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"), Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("127.0.0.2"), testIP: netip.MustParseAddr("127.0.0.2"),
expected: true, expected: true,
@@ -29,8 +32,11 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Localhost standard address", name: "Localhost standard address",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"), Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("127.0.0.1"), testIP: netip.MustParseAddr("127.0.0.1"),
expected: true, expected: true,
@@ -38,8 +44,11 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Localhost range edge", name: "Localhost range edge",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"), Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("127.255.255.255"), testIP: netip.MustParseAddr("127.255.255.255"),
expected: true, expected: true,
@@ -47,8 +56,11 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Local IP matches", name: "Local IP matches",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"), Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("192.168.1.1"), testIP: netip.MustParseAddr("192.168.1.1"),
expected: true, expected: true,
@@ -56,26 +68,23 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Local IP doesn't match", name: "Local IP doesn't match",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"), Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("192.168.1.2"), testIP: netip.MustParseAddr("192.168.1.2"),
expected: false, expected: false,
}, },
{
name: "Local IP doesn't match - addresses 32 apart",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
},
testIP: netip.MustParseAddr("192.168.1.33"),
expected: false,
},
{ {
name: "IPv6 address", name: "IPv6 address",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("fe80::1"), IP: net.ParseIP("fe80::1"),
Network: netip.MustParsePrefix("192.168.1.0/24"), Network: &net.IPNet{
IP: net.ParseIP("fe80::"),
Mask: net.CIDRMask(64, 128),
},
}, },
testIP: netip.MustParseAddr("fe80::1"), testIP: netip.MustParseAddr("fe80::1"),
expected: false, expected: false,
@@ -183,8 +192,10 @@ func BenchmarkIPChecks(b *testing.B) {
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i)) interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
} }
// Setup bitmap // Setup bitmap version
bitmapManager := newLocalIPManager() bitmapManager := &localIPManager{
ipv4Bitmap: [1 << 16]uint32{},
}
for _, ip := range interfaces[:8] { // Add half of IPs for _, ip := range interfaces[:8] { // Add half of IPs
bitmapManager.setBitmapBit(ip) bitmapManager.setBitmapBit(ip)
} }
@@ -237,7 +248,7 @@ func BenchmarkWGPosition(b *testing.B) {
// Create two managers - one checks WG IP first, other checks it last // Create two managers - one checks WG IP first, other checks it last
b.Run("WG_First", func(b *testing.B) { b.Run("WG_First", func(b *testing.B) {
bm := newLocalIPManager() bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
bm.setBitmapBit(wgIP) bm.setBitmapBit(wgIP)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@@ -246,7 +257,7 @@ func BenchmarkWGPosition(b *testing.B) {
}) })
b.Run("WG_Last", func(b *testing.B) { b.Run("WG_Last", func(b *testing.B) {
bm := newLocalIPManager() bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
// Fill with other IPs first // Fill with other IPs first
for i := 0; i < 15; i++ { for i := 0; i < 15; i++ {
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i))) bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))

View File

@@ -29,15 +29,14 @@ func (r *PeerRule) ID() string {
} }
type RouteRule struct { type RouteRule struct {
id string id string
mgmtId []byte mgmtId []byte
sources []netip.Prefix sources []netip.Prefix
dstSet firewall.Set destination netip.Prefix
destinations []netip.Prefix proto firewall.Protocol
proto firewall.Protocol srcPort *firewall.Port
srcPort *firewall.Port dstPort *firewall.Port
dstPort *firewall.Port action firewall.Action
action firewall.Action
} }
// ID returns the rule id // ID returns the rule id

View File

@@ -38,8 +38,11 @@ func TestTracePacket(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: netip.MustParseAddr("100.10.0.100"), IP: net.ParseIP("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"), Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
} }
}, },
} }
@@ -195,12 +198,12 @@ func TestTracePacket(t *testing.T) {
m.forwarder.Store(&forwarder.Forwarder{}) m.forwarder.Store(&forwarder.Forwarder{})
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32) src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32) dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept) _, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err) require.NoError(t, err)
}, },
packetBuilder: func() *PacketBuilder { packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN) return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
@@ -219,12 +222,12 @@ func TestTracePacket(t *testing.T) {
m.nativeRouter.Store(false) m.nativeRouter.Store(false)
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32) src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32) dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop) _, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
require.NoError(t, err) require.NoError(t, err)
}, },
packetBuilder: func() *PacketBuilder { packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN) return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
@@ -242,7 +245,7 @@ func TestTracePacket(t *testing.T) {
m.nativeRouter.Store(true) m.nativeRouter.Store(true)
}, },
packetBuilder: func() *PacketBuilder { packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN) return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
@@ -260,7 +263,7 @@ func TestTracePacket(t *testing.T) {
m.routingEnabled.Store(false) m.routingEnabled.Store(false)
}, },
packetBuilder: func() *PacketBuilder { packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN) return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
@@ -422,8 +425,8 @@ func TestTracePacket(t *testing.T) {
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")), require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
"100.10.0.100 should be recognized as a local IP") "100.10.0.100 should be recognized as a local IP")
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("192.168.17.2")), require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")),
"192.168.17.2 should not be recognized as a local IP") "172.17.0.2 should not be recognized as a local IP")
pb := tc.packetBuilder() pb := tc.packetBuilder()

View File

@@ -39,12 +39,8 @@ const (
// EnvForceUserspaceRouter forces userspace routing even if native routing is available. // EnvForceUserspaceRouter forces userspace routing even if native routing is available.
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER" EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
// EnvEnableLocalForwarding enables forwarding of local traffic to the native stack for internal (non-NetBird) interfaces. // EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack
// Default off as it might be security risk because sockets listening on localhost only will become accessible. // Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible
EnvEnableLocalForwarding = "NB_ENABLE_LOCAL_FORWARDING"
// EnvEnableNetstackLocalForwarding is an alias for EnvEnableLocalForwarding.
// In netstack mode, it enables forwarding of local traffic to the native stack for all interfaces.
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING" EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
) )
@@ -53,10 +49,10 @@ var errNatNotSupported = errors.New("nat not supported with userspace firewall")
// RuleSet is a set of rules grouped by a string key // RuleSet is a set of rules grouped by a string key
type RuleSet map[string]PeerRule type RuleSet map[string]PeerRule
type RouteRules []*RouteRule type RouteRules []RouteRule
func (r RouteRules) Sort() { func (r RouteRules) Sort() {
slices.SortStableFunc(r, func(a, b *RouteRule) int { slices.SortStableFunc(r, func(a, b RouteRule) int {
// Deny rules come first // Deny rules come first
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop { if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
return -1 return -1
@@ -75,6 +71,7 @@ type Manager struct {
// incomingRules is used for filtering and hooks // incomingRules is used for filtering and hooks
incomingRules map[netip.Addr]RuleSet incomingRules map[netip.Addr]RuleSet
routeRules RouteRules routeRules RouteRules
wgNetwork *net.IPNet
decoders sync.Pool decoders sync.Pool
wgIface common.IFaceMapper wgIface common.IFaceMapper
nativeFirewall firewall.Manager nativeFirewall firewall.Manager
@@ -102,8 +99,6 @@ type Manager struct {
forwarder atomic.Pointer[forwarder.Forwarder] forwarder atomic.Pointer[forwarder.Forwarder]
logger *nblog.Logger logger *nblog.Logger
flowLogger nftypes.FlowLogger flowLogger nftypes.FlowLogger
blockRule firewall.Rule
} }
// decoder for packages // decoder for packages
@@ -151,11 +146,6 @@ func parseCreateEnv() (bool, bool) {
if err != nil { if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err) log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
} }
} else if val := os.Getenv(EnvEnableLocalForwarding); val != "" {
enableLocalForwarding, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
}
} }
return disableConntrack, enableLocalForwarding return disableConntrack, enableLocalForwarding
@@ -211,35 +201,41 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
} }
} }
if err := m.blockInvalidRouted(iface); err != nil {
log.Errorf("failed to block invalid routed traffic: %v", err)
}
if err := iface.SetFilter(m); err != nil { if err := iface.SetFilter(m); err != nil {
return nil, fmt.Errorf("set filter: %w", err) return nil, fmt.Errorf("set filter: %w", err)
} }
return m, nil return m, nil
} }
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) { func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
if m.forwarder.Load() == nil {
return nil
}
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String()) wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
if err != nil { if err != nil {
return nil, fmt.Errorf("parse wireguard network: %w", err) return fmt.Errorf("parse wireguard network: %w", err)
} }
log.Debugf("blocking invalid routed traffic for %s", wgPrefix) log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
rule, err := m.addRouteFiltering( if _, err := m.AddRouteFiltering(
nil, nil,
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}, []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
firewall.Network{Prefix: wgPrefix}, wgPrefix,
firewall.ProtocolALL, firewall.ProtocolALL,
nil, nil,
nil, nil,
firewall.ActionDrop, firewall.ActionDrop,
) ); err != nil {
if err != nil { return fmt.Errorf("block wg nte : %w", err)
return nil, fmt.Errorf("block wg nte : %w", err)
} }
// TODO: Block networks that we're a client of // TODO: Block networks that we're a client of
return rule, nil return nil
} }
func (m *Manager) determineRouting() error { func (m *Manager) determineRouting() error {
@@ -277,7 +273,7 @@ func (m *Manager) determineRouting() error {
log.Info("userspace routing is forced") log.Info("userspace routing is forced")
case !m.netstack && m.nativeFirewall != nil: case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported():
// if the OS supports routing natively, then we don't need to filter/route ourselves // if the OS supports routing natively, then we don't need to filter/route ourselves
// netstack mode won't support native routing as there is no interface // netstack mode won't support native routing as there is no interface
@@ -334,10 +330,6 @@ func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) IsStateful() bool {
return m.stateful
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddNatRule(pair) return m.nativeFirewall.AddNatRule(pair)
@@ -421,23 +413,10 @@ func (m *Manager) AddPeerFiltering(
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
id []byte, id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination firewall.Network, destination netip.Prefix,
proto firewall.Protocol, proto firewall.Protocol,
sPort, dPort *firewall.Port, sPort *firewall.Port,
action firewall.Action, dPort *firewall.Port,
) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.addRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
func (m *Manager) addRouteFiltering(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
) (firewall.Rule, error) { ) (firewall.Rule, error) {
if m.nativeRouter.Load() && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
@@ -447,39 +426,34 @@ func (m *Manager) addRouteFiltering(
ruleID := uuid.New().String() ruleID := uuid.New().String()
rule := RouteRule{ rule := RouteRule{
// TODO: consolidate these IDs // TODO: consolidate these IDs
id: ruleID, id: ruleID,
mgmtId: id, mgmtId: id,
sources: sources, sources: sources,
dstSet: destination.Set, destination: destination,
proto: proto, proto: proto,
srcPort: sPort, srcPort: sPort,
dstPort: dPort, dstPort: dPort,
action: action, action: action,
}
if destination.IsPrefix() {
rule.destinations = []netip.Prefix{destination.Prefix}
} }
m.routeRules = append(m.routeRules, &rule) m.mutex.Lock()
m.routeRules = append(m.routeRules, rule)
m.routeRules.Sort() m.routeRules.Sort()
m.mutex.Unlock()
return &rule, nil return &rule, nil
} }
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.deleteRouteRule(rule)
}
func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.DeleteRouteRule(rule) return m.nativeFirewall.DeleteRouteRule(rule)
} }
m.mutex.Lock()
defer m.mutex.Unlock()
ruleID := rule.ID() ruleID := rule.ID()
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool { idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool {
return r.id == ruleID return r.id == ruleID
}) })
if idx < 0 { if idx < 0 {
@@ -535,52 +509,6 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
return m.nativeFirewall.DeleteDNATRule(rule) return m.nativeFirewall.DeleteDNATRule(rule)
} }
// UpdateSet updates the rule destinations associated with the given set
// by merging the existing prefixes with the new ones, then deduplicating.
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.UpdateSet(set, prefixes)
}
m.mutex.Lock()
defer m.mutex.Unlock()
var matches []*RouteRule
for _, rule := range m.routeRules {
if rule.dstSet == set {
matches = append(matches, rule)
}
}
if len(matches) == 0 {
return fmt.Errorf("no route rule found for set: %s", set)
}
destinations := matches[0].destinations
for _, prefix := range prefixes {
if prefix.Addr().Is4() {
destinations = append(destinations, prefix)
}
}
slices.SortFunc(destinations, func(a, b netip.Prefix) int {
cmp := a.Addr().Compare(b.Addr())
if cmp != 0 {
return cmp
}
return a.Bits() - b.Bits()
})
destinations = slices.Compact(destinations)
for _, rule := range matches {
rule.destinations = destinations
}
log.Debugf("updated set %s to prefixes %v", set.HashedName(), destinations)
return nil
}
// DropOutgoing filter outgoing packets // DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte, size int) bool { func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
return m.processOutgoingHooks(packetData, size) return m.processOutgoingHooks(packetData, size)
@@ -618,8 +546,9 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
return true return true
} }
// for netflow we keep track even if the firewall is stateless if m.stateful {
m.trackOutbound(d, srcIP, dstIP, size) m.trackOutbound(d, srcIP, dstIP, size)
}
return false return false
} }
@@ -671,7 +600,7 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size) m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
} }
} }
@@ -684,7 +613,7 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size) m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size) m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size)
} }
} }
@@ -729,8 +658,7 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
d := m.decoders.Get().(*decoder) d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d) defer m.decoders.Put(d)
valid, fragment := m.isValidPacket(d, packetData) if !m.isValidPacket(d, packetData) {
if !valid {
return true return true
} }
@@ -740,13 +668,6 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
return true return true
} }
// TODO: pass fragments of routed packets to forwarder
if fragment {
m.logger.Trace("packet is a fragment: src=%v dst=%v id=%v flags=%v",
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
return false
}
// For all inbound traffic, first check if it matches a tracked connection. // For all inbound traffic, first check if it matches a tracked connection.
// This must happen before any other filtering because the packets are statefully tracked. // This must happen before any other filtering because the packets are statefully tracked.
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) { if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
@@ -757,7 +678,7 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
return m.handleLocalTraffic(d, srcIP, dstIP, packetData, size) return m.handleLocalTraffic(d, srcIP, dstIP, packetData, size)
} }
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData, size) return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
} }
// handleLocalTraffic handles local traffic. // handleLocalTraffic handles local traffic.
@@ -788,27 +709,29 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
return true return true
} }
// If requested we pass local traffic to internal interfaces to the forwarder. // if running in netstack mode we need to pass this to the forwarder
// netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder. if m.netstack {
if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) { return m.handleNetstackLocalTraffic(packetData)
return m.handleForwardedLocalTraffic(packetData)
} }
// track inbound packets to get the correct direction and session id for flows // track inbound packets to get the correct direction and session id for flows
m.trackInbound(d, srcIP, dstIP, ruleID, size) m.trackInbound(d, srcIP, dstIP, ruleID, size)
// pass to either native or virtual stack (to be picked up by listeners)
return false return false
} }
func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool { func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
fwd := m.forwarder.Load() if !m.localForwarding {
if fwd == nil { // pass to virtual tcp/ip stack to be picked up by listeners
return false
}
if m.forwarder.Load() == nil {
m.logger.Trace("Dropping local packet (forwarder not initialized)") m.logger.Trace("Dropping local packet (forwarder not initialized)")
return true return true
} }
if err := fwd.InjectIncomingPacket(packetData); err != nil { if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject local packet: %v", err) m.logger.Error("Failed to inject local packet: %v", err)
} }
@@ -818,7 +741,7 @@ func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
// handleRoutedTraffic handles routed traffic. // handleRoutedTraffic handles routed traffic.
// If it returns true, the packet should be dropped. // If it returns true, the packet should be dropped.
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool { func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool {
// Drop if routing is disabled // Drop if routing is disabled
if !m.routingEnabled.Load() { if !m.routingEnabled.Load() {
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s", m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
@@ -828,15 +751,13 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
// Pass to native stack if native router is enabled or forced // Pass to native stack if native router is enabled or forced
if m.nativeRouter.Load() { if m.nativeRouter.Load() {
m.trackInbound(d, srcIP, dstIP, nil, size)
return false return false
} }
proto, pnum := getProtocolFromPacket(d) proto, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) if ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
if !pass {
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort) ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
@@ -851,23 +772,13 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
SourcePort: srcPort, SourcePort: srcPort,
DestPort: dstPort, DestPort: dstPort,
// TODO: icmp type/code // TODO: icmp type/code
RxPackets: 1,
RxBytes: uint64(size),
}) })
return true return true
} }
// Let forwarder handle the packet if it passed route ACLs // Let forwarder handle the packet if it passed route ACLs
fwd := m.forwarder.Load() if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
if fwd == nil { m.logger.Error("Failed to inject incoming packet: %v", err)
m.logger.Trace("failed to forward routed packet (forwarder not initialized)")
} else {
fwd.RegisterRuleID(srcIP, dstIP, srcPort, dstPort, ruleID)
if err := fwd.InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject routed packet: %v", err)
fwd.DeleteRuleID(srcIP, dstIP, srcPort, dstPort)
}
} }
// Forwarded packets shouldn't reach the native stack, hence they won't be visible in a packet capture // Forwarded packets shouldn't reach the native stack, hence they won't be visible in a packet capture
@@ -898,32 +809,17 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
} }
} }
// isValidPacket checks if the packet is valid. func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
// It returns true, false if the packet is valid and not a fragment.
// It returns true, true if the packet is a fragment and valid.
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Trace("couldn't decode packet, err: %s", err) m.logger.Trace("couldn't decode packet, err: %s", err)
return false, false return false
} }
l := len(d.decoded) if len(d.decoded) < 2 {
m.logger.Trace("packet doesn't have network and transport layers")
// L3 and L4 are mandatory return false
if l >= 2 {
return true, false
} }
return true
// Fragments are also valid
if l == 1 && d.decoded[0] == layers.LayerTypeIPv4 {
ip4 := d.ip4
if ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0 {
return true, true
}
}
m.logger.Trace("packet doesn't have network and transport layers")
return false, false
} }
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size int) bool { func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size int) bool {
@@ -1063,15 +959,8 @@ func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol
return nil, false return nil, false
} }
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool { func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
destMatched := false if !rule.destination.Contains(dstAddr) {
for _, dst := range rule.destinations {
if dst.Contains(dstAddr) {
destMatched = true
break
}
}
if !destMatched {
return false return false
} }
@@ -1099,6 +988,11 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
return true return true
} }
// SetNetwork of the wireguard interface to which filtering applied
func (m *Manager) SetNetwork(network *net.IPNet) {
m.wgNetwork = network
}
// AddUDPPacketHook calls hook when UDP packet from given direction matched // AddUDPPacketHook calls hook when UDP packet from given direction matched
// //
// Hook function returns flag which indicates should be the matched package dropped or not // Hook function returns flag which indicates should be the matched package dropped or not
@@ -1168,22 +1062,7 @@ func (m *Manager) EnableRouting() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if err := m.determineRouting(); err != nil { return m.determineRouting()
return fmt.Errorf("determine routing: %w", err)
}
if m.forwarder.Load() == nil {
return nil
}
rule, err := m.blockInvalidRouted(m.wgIface)
if err != nil {
return fmt.Errorf("block invalid routed: %w", err)
}
m.blockRule = rule
return nil
} }
func (m *Manager) DisableRouting() error { func (m *Manager) DisableRouting() error {
@@ -1208,12 +1087,5 @@ func (m *Manager) DisableRouting() error {
log.Debug("forwarder stopped") log.Debug("forwarder stopped")
if m.blockRule != nil {
if err := m.deleteRouteRule(m.blockRule); err != nil {
return fmt.Errorf("delete block rule: %w", err)
}
m.blockRule = nil
}
return nil return nil
} }

View File

@@ -174,6 +174,11 @@ func BenchmarkCoreFiltering(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
// Apply scenario-specific setup // Apply scenario-specific setup
sc.setupFunc(manager) sc.setupFunc(manager)
@@ -214,6 +219,11 @@ func BenchmarkStateScaling(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
// Pre-populate connection table // Pre-populate connection table
srcIPs := generateRandomIPs(count) srcIPs := generateRandomIPs(count)
dstIPs := generateRandomIPs(count) dstIPs := generateRandomIPs(count)
@@ -257,6 +267,11 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
srcIP := generateRandomIPs(1)[0] srcIP := generateRandomIPs(1)[0]
dstIP := generateRandomIPs(1)[0] dstIP := generateRandomIPs(1)[0]
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP) outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
@@ -289,6 +304,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -302,6 +321,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -316,6 +339,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -329,6 +356,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -342,6 +373,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -355,6 +390,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -369,6 +408,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "post_handshake", state: "post_handshake",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -383,6 +426,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -396,6 +443,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -542,6 +593,11 @@ func BenchmarkLongLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
// Single rule to allow all return traffic from port 80 // Single rule to allow all return traffic from port 80
@@ -625,6 +681,11 @@ func BenchmarkShortLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
// Single rule to allow all return traffic from port 80 // Single rule to allow all return traffic from port 80
@@ -736,6 +797,11 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "") _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
@@ -816,6 +882,11 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
if sc.rules { if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "") _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
require.NoError(b, err) require.NoError(b, err)
@@ -961,8 +1032,7 @@ func BenchmarkRouteACLs(b *testing.B) {
} }
for _, r := range rules { for _, r := range rules {
dst := fw.Network{Prefix: r.dest} _, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }

View File

@@ -15,12 +15,15 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/management/domain"
) )
func TestPeerACLFiltering(t *testing.T) { func TestPeerACLFiltering(t *testing.T) {
localIP := netip.MustParseAddr("100.10.0.100") localIP := net.ParseIP("100.10.0.100")
wgNet := netip.MustParsePrefix("100.10.0.0/16") wgNet := &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
@@ -39,6 +42,8 @@ func TestPeerACLFiltering(t *testing.T) {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))
}) })
manager.wgNetwork = wgNet
err = manager.UpdateLocalIPs() err = manager.UpdateLocalIPs()
require.NoError(t, err) require.NoError(t, err)
@@ -183,281 +188,6 @@ func TestPeerACLFiltering(t *testing.T) {
ruleAction: fw.ActionAccept, ruleAction: fw.ActionAccept,
shouldBeBlocked: true, shouldBeBlocked: true,
}, },
{
name: "Allow TCP traffic without port specification",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "Allow UDP traffic without port specification",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 53,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "TCP packet doesn't match UDP filter with same port",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "UDP packet doesn't match TCP filter with same port",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "ICMP packet doesn't match TCP filter",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolICMP,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "ICMP packet doesn't match UDP filter",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolICMP,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "Allow TCP traffic within port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8080,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "Block TCP traffic outside port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 7999,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "Edge Case - Port at Range Boundary",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8100,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "UDP Port Range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 5060,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{5060, 5070}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "Allow multiple destination ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8080,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "Allow multiple source ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 80,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
// New drop test cases
{
name: "Drop TCP traffic from WG peer",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop UDP traffic from WG peer",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 53,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleDstPort: &fw.Port{Values: []uint16{53}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop ICMP traffic from WG peer",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolICMP,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolICMP,
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop all traffic from WG peer",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolALL,
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop traffic from multiple source ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 80,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop multiple destination ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8080,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop TCP traffic within port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8080,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Accept TCP traffic outside drop port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 7999,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: false,
},
{
name: "Drop TCP traffic with source port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 32100,
dstPort: 80,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleSrcPort: &fw.Port{IsRange: true, Values: []uint16{32000, 33000}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Mixed rule - drop specific port but allow other ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
} }
t.Run("Implicit DROP (no rules)", func(t *testing.T) { t.Run("Implicit DROP (no rules)", func(t *testing.T) {
@@ -468,28 +198,6 @@ func TestPeerACLFiltering(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
if tc.ruleAction == fw.ActionDrop {
// add general accept rule to test drop rule
// TODO: this only works because 0.0.0.0 is tested last, we need to implement order
rules, err := manager.AddPeerFiltering(
nil,
net.ParseIP("0.0.0.0"),
fw.ProtocolALL,
nil,
nil,
fw.ActionAccept,
"",
)
require.NoError(t, err)
require.NotEmpty(t, rules)
t.Cleanup(func() {
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
})
}
rules, err := manager.AddPeerFiltering( rules, err := manager.AddPeerFiltering(
nil, nil,
net.ParseIP(tc.ruleIP), net.ParseIP(tc.ruleIP),
@@ -575,13 +283,14 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
dev := mocks.NewMockDevice(ctrl) dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes() dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix(network) localIP, wgNet, err := net.ParseCIDR(network)
require.NoError(tb, err)
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: wgNet.Addr(), IP: localIP,
Network: wgNet, Network: wgNet,
} }
}, },
@@ -594,8 +303,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
} }
manager, err := Create(ifaceMock, false, flowLogger) manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(tb, err)
require.NoError(tb, manager.EnableRouting()) require.NoError(tb, manager.EnableRouting())
require.NoError(tb, err)
require.NotNil(tb, manager) require.NotNil(tb, manager)
require.True(tb, manager.routingEnabled.Load()) require.True(tb, manager.routingEnabled.Load())
require.False(tb, manager.nativeRouter.Load()) require.False(tb, manager.nativeRouter.Load())
@@ -612,7 +321,7 @@ func TestRouteACLFiltering(t *testing.T) {
type rule struct { type rule struct {
sources []netip.Prefix sources []netip.Prefix
dest fw.Network dest netip.Prefix
proto fw.Protocol proto fw.Protocol
srcPort *fw.Port srcPort *fw.Port
dstPort *fw.Port dstPort *fw.Port
@@ -638,7 +347,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443, dstPort: 443,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}}, dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -654,7 +363,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443, dstPort: 443,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}}, dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -670,7 +379,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443, dstPort: 443,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")}, dest: netip.MustParsePrefix("0.0.0.0/0"),
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}}, dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -686,7 +395,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 53, dstPort: 53,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolUDP, proto: fw.ProtocolUDP,
dstPort: &fw.Port{Values: []uint16{53}}, dstPort: &fw.Port{Values: []uint16{53}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -700,7 +409,7 @@ func TestRouteACLFiltering(t *testing.T) {
proto: fw.ProtocolICMP, proto: fw.ProtocolICMP,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")}, dest: netip.MustParsePrefix("0.0.0.0/0"),
proto: fw.ProtocolICMP, proto: fw.ProtocolICMP,
action: fw.ActionAccept, action: fw.ActionAccept,
}, },
@@ -715,7 +424,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -731,7 +440,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8080, dstPort: 8080,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -747,7 +456,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -763,7 +472,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -779,7 +488,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
srcPort: &fw.Port{Values: []uint16{12345}}, srcPort: &fw.Port{Values: []uint16{12345}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -798,7 +507,7 @@ func TestRouteACLFiltering(t *testing.T) {
netip.MustParsePrefix("100.10.0.0/16"), netip.MustParsePrefix("100.10.0.0/16"),
netip.MustParsePrefix("172.16.0.0/16"), netip.MustParsePrefix("172.16.0.0/16"),
}, },
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -812,7 +521,7 @@ func TestRouteACLFiltering(t *testing.T) {
proto: fw.ProtocolICMP, proto: fw.ProtocolICMP,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
action: fw.ActionAccept, action: fw.ActionAccept,
}, },
@@ -827,13 +536,33 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
}, },
shouldPass: true, shouldPass: true,
}, },
{
name: "Multiple source networks with mismatched protocol",
srcIP: "172.16.0.1",
dstIP: "192.168.1.100",
// Should not match TCP rule
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 80,
rule: rule{
sources: []netip.Prefix{
netip.MustParsePrefix("100.10.0.0/16"),
netip.MustParsePrefix("172.16.0.0/16"),
},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
},
shouldPass: false,
},
{ {
name: "Allow multiple destination ports", name: "Allow multiple destination ports",
srcIP: "100.10.0.1", srcIP: "100.10.0.1",
@@ -843,7 +572,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8080, dstPort: 8080,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80, 8080, 443}}, dstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -859,7 +588,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
srcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}}, srcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -875,7 +604,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
srcPort: &fw.Port{Values: []uint16{12345}}, srcPort: &fw.Port{Values: []uint16{12345}},
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
@@ -892,7 +621,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8080, dstPort: 8080,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{ dstPort: &fw.Port{
IsRange: true, IsRange: true,
@@ -911,7 +640,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 7999, dstPort: 7999,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{ dstPort: &fw.Port{
IsRange: true, IsRange: true,
@@ -930,7 +659,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
srcPort: &fw.Port{ srcPort: &fw.Port{
IsRange: true, IsRange: true,
@@ -949,7 +678,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443, dstPort: 443,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
srcPort: &fw.Port{ srcPort: &fw.Port{
IsRange: true, IsRange: true,
@@ -971,7 +700,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8100, dstPort: 8100,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{ dstPort: &fw.Port{
IsRange: true, IsRange: true,
@@ -990,7 +719,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 5060, dstPort: 5060,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolUDP, proto: fw.ProtocolUDP,
dstPort: &fw.Port{ dstPort: &fw.Port{
IsRange: true, IsRange: true,
@@ -1009,7 +738,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8080, dstPort: 8080,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
dstPort: &fw.Port{ dstPort: &fw.Port{
IsRange: true, IsRange: true,
@@ -1028,7 +757,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443, dstPort: 443,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}}, dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionDrop, action: fw.ActionDrop,
@@ -1044,7 +773,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
action: fw.ActionDrop, action: fw.ActionDrop,
}, },
@@ -1062,158 +791,17 @@ func TestRouteACLFiltering(t *testing.T) {
netip.MustParsePrefix("100.10.0.0/16"), netip.MustParsePrefix("100.10.0.0/16"),
netip.MustParsePrefix("172.16.0.0/16"), netip.MustParsePrefix("172.16.0.0/16"),
}, },
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionDrop, action: fw.ActionDrop,
}, },
shouldPass: false, shouldPass: false,
}, },
{
name: "Drop empty destination set",
srcIP: "172.16.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 80,
rule: rule{
sources: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
},
dest: fw.Network{Set: fw.Set{}},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
},
shouldPass: false,
},
{
name: "Accept TCP traffic outside drop port range",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 7999,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
action: fw.ActionDrop,
},
shouldPass: true,
},
{
name: "Allow TCP traffic without port specification",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
action: fw.ActionAccept,
},
shouldPass: true,
},
{
name: "Allow UDP traffic without port specification",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 53,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP,
action: fw.ActionAccept,
},
shouldPass: true,
},
{
name: "TCP packet doesn't match UDP filter with same port",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
},
shouldPass: false,
},
{
name: "UDP packet doesn't match TCP filter with same port",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
},
shouldPass: false,
},
{
name: "ICMP packet doesn't match TCP filter",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolICMP,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
action: fw.ActionAccept,
},
shouldPass: false,
},
{
name: "ICMP packet doesn't match UDP filter",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolICMP,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP,
action: fw.ActionAccept,
},
shouldPass: false,
},
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
if tc.rule.action == fw.ActionDrop {
// add general accept rule to test drop rule
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
fw.ProtocolALL,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
t.Cleanup(func() {
require.NoError(t, manager.DeleteRouteRule(rule))
})
}
rule, err := manager.AddRouteFiltering( rule, err := manager.AddRouteFiltering(
nil, nil,
tc.rule.sources, tc.rule.sources,
@@ -1248,7 +836,7 @@ func TestRouteACLOrder(t *testing.T) {
name string name string
rules []struct { rules []struct {
sources []netip.Prefix sources []netip.Prefix
dest fw.Network dest netip.Prefix
proto fw.Protocol proto fw.Protocol
srcPort *fw.Port srcPort *fw.Port
dstPort *fw.Port dstPort *fw.Port
@@ -1269,7 +857,7 @@ func TestRouteACLOrder(t *testing.T) {
name: "Drop rules take precedence over accept", name: "Drop rules take precedence over accept",
rules: []struct { rules: []struct {
sources []netip.Prefix sources []netip.Prefix
dest fw.Network dest netip.Prefix
proto fw.Protocol proto fw.Protocol
srcPort *fw.Port srcPort *fw.Port
dstPort *fw.Port dstPort *fw.Port
@@ -1278,7 +866,7 @@ func TestRouteACLOrder(t *testing.T) {
{ {
// Accept rule added first // Accept rule added first
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80, 443}}, dstPort: &fw.Port{Values: []uint16{80, 443}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -1286,7 +874,7 @@ func TestRouteACLOrder(t *testing.T) {
{ {
// Drop rule added second but should be evaluated first // Drop rule added second but should be evaluated first
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}}, dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionDrop, action: fw.ActionDrop,
@@ -1324,7 +912,7 @@ func TestRouteACLOrder(t *testing.T) {
name: "Multiple drop rules take precedence", name: "Multiple drop rules take precedence",
rules: []struct { rules: []struct {
sources []netip.Prefix sources []netip.Prefix
dest fw.Network dest netip.Prefix
proto fw.Protocol proto fw.Protocol
srcPort *fw.Port srcPort *fw.Port
dstPort *fw.Port dstPort *fw.Port
@@ -1333,14 +921,14 @@ func TestRouteACLOrder(t *testing.T) {
{ {
// Accept all // Accept all
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")}, dest: netip.MustParsePrefix("0.0.0.0/0"),
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
action: fw.ActionAccept, action: fw.ActionAccept,
}, },
{ {
// Drop specific port // Drop specific port
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}}, dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionDrop, action: fw.ActionDrop,
@@ -1348,7 +936,7 @@ func TestRouteACLOrder(t *testing.T) {
{ {
// Drop different port // Drop different port
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionDrop, action: fw.ActionDrop,
@@ -1427,50 +1015,3 @@ func TestRouteACLOrder(t *testing.T) {
}) })
} }
} }
func TestRouteACLSet(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"),
}
},
}
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
set := fw.NewDomainSet(domain.List{"example.org"})
// Add rule that uses the set (initially empty)
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
srcIP := netip.MustParseAddr("100.10.0.1")
dstIP := netip.MustParseAddr("192.168.1.100")
// Check that traffic is dropped (empty set shouldn't match anything)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
require.False(t, isAllowed, "Empty set should not allow any traffic")
err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
require.NoError(t, err)
// Now the packet should be allowed
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
}

View File

@@ -1,6 +1,7 @@
package uspfilter package uspfilter
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -20,11 +21,10 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
"github.com/netbirdio/netbird/management/domain"
) )
var logger = log.NewFromLogrus(logrus.StandardLogger()) var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger() var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
type IFaceMock struct { type IFaceMock struct {
SetFilterFunc func(device.PacketFilter) error SetFilterFunc func(device.PacketFilter) error
@@ -271,8 +271,11 @@ func TestNotMatchByIP(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: netip.MustParseAddr("100.10.0.100"), IP: net.ParseIP("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"), Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
} }
}, },
} }
@@ -282,6 +285,10 @@ func TestNotMatchByIP(t *testing.T) {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
} }
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
ip := net.ParseIP("0.0.0.0") ip := net.ParseIP("0.0.0.0")
proto := fw.ProtocolUDP proto := fw.ProtocolUDP
@@ -389,6 +396,10 @@ func TestProcessOutgoingHooks(t *testing.T) {
}, false, flowLogger) }, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close() manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger) manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
defer func() { defer func() {
@@ -498,6 +509,11 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}, false, flowLogger) }, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close() // Close the existing tracker manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger) manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
manager.decoders = sync.Pool{ manager.decoders = sync.Pool{
@@ -696,203 +712,3 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}) })
} }
} }
func TestUpdateSetMerge(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
set := fw.NewDomainSet(domain.List{"example.org"})
initialPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.1.0/24"),
}
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
// Update the set with initial prefixes
err = manager.UpdateSet(set, initialPrefixes)
require.NoError(t, err)
// Test initial prefixes work
srcIP := netip.MustParseAddr("100.10.0.1")
dstIP1 := netip.MustParseAddr("10.0.0.100")
dstIP2 := netip.MustParseAddr("192.168.1.100")
dstIP3 := netip.MustParseAddr("172.16.0.100")
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed")
require.False(t, isAllowed3, "Traffic to 172.16.0.100 should be denied")
newPrefixes := []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("10.1.0.0/24"),
}
err = manager.UpdateSet(set, newPrefixes)
require.NoError(t, err)
// Check that all original prefixes are still included
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update")
// Check that new prefixes are included
dstIP4 := netip.MustParseAddr("172.16.1.100")
dstIP5 := netip.MustParseAddr("10.1.0.50")
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80)
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed")
require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed")
// Verify the rule has all prefixes
manager.mutex.RLock()
foundRule := false
for _, r := range manager.routeRules {
if r.id == rule.ID() {
foundRule = true
require.Len(t, r.destinations, len(initialPrefixes)+len(newPrefixes),
"Rule should have all prefixes merged")
}
}
manager.mutex.RUnlock()
require.True(t, foundRule, "Rule should be found")
}
func TestUpdateSetDeduplication(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
set := fw.NewDomainSet(domain.List{"example.org"})
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
initialPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("10.0.0.0/24"), // Duplicate
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.0/24"), // Duplicate
}
err = manager.UpdateSet(set, initialPrefixes)
require.NoError(t, err)
// Check the internal state for deduplication
manager.mutex.RLock()
foundRule := false
for _, r := range manager.routeRules {
if r.id == rule.ID() {
foundRule = true
// Should have deduplicated to 2 prefixes
require.Len(t, r.destinations, 2, "Duplicate prefixes should be removed")
// Check the prefixes are correct
expectedPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.1.0/24"),
}
for i, prefix := range expectedPrefixes {
require.True(t, r.destinations[i] == prefix,
"Prefix should match expected value")
}
}
}
manager.mutex.RUnlock()
require.True(t, foundRule, "Rule should be found")
// Test with overlapping prefixes of different sizes
overlappingPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/16"), // More general
netip.MustParsePrefix("10.0.0.0/24"), // More specific (already exists)
netip.MustParsePrefix("192.168.0.0/16"), // More general
netip.MustParsePrefix("192.168.1.0/24"), // More specific (already exists)
}
err = manager.UpdateSet(set, overlappingPrefixes)
require.NoError(t, err)
// Check that all prefixes are included (no deduplication of overlapping prefixes)
manager.mutex.RLock()
for _, r := range manager.routeRules {
if r.id == rule.ID() {
// Should have all 4 prefixes (2 original + 2 new more general ones)
require.Len(t, r.destinations, 4,
"Overlapping prefixes should not be deduplicated")
// Verify they're sorted correctly (more specific prefixes should come first)
prefixes := make([]string, 0, len(r.destinations))
for _, p := range r.destinations {
prefixes = append(prefixes, p.String())
}
// Check sorted order
require.Equal(t, []string{
"10.0.0.0/16",
"10.0.0.0/24",
"192.168.0.0/16",
"192.168.1.0/24",
}, prefixes, "Prefixes should be sorted")
}
}
manager.mutex.RUnlock()
// Test functionality with all prefixes
testCases := []struct {
dstIP netip.Addr
expected bool
desc string
}{
{netip.MustParseAddr("10.0.0.100"), true, "IP in both /16 and /24"},
{netip.MustParseAddr("10.0.1.100"), true, "IP only in /16"},
{netip.MustParseAddr("192.168.1.100"), true, "IP in both /16 and /24"},
{netip.MustParseAddr("192.168.2.100"), true, "IP only in /16"},
{netip.MustParseAddr("172.16.0.100"), false, "IP not in any prefix"},
}
srcIP := netip.MustParseAddr("100.10.0.1")
for _, tc := range testCases {
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80)
require.Equal(t, tc.expected, isAllowed, tc.desc)
}
}

View File

@@ -150,7 +150,7 @@ func isZeros(ip net.IP) bool {
// NewUDPMuxDefault creates an implementation of UDPMux // NewUDPMuxDefault creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
if params.Logger == nil { if params.Logger == nil {
params.Logger = getLogger() params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
} }
mux := &UDPMuxDefault{ mux := &UDPMuxDefault{
@@ -455,9 +455,3 @@ func newBufferHolder(size int) *bufferHolder {
buf: make([]byte, size), buf: make([]byte, size),
} }
} }
func getLogger() logging.LeveledLogger {
fac := logging.NewDefaultLoggerFactory()
//fac.Writer = log.StandardLogger().Writer()
return fac.NewLogger("ice")
}

View File

@@ -49,7 +49,7 @@ type UniversalUDPMuxParams struct {
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux // NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault { func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
if params.Logger == nil { if params.Logger == nil {
params.Logger = getLogger() params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
} }
if params.XORMappedAddrCacheTTL == 0 { if params.XORMappedAddrCacheTTL == 0 {
params.XORMappedAddrCacheTTL = time.Second * 25 params.XORMappedAddrCacheTTL = time.Second * 25
@@ -164,7 +164,7 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
return nil return nil
} }
if u.address.Network.Contains(a) { if u.address.Network.Contains(a.AsSlice()) {
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address) log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address) return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
} }

View File

@@ -1,17 +0,0 @@
package configurer
import (
"net"
"net/netip"
)
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
ipNets := make([]net.IPNet, len(prefixes))
for i, prefix := range prefixes {
ipNets[i] = net.IPNet{
IP: prefix.Addr().AsSlice(), // Convert netip.Addr to net.IP
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
}
}
return ipNets
}

View File

@@ -5,7 +5,6 @@ package configurer
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -13,8 +12,6 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
var zeroKey wgtypes.Key
type KernelConfigurer struct { type KernelConfigurer struct {
deviceName string deviceName string
} }
@@ -46,7 +43,7 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
return nil return nil
} }
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {
return err return err
@@ -55,7 +52,7 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false, ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP // don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: prefixesToIPNets(allowedIps), AllowedIPs: allowedIps,
PersistentKeepaliveInterval: &keepAlive, PersistentKeepaliveInterval: &keepAlive,
Endpoint: endpoint, Endpoint: endpoint,
PresharedKey: preSharedKey, PresharedKey: preSharedKey,
@@ -92,10 +89,10 @@ func (c *KernelConfigurer) RemovePeer(peerKey string) error {
return nil return nil
} }
func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error { func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
ipNet := net.IPNet{ _, ipNet, err := net.ParseCIDR(allowedIP)
IP: allowedIP.Addr().AsSlice(), if err != nil {
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()), return err
} }
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
@@ -106,7 +103,7 @@ func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix)
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
UpdateOnly: true, UpdateOnly: true,
ReplaceAllowedIPs: false, ReplaceAllowedIPs: false,
AllowedIPs: []net.IPNet{ipNet}, AllowedIPs: []net.IPNet{*ipNet},
} }
config := wgtypes.Config{ config := wgtypes.Config{
@@ -119,10 +116,10 @@ func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix)
return nil return nil
} }
func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error { func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP string) error {
ipNet := net.IPNet{ _, ipNet, err := net.ParseCIDR(allowedIP)
IP: allowedIP.Addr().AsSlice(), if err != nil {
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()), return fmt.Errorf("parse allowed IP: %w", err)
} }
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
@@ -190,11 +187,7 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
if err != nil { if err != nil {
return err return err
} }
defer func() { defer wg.Close()
if err := wg.Close(); err != nil {
log.Errorf("Failed to close wgctrl client: %v", err)
}
}()
// validate if device with name exists // validate if device with name exists
_, err = wg.Device(c.deviceName) _, err = wg.Device(c.deviceName)
@@ -208,71 +201,14 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
func (c *KernelConfigurer) Close() { func (c *KernelConfigurer) Close() {
} }
func (c *KernelConfigurer) FullStats() (*Stats, error) { func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) {
wg, err := wgctrl.New() peer, err := c.getPeer(c.deviceName, peerKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("wgctl: %w", err) return WGStats{}, fmt.Errorf("get wireguard stats: %w", err)
} }
defer func() { return WGStats{
err = wg.Close() LastHandshake: peer.LastHandshakeTime,
if err != nil { TxBytes: peer.TransmitBytes,
log.Errorf("Got error while closing wgctl: %v", err) RxBytes: peer.ReceiveBytes,
} }, nil
}()
wgDevice, err := wg.Device(c.deviceName)
if err != nil {
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
}
fullStats := &Stats{
DeviceName: wgDevice.Name,
PublicKey: wgDevice.PublicKey.String(),
ListenPort: wgDevice.ListenPort,
FWMark: wgDevice.FirewallMark,
Peers: []Peer{},
}
for _, p := range wgDevice.Peers {
peer := Peer{
PublicKey: p.PublicKey.String(),
AllowedIPs: p.AllowedIPs,
TxBytes: p.TransmitBytes,
RxBytes: p.ReceiveBytes,
LastHandshake: p.LastHandshakeTime,
PresharedKey: p.PresharedKey != zeroKey,
}
if p.Endpoint != nil {
peer.Endpoint = *p.Endpoint
}
fullStats.Peers = append(fullStats.Peers, peer)
}
return fullStats, nil
}
func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
stats := make(map[string]WGStats)
wg, err := wgctrl.New()
if err != nil {
return nil, fmt.Errorf("wgctl: %w", err)
}
defer func() {
err = wg.Close()
if err != nil {
log.Errorf("Got error while closing wgctl: %v", err)
}
}()
wgDevice, err := wg.Device(c.deviceName)
if err != nil {
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
}
for _, peer := range wgDevice.Peers {
stats[peer.PublicKey.String()] = WGStats{
LastHandshake: peer.LastHandshakeTime,
TxBytes: peer.TransmitBytes,
RxBytes: peer.ReceiveBytes,
}
}
return stats, nil
} }

View File

@@ -1,11 +1,9 @@
package configurer package configurer
import ( import (
"encoding/base64"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"net" "net"
"net/netip"
"os" "os"
"runtime" "runtime"
"strconv" "strconv"
@@ -19,20 +17,6 @@ import (
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
const (
privateKey = "private_key"
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
ipcKeyTxBytes = "tx_bytes"
ipcKeyRxBytes = "rx_bytes"
allowedIP = "allowed_ip"
endpoint = "endpoint"
fwmark = "fwmark"
listenPort = "listen_port"
publicKey = "public_key"
presharedKey = "preshared_key"
)
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found") var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
type WGUSPConfigurer struct { type WGUSPConfigurer struct {
@@ -68,7 +52,7 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {
return err return err
@@ -77,7 +61,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false, ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP // don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: prefixesToIPNets(allowedIps), AllowedIPs: allowedIps,
PersistentKeepaliveInterval: &keepAlive, PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey, PresharedKey: preSharedKey,
Endpoint: endpoint, Endpoint: endpoint,
@@ -107,10 +91,10 @@ func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error { func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
ipNet := net.IPNet{ _, ipNet, err := net.ParseCIDR(allowedIP)
IP: allowedIP.Addr().AsSlice(), if err != nil {
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()), return err
} }
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
@@ -121,7 +105,7 @@ func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) e
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
UpdateOnly: true, UpdateOnly: true,
ReplaceAllowedIPs: false, ReplaceAllowedIPs: false,
AllowedIPs: []net.IPNet{ipNet}, AllowedIPs: []net.IPNet{*ipNet},
} }
config := wgtypes.Config{ config := wgtypes.Config{
@@ -131,7 +115,7 @@ func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) e
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error { func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
ipc, err := c.device.IpcGet() ipc, err := c.device.IpcGet()
if err != nil { if err != nil {
return err return err
@@ -154,8 +138,6 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix
foundPeer := false foundPeer := false
removedAllowedIP := false removedAllowedIP := false
ip := allowedIP.String()
for _, line := range lines { for _, line := range lines {
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
@@ -178,8 +160,8 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix
// Append the line to the output string // Append the line to the output string
if foundPeer && strings.HasPrefix(line, "allowed_ip=") { if foundPeer && strings.HasPrefix(line, "allowed_ip=") {
allowedIPStr := strings.TrimPrefix(line, "allowed_ip=") allowedIP := strings.TrimPrefix(line, "allowed_ip=")
_, ipNet, err := net.ParseCIDR(allowedIPStr) _, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil { if err != nil {
return err return err
} }
@@ -196,15 +178,6 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
ipcStr, err := c.device.IpcGet()
if err != nil {
return nil, fmt.Errorf("IpcGet failed: %w", err)
}
return parseStatus(c.deviceName, ipcStr)
}
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool // startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
func (t *WGUSPConfigurer) startUAPI() { func (t *WGUSPConfigurer) startUAPI() {
var err error var err error
@@ -244,75 +217,91 @@ func (t *WGUSPConfigurer) Close() {
} }
} }
func (t *WGUSPConfigurer) GetStats() (map[string]WGStats, error) { func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) {
ipc, err := t.device.IpcGet() ipc, err := t.device.IpcGet()
if err != nil { if err != nil {
return nil, fmt.Errorf("ipc get: %w", err) return WGStats{}, fmt.Errorf("ipc get: %w", err)
} }
return parseTransfers(ipc) stats, err := findPeerInfo(ipc, peerKey, []string{
"last_handshake_time_sec",
"last_handshake_time_nsec",
"tx_bytes",
"rx_bytes",
})
if err != nil {
return WGStats{}, fmt.Errorf("find peer info: %w", err)
}
sec, err := strconv.ParseInt(stats["last_handshake_time_sec"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse handshake sec: %w", err)
}
nsec, err := strconv.ParseInt(stats["last_handshake_time_nsec"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse handshake nsec: %w", err)
}
txBytes, err := strconv.ParseInt(stats["tx_bytes"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse tx_bytes: %w", err)
}
rxBytes, err := strconv.ParseInt(stats["rx_bytes"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse rx_bytes: %w", err)
}
return WGStats{
LastHandshake: time.Unix(sec, nsec),
TxBytes: txBytes,
RxBytes: rxBytes,
}, nil
} }
func parseTransfers(ipc string) (map[string]WGStats, error) { func findPeerInfo(ipcInput string, peerKey string, searchConfigKeys []string) (map[string]string, error) {
stats := make(map[string]WGStats) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
var ( if err != nil {
currentKey string return nil, fmt.Errorf("parse key: %w", err)
currentStats WGStats }
hasPeer bool
) hexKey := hex.EncodeToString(peerKeyParsed[:])
lines := strings.Split(ipc, "\n")
lines := strings.Split(ipcInput, "\n")
configFound := map[string]string{}
foundPeer := false
for _, line := range lines { for _, line := range lines {
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
// If we're within the details of the found peer and encounter another public key, // If we're within the details of the found peer and encounter another public key,
// this means we're starting another peer's details. So, stop. // this means we're starting another peer's details. So, stop.
if strings.HasPrefix(line, "public_key=") { if strings.HasPrefix(line, "public_key=") && foundPeer {
peerID := strings.TrimPrefix(line, "public_key=") break
h, err := hex.DecodeString(peerID)
if err != nil {
return nil, fmt.Errorf("decode peerID: %w", err)
}
currentKey = base64.StdEncoding.EncodeToString(h)
currentStats = WGStats{} // Reset stats for the new peer
hasPeer = true
stats[currentKey] = currentStats
continue
} }
if !hasPeer { // Identify the peer with the specific public key
continue if line == fmt.Sprintf("public_key=%s", hexKey) {
foundPeer = true
} }
key := strings.SplitN(line, "=", 2) for _, key := range searchConfigKeys {
if len(key) != 2 { if foundPeer && strings.HasPrefix(line, key+"=") {
continue v := strings.SplitN(line, "=", 2)
} configFound[v[0]] = v[1]
switch key[0] {
case ipcKeyLastHandshakeTimeSec:
hs, err := toLastHandshake(key[1])
if err != nil {
return nil, err
} }
currentStats.LastHandshake = hs
stats[currentKey] = currentStats
case ipcKeyRxBytes:
rxBytes, err := toBytes(key[1])
if err != nil {
return nil, fmt.Errorf("parse rx_bytes: %w", err)
}
currentStats.RxBytes = rxBytes
stats[currentKey] = currentStats
case ipcKeyTxBytes:
TxBytes, err := toBytes(key[1])
if err != nil {
return nil, fmt.Errorf("parse tx_bytes: %w", err)
}
currentStats.TxBytes = TxBytes
stats[currentKey] = currentStats
} }
} }
return stats, nil // todo: use multierr
for _, key := range searchConfigKeys {
if _, ok := configFound[key]; !ok {
return configFound, fmt.Errorf("config key not found: %s", key)
}
}
if !foundPeer {
return nil, fmt.Errorf("%w: %s", ErrPeerNotFound, peerKey)
}
return configFound, nil
} }
func toWgUserspaceString(wgCfg wgtypes.Config) string { func toWgUserspaceString(wgCfg wgtypes.Config) string {
@@ -366,154 +355,9 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
return sb.String() return sb.String()
} }
func toLastHandshake(stringVar string) (time.Time, error) {
sec, err := strconv.ParseInt(stringVar, 10, 64)
if err != nil {
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
}
return time.Unix(sec, 0), nil
}
func toBytes(s string) (int64, error) {
return strconv.ParseInt(s, 10, 64)
}
func getFwmark() int { func getFwmark() int {
if nbnet.AdvancedRouting() { if nbnet.AdvancedRouting() {
return nbnet.ControlPlaneMark return nbnet.NetbirdFwmark
} }
return 0 return 0
} }
func hexToWireguardKey(hexKey string) (wgtypes.Key, error) {
// Decode hex string to bytes
keyBytes, err := hex.DecodeString(hexKey)
if err != nil {
return wgtypes.Key{}, fmt.Errorf("failed to decode hex key: %w", err)
}
// Check if we have the right number of bytes (WireGuard keys are 32 bytes)
if len(keyBytes) != 32 {
return wgtypes.Key{}, fmt.Errorf("invalid key length: expected 32 bytes, got %d", len(keyBytes))
}
// Convert to wgtypes.Key
var key wgtypes.Key
copy(key[:], keyBytes)
return key, nil
}
func parseStatus(deviceName, ipcStr string) (*Stats, error) {
stats := &Stats{DeviceName: deviceName}
var currentPeer *Peer
for _, line := range strings.Split(strings.TrimSpace(ipcStr), "\n") {
if line == "" {
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
continue
}
key := parts[0]
val := parts[1]
switch key {
case privateKey:
key, err := hexToWireguardKey(val)
if err != nil {
log.Errorf("failed to parse private key: %v", err)
continue
}
stats.PublicKey = key.PublicKey().String()
case publicKey:
// Save previous peer
if currentPeer != nil {
stats.Peers = append(stats.Peers, *currentPeer)
}
key, err := hexToWireguardKey(val)
if err != nil {
log.Errorf("failed to parse public key: %v", err)
continue
}
currentPeer = &Peer{
PublicKey: key.String(),
}
case listenPort:
if port, err := strconv.Atoi(val); err == nil {
stats.ListenPort = port
}
case fwmark:
if fwmark, err := strconv.Atoi(val); err == nil {
stats.FWMark = fwmark
}
case endpoint:
if currentPeer == nil {
continue
}
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
if err != nil {
log.Errorf("failed to parse endpoint: %v", err)
continue
}
port, err := strconv.Atoi(portStr)
if err != nil {
log.Errorf("failed to parse endpoint port: %v", err)
continue
}
currentPeer.Endpoint = net.UDPAddr{
IP: net.ParseIP(host),
Port: port,
}
case allowedIP:
if currentPeer == nil {
continue
}
_, ipnet, err := net.ParseCIDR(val)
if err == nil {
currentPeer.AllowedIPs = append(currentPeer.AllowedIPs, *ipnet)
}
case ipcKeyTxBytes:
if currentPeer == nil {
continue
}
rxBytes, err := toBytes(val)
if err != nil {
continue
}
currentPeer.TxBytes = rxBytes
case ipcKeyRxBytes:
if currentPeer == nil {
continue
}
rxBytes, err := toBytes(val)
if err != nil {
continue
}
currentPeer.RxBytes = rxBytes
case ipcKeyLastHandshakeTimeSec:
if currentPeer == nil {
continue
}
ts, err := toLastHandshake(val)
if err != nil {
continue
}
currentPeer.LastHandshake = ts
case presharedKey:
if currentPeer == nil {
continue
}
if val != "" {
currentPeer.PresharedKey = true
}
}
}
if currentPeer != nil {
stats.Peers = append(stats.Peers, *currentPeer)
}
return stats, nil
}

View File

@@ -2,8 +2,10 @@ package configurer
import ( import (
"encoding/hex" "encoding/hex"
"fmt"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@@ -32,35 +34,58 @@ errno=0
` `
func Test_parseTransfers(t *testing.T) { func Test_findPeerInfo(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
peerKey string peerKey string
want WGStats searchKeys []string
want map[string]string
wantErr bool
}{ }{
{ {
name: "single", name: "single",
peerKey: "b85996fecc9c7f1fc6d2572a76eda11d59bcd20be8e543b15ce4bd85a8e75a33", peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
want: WGStats{ searchKeys: []string{"tx_bytes"},
TxBytes: 0, want: map[string]string{
RxBytes: 0, "tx_bytes": "38333",
}, },
wantErr: false,
}, },
{ {
name: "multiple", name: "multiple",
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376", peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
want: WGStats{ searchKeys: []string{"tx_bytes", "rx_bytes"},
TxBytes: 38333, want: map[string]string{
RxBytes: 2224, "tx_bytes": "38333",
"rx_bytes": "2224",
}, },
wantErr: false,
}, },
{ {
name: "lastpeer", name: "lastpeer",
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58", peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
want: WGStats{ searchKeys: []string{"tx_bytes", "rx_bytes"},
TxBytes: 1212111, want: map[string]string{
RxBytes: 1929999999, "tx_bytes": "1212111",
"rx_bytes": "1929999999",
}, },
wantErr: false,
},
{
name: "peer not found",
peerKey: "1111111111111111111111111111111111111111111111111111111111111111",
searchKeys: nil,
want: nil,
wantErr: true,
},
{
name: "key not found",
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
searchKeys: []string{"tx_bytes", "unknown_key"},
want: map[string]string{
"tx_bytes": "1212111",
},
wantErr: true,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
@@ -71,19 +96,9 @@ func Test_parseTransfers(t *testing.T) {
key, err := wgtypes.NewKey(res) key, err := wgtypes.NewKey(res)
require.NoError(t, err) require.NoError(t, err)
stats, err := parseTransfers(ipcFixture) got, err := findPeerInfo(ipcFixture, key.String(), tt.searchKeys)
if err != nil { assert.Equalf(t, tt.wantErr, err != nil, fmt.Sprintf("findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys))
require.NoError(t, err) assert.Equalf(t, tt.want, got, "findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys)
return
}
stat, ok := stats[key.String()]
if !ok {
require.True(t, ok)
return
}
require.Equal(t, tt.want, stat)
}) })
} }
} }

View File

@@ -1,24 +0,0 @@
package configurer
import (
"net"
"time"
)
type Peer struct {
PublicKey string
Endpoint net.UDPAddr
AllowedIPs []net.IPNet
TxBytes int64
RxBytes int64
LastHandshake time.Time
PresharedKey bool
}
type Stats struct {
DeviceName string
PublicKey string
ListenPort int
FWMark int
Peers []Peer
}

View File

@@ -24,7 +24,6 @@ type WGTunDevice struct {
mtu int mtu int
iceBind *bind.ICEBind iceBind *bind.ICEBind
tunAdapter TunAdapter tunAdapter TunAdapter
disableDNS bool
name string name string
device *device.Device device *device.Device
@@ -33,7 +32,7 @@ type WGTunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice { func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
return &WGTunDevice{ return &WGTunDevice{
address: address, address: address,
port: port, port: port,
@@ -41,7 +40,6 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
mtu: mtu, mtu: mtu,
iceBind: iceBind, iceBind: iceBind,
tunAdapter: tunAdapter, tunAdapter: tunAdapter,
disableDNS: disableDNS,
} }
} }
@@ -51,13 +49,6 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
routesString := routesToString(routes) routesString := routesToString(routes)
searchDomainsToString := searchDomainsToString(searchDomains) searchDomainsToString := searchDomainsToString(searchDomains)
// Skip DNS configuration when DisableDNS is enabled
if t.disableDNS {
log.Info("DNS is disabled, skipping DNS and search domain configuration")
dns = ""
searchDomainsToString = ""
}
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString) fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
if err != nil { if err != nil {
log.Errorf("failed to create Android interface: %s", err) log.Errorf("failed to create Android interface: %s", err)

View File

@@ -1,6 +1,7 @@
package device package device
import ( import (
"net"
"net/netip" "net/netip"
"sync" "sync"
@@ -23,6 +24,9 @@ type PacketFilter interface {
// RemovePacketHook removes hook by ID // RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error RemovePacketHook(hookID string) error
// SetNetwork of the wireguard interface to which filtering applied
SetNetwork(*net.IPNet)
} }
// FilteredDevice to override Read or Write of packets // FilteredDevice to override Read or Write of packets

View File

@@ -51,11 +51,7 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
log.Info("create nbnetstack tun interface") log.Info("create nbnetstack tun interface")
// TODO: get from service listener runtime IP // TODO: get from service listener runtime IP
dnsAddr, err := nbnet.GetLastIPFromNetwork(t.address.Network, 1) dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
if err != nil {
return nil, fmt.Errorf("last ip: %w", err)
}
log.Debugf("netstack using address: %s", t.address.IP) log.Debugf("netstack using address: %s", t.address.IP)
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu) t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
log.Debugf("netstack using dns address: %s", dnsAddr) log.Debugf("netstack using dns address: %s", dnsAddr)

View File

@@ -2,7 +2,6 @@ package device
import ( import (
"net" "net"
"net/netip"
"time" "time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -12,11 +11,10 @@ import (
type WGConfigurer interface { type WGConfigurer interface {
ConfigureInterface(privateKey string, port int) error ConfigureInterface(privateKey string, port int) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error RemoveAllowedIP(peerKey string, allowedIP string) error
Close() Close()
GetStats() (map[string]configurer.WGStats, error) GetStats(peerKey string) (configurer.WGStats, error)
FullStats() (*configurer.Stats, error)
} }

View File

@@ -64,15 +64,7 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error {
} }
ip := address.IP.String() ip := address.IP.String()
mask := "0x" + address.Network.Mask.String()
// Convert prefix length to hex netmask
prefixLen := address.Network.Bits()
if !address.IP.Is4() {
return fmt.Errorf("IPv6 not supported for interface assignment")
}
maskBits := uint32(0xffffffff) << (32 - prefixLen)
mask := fmt.Sprintf("0x%08x", maskBits)
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name) log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)

View File

@@ -43,7 +43,6 @@ type WGIFaceOpts struct {
MobileArgs *device.MobileIFaceArguments MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net TransportNet transport.Net
FilterFn bind.FilterFn FilterFn bind.FilterFn
DisableDNS bool
} }
// WGIface represents an interface instance // WGIface represents an interface instance
@@ -112,14 +111,14 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
} }
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist // UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
// Endpoint is optional. // Endpoint is optional
// If allowedIps is given it will be added to the existing ones.
func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps) netIPNets := prefixesToIPNets(allowedIps)
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint)
return w.configurer.UpdatePeer(peerKey, netIPNets, keepAlive, endpoint, preSharedKey)
} }
// RemovePeer removes a Wireguard Peer from the interface iface // RemovePeer removes a Wireguard Peer from the interface iface
@@ -132,7 +131,7 @@ func (w *WGIface) RemovePeer(peerKey string) error {
} }
// AddAllowedIP adds a prefix to the allowed IPs list of peer // AddAllowedIP adds a prefix to the allowed IPs list of peer
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error { func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
@@ -141,7 +140,7 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
} }
// RemoveAllowedIP removes a prefix from the allowed IPs list of peer // RemoveAllowedIP removes a prefix from the allowed IPs list of peer
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error { func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
@@ -186,6 +185,7 @@ func (w *WGIface) SetFilter(filter device.PacketFilter) error {
} }
w.filter = filter w.filter = filter
w.filter.SetNetwork(w.tun.WgAddress().Network)
w.tun.FilteredDevice().SetFilter(filter) w.tun.FilteredDevice().SetFilter(filter)
return nil return nil
@@ -212,13 +212,9 @@ func (w *WGIface) GetWGDevice() *wgdevice.Device {
return w.tun.Device() return w.tun.Device()
} }
// GetStats returns the last handshake time, rx and tx bytes // GetStats returns the last handshake time, rx and tx bytes for the given peer
func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) { func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return w.configurer.GetStats() return w.configurer.GetStats(peerKey)
}
func (w *WGIface) FullStats() (*configurer.Stats, error) {
return w.configurer.FullStats()
} }
func (w *WGIface) waitUntilRemoved() error { func (w *WGIface) waitUntilRemoved() error {
@@ -255,3 +251,14 @@ func (w *WGIface) GetNet() *netstack.Net {
return w.tun.GetNet() return w.tun.GetNet()
} }
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
ipNets := make([]net.IPNet, len(prefixes))
for i, prefix := range prefixes {
ipNets[i] = net.IPNet{
IP: net.IP(prefix.Addr().AsSlice()), // Convert netip.Addr to net.IP
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
}
}
return ipNets
}

View File

@@ -18,7 +18,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS), tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind), wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
} }
return wgIFace, nil return wgIFace, nil

View File

@@ -5,6 +5,7 @@
package mocks package mocks
import ( import (
net "net"
"net/netip" "net/netip"
reflect "reflect" reflect "reflect"
@@ -89,3 +90,15 @@ func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomo
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
} }
// SetNetwork mocks base method.
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetNetwork", arg0)
}
// SetNetwork indicates an expected call of SetNetwork.
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
}

View File

@@ -1,6 +1,8 @@
package netstack package netstack
import ( import (
"fmt"
"net"
"net/netip" "net/netip"
"os" "os"
"strconv" "strconv"
@@ -13,8 +15,8 @@ import (
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY" const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
type NetStackTun struct { //nolint:revive type NetStackTun struct { //nolint:revive
address netip.Addr address net.IP
dnsAddress netip.Addr dnsAddress net.IP
mtu int mtu int
listenAddress string listenAddress string
@@ -22,7 +24,7 @@ type NetStackTun struct { //nolint:revive
tundev tun.Device tundev tun.Device
} }
func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun { func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu int) *NetStackTun {
return &NetStackTun{ return &NetStackTun{
address: address, address: address,
dnsAddress: dnsAddress, dnsAddress: dnsAddress,
@@ -32,9 +34,19 @@ func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.A
} }
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) { func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
addr, ok := netip.AddrFromSlice(t.address)
if !ok {
return nil, nil, fmt.Errorf("convert address to netip.Addr: %v", t.address)
}
dnsAddr, ok := netip.AddrFromSlice(t.dnsAddress)
if !ok {
return nil, nil, fmt.Errorf("convert dns address to netip.Addr: %v", t.dnsAddress)
}
nsTunDev, tunNet, err := netstack.CreateNetTUN( nsTunDev, tunNet, err := netstack.CreateNetTUN(
[]netip.Addr{t.address}, []netip.Addr{addr.Unmap()},
[]netip.Addr{t.dnsAddress}, []netip.Addr{dnsAddr.Unmap()},
t.mtu) t.mtu)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err

View File

@@ -2,27 +2,28 @@ package wgaddr
import ( import (
"fmt" "fmt"
"net/netip" "net"
) )
// Address WireGuard parsed address // Address WireGuard parsed address
type Address struct { type Address struct {
IP netip.Addr IP net.IP
Network netip.Prefix Network *net.IPNet
} }
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address // ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
func ParseWGAddress(address string) (Address, error) { func ParseWGAddress(address string) (Address, error) {
prefix, err := netip.ParsePrefix(address) ip, network, err := net.ParseCIDR(address)
if err != nil { if err != nil {
return Address{}, err return Address{}, err
} }
return Address{ return Address{
IP: prefix.Addr().Unmap(), IP: ip,
Network: prefix.Masked(), Network: network,
}, nil }, nil
} }
func (addr Address) String() string { func (addr Address) String() string {
return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits()) maskSize, _ := addr.Network.Mask.Size()
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
} }

View File

@@ -24,8 +24,6 @@
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run" !define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
!define NETBIRD_DATA_DIR "$COMMONPROGRAMDATA\Netbird"
Unicode True Unicode True
###################################################################### ######################################################################
@@ -51,10 +49,6 @@ ShowInstDetails Show
###################################################################### ######################################################################
!include "MUI2.nsh"
!include LogicLib.nsh
!include "nsDialogs.nsh"
!define MUI_ICON "${ICON}" !define MUI_ICON "${ICON}"
!define MUI_UNICON "${ICON}" !define MUI_UNICON "${ICON}"
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}" !define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
@@ -64,6 +58,9 @@ ShowInstDetails Show
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink" !define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
###################################################################### ######################################################################
!include "MUI2.nsh"
!include LogicLib.nsh
!define MUI_ABORTWARNING !define MUI_ABORTWARNING
!define MUI_UNABORTWARNING !define MUI_UNABORTWARNING
@@ -73,16 +70,13 @@ ShowInstDetails Show
!insertmacro MUI_PAGE_DIRECTORY !insertmacro MUI_PAGE_DIRECTORY
; Custom page for autostart checkbox
Page custom AutostartPage AutostartPageLeave Page custom AutostartPage AutostartPageLeave
!insertmacro MUI_PAGE_INSTFILES !insertmacro MUI_PAGE_INSTFILES
!insertmacro MUI_PAGE_FINISH !insertmacro MUI_PAGE_FINISH
!insertmacro MUI_UNPAGE_WELCOME
UninstPage custom un.DeleteDataPage un.DeleteDataPageLeave
!insertmacro MUI_UNPAGE_CONFIRM !insertmacro MUI_UNPAGE_CONFIRM
!insertmacro MUI_UNPAGE_INSTFILES !insertmacro MUI_UNPAGE_INSTFILES
@@ -95,10 +89,6 @@ UninstPage custom un.DeleteDataPage un.DeleteDataPageLeave
Var AutostartCheckbox Var AutostartCheckbox
Var AutostartEnabled Var AutostartEnabled
; Variables for uninstall data deletion option
Var DeleteDataCheckbox
Var DeleteDataEnabled
###################################################################### ######################################################################
; Function to create the autostart options page ; Function to create the autostart options page
@@ -114,8 +104,8 @@ Function AutostartPage
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts" ${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
Pop $AutostartCheckbox Pop $AutostartCheckbox
${NSD_Check} $AutostartCheckbox ${NSD_Check} $AutostartCheckbox ; Default to checked
StrCpy $AutostartEnabled "1" StrCpy $AutostartEnabled "1" ; Default to enabled
nsDialogs::Show nsDialogs::Show
FunctionEnd FunctionEnd
@@ -125,30 +115,6 @@ Function AutostartPageLeave
${NSD_GetState} $AutostartCheckbox $AutostartEnabled ${NSD_GetState} $AutostartCheckbox $AutostartEnabled
FunctionEnd FunctionEnd
; Function to create the uninstall data deletion page
Function un.DeleteDataPage
!insertmacro MUI_HEADER_TEXT "Uninstall Options" "Choose whether to delete ${APP_NAME} data."
nsDialogs::Create 1018
Pop $0
${If} $0 == error
Abort
${EndIf}
${NSD_CreateCheckbox} 0 20u 100% 10u "Delete all ${APP_NAME} configuration and state data (${NETBIRD_DATA_DIR})"
Pop $DeleteDataCheckbox
${NSD_Uncheck} $DeleteDataCheckbox
StrCpy $DeleteDataEnabled "0"
nsDialogs::Show
FunctionEnd
; Function to handle leaving the data deletion page
Function un.DeleteDataPageLeave
${NSD_GetState} $DeleteDataCheckbox $DeleteDataEnabled
FunctionEnd
Function GetAppFromCommand Function GetAppFromCommand
Exch $1 Exch $1
Push $2 Push $2
@@ -210,10 +176,10 @@ ${EndIf}
FunctionEnd FunctionEnd
###################################################################### ######################################################################
Section -MainProgram Section -MainProgram
${INSTALL_TYPE} ${INSTALL_TYPE}
# SetOverwrite ifnewer # SetOverwrite ifnewer
SetOutPath "$INSTDIR" SetOutPath "$INSTDIR"
File /r "..\\dist\\netbird_windows_amd64\\" File /r "..\\dist\\netbird_windows_amd64\\"
SectionEnd SectionEnd
###################################################################### ######################################################################
@@ -259,58 +225,31 @@ SectionEnd
Section Uninstall Section Uninstall
${INSTALL_TYPE} ${INSTALL_TYPE}
DetailPrint "Stopping Netbird service..."
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop' ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
DetailPrint "Uninstalling Netbird service..."
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall' ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
DetailPrint "Terminating Netbird UI process..." # kill ui client
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f` ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry ; Remove autostart registry entry
DetailPrint "Removing autostart registry entry if exists..."
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Handle data deletion based on checkbox
DetailPrint "Checking if user requested data deletion..."
${If} $DeleteDataEnabled == "1"
DetailPrint "User opted to delete Netbird data. Removing ${NETBIRD_DATA_DIR}..."
ClearErrors
RMDir /r "${NETBIRD_DATA_DIR}"
IfErrors 0 +2 ; If no errors, jump over the message
DetailPrint "Error deleting Netbird data directory. It might be in use or already removed."
DetailPrint "Netbird data directory removal complete."
${Else}
DetailPrint "User did not opt to delete Netbird data."
${EndIf}
# wait the service uninstall take unblock the executable # wait the service uninstall take unblock the executable
DetailPrint "Waiting for service handle to be released..."
Sleep 3000 Sleep 3000
DetailPrint "Deleting application files..."
Delete "$INSTDIR\${UI_APP_EXE}" Delete "$INSTDIR\${UI_APP_EXE}"
Delete "$INSTDIR\${MAIN_APP_EXE}" Delete "$INSTDIR\${MAIN_APP_EXE}"
Delete "$INSTDIR\wintun.dll" Delete "$INSTDIR\wintun.dll"
Delete "$INSTDIR\opengl32.dll" Delete "$INSTDIR\opengl32.dll"
DetailPrint "Removing application directory..."
RmDir /r "$INSTDIR" RmDir /r "$INSTDIR"
DetailPrint "Removing shortcuts..."
SetShellVarContext all SetShellVarContext all
Delete "$DESKTOP\${APP_NAME}.lnk" Delete "$DESKTOP\${APP_NAME}.lnk"
Delete "$SMPROGRAMS\${APP_NAME}.lnk" Delete "$SMPROGRAMS\${APP_NAME}.lnk"
DetailPrint "Removing registry keys..."
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}" DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}" DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
DetailPrint "Removing application directory from PATH..."
EnVar::SetHKLM EnVar::SetHKLM
EnVar::DeleteValue "path" "$INSTDIR" EnVar::DeleteValue "path" "$INSTDIR"
DetailPrint "Uninstallation finished."
SectionEnd SectionEnd

View File

@@ -18,7 +18,7 @@ func (r RuleID) ID() string {
func GenerateRouteRuleKey( func GenerateRouteRuleKey(
sources []netip.Prefix, sources []netip.Prefix,
destination manager.Network, destination netip.Prefix,
proto manager.Protocol, proto manager.Protocol,
sPort *manager.Port, sPort *manager.Port,
dPort *manager.Port, dPort *manager.Port,

View File

@@ -18,7 +18,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
) )
@@ -26,12 +25,7 @@ var ErrSourceRangesEmpty = errors.New("sources range is empty")
// Manager is a ACL rules manager // Manager is a ACL rules manager
type Manager interface { type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) ApplyFiltering(networkMap *mgmProto.NetworkMap)
}
type protoMatch struct {
ips map[string]int
policyID []byte
} }
// DefaultManager uses firewall manager to handle // DefaultManager uses firewall manager to handle
@@ -54,15 +48,10 @@ func NewDefaultManager(fm firewall.Manager) *DefaultManager {
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy. // ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
// //
// If allowByDefault is true it appends allow ALL traffic rules to input and output chains. // If allowByDefault is true it appends allow ALL traffic rules to input and output chains.
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) { func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
d.mutex.Lock() d.mutex.Lock()
defer d.mutex.Unlock() defer d.mutex.Unlock()
if d.firewall == nil {
log.Debug("firewall manager is not supported, skipping firewall rules")
return
}
start := time.Now() start := time.Now()
defer func() { defer func() {
total := 0 total := 0
@@ -74,9 +63,21 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
time.Since(start), total) time.Since(start), total)
}() }()
if d.firewall == nil {
log.Debug("firewall manager is not supported, skipping firewall rules")
return
}
d.applyPeerACLs(networkMap) d.applyPeerACLs(networkMap)
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil { // If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
// then the mgmt server is older than the client, and we need to allow all traffic for routes
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
if err := d.firewall.SetLegacyManagement(isLegacy); err != nil {
log.Errorf("failed to set legacy management flag: %v", err)
}
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil {
log.Errorf("Failed to apply route ACLs: %v", err) log.Errorf("Failed to apply route ACLs: %v", err)
} }
@@ -170,16 +171,16 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
d.peerRulesPairs = newRulePairs d.peerRulesPairs = newRulePairs
} }
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error { func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
newRouteRules := make(map[id.RuleID]struct{}, len(rules)) newRouteRules := make(map[id.RuleID]struct{}, len(rules))
var merr *multierror.Error var merr *multierror.Error
// Apply new rules - firewall manager will return existing rule ID if already present // Apply new rules - firewall manager will return existing rule ID if already present
for _, rule := range rules { for _, rule := range rules {
id, err := d.applyRouteACL(rule, dynamicResolver) id, err := d.applyRouteACL(rule)
if err != nil { if err != nil {
if errors.Is(err, ErrSourceRangesEmpty) { if errors.Is(err, ErrSourceRangesEmpty) {
log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err) log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err)
} else { } else {
merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err)) merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err))
} }
@@ -202,7 +203,7 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dyn
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (id.RuleID, error) { func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
if len(rule.SourceRanges) == 0 { if len(rule.SourceRanges) == 0 {
return "", ErrSourceRangesEmpty return "", ErrSourceRangesEmpty
} }
@@ -216,9 +217,15 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamic
sources = append(sources, source) sources = append(sources, source)
} }
destination, err := determineDestination(rule, dynamicResolver, sources) var destination netip.Prefix
if err != nil { if rule.IsDynamic {
return "", fmt.Errorf("determine destination: %w", err) destination = getDefault(sources[0])
} else {
var err error
destination, err = netip.ParsePrefix(rule.Destination)
if err != nil {
return "", fmt.Errorf("parse destination: %w", err)
}
} }
protocol, err := convertToFirewallProtocol(rule.Protocol) protocol, err := convertToFirewallProtocol(rule.Protocol)
@@ -233,7 +240,7 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamic
dPorts := convertPortInfo(rule.PortInfo) dPorts := convertPortInfo(rule.PortInfo)
addedRule, err := d.firewall.AddRouteFiltering(rule.PolicyID, sources, destination, protocol, nil, dPorts, action) addedRule, err := d.firewall.AddRouteFiltering(rule.Id, sources, destination, protocol, nil, dPorts, action)
if err != nil { if err != nil {
return "", fmt.Errorf("add route rule: %w", err) return "", fmt.Errorf("add route rule: %w", err)
} }
@@ -282,13 +289,11 @@ func (d *DefaultManager) protoRuleToFirewallRule(
var rules []firewall.Rule var rules []firewall.Rule
switch r.Direction { switch r.Direction {
case mgmProto.RuleDirection_IN: case mgmProto.RuleDirection_IN:
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName) rules, err = d.addInRules(r.Id, ip, protocol, port, action, ipsetName)
case mgmProto.RuleDirection_OUT: case mgmProto.RuleDirection_OUT:
if d.firewall.IsStateful() { // TODO: Remove this soon. Outbound rules are obsolete.
return "", nil, nil // We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
} rules, err = d.addOutRules(r.Id, ip, protocol, port, action, ipsetName)
// return traffic for outbound connections if firewall is stateless
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
default: default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule") return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
} }
@@ -383,8 +388,10 @@ func (d *DefaultManager) squashAcceptRules(
} }
} }
in := map[mgmProto.RuleProtocol]*protoMatch{} type protoMatch map[mgmProto.RuleProtocol]map[string]int
out := map[mgmProto.RuleProtocol]*protoMatch{}
in := protoMatch{}
out := protoMatch{}
// trace which type of protocols was squashed // trace which type of protocols was squashed
squashedRules := []*mgmProto.FirewallRule{} squashedRules := []*mgmProto.FirewallRule{}
@@ -397,22 +404,14 @@ func (d *DefaultManager) squashAcceptRules(
// 2. Any of rule contains Port. // 2. Any of rule contains Port.
// //
// We zeroed this to notify squash function that this protocol can't be squashed. // We zeroed this to notify squash function that this protocol can't be squashed.
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) { addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols protoMatch) {
hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP || drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
r.Port != "" || !portInfoEmpty(r.PortInfo) if drop {
protocols[r.Protocol] = map[string]int{}
if hasPortRestrictions {
// Don't squash rules with port restrictions
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
return return
} }
if _, ok := protocols[r.Protocol]; !ok { if _, ok := protocols[r.Protocol]; !ok {
protocols[r.Protocol] = &protoMatch{ protocols[r.Protocol] = map[string]int{}
ips: map[string]int{},
// store the first encountered PolicyID for this protocol
policyID: r.PolicyID,
}
} }
// special case, when we receive this all network IP address // special case, when we receive this all network IP address
@@ -424,7 +423,7 @@ func (d *DefaultManager) squashAcceptRules(
return return
} }
ipset := protocols[r.Protocol].ips ipset := protocols[r.Protocol]
if _, ok := ipset[r.PeerIP]; ok { if _, ok := ipset[r.PeerIP]; ok {
return return
@@ -450,10 +449,9 @@ func (d *DefaultManager) squashAcceptRules(
mgmProto.RuleProtocol_UDP, mgmProto.RuleProtocol_UDP,
} }
squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) { squash := func(matches protoMatch, direction mgmProto.RuleDirection) {
for _, protocol := range protocolOrders { for _, protocol := range protocolOrders {
match, ok := matches[protocol] if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 {
// don't squash if : // don't squash if :
// 1. Rules not cover all peers in the network // 1. Rules not cover all peers in the network
// 2. Rules cover only one peer in the network. // 2. Rules cover only one peer in the network.
@@ -466,7 +464,6 @@ func (d *DefaultManager) squashAcceptRules(
Direction: direction, Direction: direction,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.RuleAction_ACCEPT,
Protocol: protocol, Protocol: protocol,
PolicyID: match.policyID,
}) })
squashedProtocols[protocol] = struct{}{} squashedProtocols[protocol] = struct{}{}
@@ -495,9 +492,9 @@ func (d *DefaultManager) squashAcceptRules(
// if we also have other not squashed rules. // if we also have other not squashed rules.
for i, r := range networkMap.FirewallRules { for i, r := range networkMap.FirewallRules {
if _, ok := squashedProtocols[r.Protocol]; ok { if _, ok := squashedProtocols[r.Protocol]; ok {
if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i { if m, ok := in[r.Protocol]; ok && m[r.PeerIP] == i {
continue continue
} else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i { } else if m, ok := out[r.Protocol]; ok && m[r.PeerIP] == i {
continue continue
} }
} }
@@ -574,33 +571,6 @@ func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
return nil return nil
} }
func determineDestination(rule *mgmProto.RouteFirewallRule, dynamicResolver bool, sources []netip.Prefix) (firewall.Network, error) {
var destination firewall.Network
if rule.IsDynamic {
if dynamicResolver {
if len(rule.Domains) > 0 {
destination.Set = firewall.NewDomainSet(domain.FromPunycodeList(rule.Domains))
} else {
// isDynamic is set but no domains = outdated management server
log.Warn("connected to an older version of management server (no domains in rules), using default destination")
destination.Prefix = getDefault(sources[0])
}
} else {
// client resolves DNS, we (router) don't know the destination
destination.Prefix = getDefault(sources[0])
}
return destination, nil
}
prefix, err := netip.ParsePrefix(rule.Destination)
if err != nil {
return destination, fmt.Errorf("parse destination: %w", err)
}
destination.Prefix = prefix
return destination, nil
}
func getDefault(prefix netip.Prefix) netip.Prefix { func getDefault(prefix netip.Prefix) netip.Prefix {
if prefix.Addr().Is6() { if prefix.Addr().Is6() {
return netip.PrefixFrom(netip.IPv6Unspecified(), 0) return netip.PrefixFrom(netip.IPv6Unspecified(), 0)

View File

@@ -1,21 +1,21 @@
package acl package acl
import ( import (
"net/netip" "context"
"net"
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
) )
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger() var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
func TestDefaultManager(t *testing.T) { func TestDefaultManager(t *testing.T) {
networkMap := &mgmProto.NetworkMap{ networkMap := &mgmProto.NetworkMap{
@@ -43,31 +43,35 @@ func TestDefaultManager(t *testing.T) {
ifaceMock := mocks.NewMockIFaceMapper(ctrl) ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes() ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any()) ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32") ip, network, err := net.ParseCIDR("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{ ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(), IP: ip,
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
require.NoError(t, err) if err != nil {
defer func() { t.Errorf("create firewall: %v", err)
err = fw.Close(nil) return
require.NoError(t, err) }
}() defer func(fw manager.Manager) {
_ = fw.Close(nil)
}(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)
t.Run("apply firewall rules", func(t *testing.T) { t.Run("apply firewall rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap, false) acl.ApplyFiltering(networkMap)
if fw.IsStateful() { if len(acl.peerRulesPairs) != 2 {
assert.Equal(t, 0, len(acl.peerRulesPairs)) t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs)
} else { return
assert.Equal(t, 2, len(acl.peerRulesPairs))
} }
}) })
@@ -89,15 +93,14 @@ func TestDefaultManager(t *testing.T) {
}, },
) )
acl.ApplyFiltering(networkMap, false) acl.ApplyFiltering(networkMap)
expectedRules := 2 // we should have one old and one new rule in the existed rules
if fw.IsStateful() { if len(acl.peerRulesPairs) != 2 {
expectedRules = 1 // only the inbound rule t.Errorf("firewall rules not applied")
return
} }
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
// check that old rule was removed // check that old rule was removed
previousCount := 0 previousCount := 0
for id := range acl.peerRulesPairs { for id := range acl.peerRulesPairs {
@@ -105,86 +108,26 @@ func TestDefaultManager(t *testing.T) {
previousCount++ previousCount++
} }
} }
if previousCount != 1 {
expectedPreviousCount := 0 t.Errorf("old rule was not removed")
if !fw.IsStateful() {
expectedPreviousCount = 1
} }
assert.Equal(t, expectedPreviousCount, previousCount)
}) })
t.Run("handle default rules", func(t *testing.T) { t.Run("handle default rules", func(t *testing.T) {
networkMap.FirewallRules = networkMap.FirewallRules[:0] networkMap.FirewallRules = networkMap.FirewallRules[:0]
networkMap.FirewallRulesIsEmpty = true networkMap.FirewallRulesIsEmpty = true
acl.ApplyFiltering(networkMap, false) if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 {
assert.Equal(t, 0, len(acl.peerRulesPairs)) t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs))
return
}
networkMap.FirewallRulesIsEmpty = false networkMap.FirewallRulesIsEmpty = false
acl.ApplyFiltering(networkMap, false) acl.ApplyFiltering(networkMap)
if len(acl.peerRulesPairs) != 1 {
expectedRules := 1 t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
if fw.IsStateful() { return
expectedRules = 1 // only inbound allow-all rule
} }
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
})
}
func TestDefaultManagerStateless(t *testing.T) {
// stateless currently only in userspace, so we have to disable kernel
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv("NB_DISABLE_CONNTRACK", "true")
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "80",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
Port: "53",
},
},
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
require.NoError(t, err)
defer func() {
err = fw.Close(nil)
require.NoError(t, err)
}()
acl := NewDefaultManager(fw)
t.Run("stateless firewall creates outbound rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap, false)
// In stateless mode, we should have both inbound and outbound rules
assert.False(t, fw.IsStateful())
assert.Equal(t, 2, len(acl.peerRulesPairs))
}) })
} }
@@ -250,19 +193,42 @@ func TestDefaultManagerSquashRules(t *testing.T) {
manager := &DefaultManager{} manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap) rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, 2, len(rules)) if len(rules) != 2 {
t.Errorf("rules should contain 2, got: %v", rules)
return
}
r := rules[0] r := rules[0]
assert.Equal(t, "0.0.0.0", r.PeerIP) switch {
assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction) case r.PeerIP != "0.0.0.0":
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol) t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action) return
case r.Direction != mgmProto.RuleDirection_IN:
t.Errorf("direction should be IN, got: %v", r.Direction)
return
case r.Protocol != mgmProto.RuleProtocol_ALL:
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return
case r.Action != mgmProto.RuleAction_ACCEPT:
t.Errorf("action should be ACCEPT, got: %v", r.Action)
return
}
r = rules[1] r = rules[1]
assert.Equal(t, "0.0.0.0", r.PeerIP) switch {
assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction) case r.PeerIP != "0.0.0.0":
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol) t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action) return
case r.Direction != mgmProto.RuleDirection_OUT:
t.Errorf("direction should be OUT, got: %v", r.Direction)
return
case r.Protocol != mgmProto.RuleProtocol_ALL:
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return
case r.Action != mgmProto.RuleAction_ACCEPT:
t.Errorf("action should be ACCEPT, got: %v", r.Action)
return
}
} }
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) { func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
@@ -326,435 +292,8 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
} }
manager := &DefaultManager{} manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap) if rules, _ := manager.squashAcceptRules(networkMap); len(rules) != len(networkMap.FirewallRules) {
assert.Equal(t, len(networkMap.FirewallRules), len(rules)) t.Errorf("we should get the same amount of rules as output, got %v", len(rules))
}
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
tests := []struct {
name string
rules []*mgmProto.FirewallRule
expectedCount int
description string
}{
{
name: "should not squash rules with port ranges",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
},
expectedCount: 4,
description: "Rules with port ranges should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with specific ports",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
},
expectedCount: 4,
description: "Rules with specific ports should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with legacy port field",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
expectedCount: 4,
description: "Rules with legacy port field should not be squashed",
},
{
name: "should not squash rules with DROP action",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "Rules with DROP action should not be squashed",
},
{
name: "should squash rules without port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 1,
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
},
{
name: "mixed rules should not squash protocol with port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "TCP should not be squashed because one rule has port restrictions",
},
{
name: "should squash UDP but not TCP when TCP has port restrictions",
rules: []*mgmProto.FirewallRule{
// TCP rules with port restrictions - should NOT be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
// UDP rules without port restrictions - SHOULD be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
},
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
{AllowedIps: []string{"10.93.0.4"}},
},
FirewallRules: tt.rules,
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
// For squashed rules, verify we get the expected 0.0.0.0 rule
if tt.expectedCount == 1 {
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
}
})
}
}
func TestPortInfoEmpty(t *testing.T) {
tests := []struct {
name string
portInfo *mgmProto.PortInfo
expected bool
}{
{
name: "nil PortInfo should be empty",
portInfo: nil,
expected: true,
},
{
name: "PortInfo with zero port should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 0,
},
},
expected: true,
},
{
name: "PortInfo with valid port should not be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
expected: false,
},
{
name: "PortInfo with nil range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: nil,
},
},
expected: true,
},
{
name: "PortInfo with zero start range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 0,
End: 100,
},
},
},
expected: true,
},
{
name: "PortInfo with zero end range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 80,
End: 0,
},
},
},
expected: true,
},
{
name: "PortInfo with valid range should not be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := portInfoEmpty(tt.portInfo)
assert.Equal(t, tt.expected, result)
})
} }
} }
@@ -798,29 +337,33 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
ifaceMock := mocks.NewMockIFaceMapper(ctrl) ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes() ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any()) ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32") ip, network, err := net.ParseCIDR("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{ ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(), IP: ip,
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
require.NoError(t, err) if err != nil {
defer func() { t.Errorf("create firewall: %v", err)
err = fw.Close(nil) return
require.NoError(t, err) }
}() defer func(fw manager.Manager) {
_ = fw.Close(nil)
}(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)
acl.ApplyFiltering(networkMap, false) acl.ApplyFiltering(networkMap)
expectedRules := 3 if len(acl.peerRulesPairs) != 3 {
if fw.IsStateful() { t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
expectedRules = 3 // 2 inbound rules + SSH rule return
} }
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
} }

View File

@@ -64,8 +64,13 @@ func (t TokenInfo) GetTokenToUse() string {
// and if that also fails, the authentication process is deemed unsuccessful // and if that also fails, the authentication process is deemed unsuccessful
// //
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow // On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
func NewOAuthFlow(ctx context.Context, config *internal.Config, isUnixDesktopClient bool) (OAuthFlow, error) { func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopClient bool) (OAuthFlow, error) {
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient { if runtime.GOOS == "linux" && !isLinuxDesktopClient {
return authenticateWithDeviceCodeFlow(ctx, config)
}
// On FreeBSD we currently do not support desktop environments and offer only Device Code Flow (#2384)
if runtime.GOOS == "freebsd" {
return authenticateWithDeviceCodeFlow(ctx, config) return authenticateWithDeviceCodeFlow(ctx, config)
} }

View File

@@ -94,22 +94,12 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
p.codeVerifier = codeVerifier p.codeVerifier = codeVerifier
codeChallenge := createCodeChallenge(codeVerifier) codeChallenge := createCodeChallenge(codeVerifier)
authURL := p.oAuthConfig.AuthCodeURL(
params := []oauth2.AuthCodeOption{ state,
oauth2.SetAuthURLParam("code_challenge_method", "S256"), oauth2.SetAuthURLParam("code_challenge_method", "S256"),
oauth2.SetAuthURLParam("code_challenge", codeChallenge), oauth2.SetAuthURLParam("code_challenge", codeChallenge),
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience), oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
} )
if !p.providerConfig.DisablePromptLogin {
if p.providerConfig.LoginFlag.IsPromptLogin() {
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
}
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
}
}
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
return AuthFlowInfo{ return AuthFlowInfo{
VerificationURIComplete: authURL, VerificationURIComplete: authURL,

View File

@@ -1,71 +0,0 @@
package auth
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
mgm "github.com/netbirdio/netbird/management/client/common"
)
func TestPromptLogin(t *testing.T) {
const (
promptLogin = "prompt=login"
maxAge0 = "max_age=0"
)
tt := []struct {
name string
loginFlag mgm.LoginFlag
disablePromptLogin bool
expect string
}{
{
name: "Prompt login",
loginFlag: mgm.LoginFlagPrompt,
expect: promptLogin,
},
{
name: "Max age 0 login",
loginFlag: mgm.LoginFlagMaxAge0,
expect: maxAge0,
},
{
name: "Disable prompt login",
loginFlag: mgm.LoginFlagPrompt,
disablePromptLogin: true,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
config := internal.PKCEAuthProviderConfig{
ClientID: "test-client-id",
Audience: "test-audience",
TokenEndpoint: "https://test-token-endpoint.com/token",
Scope: "openid email profile",
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
RedirectURLs: []string{"http://127.0.0.1:33992/"},
UseIDToken: true,
LoginFlag: tc.loginFlag,
}
pkce, err := NewPKCEAuthorizationFlow(config)
if err != nil {
t.Fatalf("Failed to create PKCEAuthorizationFlow: %v", err)
}
authInfo, err := pkce.RequestAuthInfo(context.Background())
if err != nil {
t.Fatalf("Failed to request auth info: %v", err)
}
if !tc.disablePromptLogin {
require.Contains(t, authInfo.VerificationURIComplete, tc.expect)
} else {
require.Contains(t, authInfo.VerificationURIComplete, promptLogin)
require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
}
})
}
}

Some files were not shown because too many files have changed in this diff Show More