Compare commits

...

120 Commits

Author SHA1 Message Date
bcmmbaga
2f15708d54 Merge branch 'feature/validate-group-association' into feature/validate-group-association-debug 2024-10-22 17:47:46 +03:00
bcmmbaga
85ffbd1db5 add panic recovery and detailed logging in peer update comparison
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-22 17:46:42 +03:00
bcmmbaga
abdba6c650 Run diff for client posture checks only
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-22 17:35:54 +03:00
bcmmbaga
b025dbeb75 increase Linux test timeout to 10 minutes
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-22 11:19:29 +03:00
bcmmbaga
ffd3cd66f4 Merge branch 'feature/optimize-network-map-updates' into feature/validate-group-association 2024-10-21 22:41:22 +03:00
bcmmbaga
964e966269 Merge branch 'main' into feature/optimize-network-map-updates 2024-10-21 22:40:40 +03:00
bcmmbaga
81f69599ea Merge branch 'feature/validate-group-association' into feature/validate-group-association-debug 2024-10-21 22:36:43 +03:00
bcmmbaga
356d3624c4 add route and group tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-21 22:34:58 +03:00
bcmmbaga
09c9f21a8b add ns group and policy tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-21 20:35:33 +03:00
bcmmbaga
62899df75d add tests for posture checks changes
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-21 20:07:45 +03:00
bcmmbaga
eb68e35f44 add tests missing tests for dns setting groups
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-21 19:57:56 +03:00
Bethuel Mmbaga
4a9f755b13 Update management/server/route_test.go
Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2024-10-21 19:02:05 +03:00
Bethuel Mmbaga
f891959d64 Update management/server/route_test.go
Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2024-10-21 19:01:50 +03:00
Bethuel Mmbaga
6a8cb2b0f9 Update management/server/route_test.go
Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2024-10-21 19:01:32 +03:00
Bethuel Mmbaga
5df1973caf Update management/server/route_test.go
Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2024-10-21 19:01:17 +03:00
Bethuel Mmbaga
a6dc54c21a Update management/server/route_test.go
Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2024-10-21 19:01:02 +03:00
bcmmbaga
9b0424ea55 update account policy check before verifying policy status
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-21 17:30:31 +03:00
bcmmbaga
ea6d037b17 skip spell check for GroupD
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-21 17:26:06 +03:00
bcmmbaga
86ce2ed72b Merge branch 'feature/validate-group-association' into feature/validate-group-association-debug 2024-10-21 17:12:42 +03:00
bcmmbaga
070e1dd890 Refactor group, ns group, policy and posture checks
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-21 17:12:19 +03:00
bcmmbaga
8b61ffa78f skip spell check for groupD
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-21 17:04:52 +03:00
Maycon Santos
9929b22afc Replace suite tests with regular go tests (#2762)
* Replace file suite tests with go tests

* Replace file suite tests with go tests
2024-10-21 14:39:28 +02:00
bcmmbaga
5c0e4097d8 Merge branch 'feature/validate-group-association' into feature/validate-group-association-debug 2024-10-21 15:35:56 +03:00
bcmmbaga
13aa9f7198 refactor peer and user
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-21 15:35:31 +03:00
bcmmbaga
8a02d3eb9e Merge branch 'feature/validate-group-association' into feature/validate-group-association-debug 2024-10-21 12:37:44 +03:00
bcmmbaga
53218f99bc Merge branch 'feature/optimize-network-map-updates' into feature/validate-group-association 2024-10-21 12:37:24 +03:00
bcmmbaga
e5ecf0e5b3 Merge branch 'main' into feature/optimize-network-map-updates
# Conflicts:
#	management/server/peer/peer.go
2024-10-21 12:36:43 +03:00
bcmmbaga
006524756c add trace logs for skip network update
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-21 12:09:48 +03:00
bcmmbaga
ced28c4376 skip the update only last sent the serial is larger
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-21 11:01:01 +03:00
Maycon Santos
88e4fc2245 Release global lock on early error (#2760) 2024-10-19 18:32:17 +02:00
Maycon Santos
c8d8748dcf Update sign workflow version (#2756) 2024-10-18 17:28:58 +02:00
Maycon Santos
507a40bd7f Fix decompress zip path (#2755)
Since 0.30.2 the decompressed binary path from the signed package has changed

now it doesn't contain the arch suffix

this change handles that
2024-10-17 20:39:59 +02:00
Maycon Santos
ccd4ae6315 Fix domain information is up to date check (#2754) 2024-10-17 19:21:35 +02:00
Bethuel Mmbaga
96d2207684 Fix JSON function compatibility for SQLite and PostgreSQL (#2746)
resolves the issue with json_array_length compatibility between SQLite and PostgreSQL. It adjusts the query to conditionally cast types:

PostgreSQL: Casts to json with ::json.
SQLite: Uses the text representation directly.
2024-10-16 17:55:30 +02:00
Emre Oksum
f942491b91 Update Zitadel version on quickstart script (#2744)
Update Zitadel version at docker compose in quickstart script from 2.54.3 to 2.54.10 because 2.54.3 isn't stable and has a lot of bugs.
2024-10-16 17:51:21 +02:00
Viktor Liu
8c8900be57 [client] Exclude loopback from NAT (#2747) 2024-10-16 17:35:59 +02:00
Maycon Santos
cee95461d1 [client] Add universal bin build and update sign workflow version (#2738)
* Add universal binaries build for macOS

* update sign pipeline version

* handle info.plist in sign workflow
2024-10-15 15:03:17 +02:00
ctrl-zzz
49e65109d2 Add session expire functionality based on inactivity (#2326)
Implemented inactivity expiration by checking the status of a peer: after a configurable period of time following netbird down, the peer shows login required.
2024-10-13 14:52:43 +02:00
Zoltan Papp
d93dd4fc7f [relay-server] Move the handshake logic to separated struct (#2648)
* Move the handshake logic to separated struct

- The server will response to the client after it ready to process the peer
- Preload the response messages

* Fix deprecated lint issue

* Fix error handling

* [relay-server] Relay measure auth time (#2675)

Measure the Relay client's authentication time
2024-10-12 18:21:34 +02:00
Viktor Liu
3a88ac78ff [client] Add table filter rules using iptables (#2727)
This specifically concerns the established/related rule since this one is not compatible with iptables-nft even if it is generated the same way by iptables-translate.
2024-10-12 10:44:48 +02:00
Maycon Santos
da3a053e2b [management] Refactor getAccountIDWithAuthorizationClaims (#2715)
This change restructures the getAccountIDWithAuthorizationClaims method to improve readability, maintainability, and performance.

- have dedicated methods to handle possible cases
- introduced Store.UpdateAccountDomainAttributes and Store.GetAccountUsers methods
- Remove GetAccount and SaveAccount dependency
- added tests
2024-10-12 08:35:51 +02:00
Zoltan Papp
0e95f16cdd [relay,client] Relay/fix/wg roaming (#2691)
If a peer connection switches from Relayed to ICE P2P, the Relayed proxy still consumes the data the other peer sends. Because the proxy is operating, the WireGuard switches back to the Relayed proxy automatically, thanks to the roaming feature.

Extend the Proxy implementation with pause/resume functions. Before switching to the p2p connection, pause the WireGuard proxy operation to prevent unnecessary package sources.
Consider waiting some milliseconds after the pause to be sure the WireGuard engine already processed all UDP msg in from the pipe.
2024-10-11 16:24:30 +02:00
bcmmbaga
cd92646348 enable diff nil structs comparison
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-10 19:33:17 +03:00
bcmmbaga
30a0d9c8c4 fix postgres tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-10 18:03:00 +03:00
bcmmbaga
a42ebb8202 fix management suite tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-10 15:26:05 +03:00
bcmmbaga
15b83cb1e5 Merge branch 'feature/optimize-network-map-updates' into feature/validate-group-association 2024-10-10 13:59:02 +03:00
bcmmbaga
fdb1a1fe00 Merge branch 'main' into feature/optimize-network-map-updates 2024-10-10 13:57:34 +03:00
bcmmbaga
8cabb07728 fix merge
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-10 10:40:24 +03:00
bcmmbaga
57f7f43ecb Merge branch 'feature/optimize-network-map-updates' into feature/validate-group-association
# Conflicts:
#	management/server/account.go
2024-10-10 09:46:32 +03:00
bcmmbaga
2e20a586cb fix merge
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-10 09:44:31 +03:00
bcmmbaga
ed3c3c214e Merge branch 'main' into feature/optimize-network-map-updates
# Conflicts:
#	management/server/testdata/store.json
2024-10-10 09:31:55 +03:00
bcmmbaga
bdf114cd74 add peer tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-10 01:51:47 +03:00
bcmmbaga
6d985c5991 go mod tidy
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-09 23:54:52 +03:00
bcmmbaga
ce7de03d6e use generic differ for netip.Addr and netip.Prefix
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-09 23:49:41 +03:00
bcmmbaga
9ee08fc441 fix nameserver tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-09 17:04:20 +03:00
bcmmbaga
271bed5f73 upgrade diff package
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-09 17:01:40 +03:00
bcmmbaga
2a751645f9 fix group tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-09 17:00:29 +03:00
bcmmbaga
d4edde90c2 fix routes tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-09 13:52:43 +03:00
bcmmbaga
5cc07ba42a fix nameserver tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-08 19:39:12 +03:00
bcmmbaga
70f1c394c1 fix typo
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-08 17:15:06 +03:00
bcmmbaga
c74a13e1a9 fix account and route tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-08 17:11:50 +03:00
bcmmbaga
1ed44b810c fix user and setup key tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-08 16:51:00 +03:00
bcmmbaga
41acacfba5 add posture checks tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-08 15:50:13 +03:00
bcmmbaga
fc7157f82f add policy tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-08 11:59:03 +03:00
bcmmbaga
63c510e80d fix merge
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-06 20:34:50 +03:00
bcmmbaga
716009b791 Merge branch 'feature/optimize-network-map-updates' into feature/validate-group-association
# Conflicts:
#	management/server/account.go
#	management/server/peer.go
#	management/server/peer_test.go
#	management/server/policy.go
#	management/server/route.go
#	management/server/route_test.go
2024-10-04 10:46:41 +03:00
bcmmbaga
a915707d13 fix merge
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-03 14:12:53 +03:00
bcmmbaga
5108888163 Merge branch 'main' into feature/optimize-network-map-updates
# Conflicts:
#	management/server/account_test.go
#	management/server/peer.go
2024-10-03 14:10:46 +03:00
bcmmbaga
4e2cf9c63a fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-06 12:15:13 +03:00
bcmmbaga
5dbdeff77a Simplify peer update condition in DNS management
Refactor the condition for updating account peers to remove redundant checks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-05 21:29:14 +03:00
bcmmbaga
7523a9e7be Refactor posture check policy linking logic
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-05 19:36:20 +03:00
bcmmbaga
75ab35563a Update route check by checking if group has peers
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-05 19:20:09 +03:00
bcmmbaga
c6650705a1 Refactor policy group handling and update logic.
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-05 16:36:11 +03:00
bcmmbaga
f29f8c009f fix tests 2024-08-16 12:24:06 +03:00
bcmmbaga
8826196503 fix linter 2024-08-15 11:19:10 +03:00
bcmmbaga
ca8565de1f Refactor duplicate diff handling logic 2024-08-15 10:57:45 +03:00
bcmmbaga
ac06346f5c Add tests 2024-08-15 10:50:03 +03:00
bcmmbaga
151969bdd7 Update network map diff logic with custom comparators 2024-08-15 10:49:45 +03:00
bcmmbaga
441136e2c6 Add NameServer and Route comparators 2024-08-15 10:49:11 +03:00
bcmmbaga
ff19b237d9 Merge branch 'feature/optimize-network-map-updates' into feature/validate-group-association 2024-08-14 14:46:44 +03:00
bcmmbaga
376ded1b2f go mod tidy 2024-08-14 14:45:23 +03:00
bcmmbaga
73b9e1c926 Merge branch 'main' into feature/optimize-network-map-updates
# Conflicts:
#	go.sum
2024-08-14 14:45:11 +03:00
bcmmbaga
fb627a308c go mod tidy 2024-08-14 13:31:25 +03:00
bcmmbaga
c918bab09a Merge branch 'main' into feature/validate-group-association
# Conflicts:
#	go.sum
2024-08-14 13:30:54 +03:00
bcmmbaga
7fa71419cd Fix tests 2024-08-14 10:27:07 +03:00
bcmmbaga
226dc95afa fix merge 2024-08-13 22:03:46 +03:00
bcmmbaga
1548542df3 Merge branch 'feature/optimize-network-map-updates' into feature/validate-group-association
# Conflicts:
#	management/server/dns_test.go
#	management/server/group.go
#	management/server/nameserver.go
#	management/server/peer.go
#	management/server/peer_test.go
#	management/server/user.go
2024-08-13 16:30:04 +03:00
bcmmbaga
34114d4a55 Fix peers update by including NetworkMap and posture Checks 2024-08-12 18:01:07 +03:00
bcmmbaga
f3ec200985 Merge branch 'main' into feature/optimize-network-map-updates 2024-08-12 13:49:02 +03:00
bcmmbaga
8ecbe675a1 Merge branch 'main' into feature/optimize-network-map-updates
# Conflicts:
#	management/server/peer.go
2024-08-09 10:55:31 +03:00
bcmmbaga
f0d91bcfc4 Add tests for peer update behavior on peers changes 2024-07-31 02:45:09 +03:00
bcmmbaga
eb9aadfd38 Add tests for peer update behavior on setup key changes 2024-07-31 01:31:05 +03:00
bcmmbaga
8bab9dc3c0 fix tests 2024-07-31 01:24:02 +03:00
bcmmbaga
02c0a9b1da Add tests for peer update behavior on route changes 2024-07-31 01:18:37 +03:00
bcmmbaga
c76cd1d86e Add tests for peer update behavior on user changes 2024-07-31 01:18:14 +03:00
bcmmbaga
d990d95236 Add tests for peer update behavior on name server changes 2024-07-30 18:36:33 +03:00
bcmmbaga
cf211f6337 Refactor 2024-07-30 17:15:47 +03:00
bcmmbaga
8d9ea40bf1 Add tests for peer update behavior on dns settings changes 2024-07-30 16:38:32 +03:00
bcmmbaga
7647701898 Add tests for peer update behavior on group changes 2024-07-30 16:01:11 +03:00
bcmmbaga
6554b26600 Add tests for peer update behavior on policy changes 2024-07-30 14:56:23 +03:00
bcmmbaga
8455455142 Add tests for peer update behavior on posture check changes 2024-07-29 21:46:50 +03:00
bcmmbaga
c48f244bee Remove unused isPolicyRuleGroupsEmpty 2024-07-26 17:47:02 +03:00
bcmmbaga
b7fcd0d753 Remove UpdatePeerSSHKey method 2024-07-23 21:16:25 +03:00
bcmmbaga
a19c2f660c Merge branch 'refs/heads/feature/optimize-network-map-updates' into feature/validate-group-association 2024-07-22 15:24:30 +03:00
bcmmbaga
936215b395 Optimize account peers updates on route changes 2024-07-22 13:51:18 +03:00
bcmmbaga
bb08adcbac Remove condition check for network serial update 2024-07-20 20:36:36 +03:00
bcmmbaga
f5ec234f09 Optimize peer update on user deletion and changes 2024-07-20 20:08:29 +03:00
bcmmbaga
26f089e30d Refactor peer account updates for efficiency 2024-07-20 12:37:25 +03:00
bcmmbaga
713c0341be Optimize update of account peers on jwt groups sync 2024-07-19 14:09:33 +03:00
bcmmbaga
1bbd8ae4b0 Optimize account peers update in DNS settings 2024-07-19 10:51:20 +03:00
bcmmbaga
a723c424f0 Refactor group changes 2024-07-19 10:51:05 +03:00
bcmmbaga
3e76deaa87 Update account peers if ns group has peers 2024-07-18 21:04:58 +03:00
bcmmbaga
36d4c21671 Optimize group change effects on account peers 2024-07-18 20:37:29 +03:00
bcmmbaga
181e8648a8 Refactor group management 2024-07-18 19:59:37 +03:00
bcmmbaga
1012c2f990 Add HasPeers function to group 2024-07-18 19:59:14 +03:00
bcmmbaga
1b28d1dfbc Refactor group link checking into re-usable functions 2024-07-18 16:41:21 +03:00
Bethuel Mmbaga
f17016b5e5 Skip peer update on unchanged network map (#2236)
* Enhance network updates by skipping unchanged messages

Optimizes the network update process
by skipping updates where no changes in the peer update message received.

* Add unit tests

* add locks

* Improve concurrency and update peer message handling

* Refactor account manager network update tests

* fix test

* Fix inverted network map update condition

* Add default group and policy to test data

* Run peer updates in a separate goroutine

* Refactor

* Refactor lock
2024-07-18 13:50:44 +03:00
bcmmbaga
b6cef2ce2c Remove account peers update on saving setup key 2024-07-16 18:30:47 +03:00
bcmmbaga
dedf13d8f1 Update account peer if posture check is linked to policy 2024-07-16 18:19:05 +03:00
bcmmbaga
d676c41c74 Remove incrementing network serial and updating peers after group deletion 2024-07-16 16:50:44 +03:00
64 changed files with 5482 additions and 1136 deletions

View File

@@ -49,7 +49,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
test_client_on_docker:
runs-on: ubuntu-20.04

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif
ignore_words_list: erro,clienta,hastable,iif,groupd
skip: go.mod,go.sum
only_warn: 1
golangci:

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.14"
SIGN_PIPE_VER: "v0.0.16"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
@@ -223,4 +223,4 @@ jobs:
repo: netbirdio/sign-pipelines
ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}" }'
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'

View File

@@ -96,6 +96,9 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
universal_binaries:
- id: netbird
archives:
- builds:
- netbird

View File

@@ -23,6 +23,9 @@ builds:
tags:
- load_wgnt_from_rsrc
universal_binaries:
- id: netbird-ui-darwin
archives:
- builds:
- netbird-ui-darwin

View File

@@ -433,10 +433,12 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
intdir := "-i"
lointdir := "-o"
if inverse {
intdir = "-o"
lointdir = "-i"
}
return []string{intdir, intf, "-s", source.String(), "-d", destination.String(), "-j", jump}
return []string{intdir, intf, "!", lointdir, "lo", "-s", source.String(), "-d", destination.String(), "-j", jump}
}
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {

View File

@@ -315,28 +315,33 @@ func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *
rule := &nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
Exprs: getEstablishedExprs(1),
}
conn.InsertRule(rule)
}
func getEstablishedExprs(register uint32) []expr.Any {
return []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: register,
},
&expr.Bitwise{
SourceRegister: register,
DestRegister: register,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: register,
Data: []byte{0, 0, 0, 0},
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
}

View File

@@ -109,6 +109,7 @@ func TestNftablesManager(t *testing.T) {
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},

View File

@@ -10,6 +10,7 @@ import (
"net/netip"
"strings"
"github.com/coreos/go-iptables/iptables"
"github.com/davecgh/go-spew/spew"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
@@ -81,7 +82,7 @@ func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFa
}
}
err = r.cleanUpDefaultForwardRules()
err = r.removeAcceptForwardRules()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
@@ -98,40 +99,7 @@ func (r *router) Reset() error {
// clear without deleting the ipsets, the nf table will be deleted by the caller
r.ipsetCounter.Clear()
return r.cleanUpDefaultForwardRules()
}
func (r *router) cleanUpDefaultForwardRules() error {
if r.filterTable == nil {
return nil
}
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list chains: %v", err)
}
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
continue
}
rules, err := r.conn.GetRules(r.filterTable, chain)
if err != nil {
return fmt.Errorf("get rules: %v", err)
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule: %v", err)
}
}
}
}
return r.conn.Flush()
return r.removeAcceptForwardRules()
}
func (r *router) loadFilterTable() (*nftables.Table, error) {
@@ -167,7 +135,9 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeNAT,
})
r.acceptForwardRules()
if err := r.acceptForwardRules(); err != nil {
log.Errorf("failed to add accept rules for the forward chain: %s", err)
}
if err := r.refreshRulesMap(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
@@ -455,11 +425,15 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
dir := expr.MetaKeyIIFNAME
notDir := expr.MetaKeyOIFNAME
if pair.Inverse {
dir = expr.MetaKeyOIFNAME
notDir = expr.MetaKeyIIFNAME
}
lo := ifname("lo")
intf := ifname(r.wgIface.Name())
exprs := []expr.Any{
&expr.Meta{
Key: dir,
@@ -470,6 +444,17 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
Register: 1,
Data: intf,
},
// We need to exclude the loopback interface as this changes the ebpf proxy port
&expr.Meta{
Key: notDir,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: lo,
},
}
exprs = append(exprs, sourceExp...)
@@ -577,19 +562,60 @@ func (r *router) RemoveAllLegacyRouteRules() error {
// that our traffic is not dropped by existing rules there.
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
func (r *router) acceptForwardRules() {
func (r *router) acceptForwardRules() error {
if r.filterTable == nil {
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
return
return nil
}
fw := "iptables"
defer func() {
log.Debugf("Used %s to add accept forward rules", fw)
}()
// Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New()
if err != nil {
// filter table exists but iptables is not
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables"
return r.acceptForwardRulesNftables()
}
return r.acceptForwardRulesIptables(ipt)
}
func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err))
} else {
log.Debugf("added iptables rule: %v", rule)
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) getAcceptForwardRules() [][]string {
intf := r.wgIface.Name()
return [][]string{
{"-i", intf, "-j", "ACCEPT"},
{"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"},
}
}
func (r *router) acceptForwardRulesNftables() error {
intf := ifname(r.wgIface.Name())
// Rule for incoming interface (iif) with counter
iifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Name: chainNameForward,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
@@ -609,6 +635,15 @@ func (r *router) acceptForwardRules() {
}
r.conn.InsertRule(iifRule)
oifExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
}
// Rule for outgoing interface (oif) with counter
oifRule := &nftables.Rule{
Table: r.filterTable,
@@ -619,36 +654,72 @@ func (r *router) acceptForwardRules() {
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 2,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 2,
Data: []byte{0, 0, 0, 0},
},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
Exprs: append(oifExprs, getEstablishedExprs(2)...),
UserData: []byte(userDataAcceptForwardRuleOif),
}
r.conn.InsertRule(oifRule)
return nil
}
func (r *router) removeAcceptForwardRules() error {
if r.filterTable == nil {
return nil
}
// Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New()
if err != nil {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
return r.removeAcceptForwardRulesNftables()
}
return r.removeAcceptForwardRulesIptables(ipt)
}
func (r *router) removeAcceptForwardRulesNftables() error {
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list chains: %v", err)
}
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
continue
}
rules, err := r.conn.GetRules(r.filterTable, chain)
if err != nil {
return fmt.Errorf("get rules: %v", err)
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule: %v", err)
}
}
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// RemoveNatRule removes a nftables rule pair from nat chains

View File

@@ -69,6 +69,12 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
Register: 1,
Data: ifname(ifaceMock.Name()),
},
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
)
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
@@ -97,6 +103,12 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
Register: 1,
Data: ifname(ifaceMock.Name()),
},
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
)
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))

View File

@@ -82,8 +82,6 @@ type Conn struct {
config ConnConfig
statusRecorder *Status
wgProxyFactory *wgproxy.Factory
wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager
@@ -106,7 +104,8 @@ type Conn struct {
beforeAddPeerHooks []nbnet.AddHookFunc
afterRemovePeerHooks []nbnet.RemoveHookFunc
endpointRelay *net.UDPAddr
wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy
// for reconnection operations
iCEDisconnected chan bool
@@ -257,8 +256,7 @@ func (conn *Conn) Close() {
conn.wgProxyICE = nil
}
err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
if err != nil {
if err := conn.removeWgPeer(); err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err)
}
@@ -430,54 +428,59 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
conn.log.Debugf("ICE connection is ready")
conn.statusICE.Set(StatusConnected)
defer conn.updateIceState(iceConnInfo)
if conn.currentConnPriority > priority {
conn.statusICE.Set(StatusConnected)
conn.updateIceState(iceConnInfo)
return
}
conn.log.Infof("set ICE to active connection")
endpoint, wgProxy, err := conn.getEndpointForICEConnInfo(iceConnInfo)
if err != nil {
return
var (
ep *net.UDPAddr
wgProxy wgproxy.Proxy
err error
)
if iceConnInfo.RelayedOnLocal {
wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn)
if err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
return
}
ep = wgProxy.EndpointAddr()
conn.wgProxyICE = wgProxy
} else {
directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String())
if err != nil {
log.Errorf("failed to resolveUDPaddr")
conn.handleConfigurationFailure(err, nil)
return
}
ep = directEp
}
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
conn.log.Debugf("Conn resolved IP is %s for endopint %s", endpoint, endpointUdpAddr.IP)
conn.connIDICE = nbnet.GenerateConnID()
for _, hook := range conn.beforeAddPeerHooks {
if err := hook(conn.connIDICE, endpointUdpAddr.IP); err != nil {
conn.log.Errorf("Before add peer hook failed: %v", err)
}
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
conn.log.Errorf("Before add peer hook failed: %v", err)
}
conn.workerRelay.DisableWgWatcher()
err = conn.configureWGEndpoint(endpointUdpAddr)
if err != nil {
if wgProxy != nil {
if err := wgProxy.CloseConn(); err != nil {
conn.log.Warnf("Failed to close turn connection: %v", err)
}
}
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
if conn.wgProxyRelay != nil {
conn.wgProxyRelay.Pause()
}
if wgProxy != nil {
wgProxy.Work()
}
if err = conn.configureWGEndpoint(ep); err != nil {
conn.handleConfigurationFailure(err, wgProxy)
return
}
wgConfigWorkaround()
if conn.wgProxyICE != nil {
if err := conn.wgProxyICE.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
}
}
conn.wgProxyICE = wgProxy
conn.currentConnPriority = priority
conn.statusICE.Set(StatusConnected)
conn.updateIceState(iceConnInfo)
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
}
@@ -492,11 +495,18 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
conn.log.Tracef("ICE connection state changed to %s", newState)
if conn.wgProxyICE != nil {
if err := conn.wgProxyICE.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
}
}
// switch back to relay connection
if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay {
if conn.isReadyToUpgrade() {
conn.log.Debugf("ICE disconnected, set Relay to active connection")
err := conn.configureWGEndpoint(conn.endpointRelay)
if err != nil {
conn.wgProxyRelay.Work()
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
conn.log.Errorf("failed to switch to relay conn: %v", err)
}
conn.workerRelay.EnableWgWatcher(conn.ctx)
@@ -506,10 +516,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
changed := conn.statusICE.Get() != newState && newState != StatusConnecting
conn.statusICE.Set(newState)
select {
case conn.iCEDisconnected <- changed:
default:
}
conn.notifyReconnectLoopICEDisconnected(changed)
peerState := State{
PubKey: conn.config.Key,
@@ -530,61 +537,48 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
if conn.ctx.Err() != nil {
if err := rci.relayedConn.Close(); err != nil {
log.Warnf("failed to close unnecessary relayed connection: %v", err)
conn.log.Warnf("failed to close unnecessary relayed connection: %v", err)
}
return
}
conn.log.Debugf("Relay connection is ready to use")
conn.statusRelay.Set(StatusConnected)
conn.log.Debugf("Relay connection has been established, setup the WireGuard")
wgProxy := conn.wgProxyFactory.GetProxy()
endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn)
wgProxy, err := conn.newProxy(rci.relayedConn)
if err != nil {
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return
}
conn.log.Infof("created new wgProxy for relay connection: %s", endpoint)
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
conn.endpointRelay = endpointUdpAddr
conn.log.Debugf("conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
if conn.currentConnPriority > connPriorityRelay {
if conn.statusICE.Get() == StatusConnected {
log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
return
}
if conn.iceP2PIsActive() {
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
conn.wgProxyRelay = wgProxy
conn.statusRelay.Set(StatusConnected)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
return
}
conn.connIDRelay = nbnet.GenerateConnID()
for _, hook := range conn.beforeAddPeerHooks {
if err := hook(conn.connIDRelay, endpointUdpAddr.IP); err != nil {
conn.log.Errorf("Before add peer hook failed: %v", err)
}
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
conn.log.Errorf("Before add peer hook failed: %v", err)
}
err = conn.configureWGEndpoint(endpointUdpAddr)
if err != nil {
wgProxy.Work()
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil {
if err := wgProxy.CloseConn(); err != nil {
conn.log.Warnf("Failed to close relay connection: %v", err)
}
conn.log.Errorf("Failed to update wg peer configuration: %v", err)
conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err)
return
}
conn.workerRelay.EnableWgWatcher(conn.ctx)
wgConfigWorkaround()
if conn.wgProxyRelay != nil {
if err := conn.wgProxyRelay.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
}
}
conn.wgProxyRelay = wgProxy
conn.currentConnPriority = connPriorityRelay
conn.statusRelay.Set(StatusConnected)
conn.wgProxyRelay = wgProxy
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
conn.log.Infof("start to communicate with peer via relay")
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
}
@@ -597,29 +591,23 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
return
}
log.Debugf("relay connection is disconnected")
conn.log.Debugf("relay connection is disconnected")
if conn.currentConnPriority == connPriorityRelay {
log.Debugf("clean up WireGuard config")
err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
if err != nil {
conn.log.Debugf("clean up WireGuard config")
if err := conn.removeWgPeer(); err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err)
}
}
if conn.wgProxyRelay != nil {
conn.endpointRelay = nil
_ = conn.wgProxyRelay.CloseConn()
conn.wgProxyRelay = nil
}
changed := conn.statusRelay.Get() != StatusDisconnected
conn.statusRelay.Set(StatusDisconnected)
select {
case conn.relayDisconnected <- changed:
default:
}
conn.notifyReconnectLoopRelayDisconnected(changed)
peerState := State{
PubKey: conn.config.Key,
@@ -627,9 +615,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
Relayed: conn.isRelayed(),
ConnStatusUpdate: time.Now(),
}
err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState)
if err != nil {
if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil {
conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
}
}
@@ -765,6 +751,16 @@ func (conn *Conn) isConnected() bool {
return true
}
func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error {
conn.connIDICE = nbnet.GenerateConnID()
for _, hook := range conn.beforeAddPeerHooks {
if err := hook(conn.connIDICE, ip); err != nil {
return err
}
}
return nil
}
func (conn *Conn) freeUpConnID() {
if conn.connIDRelay != "" {
for _, hook := range conn.afterRemovePeerHooks {
@@ -785,21 +781,52 @@ func (conn *Conn) freeUpConnID() {
}
}
func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr, wgproxy.Proxy, error) {
if !iceConnInfo.RelayedOnLocal {
return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil
}
conn.log.Debugf("setup ice turn connection")
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
conn.log.Debugf("setup proxied WireGuard connection")
wgProxy := conn.wgProxyFactory.GetProxy()
ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn)
if err != nil {
if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
if errClose := wgProxy.CloseConn(); errClose != nil {
conn.log.Warnf("failed to close turn proxy connection: %v", errClose)
}
return nil, nil, err
return nil, err
}
return wgProxy, nil
}
func (conn *Conn) isReadyToUpgrade() bool {
return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay
}
func (conn *Conn) iceP2PIsActive() bool {
return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected
}
func (conn *Conn) removeWgPeer() error {
return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
}
func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) {
select {
case conn.relayDisconnected <- changed:
default:
}
}
func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) {
select {
case conn.iCEDisconnected <- changed:
default:
}
}
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
if wgProxy != nil {
if ierr := wgProxy.CloseConn(); ierr != nil {
conn.log.Warnf("Failed to close wg proxy: %v", ierr)
}
}
if conn.wgProxyRelay != nil {
conn.wgProxyRelay.Work()
}
return ep, wgProxy, nil
}
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {

View File

@@ -5,7 +5,6 @@ package ebpf
import (
"context"
"fmt"
"io"
"net"
"os"
"sync"
@@ -94,13 +93,12 @@ func (p *WGEBPFProxy) Listen() error {
}
// AddTurnConn add new turn connection for the proxy
func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) {
func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (*net.UDPAddr, error) {
wgEndpointPort, err := p.storeTurnConn(turnConn)
if err != nil {
return nil, err
}
go p.proxyToLocal(ctx, wgEndpointPort, turnConn)
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
wgEndpoint := &net.UDPAddr{
@@ -137,35 +135,6 @@ func (p *WGEBPFProxy) Free() error {
return nberrors.FormatErrorOrNil(result)
}
func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) {
defer p.removeTurnConn(endpointPort)
var (
err error
n int
)
buf := make([]byte, 1500)
for ctx.Err() == nil {
n, err = remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return
}
if err != io.EOF {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
return
}
if err := p.sendPkg(buf[:n], endpointPort); err != nil {
if ctx.Err() != nil || p.ctx.Err() != nil {
return
}
log.Errorf("failed to write out turn pkg to local conn: %v", err)
}
}
}
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
// From this go routine has only one instance.
func (p *WGEBPFProxy) proxyToRemote() {
@@ -280,7 +249,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
return packetConn, nil
}
func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error {
func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
localhost := net.ParseIP("127.0.0.1")
payload := gopacket.Payload(data)

View File

@@ -4,8 +4,13 @@ package ebpf
import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
log "github.com/sirupsen/logrus"
)
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
@@ -13,20 +18,55 @@ type ProxyWrapper struct {
WgeBPFProxy *WGEBPFProxy
remoteConn net.Conn
cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread
ctx context.Context
cancel context.CancelFunc
wgEndpointAddr *net.UDPAddr
pausedMu sync.Mutex
paused bool
isStarted bool
}
func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
ctxConn, cancel := context.WithCancel(ctx)
addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn)
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
if err != nil {
cancel()
return nil, fmt.Errorf("add turn conn: %w", err)
return fmt.Errorf("add turn conn: %w", err)
}
e.remoteConn = remoteConn
e.cancel = cancel
return addr, err
p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx)
p.wgEndpointAddr = addr
return err
}
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
return p.wgEndpointAddr
}
func (p *ProxyWrapper) Work() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = false
p.pausedMu.Unlock()
if !p.isStarted {
p.isStarted = true
go p.proxyToLocal(p.ctx)
}
}
func (p *ProxyWrapper) Pause() {
if p.remoteConn == nil {
return
}
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
p.pausedMu.Lock()
p.paused = true
p.pausedMu.Unlock()
}
// CloseConn close the remoteConn and automatically remove the conn instance from the map
@@ -42,3 +82,45 @@ func (e *ProxyWrapper) CloseConn() error {
}
return nil
}
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
buf := make([]byte, 1500)
for {
n, err := p.readFromRemote(ctx, buf)
if err != nil {
return
}
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
}
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
p.pausedMu.Unlock()
if err != nil {
if ctx.Err() != nil {
return
}
log.Errorf("failed to write out turn pkg to local conn: %v", err)
}
}
}
func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, error) {
n, err := p.remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return 0, ctx.Err()
}
if !errors.Is(err, io.EOF) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
}
return 0, err
}
return n, nil
}

View File

@@ -7,6 +7,9 @@ import (
// Proxy is a transfer layer between the relayed connection and the WireGuard
type Proxy interface {
AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error)
AddTurnConn(ctx context.Context, turnConn net.Conn) error
EndpointAddr() *net.UDPAddr
Work()
Pause()
CloseConn() error
}

View File

@@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
relayedConn := newMockConn()
_, err := tt.proxy.AddTurnConn(ctx, relayedConn)
err := tt.proxy.AddTurnConn(ctx, relayedConn)
if err != nil {
t.Errorf("error: %v", err)
}

View File

@@ -15,13 +15,17 @@ import (
// WGUserSpaceProxy proxies
type WGUserSpaceProxy struct {
localWGListenPort int
ctx context.Context
cancel context.CancelFunc
remoteConn net.Conn
localConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex
closed bool
pausedMu sync.Mutex
paused bool
isStarted bool
}
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
@@ -33,24 +37,60 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
return p
}
// AddTurnConn start the proxy with the given remote conn
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
p.ctx, p.cancel = context.WithCancel(ctx)
p.remoteConn = remoteConn
var err error
// AddTurnConn
// The provided Context must be non-nil. If the context expires before
// the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the
// connection.
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
dialer := net.Dialer{}
p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err)
return nil, err
return err
}
go p.proxyToRemote()
go p.proxyToLocal()
p.ctx, p.cancel = context.WithCancel(ctx)
p.localConn = localConn
p.remoteConn = remoteConn
return p.localConn.LocalAddr(), err
return err
}
func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr {
if p.localConn == nil {
return nil
}
endpointUdpAddr, _ := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
return endpointUdpAddr
}
// Work starts the proxy or resumes it if it was paused
func (p *WGUserSpaceProxy) Work() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = false
p.pausedMu.Unlock()
if !p.isStarted {
p.isStarted = true
go p.proxyToRemote(p.ctx)
go p.proxyToLocal(p.ctx)
}
}
// Pause pauses the proxy from receiving data from the remote peer
func (p *WGUserSpaceProxy) Pause() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = true
p.pausedMu.Unlock()
}
// CloseConn close the localConn
@@ -85,7 +125,7 @@ func (p *WGUserSpaceProxy) close() error {
}
// proxyToRemote proxies from Wireguard to the RemoteKey
func (p *WGUserSpaceProxy) proxyToRemote() {
func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) {
defer func() {
if err := p.close(); err != nil {
log.Warnf("error in proxy to remote loop: %s", err)
@@ -93,10 +133,10 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
}()
buf := make([]byte, 1500)
for p.ctx.Err() == nil {
for ctx.Err() == nil {
n, err := p.localConn.Read(buf)
if err != nil {
if p.ctx.Err() != nil {
if ctx.Err() != nil {
return
}
log.Debugf("failed to read from wg interface conn: %s", err)
@@ -105,7 +145,7 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
_, err = p.remoteConn.Write(buf[:n])
if err != nil {
if p.ctx.Err() != nil {
if ctx.Err() != nil {
return
}
@@ -116,7 +156,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
}
// proxyToLocal proxies from the Remote peer to local WireGuard
func (p *WGUserSpaceProxy) proxyToLocal() {
// if the proxy is paused it will drain the remote conn and drop the packets
func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) {
defer func() {
if err := p.close(); err != nil {
log.Warnf("error in proxy to local loop: %s", err)
@@ -124,19 +165,27 @@ func (p *WGUserSpaceProxy) proxyToLocal() {
}()
buf := make([]byte, 1500)
for p.ctx.Err() == nil {
for {
n, err := p.remoteConn.Read(buf)
if err != nil {
if p.ctx.Err() != nil {
if ctx.Err() != nil {
return
}
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return
}
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
}
_, err = p.localConn.Write(buf[:n])
p.pausedMu.Unlock()
if err != nil {
if p.ctx.Err() != nil {
if ctx.Err() != nil {
return
}
log.Debugf("failed to write to wg interface conn: %s", err)

View File

@@ -1,12 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>CFBundleExecutable</key>
<string>netbird-ui</string>
<key>CFBundleIconFile</key>
<string>Netbird</string>
<key>LSUIElement</key>
<string>1</string>
</dict>
</plist>

View File

@@ -8,11 +8,11 @@ cask "{{ $projectName }}" do
if Hardware::CPU.intel?
url "{{ $amdURL }}"
sha256 "{{ crypto.SHA256 $amdFileBytes }}"
app "netbird_ui_darwin_amd64", target: "Netbird UI.app"
app "netbird_ui_darwin", target: "Netbird UI.app"
else
url "{{ $armURL }}"
sha256 "{{ crypto.SHA256 $armFileBytes }}"
app "netbird_ui_darwin_arm64", target: "Netbird UI.app"
app "netbird_ui_darwin", target: "Netbird UI.app"
end
depends_on formula: "netbird"
@@ -36,4 +36,4 @@ cask "{{ $projectName }}" do
name "Netbird UI"
desc "Netbird UI Client"
homepage "https://www.netbird.io/"
end
end

3
go.mod
View File

@@ -71,6 +71,7 @@ require (
github.com/pion/transport/v3 v3.0.1
github.com/pion/turn/v3 v3.0.1
github.com/prometheus/client_golang v1.19.1
github.com/r3labs/diff/v3 v3.0.1
github.com/rs/xid v1.3.0
github.com/shirou/gopsutil/v3 v3.24.4
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
@@ -210,6 +211,8 @@ require (
github.com/tklauser/go-sysconf v0.3.14 // indirect
github.com/tklauser/numcpus v0.8.0 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/yuin/goldmark v1.7.1 // indirect
github.com/zeebo/blake3 v0.2.3 // indirect
go.opencensus.io v0.24.0 // indirect

6
go.sum
View File

@@ -605,6 +605,8 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek=
github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk=
github.com/r3labs/diff/v3 v3.0.1 h1:CBKqf3XmNRHXKmdU7mZP1w7TV0pDyVCis1AUHtA4Xtg=
github.com/r3labs/diff/v3 v3.0.1/go.mod h1:f1S9bourRbiM66NskseyUdo0fTmEE0qKrikYJX63dgo=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
@@ -697,6 +699,10 @@ github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhg
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU=
github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=

View File

@@ -873,7 +873,7 @@ services:
zitadel:
restart: 'always'
networks: [netbird]
image: 'ghcr.io/zitadel/zitadel:v2.54.3'
image: 'ghcr.io/zitadel/zitadel:v2.54.10'
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
env_file:
- ./zitadel.env

View File

@@ -4,6 +4,7 @@ import (
"context"
"crypto/sha256"
b64 "encoding/base64"
"errors"
"fmt"
"hash/crc32"
"math/rand"
@@ -44,12 +45,15 @@ import (
)
const (
PublicCategory = "public"
PrivateCategory = "private"
UnknownCategory = "unknown"
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
DefaultPeerLoginExpiration = 24 * time.Hour
PublicCategory = "public"
PrivateCategory = "private"
UnknownCategory = "unknown"
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
DefaultPeerLoginExpiration = 24 * time.Hour
DefaultPeerInactivityExpiration = 10 * time.Minute
emptyUserID = "empty user ID in claims"
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
)
type userLoggedInOnce bool
@@ -98,7 +102,6 @@ type AccountManager interface {
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error
GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error)
GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error)
GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error)
@@ -178,6 +181,8 @@ type DefaultAccountManager struct {
dnsDomain string
peerLoginExpiry Scheduler
peerInactivityExpiry Scheduler
// userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account
userDeleteFromIDPEnabled bool
@@ -195,6 +200,13 @@ type Settings struct {
// Applies to all peers that have Peer.LoginExpirationEnabled set to true.
PeerLoginExpiration time.Duration
// PeerInactivityExpirationEnabled globally enables or disables peer inactivity expiration
PeerInactivityExpirationEnabled bool
// PeerInactivityExpiration is a setting that indicates when peer inactivity expires.
// Applies to all peers that have Peer.PeerInactivityExpirationEnabled set to true.
PeerInactivityExpiration time.Duration
// RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements
RegularUsersViewBlocked bool
@@ -225,6 +237,9 @@ func (s *Settings) Copy() *Settings {
GroupsPropagationEnabled: s.GroupsPropagationEnabled,
JWTAllowGroups: s.JWTAllowGroups,
RegularUsersViewBlocked: s.RegularUsersViewBlocked,
PeerInactivityExpirationEnabled: s.PeerInactivityExpirationEnabled,
PeerInactivityExpiration: s.PeerInactivityExpiration,
}
if s.Extra != nil {
settings.Extra = s.Extra.Copy()
@@ -606,6 +621,60 @@ func (a *Account) GetPeersWithExpiration() []*nbpeer.Peer {
return peers
}
// GetInactivePeers returns peers that have been expired by inactivity
func (a *Account) GetInactivePeers() []*nbpeer.Peer {
var peers []*nbpeer.Peer
for _, inactivePeer := range a.GetPeersWithInactivity() {
inactive, _ := inactivePeer.SessionExpired(a.Settings.PeerInactivityExpiration)
if inactive {
peers = append(peers, inactivePeer)
}
}
return peers
}
// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
// If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are not connected.
func (a *Account) GetNextInactivePeerExpiration() (time.Duration, bool) {
peersWithExpiry := a.GetPeersWithInactivity()
if len(peersWithExpiry) == 0 {
return 0, false
}
var nextExpiry *time.Duration
for _, peer := range peersWithExpiry {
if peer.Status.LoginExpired || peer.Status.Connected {
continue
}
_, duration := peer.SessionExpired(a.Settings.PeerInactivityExpiration)
if nextExpiry == nil || duration < *nextExpiry {
// if expiration is below 1s return 1s duration
// this avoids issues with ticker that can't be set to < 0
if duration < time.Second {
return time.Second, true
}
nextExpiry = &duration
}
}
if nextExpiry == nil {
return 0, false
}
return *nextExpiry, true
}
// GetPeersWithInactivity eturns a list of peers that have Peer.InactivityExpirationEnabled set to true and that were added by a user
func (a *Account) GetPeersWithInactivity() []*nbpeer.Peer {
peers := make([]*nbpeer.Peer, 0)
for _, peer := range a.Peers {
if peer.InactivityExpirationEnabled && peer.AddedWithSSOLogin() {
peers = append(peers, peer)
}
}
return peers
}
// GetPeers returns a list of all Account peers
func (a *Account) GetPeers() []*nbpeer.Peer {
var peers []*nbpeer.Peer
@@ -972,6 +1041,7 @@ func BuildManager(
dnsDomain: dnsDomain,
eventStore: eventStore,
peerLoginExpiry: NewDefaultScheduler(),
peerInactivityExpiry: NewDefaultScheduler(),
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
integratedPeerValidator: integratedPeerValidator,
metrics: metrics,
@@ -1100,6 +1170,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
am.checkAndSchedulePeerLoginExpiration(ctx, account)
}
err = am.handleInactivityExpirationSettings(ctx, account, oldSettings, newSettings, userID, accountID)
if err != nil {
return nil, err
}
updatedAccount := account.UpdateSettings(newSettings)
err = am.Store.SaveAccount(ctx, account)
@@ -1110,6 +1185,26 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return updatedAccount, nil
}
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
event := activity.AccountPeerInactivityExpirationEnabled
if !newSettings.PeerInactivityExpirationEnabled {
event = activity.AccountPeerInactivityExpirationDisabled
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
} else {
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
}
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
}
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
}
return nil
}
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
return func() (time.Duration, bool) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
@@ -1145,6 +1240,43 @@ func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context
}
}
// peerInactivityExpirationJob marks login expired for all inactive peers and returns the minimum duration in which the next peer of the account will expire by inactivity if found
func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
return func() (time.Duration, bool) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
log.Errorf("failed getting account %s expiring peers", account.Id)
return account.GetNextInactivePeerExpiration()
}
expiredPeers := account.GetInactivePeers()
var peerIDs []string
for _, peer := range expiredPeers {
peerIDs = append(peerIDs, peer.ID)
}
log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id)
if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil {
log.Errorf("failed updating account peers while expiring peers for account %s", account.Id)
return account.GetNextInactivePeerExpiration()
}
return account.GetNextInactivePeerExpiration()
}
}
// checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions
func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *Account) {
am.peerInactivityExpiry.Cancel(ctx, []string{account.Id})
if nextRun, ok := account.GetNextInactivePeerExpiration(); ok {
go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id))
}
}
// newAccount creates a new Account with a generated ID and generated default setup keys.
// If ID is already in use (due to collision) we try one more time before returning error
func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*Account, error) {
@@ -1285,7 +1417,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
}
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil {
return "", err
}
return account.Id, nil
@@ -1300,28 +1432,39 @@ func isNil(i idp.Manager) bool {
}
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, account *Account) error {
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
if !isNil(am.idpManager) {
accountUsers, err := am.Store.GetAccountUsers(ctx, accountID)
if err != nil {
return err
}
cachedAccount := &Account{
Id: accountID,
Users: make(map[string]*User),
}
for _, user := range accountUsers {
cachedAccount.Users[user.Id] = user
}
// user can be nil if it wasn't found (e.g., just created)
user, err := am.lookupUserInCache(ctx, userID, account)
user, err := am.lookupUserInCache(ctx, userID, cachedAccount)
if err != nil {
return err
}
if user != nil && user.AppMetadata.WTAccountID == account.Id {
if user != nil && user.AppMetadata.WTAccountID == accountID {
// it was already set, so we skip the unnecessary update
log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s",
account.Id, userID)
accountID, userID)
return nil
}
err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: account.Id})
err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
if err != nil {
return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err)
}
// refresh cache to reflect the update
_, err = am.refreshCache(ctx, account.Id)
_, err = am.refreshCache(ctx, accountID)
if err != nil {
return err
}
@@ -1545,48 +1688,69 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun
return am.cacheManager.Set(am.ctx, accountID, data, cacheStore.WithExpiration(cacheEntryExpiration()))
}
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account
func (am *DefaultAccountManager) updateAccountDomainAttributes(ctx context.Context, account *Account, claims jwtclaims.AuthorizationClaims,
// updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes
func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims,
primaryDomain bool,
) error {
if claims.Domain != "" {
account.IsDomainPrimaryAccount = primaryDomain
lowerDomain := strings.ToLower(claims.Domain)
userObj := account.Users[claims.UserId]
if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin {
account.Domain = lowerDomain
}
// prevent updating category for different domain until admin logs in
if account.Domain == lowerDomain {
account.DomainCategory = claims.DomainCategory
}
} else {
if claims.Domain == "" {
log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims)
return nil
}
err := am.Store.SaveAccount(ctx, account)
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlockAccount()
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
return err
}
return nil
if domainIsUpToDate(accountDomain, domainCategory, claims) {
return nil
}
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
if err != nil {
log.WithContext(ctx).Errorf("error getting user: %v", err)
return err
}
newDomain := accountDomain
newCategoty := domainCategory
lowerDomain := strings.ToLower(claims.Domain)
if accountDomain != lowerDomain && user.HasAdminPower() {
newDomain = lowerDomain
}
if accountDomain == lowerDomain {
newCategoty = claims.DomainCategory
}
return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain)
}
// handleExistingUserAccount handles existing User accounts and update its domain attributes.
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
// was previously unclassified or classified as public so N users that logged int that time, has they own account
// and peers that shouldn't be lost.
func (am *DefaultAccountManager) handleExistingUserAccount(
ctx context.Context,
existingAcc *Account,
primaryDomain bool,
userAccountID string,
domainAccountID string,
claims jwtclaims.AuthorizationClaims,
) error {
err := am.updateAccountDomainAttributes(ctx, existingAcc, claims, primaryDomain)
primaryDomain := domainAccountID == "" || userAccountID == domainAccountID
err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, claims, primaryDomain)
if err != nil {
return err
}
// we should register the account ID to this user's metadata in our IDP manager
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, existingAcc)
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, userAccountID)
if err != nil {
return err
}
@@ -1594,44 +1758,58 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
return nil
}
// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
// addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
// otherwise it will create a new account and make it primary account for the domain.
func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) {
func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
if claims.UserId == "" {
return nil, fmt.Errorf("user ID is empty")
return "", fmt.Errorf("user ID is empty")
}
var (
account *Account
err error
)
lowerDomain := strings.ToLower(claims.Domain)
// if domain already has a primary account, add regular user
if domainAcc != nil {
account = domainAcc
account.Users[claims.UserId] = NewRegularUser(claims.UserId)
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, err
}
} else {
account, err = am.newAccount(ctx, claims.UserId, lowerDomain)
if err != nil {
return nil, err
}
err = am.updateAccountDomainAttributes(ctx, account, claims, true)
if err != nil {
return nil, err
}
}
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, account)
newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain)
if err != nil {
return nil, err
return "", err
}
am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.UserJoined, nil)
newAccount.Domain = lowerDomain
newAccount.DomainCategory = claims.DomainCategory
newAccount.IsDomainPrimaryAccount = true
return account, nil
err = am.Store.SaveAccount(ctx, newAccount)
if err != nil {
return "", err
}
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id)
if err != nil {
return "", err
}
am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil)
return newAccount.Id, nil
}
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
defer unlockAccount()
usersMap := make(map[string]*User)
usersMap[claims.UserId] = NewRegularUser(claims.UserId)
err := am.Store.SaveUsers(domainAccountID, usersMap)
if err != nil {
return "", err
}
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, domainAccountID)
if err != nil {
return "", err
}
am.StoreEvent(ctx, claims.UserId, claims.UserId, domainAccountID, activity.UserJoined, nil)
return domainAccountID, nil
}
// redeemInvite checks whether user has been invited and redeems the invite
@@ -1775,7 +1953,7 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s
// GetAccountIDFromToken returns an account ID associated with this token.
func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
if claims.UserId == "" {
return "", "", fmt.Errorf("user ID is empty")
return "", "", errors.New(emptyUserID)
}
if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations.
@@ -1953,24 +2131,29 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return fmt.Errorf("error getting account: %w", err)
}
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account)
if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account)
} else {
log.WithContext(ctx).Tracef("Skipping account peers update for account: %s", accountID)
}
}
return nil
}
// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims.
// if domain is not private or domain is invalid, it will return the account ID by user ID.
// if domain is of the PrivateCategory category, it will evaluate
// if account is new, existing or if there is another account with the same domain
//
// Use cases:
//
// New user + New account + New domain -> create account, user role = admin (if private domain, index domain)
// New user + New account + New domain -> create account, user role = owner (if private domain, index domain)
//
// New user + New account + Existing Private Domain -> add user to the existing account, user role = regular (not admin)
// New user + New account + Existing Private Domain -> add user to the existing account, user role = user (not admin)
//
// New user + New account + Existing Public Domain -> create account, user role = admin
// New user + New account + Existing Public Domain -> create account, user role = owner
//
// Existing user + Existing account + Existing Domain -> Nothing changes (if private, index domain)
//
@@ -1980,98 +2163,124 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
if claims.UserId == "" {
return "", fmt.Errorf("user ID is empty")
return "", errors.New(emptyUserID)
}
// if Account ID is part of the claims
// it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
if claims.AccountId != "" {
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId)
if err != nil {
return "", err
}
if !exists {
return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId)
}
return claims.AccountId, nil
}
return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain)
} else if claims.AccountId != "" {
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
if err != nil {
return "", err
}
if userAccountID != claims.AccountId {
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
}
domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
if err != nil {
return "", err
}
if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain {
return userAccountID, nil
}
}
start := time.Now()
unlock := am.Store.AcquireGlobalLock(ctx)
defer unlock()
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
if claims.AccountId != "" {
return am.handlePrivateAccountWithIDFromClaim(ctx, claims)
}
// We checked if the domain has a primary account already
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, claims.Domain)
if cancel != nil {
defer cancel()
}
if err != nil {
// if NotFound we are good to continue, otherwise return error
e, ok := status.FromError(err)
if !ok || e.Type() != status.NotFound {
return "", err
}
return "", err
}
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
if err == nil {
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID)
defer unlockAccount()
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
if err != nil {
return "", err
}
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
// was previously unclassified or classified as public so N users that logged int that time, has they own account
// and peers that shouldn't be lost.
primaryDomain := domainAccountID == "" || account.Id == domainAccountID
if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil {
return "", err
}
return account.Id, nil
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
var domainAccount *Account
if domainAccountID != "" {
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
defer unlockAccount()
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
if err != nil {
return "", err
}
}
account, err := am.handleNewUserAccount(ctx, domainAccount, claims)
if err != nil {
return "", err
}
return account.Id, nil
} else {
// other error
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err
}
if userAccountID != "" {
if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, claims); err != nil {
return "", err
}
return userAccountID, nil
}
if domainAccountID != "" {
return am.addNewUserToDomainAccount(ctx, domainAccountID, claims)
}
return am.addNewPrivateAccount(ctx, domainAccountID, claims)
}
func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) {
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
return "", nil, err
}
if domainAccountID != "" {
return domainAccountID, nil, nil
}
log.WithContext(ctx).Debugf("no primary account found for domain %s, acquiring global lock", domain)
cancel := am.Store.AcquireGlobalLock(ctx)
// check again if the domain has a primary account because of simultaneous requests
domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
if handleNotFound(err) != nil {
cancel()
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
return "", nil, err
}
return domainAccountID, cancel, nil
}
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
if err != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err
}
if userAccountID != claims.AccountId {
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
}
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
return "", err
}
if domainIsUpToDate(accountDomain, domainCategory, claims) {
return claims.AccountId, nil
}
// We checked if the domain has a primary account already
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
return "", err
}
err = am.handleExistingUserAccount(ctx, claims.AccountId, domainAccountID, claims)
if err != nil {
return "", err
}
return claims.AccountId, nil
}
func handleNotFound(err error) error {
if err == nil {
return nil
}
e, ok := status.FromError(err)
if !ok || e.Type() != status.NotFound {
return err
}
return nil
}
func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool {
return domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain
}
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
@@ -2337,6 +2546,9 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac
PeerLoginExpiration: DefaultPeerLoginExpiration,
GroupsPropagationEnabled: true,
RegularUsersViewBlocked: true,
PeerInactivityExpirationEnabled: false,
PeerInactivityExpiration: DefaultPeerInactivityExpiration,
},
}

File diff suppressed because it is too large Load Diff

View File

@@ -139,6 +139,13 @@ const (
PostureCheckUpdated Activity = 61
// PostureCheckDeleted indicates that the user deleted a posture check
PostureCheckDeleted Activity = 62
PeerInactivityExpirationEnabled Activity = 63
PeerInactivityExpirationDisabled Activity = 64
AccountPeerInactivityExpirationEnabled Activity = 65
AccountPeerInactivityExpirationDisabled Activity = 66
AccountPeerInactivityExpirationDurationUpdated Activity = 67
)
var activityMap = map[Activity]Code{
@@ -205,6 +212,13 @@ var activityMap = map[Activity]Code{
PostureCheckCreated: {"Posture check created", "posture.check.created"},
PostureCheckUpdated: {"Posture check updated", "posture.check.updated"},
PostureCheckDeleted: {"Posture check deleted", "posture.check.deleted"},
PeerInactivityExpirationEnabled: {"Peer inactivity expiration enabled", "peer.inactivity.expiration.enable"},
PeerInactivityExpirationDisabled: {"Peer inactivity expiration disabled", "peer.inactivity.expiration.disable"},
AccountPeerInactivityExpirationEnabled: {"Account peer inactivity expiration enabled", "account.peer.inactivity.expiration.enable"},
AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"},
AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"},
}
// StringCode returns a string code of the activity

View File

@@ -0,0 +1,82 @@
package differs
import (
"fmt"
"net/netip"
"reflect"
"github.com/r3labs/diff/v3"
)
// NetIPAddr is a custom differ for netip.Addr
type NetIPAddr struct {
DiffFunc func(path []string, a, b reflect.Value, p interface{}) error
}
func (differ NetIPAddr) Match(a, b reflect.Value) bool {
return diff.AreType(a, b, reflect.TypeOf(netip.Addr{}))
}
func (differ NetIPAddr) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error {
if a.Kind() == reflect.Invalid {
cl.Add(diff.CREATE, path, nil, b.Interface())
return nil
}
if b.Kind() == reflect.Invalid {
cl.Add(diff.DELETE, path, a.Interface(), nil)
return nil
}
fromAddr, ok1 := a.Interface().(netip.Addr)
toAddr, ok2 := b.Interface().(netip.Addr)
if !ok1 || !ok2 {
return fmt.Errorf("invalid type for netip.Addr")
}
if fromAddr.String() != toAddr.String() {
cl.Add(diff.UPDATE, path, fromAddr.String(), toAddr.String())
}
return nil
}
func (differ NetIPAddr) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) {
differ.DiffFunc = dfunc //nolint
}
// NetIPPrefix is a custom differ for netip.Prefix
type NetIPPrefix struct {
DiffFunc func(path []string, a, b reflect.Value, p interface{}) error
}
func (differ NetIPPrefix) Match(a, b reflect.Value) bool {
return diff.AreType(a, b, reflect.TypeOf(netip.Prefix{}))
}
func (differ NetIPPrefix) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error {
if a.Kind() == reflect.Invalid {
cl.Add(diff.CREATE, path, nil, b.Interface())
return nil
}
if b.Kind() == reflect.Invalid {
cl.Add(diff.DELETE, path, a.Interface(), nil)
return nil
}
fromPrefix, ok1 := a.Interface().(netip.Prefix)
toPrefix, ok2 := b.Interface().(netip.Prefix)
if !ok1 || !ok2 {
return fmt.Errorf("invalid type for netip.Addr")
}
if fromPrefix.String() != toPrefix.String() {
cl.Add(diff.UPDATE, path, fromPrefix.String(), toPrefix.String())
}
return nil
}
func (differ NetIPPrefix) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) {
differ.DiffFunc = dfunc //nolint
}

View File

@@ -125,26 +125,31 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
oldSettings := account.DNSSettings.Copy()
account.DNSSettings = dnsSettingsToSave.Copy()
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
for _, id := range addedGroups {
group := account.GetGroup(id)
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
}
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
for _, id := range removedGroups {
group := account.GetGroup(id)
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
}
am.updateAccountPeers(ctx, account)
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
am.updateAccountPeers(ctx, account)
} else {
log.WithContext(ctx).Tracef("Skipping account peers update for account dns settings: %s", accountID)
}
return nil
}

View File

@@ -6,9 +6,11 @@ import (
"net/netip"
"reflect"
"testing"
"time"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -476,3 +478,145 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
t.Errorf("Cache should contain name server group 'group2'")
}
}
func TestDNSAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{
{
ID: "groupA",
Name: "GroupA",
Peers: []string{},
},
{
ID: "groupB",
Name: "GroupB",
Peers: []string{},
},
})
assert.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
})
// Saving DNS settings with groups that have no peers should not trigger updates to account peers or send peer updates
t.Run("saving dns setting with unused groups", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
DisabledManagementGroups: []string{"groupA"},
})
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
})
assert.NoError(t, err)
_, err = manager.CreateNameServerGroup(
context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{
IP: netip.MustParseAddr(peer1.IP.String()),
NSType: dns.UDPNameServerType,
Port: dns.DefaultDNSPort,
}},
[]string{"groupA"},
true, []string{}, true, userID, false,
)
assert.NoError(t, err)
// Saving DNS settings with groups that have peers should update account peers and send peer update
t.Run("saving dns setting with used groups", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
DisabledManagementGroups: []string{"groupA", "groupB"},
})
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Saving unchanged DNS settings with used groups should update account peers and not send peer update
// since there is no change in the network map
t.Run("saving unchanged dns setting with used groups", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
DisabledManagementGroups: []string{"groupA", "groupB"},
})
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Removing group with no peers from DNS settings should not trigger updates to account peers or send peer updates
t.Run("removing group with no peers from dns settings", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
DisabledManagementGroups: []string{"groupA"},
})
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Removing group with peers from DNS settings should trigger updates to account peers and send peer updates
t.Run("removing group with peers from dns settings", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
DisabledManagementGroups: []string{},
})
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
}

View File

@@ -95,6 +95,9 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
account.Settings = &Settings{
PeerLoginExpirationEnabled: false,
PeerLoginExpiration: DefaultPeerLoginExpiration,
PeerInactivityExpirationEnabled: false,
PeerInactivityExpiration: DefaultPeerInactivityExpiration,
}
}

View File

@@ -121,12 +121,21 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
eventsToStore = append(eventsToStore, events...)
}
newGroupIDs := make([]string, 0, len(newGroups))
for _, newGroup := range newGroups {
newGroupIDs = append(newGroupIDs, newGroup.ID)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.updateAccountPeers(ctx, account)
if areGroupChangesAffectPeers(account, newGroupIDs) {
am.updateAccountPeers(ctx, account)
} else {
log.WithContext(ctx).Tracef("Skipping account peers update for group: %v", newGroupIDs)
}
for _, storeEvent := range eventsToStore {
storeEvent()
@@ -238,8 +247,6 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
am.updateAccountPeers(ctx, account)
return nil
}
@@ -282,8 +289,6 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, us
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta())
}
am.updateAccountPeers(ctx, account)
return allErrors
}
@@ -336,7 +341,11 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
return err
}
am.updateAccountPeers(ctx, account)
if areGroupChangesAffectPeers(account, []string{group.ID}) {
am.updateAccountPeers(ctx, account)
} else {
log.WithContext(ctx).Tracef("Skipping account peers update for group: %s", groupID)
}
return nil
}
@@ -366,7 +375,11 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
}
}
am.updateAccountPeers(ctx, account)
if areGroupChangesAffectPeers(account, []string{group.ID}) {
am.updateAccountPeers(ctx, account)
} else {
log.WithContext(ctx).Tracef("Skipping account peers update for group: %s", groupID)
}
return nil
}
@@ -469,3 +482,32 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
}
return false, nil
}
// anyGroupHasPeers checks if any of the given groups in the account have peers.
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
return true
}
}
return false
}
func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) {
return true
}
if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked {
return true
}
if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked {
return true
}
if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked {
return true
}
}
return false
}

View File

@@ -44,3 +44,8 @@ func (g *Group) Copy() *Group {
copy(group.Peers, g.Peers)
return group
}
// HasPeers checks if the group has any peers.
func (g *Group) HasPeers() bool {
return len(g.Peers) > 0
}

View File

@@ -4,13 +4,16 @@ import (
"context"
"errors"
"fmt"
"net/netip"
"testing"
"time"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
@@ -384,3 +387,312 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A
}
return am, acc, nil
}
func TestGroupAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
},
{
ID: "groupB",
Name: "GroupB",
Peers: []string{},
},
{
ID: "groupC",
Name: "GroupC",
Peers: []string{peer1.ID, peer3.ID},
},
{
ID: "groupD",
Name: "GroupD",
Peers: []string{},
},
})
assert.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
})
// Saving a group that is not linked to any resource should not update account peers
t.Run("saving unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "groupB",
Name: "GroupB",
Peers: []string{peer1.ID, peer2.ID},
})
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Adding a peer to a group that is not linked to any resource should not update account peers
// and not send peer update
t.Run("adding peer to unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.GroupAddPeer(context.Background(), account.Id, "groupB", peer3.ID)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Removing a peer from a group that is not linked to any resource should not update account peers
// and not send peer update
t.Run("removing peer from unliked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.GroupDeletePeer(context.Background(), account.Id, "groupB", peer3.ID)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Deleting group should not update account peers and not send peer update
t.Run("deleting group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupB")
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// adding a group to policy
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
ID: "policy",
Enabled: true,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}, false)
assert.NoError(t, err)
// Saving a group linked to policy should update account peers and send peer update
t.Run("saving linked group to policy", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
})
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Saving an unchanged group should trigger account peers update and not send peer update
// since there is no change in the network map
t.Run("saving unchanged group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
})
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// adding peer to a used group should update account peers and send peer update
t.Run("adding peer to linked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.GroupAddPeer(context.Background(), account.Id, "groupA", peer3.ID)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// removing peer from a linked group should update account peers and send peer update
t.Run("removing peer from linked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.GroupDeletePeer(context.Background(), account.Id, "groupA", peer3.ID)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Saving a group linked to name server group should update account peers and send peer update
t.Run("saving group linked to name server group", func(t *testing.T) {
_, err = manager.CreateNameServerGroup(
context.Background(), account.Id, "nsGroup", "nsGroup", []nbdns.NameServer{{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
}},
[]string{"groupC"},
true, nil, true, userID, false,
)
assert.NoError(t, err)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "groupC",
Name: "GroupC",
Peers: []string{peer1.ID, peer3.ID},
})
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Saving a group linked to route should update account peers and send peer update
t.Run("saving group linked to route", func(t *testing.T) {
newRoute := route.Route{
ID: "route",
Network: netip.MustParsePrefix("192.168.0.0/16"),
NetID: "superNet",
NetworkType: route.IPv4Network,
PeerGroups: []string{"groupA"},
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{"groupC"},
}
_, err := manager.CreateRoute(
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
)
require.NoError(t, err)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
})
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Saving a group linked to dns settings should update account peers and send peer update
t.Run("saving group linked to dns settings", func(t *testing.T) {
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
DisabledManagementGroups: []string{"groupD"},
})
assert.NoError(t, err)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "groupD",
Name: "GroupD",
Peers: []string{peer1.ID},
})
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
}

View File

@@ -78,6 +78,9 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked,
PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
}
if req.Settings.Extra != nil {

View File

@@ -54,6 +54,14 @@ components:
description: Period of time after which peer login expires (seconds).
type: integer
example: 43200
peer_inactivity_expiration_enabled:
description: Enables or disables peer inactivity expiration globally. After peer's session has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login).
type: boolean
example: true
peer_inactivity_expiration:
description: Period of time of inactivity after which peer session expires (seconds).
type: integer
example: 43200
regular_users_view_blocked:
description: Allows blocking regular users from viewing parts of the system.
type: boolean
@@ -81,6 +89,8 @@ components:
required:
- peer_login_expiration_enabled
- peer_login_expiration
- peer_inactivity_expiration_enabled
- peer_inactivity_expiration
- regular_users_view_blocked
AccountExtraSettings:
type: object
@@ -243,6 +253,9 @@ components:
login_expiration_enabled:
type: boolean
example: false
inactivity_expiration_enabled:
type: boolean
example: false
approval_required:
description: (Cloud only) Indicates whether peer needs approval
type: boolean
@@ -251,6 +264,7 @@ components:
- name
- ssh_enabled
- login_expiration_enabled
- inactivity_expiration_enabled
Peer:
allOf:
- $ref: '#/components/schemas/PeerMinimum'
@@ -327,6 +341,10 @@ components:
type: string
format: date-time
example: "2023-05-05T09:00:35.477782Z"
inactivity_expiration_enabled:
description: Indicates whether peer inactivity expiration has been enabled or not
type: boolean
example: false
approval_required:
description: (Cloud only) Indicates whether peer needs approval
type: boolean
@@ -354,6 +372,7 @@ components:
- last_seen
- login_expiration_enabled
- login_expired
- inactivity_expiration_enabled
- os
- ssh_enabled
- user_id

View File

@@ -220,6 +220,12 @@ type AccountSettings struct {
// JwtGroupsEnabled Allows extract groups from JWT claim and add it to account groups.
JwtGroupsEnabled *bool `json:"jwt_groups_enabled,omitempty"`
// PeerInactivityExpiration Period of time of inactivity after which peer session expires (seconds).
PeerInactivityExpiration int `json:"peer_inactivity_expiration"`
// PeerInactivityExpirationEnabled Enables or disables peer inactivity expiration globally. After peer's session has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login).
PeerInactivityExpirationEnabled bool `json:"peer_inactivity_expiration_enabled"`
// PeerLoginExpiration Period of time after which peer login expires (seconds).
PeerLoginExpiration int `json:"peer_login_expiration"`
@@ -538,6 +544,9 @@ type Peer struct {
// Id Peer ID
Id string `json:"id"`
// InactivityExpirationEnabled Indicates whether peer inactivity expiration has been enabled or not
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
// Ip Peer's IP address
Ip string `json:"ip"`
@@ -613,6 +622,9 @@ type PeerBatch struct {
// Id Peer ID
Id string `json:"id"`
// InactivityExpirationEnabled Indicates whether peer inactivity expiration has been enabled or not
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
// Ip Peer's IP address
Ip string `json:"ip"`
@@ -677,10 +689,11 @@ type PeerNetworkRangeCheckAction string
// PeerRequest defines model for PeerRequest.
type PeerRequest struct {
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired *bool `json:"approval_required,omitempty"`
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
Name string `json:"name"`
SshEnabled bool `json:"ssh_enabled"`
ApprovalRequired *bool `json:"approval_required,omitempty"`
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
Name string `json:"name"`
SshEnabled bool `json:"ssh_enabled"`
}
// PersonalAccessToken defines model for PersonalAccessToken.

View File

@@ -7,6 +7,8 @@ import (
"net/http"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
@@ -14,7 +16,6 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
)
// PeersHandler is a handler that returns peers of the account
@@ -87,6 +88,8 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
SSHEnabled: req.SshEnabled,
Name: req.Name,
LoginExpirationEnabled: req.LoginExpirationEnabled,
InactivityExpirationEnabled: req.InactivityExpirationEnabled,
}
if req.ApprovalRequired != nil {
@@ -331,29 +334,30 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
}
return &api.Peer{
Id: peer.ID,
Name: peer.Name,
Ip: peer.IP.String(),
ConnectionIp: peer.Location.ConnectionIP.String(),
Connected: peer.Status.Connected,
LastSeen: peer.Status.LastSeen,
Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
KernelVersion: peer.Meta.KernelVersion,
GeonameId: int(peer.Location.GeoNameID),
Version: peer.Meta.WtVersion,
Groups: groupsInfo,
SshEnabled: peer.SSHEnabled,
Hostname: peer.Meta.Hostname,
UserId: peer.UserID,
UiVersion: peer.Meta.UIVersion,
DnsLabel: fqdn(peer, dnsDomain),
LoginExpirationEnabled: peer.LoginExpirationEnabled,
LastLogin: peer.LastLogin,
LoginExpired: peer.Status.LoginExpired,
ApprovalRequired: !approved,
CountryCode: peer.Location.CountryCode,
CityName: peer.Location.CityName,
SerialNumber: peer.Meta.SystemSerialNumber,
Id: peer.ID,
Name: peer.Name,
Ip: peer.IP.String(),
ConnectionIp: peer.Location.ConnectionIP.String(),
Connected: peer.Status.Connected,
LastSeen: peer.Status.LastSeen,
Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
KernelVersion: peer.Meta.KernelVersion,
GeonameId: int(peer.Location.GeoNameID),
Version: peer.Meta.WtVersion,
Groups: groupsInfo,
SshEnabled: peer.SSHEnabled,
Hostname: peer.Meta.Hostname,
UserId: peer.UserID,
UiVersion: peer.Meta.UIVersion,
DnsLabel: fqdn(peer, dnsDomain),
LoginExpirationEnabled: peer.LoginExpirationEnabled,
LastLogin: peer.LastLogin,
LoginExpired: peer.Status.LoginExpired,
ApprovalRequired: !approved,
CountryCode: peer.Location.CountryCode,
CityName: peer.Location.CityName,
SerialNumber: peer.Meta.SystemSerialNumber,
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
}
}
@@ -387,6 +391,8 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
CountryCode: peer.Location.CountryCode,
CityName: peer.Location.CityName,
SerialNumber: peer.Meta.SystemSerialNumber,
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
}
}

View File

@@ -57,7 +57,6 @@ type MockAccountManager struct {
GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
MarkPATUsedFunc func(ctx context.Context, pat string) error
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
@@ -434,14 +433,6 @@ func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) (
return nil, status.Errorf(codes.Unimplemented, "method ListUsers is not implemented")
}
// UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager
func (am *MockAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
if am.UpdatePeerSSHKeyFunc != nil {
return am.UpdatePeerSSHKeyFunc(ctx, peerID, sshKey)
}
return status.Errorf(codes.Unimplemented, "method UpdatePeerSSHKey is not implemented")
}
// UpdatePeer mocks UpdatePeerFunc function of the account manager
func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) {
if am.UpdatePeerFunc != nil {

View File

@@ -8,6 +8,7 @@ import (
"github.com/miekg/dns"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
@@ -66,13 +67,15 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
account.NameServerGroups[newNSGroup.ID] = newNSGroup
account.Network.IncSerial()
err = am.Store.SaveAccount(ctx, account)
if err != nil {
if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err
}
am.updateAccountPeers(ctx, account)
if anyGroupHasPeers(account, newNSGroup.Groups) {
am.updateAccountPeers(ctx, account)
} else {
log.WithContext(ctx).Tracef("Skipping account peers update for ns group: %s", newNSGroup.ID)
}
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
return newNSGroup.Copy(), nil
@@ -80,7 +83,6 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
// SaveNameServerGroup saves nameserver group
func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
@@ -98,16 +100,19 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return err
}
oldNSGroup := account.NameServerGroups[nsGroupToSave.ID]
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave
account.Network.IncSerial()
err = am.Store.SaveAccount(ctx, account)
if err != nil {
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.updateAccountPeers(ctx, account)
if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
am.updateAccountPeers(ctx, account)
} else {
log.WithContext(ctx).Tracef("Skipping account peers update for ns group: %s", nsGroupToSave.ID)
}
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
return nil
@@ -131,13 +136,15 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
delete(account.NameServerGroups, nsGroupID)
account.Network.IncSerial()
err = am.Store.SaveAccount(ctx, account)
if err != nil {
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.updateAccountPeers(ctx, account)
if anyGroupHasPeers(account, nsGroup.Groups) {
am.updateAccountPeers(ctx, account)
} else {
log.WithContext(ctx).Tracef("Skipping account peers update for peer: %s", nsGroupID)
}
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
return nil
@@ -277,3 +284,11 @@ func validateDomain(domain string) error {
return nil
}
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool {
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false
}
return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups)
}

View File

@@ -4,7 +4,9 @@ import (
"context"
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
@@ -935,3 +937,179 @@ func TestValidateDomain(t *testing.T) {
}
}
func TestNameServerAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
var newNameServerGroupA *nbdns.NameServerGroup
var newNameServerGroupB *nbdns.NameServerGroup
err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
{
ID: "groupA",
Name: "GroupA",
Peers: []string{},
},
{
ID: "groupB",
Name: "GroupB",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
},
})
assert.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
})
// Creating a nameserver group with a distribution group no peers should not update account peers
// and not send peer update
t.Run("creating nameserver group with distribution group no peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
newNameServerGroupA, err = manager.CreateNameServerGroup(
context.Background(), account.Id, "nsGroupA", "nsGroupA", []nbdns.NameServer{{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
}},
[]string{"groupA"},
true, []string{}, true, userID, false,
)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// saving a nameserver group with a distribution group with no peers should not update account peers
// and not send peer update
t.Run("saving nameserver group with distribution group no peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupA)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Creating a nameserver group with a distribution group no peers should update account peers and send peer update
t.Run("creating nameserver group with distribution group has peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
newNameServerGroupB, err = manager.CreateNameServerGroup(
context.Background(), account.Id, "nsGroupB", "nsGroupB", []nbdns.NameServer{{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
}},
[]string{"groupB"},
true, []string{}, true, userID, false,
)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// saving a nameserver group with a distribution group with peers should update account peers and send peer update
t.Run("saving nameserver group with distribution group has peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
newNameServerGroupB.NameServers = []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
}
err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupB)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// saving unchanged nameserver group should update account peers and not send peer update
t.Run("saving unchanged nameserver group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
newNameServerGroupB.NameServers = []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
}
err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupB)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Deleting a nameserver group should update account peers and send peer update
t.Run("deleting nameserver group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err = manager.DeleteNameServerGroup(context.Background(), account.Id, newNameServerGroupB.ID, userID)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
}

View File

@@ -41,9 +41,9 @@ type Network struct {
Dns string
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
// Used to synchronize state to the client apps.
Serial uint64
Serial uint64 `diff:"-"`
mu sync.Mutex `json:"-" gorm:"-"`
mu sync.Mutex `json:"-" gorm:"-" diff:"-"`
}
// NewNetwork creates a new Network initializing it with a Serial=0

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net"
"slices"
"strings"
"sync"
"time"
@@ -110,6 +111,31 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
return err
}
expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account)
if err != nil {
return err
}
if peer.AddedWithSSOLogin() {
if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, account)
}
if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
}
}
if expired {
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
// the expired one. Here we notify them that connection is now allowed again.
am.updateAccountPeers(ctx, account)
}
return nil
}
func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *Account) (bool, error) {
oldStatus := peer.Status.Copy()
newStatus := oldStatus
newStatus.LastSeen = time.Now().UTC()
@@ -138,25 +164,15 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
account.UpdatePeer(peer)
err = am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
if err != nil {
return err
return false, err
}
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, account)
}
if oldStatus.LoginExpired {
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
// the expired one. Here we notify them that connection is now allowed again.
am.updateAccountPeers(ctx, account)
}
return nil
return oldStatus.LoginExpired, nil
}
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated.
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated.
func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
@@ -185,7 +201,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
}
if peer.Name != update.Name {
peerLabelUpdated := peer.Name != update.Name
if peerLabelUpdated {
peer.Name = update.Name
existingLabels := account.getPeerDNSLabels()
@@ -219,6 +237,25 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
}
}
if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled {
if !peer.AddedWithSSOLogin() {
return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated")
}
peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled
event := activity.PeerInactivityExpirationEnabled
if !update.InactivityExpirationEnabled {
event = activity.PeerInactivityExpirationDisabled
}
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
}
}
account.UpdatePeer(peer)
err = am.Store.SaveAccount(ctx, account)
@@ -226,7 +263,11 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return nil, err
}
am.updateAccountPeers(ctx, account)
if peerLabelUpdated {
am.updateAccountPeers(ctx, account)
} else {
log.WithContext(ctx).Tracef("Skipping account peers update for peer: %s", update.ID)
}
return peer, nil
}
@@ -270,6 +311,7 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou
FirewallRulesIsEmpty: true,
},
},
NetworkMap: &NetworkMap{},
})
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
@@ -288,6 +330,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
updateAccountPeers := isPeerInActiveGroup(account, peerID)
err = am.deletePeers(ctx, account, []string{peerID}, userID)
if err != nil {
return err
@@ -298,7 +342,11 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
am.updateAccountPeers(ctx, account)
if updateAccountPeers {
am.updateAccountPeers(ctx, account)
} else {
log.WithContext(ctx).Tracef("Skipping account peers update for peer: %s", peerID)
}
return nil
}
@@ -388,9 +436,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}
var newPeer *nbpeer.Peer
var groupsToAdd []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
var groupsToAdd []string
var setupKeyID string
var setupKeyName string
var ephemeral bool
@@ -442,23 +490,24 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
registrationTime := time.Now().UTC()
newPeer = &nbpeer.Peer{
ID: xid.New().String(),
AccountID: accountID,
Key: peer.Key,
SetupKey: upperKey,
IP: freeIP,
Meta: peer.Meta,
Name: peer.Meta.Hostname,
DNSLabel: freeLabel,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
SSHEnabled: false,
SSHKey: peer.SSHKey,
LastLogin: registrationTime,
CreatedAt: registrationTime,
LoginExpirationEnabled: addedByUser,
Ephemeral: ephemeral,
Location: peer.Location,
ID: xid.New().String(),
AccountID: accountID,
Key: peer.Key,
SetupKey: upperKey,
IP: freeIP,
Meta: peer.Meta,
Name: peer.Meta.Hostname,
DNSLabel: freeLabel,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
SSHEnabled: false,
SSHKey: peer.SSHKey,
LastLogin: registrationTime,
CreatedAt: registrationTime,
LoginExpirationEnabled: addedByUser,
Ephemeral: ephemeral,
Location: peer.Location,
InactivityExpirationEnabled: addedByUser,
}
opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
@@ -541,7 +590,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account)
if areGroupChangesAffectPeers(account, groupsToAdd) {
am.updateAccountPeers(ctx, account)
} else {
log.WithContext(ctx).Tracef("Skipping account peers update for peer: %s", newPeer.ID)
}
approvedPeersMap, err := am.GetValidatedPeers(account)
if err != nil {
@@ -862,51 +915,6 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings
return false
}
// UpdatePeerSSHKey updates peer's public SSH key
func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
if sshKey == "" {
log.WithContext(ctx).Debugf("empty SSH key provided for peer %s, skipping update", peerID)
return nil
}
account, err := am.Store.GetAccountByPeerID(ctx, peerID)
if err != nil {
return err
}
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(ctx, account.Id)
if err != nil {
return err
}
peer := account.GetPeer(peerID)
if peer == nil {
return status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
}
if peer.SSHKey == sshKey {
log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peerID)
return nil
}
peer.SSHKey = sshKey
account.UpdatePeer(peer)
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return err
}
// trigger network map update
am.updateAccountPeers(ctx, account)
return nil
}
// GetPeer for a given accountID, peerID and userID error if not found.
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
@@ -999,7 +1007,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
postureChecks := am.getPeerPostureChecks(account, p)
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update})
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
}(peer)
}
@@ -1013,3 +1021,15 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
}
return labelMap
}
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
// in an active DNS, route, or ACL configuration.
func isPeerInActiveGroup(account *Account, peerID string) bool {
peerGroupIDs := make([]string, 0)
for _, group := range account.Groups {
if slices.Contains(group.Peers, peerID) {
peerGroupIDs = append(peerGroupIDs, group.ID)
}
}
return areGroupChangesAffectPeers(account, peerGroupIDs)
}

View File

@@ -17,35 +17,37 @@ type Peer struct {
// WireGuard public key
Key string `gorm:"index"`
// A setup key this peer was registered with
SetupKey string
SetupKey string `diff:"-"`
// IP address of the Peer
IP net.IP `gorm:"serializer:json"`
// Meta is a Peer system meta data
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_" diff:"-"`
// Name is peer's name (machine name)
Name string
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
DNSLabel string
// Status peer's management connection status
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_" diff:"-"`
// The user ID that registered the peer
UserID string
UserID string `diff:"-"`
// SSHKey is a public SSH key of the peer
SSHKey string
// SSHEnabled indicates whether SSH server is enabled on the peer
SSHEnabled bool
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
// Works with LastLogin
LoginExpirationEnabled bool
LoginExpirationEnabled bool `diff:"-"`
InactivityExpirationEnabled bool `diff:"-"`
// LastLogin the time when peer performed last login operation
LastLogin time.Time
LastLogin time.Time `diff:"-"`
// CreatedAt records the time the peer was created
CreatedAt time.Time
CreatedAt time.Time `diff:"-"`
// Indicate ephemeral peer attribute
Ephemeral bool
Ephemeral bool `diff:"-"`
// Geo location based on connection IP
Location Location `gorm:"embedded;embeddedPrefix:location_"`
Location Location `gorm:"embedded;embeddedPrefix:location_" diff:"-"`
}
type PeerStatus struct { //nolint:revive
@@ -187,6 +189,7 @@ func (p *Peer) Copy() *Peer {
CreatedAt: p.CreatedAt,
Ephemeral: p.Ephemeral,
Location: p.Location,
InactivityExpirationEnabled: p.InactivityExpirationEnabled,
}
}
@@ -219,6 +222,22 @@ func (p *Peer) MarkLoginExpired(expired bool) {
p.Status = newStatus
}
// SessionExpired indicates whether the peer's session has expired or not.
// If Peer.LastLogin plus the expiresIn duration has happened already; then session has expired.
// Return true if a session has expired, false otherwise, and time left to expiration (negative when expired).
// Session expiration can be disabled/enabled on a Peer level via Peer.LoginExpirationEnabled property.
// Session expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled.
// Only peers added by interactive SSO login can be expired.
func (p *Peer) SessionExpired(expiresIn time.Duration) (bool, time.Duration) {
if !p.AddedWithSSOLogin() || !p.InactivityExpirationEnabled || p.Status.Connected {
return false, 0
}
expiresAt := p.Status.LastSeen.Add(expiresIn)
now := time.Now()
timeLeft := expiresAt.Sub(now)
return timeLeft <= 0, timeLeft
}
// LoginExpired indicates whether the peer's login has expired or not.
// If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired.
// Return true if a login has expired, false otherwise, and time left to expiration (negative when expired).

View File

@@ -82,6 +82,68 @@ func TestPeer_LoginExpired(t *testing.T) {
}
}
func TestPeer_SessionExpired(t *testing.T) {
tt := []struct {
name string
expirationEnabled bool
lastLogin time.Time
connected bool
expected bool
accountSettings *Settings
}{
{
name: "Peer Inactivity Expiration Disabled. Peer Inactivity Should Not Expire",
expirationEnabled: false,
connected: false,
lastLogin: time.Now().UTC().Add(-1 * time.Second),
accountSettings: &Settings{
PeerInactivityExpirationEnabled: true,
PeerInactivityExpiration: time.Hour,
},
expected: false,
},
{
name: "Peer Inactivity Should Expire",
expirationEnabled: true,
connected: false,
lastLogin: time.Now().UTC().Add(-1 * time.Second),
accountSettings: &Settings{
PeerInactivityExpirationEnabled: true,
PeerInactivityExpiration: time.Second,
},
expected: true,
},
{
name: "Peer Inactivity Should Not Expire",
expirationEnabled: true,
connected: true,
lastLogin: time.Now().UTC(),
accountSettings: &Settings{
PeerInactivityExpirationEnabled: true,
PeerInactivityExpiration: time.Second,
},
expected: false,
},
}
for _, c := range tt {
t.Run(c.name, func(t *testing.T) {
peerStatus := &nbpeer.PeerStatus{
Connected: c.connected,
}
peer := &nbpeer.Peer{
InactivityExpirationEnabled: c.expirationEnabled,
LastLogin: c.lastLogin,
Status: peerStatus,
UserID: userID,
}
expired, _ := peer.SessionExpired(c.accountSettings.PeerInactivityExpiration)
assert.Equal(t, expired, c.expected)
})
}
}
func TestAccountManager_GetNetworkMap(t *testing.T) {
manager, err := createManager(t)
if err != nil {
@@ -1191,3 +1253,322 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed.UTC())
assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes)
}
func TestPeerAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID)
require.NoError(t, err)
err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
},
{
ID: "groupB",
Name: "GroupB",
Peers: []string{},
},
{
ID: "groupC",
Name: "GroupC",
Peers: []string{},
},
})
require.NoError(t, err)
// create a user with auto groups
_, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*User{
{
Id: "regularUser1",
AccountID: account.Id,
Role: UserRoleAdmin,
Issued: UserIssuedAPI,
AutoGroups: []string{"groupA"},
},
{
Id: "regularUser2",
AccountID: account.Id,
Role: UserRoleAdmin,
Issued: UserIssuedAPI,
AutoGroups: []string{"groupB"},
},
{
Id: "regularUser3",
AccountID: account.Id,
Role: UserRoleAdmin,
Issued: UserIssuedAPI,
AutoGroups: []string{"groupC"},
},
}, true)
require.NoError(t, err)
var peer4 *nbpeer.Peer
var peer5 *nbpeer.Peer
var peer6 *nbpeer.Peer
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
})
// Updating not expired peer and peer expiration is enabled should not update account peers and not send peer update
t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
_, err := manager.UpdatePeer(context.Background(), account.Id, userID, peer2)
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Adding peer to unlinked group should not update account peers and not send peer update
t.Run("adding peer to unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
key, err := wgtypes.GeneratePrivateKey()
require.NoError(t, err)
expectedPeerKey := key.PublicKey().String()
peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Deleting peer with unlinked group should not update account peers and not send peer update
t.Run("deleting peer with unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err = manager.DeletePeer(context.Background(), account.Id, peer4.ID, userID)
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Updating peer label should update account peers and send peer update
t.Run("updating peer label", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
peer1.Name = "peer-1"
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, peer1)
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Adding peer to group linked with policy should update account peers and send peer update
t.Run("adding peer to group linked with policy", func(t *testing.T) {
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
ID: "policy",
Enabled: true,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}, false)
require.NoError(t, err)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
key, err := wgtypes.GeneratePrivateKey()
require.NoError(t, err)
expectedPeerKey := key.PublicKey().String()
peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{
Key: expectedPeerKey,
LoginExpirationEnabled: true,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Deleting peer with linked group to policy should update account peers and send peer update
t.Run("deleting peer with linked group to policy", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err = manager.DeletePeer(context.Background(), account.Id, peer4.ID, userID)
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Adding peer to group linked with route should update account peers and send peer update
t.Run("adding peer to group linked with route", func(t *testing.T) {
route := nbroute.Route{
ID: "testingRoute1",
Network: netip.MustParsePrefix("100.65.250.202/32"),
NetID: "superNet",
NetworkType: nbroute.IPv4Network,
PeerGroups: []string{"groupB"},
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{"groupB"},
}
_, err := manager.CreateRoute(
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer,
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
route.Groups, []string{}, true, userID, route.KeepRoute,
)
require.NoError(t, err)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
key, err := wgtypes.GeneratePrivateKey()
require.NoError(t, err)
expectedPeerKey := key.PublicKey().String()
peer5, _, _, err = manager.AddPeer(context.Background(), "", "regularUser2", &nbpeer.Peer{
Key: expectedPeerKey,
LoginExpirationEnabled: true,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Deleting peer with linked group to route should update account peers and send peer update
t.Run("deleting peer with linked group to route", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err = manager.DeletePeer(context.Background(), account.Id, peer5.ID, userID)
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Adding peer to group linked with name server group should update account peers and send peer update
t.Run("adding peer to group linked with name server group", func(t *testing.T) {
_, err = manager.CreateNameServerGroup(
context.Background(), account.Id, "nsGroup", "nsGroup", []nbdns.NameServer{{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
}},
[]string{"groupC"},
true, []string{}, true, userID, false,
)
require.NoError(t, err)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
key, err := wgtypes.GeneratePrivateKey()
require.NoError(t, err)
expectedPeerKey := key.PublicKey().String()
peer6, _, _, err = manager.AddPeer(context.Background(), "", "regularUser3", &nbpeer.Peer{
Key: expectedPeerKey,
LoginExpirationEnabled: true,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Deleting peer with linked group to name server group should update account peers and send peer update
t.Run("deleting peer with linked group to route", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err = manager.DeletePeer(context.Background(), account.Id, peer6.ID, userID)
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
}

View File

@@ -203,6 +203,18 @@ func (p *Policy) UpgradeAndFix() {
}
}
// ruleGroups returns a list of all groups referenced in the policy's rules,
// including sources and destinations.
func (p *Policy) ruleGroups() []string {
groups := make([]string, 0)
for _, rule := range p.Rules {
groups = append(groups, rule.Sources...)
groups = append(groups, rule.Destinations...)
}
return groups
}
// FirewallRule is a rule of the firewall.
type FirewallRule struct {
// PeerIP of the peer
@@ -348,7 +360,8 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
return err
}
if err = am.savePolicy(account, policy, isUpdate); err != nil {
updateAccountPeers, err := am.savePolicy(account, policy, isUpdate)
if err != nil {
return err
}
@@ -363,7 +376,11 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
}
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
am.updateAccountPeers(ctx, account)
if updateAccountPeers {
am.updateAccountPeers(ctx, account)
} else {
log.WithContext(ctx).Tracef("Skipping account peers update for policy: %s", policy.ID)
}
return nil
}
@@ -428,7 +445,7 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string)
// savePolicy saves or updates a policy in the given account.
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error {
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) {
for index, rule := range policyToSave.Rules {
rule.Sources = filterValidGroupIDs(account, rule.Sources)
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
@@ -442,18 +459,25 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli
if isUpdate {
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
if policyIdx < 0 {
return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
}
oldPolicy := account.Policies[policyIdx]
// Update the existing policy
account.Policies[policyIdx] = policyToSave
return nil
if !policyToSave.Enabled && !oldPolicy.Enabled {
return false, nil
}
updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups())
return updateAccountPeers, nil
}
// Add the new policy to the account
account.Policies = append(account.Policies, policyToSave)
return nil
return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
}
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {

View File

@@ -5,7 +5,9 @@ import (
"fmt"
"net"
"testing"
"time"
"github.com/rs/xid"
"github.com/stretchr/testify/assert"
"golang.org/x/exp/slices"
@@ -824,3 +826,375 @@ func sortFunc() func(a *FirewallRule, b *FirewallRule) int {
return 0 // a is equal to b
}
}
func TestPolicyAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer3.ID},
},
{
ID: "groupB",
Name: "GroupB",
Peers: []string{},
},
{
ID: "groupC",
Name: "GroupC",
Peers: []string{},
},
{
ID: "groupD",
Name: "GroupD",
Peers: []string{peer1.ID, peer2.ID},
},
})
assert.NoError(t, err)
updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
})
updMsg2 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
})
// Saving policy with rule groups with no peers should not update account's peers and not send peer update
t.Run("saving policy with rule groups with no peers", func(t *testing.T) {
policy := Policy{
ID: "policy-rule-groups-no-peers",
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupB"},
Destinations: []string{"groupC"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg1)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Saving policy with source group containing peers, but destination group without peers should
// update account's peers and send peer update
t.Run("saving policy where source has peers but destination does not", func(t *testing.T) {
policy := Policy{
ID: "policy-source-has-peers-destination-none",
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupB"},
Protocol: PolicyRuleProtocolTCP,
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg1)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Saving policy with destination group containing peers, but source group without peers should
// update account's peers and send peer update
t.Run("saving policy where destination has peers but source does not", func(t *testing.T) {
policy := Policy{
ID: "policy-destination-has-peers-source-none",
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: false,
Sources: []string{"groupC"},
Destinations: []string{"groupD"},
Bidirectional: true,
Protocol: PolicyRuleProtocolTCP,
Action: PolicyTrafficActionAccept,
},
},
}
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg2)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Saving policy with destination and source groups containing peers should update account's peers
// and send peer update
t.Run("saving policy with source and destination groups with peers", func(t *testing.T) {
policy := Policy{
ID: "policy-source-destination-peers",
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupD"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg1)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Disabling policy with destination and source groups containing peers should update account's peers
// and send peer update
t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) {
policy := Policy{
ID: "policy-source-destination-peers",
Enabled: false,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupD"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg1)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Updating disabled policy with destination and source groups containing peers should not update account's peers
// or send peer update
t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) {
policy := Policy{
ID: "policy-source-destination-peers",
Description: "updated description",
Enabled: false,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg1)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Enabling policy with destination and source groups containing peers should update account's peers
// and send peer update
t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) {
policy := Policy{
ID: "policy-source-destination-peers",
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupD"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg1)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Saving unchanged policy should trigger account peers update but not send peer update
t.Run("saving unchanged policy", func(t *testing.T) {
policy := Policy{
ID: "policy-source-destination-peers",
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupD"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg1)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Deleting policy should trigger account peers update and send peer update
t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) {
policyID := "policy-source-destination-peers"
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg1)
close(done)
}()
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Deleting policy with destination group containing peers, but source group without peers should
// update account's peers and send peer update
t.Run("deleting policy where destination has peers but source does not", func(t *testing.T) {
policyID := "policy-destination-has-peers-source-none"
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg2)
close(done)
}()
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Deleting policy with no peers in groups should not update account's peers and not send peer update
t.Run("deleting policy with no peers in groups", func(t *testing.T) {
policyID := "policy-rule-groups-no-peers" // Deleting the policy created in Case 2
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg1)
close(done)
}()
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}

View File

@@ -8,6 +8,7 @@ import (
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
)
const (
@@ -67,8 +68,11 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
}
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
if exists {
if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) {
am.updateAccountPeers(ctx, account)
} else {
log.WithContext(ctx).Tracef("Skipping account peers update for posture checks: %s", postureChecks.ID)
}
return nil
@@ -148,13 +152,9 @@ func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureCh
return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID)
}
// check policy links
for _, policy := range account.Policies {
for _, id := range policy.SourcePostureChecks {
if id == postureChecksID {
return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name)
}
}
// Check if posture check is linked to any policy
if isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureChecksID); isLinked {
return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", linkedPolicy.Name)
}
postureChecks := account.PostureChecks[postureChecksIdx]
@@ -217,3 +217,25 @@ func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks
}
}
}
func isPostureCheckLinkedToPolicy(account *Account, postureChecksID string) (bool, *Policy) {
for _, policy := range account.Policies {
if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
return true, policy
}
}
return false, nil
}
// arePostureCheckChangesAffectingPeers checks if the changes in posture checks are affecting peers.
func arePostureCheckChangesAffectingPeers(account *Account, postureCheckID string, exists bool) bool {
if !exists {
return false
}
isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureCheckID)
if !isLinked {
return false
}
return anyGroupHasPeers(account, linkedPolicy.ruleGroups())
}

View File

@@ -3,7 +3,10 @@ package server
import (
"context"
"testing"
"time"
"github.com/netbirdio/netbird/management/server/group"
"github.com/rs/xid"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/posture"
@@ -118,3 +121,413 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) {
return am.Store.GetAccount(context.Background(), account.Id)
}
func TestPostureCheckAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{
{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
},
{
ID: "groupB",
Name: "GroupB",
Peers: []string{},
},
{
ID: "groupC",
Name: "GroupC",
Peers: []string{},
},
})
assert.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
})
postureCheck := posture.Checks{
ID: "postureCheck",
Name: "postureCheck",
AccountID: account.Id,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.28.0",
},
},
}
// Saving unused posture check should not update account peers and not send peer update
t.Run("saving unused posture check", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Updating unused posture check should not update account peers and not send peer update
t.Run("updating unused posture check", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
postureCheck.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0",
},
}
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
policy := Policy{
ID: "policyA",
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: []string{postureCheck.ID},
}
// Linking posture check to policy should trigger update account peers and send peer update
t.Run("linking posture check to policy with peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Updating linked posture checks should update account peers and send peer update
t.Run("updating linked to posture check with peers", func(t *testing.T) {
postureCheck.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0",
},
ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{
{LinuxPath: "/usr/bin/netbird", MacPath: "/usr/local/bin/netbird"},
},
},
}
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Saving unchanged posture check should not trigger account peers update and not send peer update
// since there is no change in the network map
t.Run("saving unchanged posture check", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Removing posture check from policy should trigger account peers update and send peer update
t.Run("removing posture check from policy", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
policy.SourcePostureChecks = []string{}
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Deleting unused posture check should not trigger account peers update and not send peer update
t.Run("deleting unused posture check", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.DeletePostureChecks(context.Background(), account.Id, "postureCheck", userID)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
assert.NoError(t, err)
// Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update
t.Run("updating linked posture check to policy with no peers", func(t *testing.T) {
policy = Policy{
ID: "policyB",
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupB"},
Destinations: []string{"groupC"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: []string{postureCheck.ID},
}
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
assert.NoError(t, err)
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
postureCheck.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0",
},
}
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Updating linked posture check to policy where destination has peers but source does not
// should trigger account peers update and send peer update
t.Run("updating linked posture check to policy where destination has peers but source does not", func(t *testing.T) {
updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
})
policy = Policy{
ID: "policyB",
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupB"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: []string{postureCheck.ID},
}
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
assert.NoError(t, err)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg1)
close(done)
}()
postureCheck.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0",
},
}
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Updating linked posture check to policy where source has peers but destination does not,
// should trigger account peers update and send peer update
t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) {
policy = Policy{
ID: "policyB",
Enabled: true,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupB"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: []string{postureCheck.ID},
}
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
assert.NoError(t, err)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
postureCheck.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0",
},
}
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
}
func TestArePostureCheckChangesAffectingPeers(t *testing.T) {
account := &Account{
Policies: []*Policy{
{
ID: "policyA",
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
},
},
SourcePostureChecks: []string{"checkA"},
},
},
Groups: map[string]*group.Group{
"groupA": {
ID: "groupA",
Peers: []string{"peer1"},
},
"groupB": {
ID: "groupB",
Peers: []string{},
},
},
PostureChecks: []*posture.Checks{
{
ID: "checkA",
},
{
ID: "checkB",
},
},
}
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
assert.True(t, result)
})
t.Run("posture check exists but is not linked to any policy", func(t *testing.T) {
result := arePostureCheckChangesAffectingPeers(account, "checkB", true)
assert.False(t, result)
})
t.Run("posture check does not exist", func(t *testing.T) {
result := arePostureCheckChangesAffectingPeers(account, "unknown", false)
assert.False(t, result)
})
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
account.Policies[0].Rules[0].Sources = []string{"groupB"}
account.Policies[0].Rules[0].Destinations = []string{"groupA"}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
assert.True(t, result)
})
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
account.Policies[0].Rules[0].Sources = []string{"groupA"}
account.Policies[0].Rules[0].Destinations = []string{"groupB"}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
assert.True(t, result)
})
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"}
account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
assert.False(t, result)
})
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
account.Groups["groupA"].Peers = []string{}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
assert.False(t, result)
})
}

View File

@@ -237,7 +237,11 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return nil, err
}
am.updateAccountPeers(ctx, account)
if isRouteChangeAffectPeers(account, &newRoute) {
am.updateAccountPeers(ctx, account)
} else {
log.WithContext(ctx).Tracef("Skipping account peers update for route: %s", newRoute.ID)
}
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
@@ -313,6 +317,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err
}
oldRoute := account.Routes[routeToSave.ID]
account.Routes[routeToSave.ID] = routeToSave
account.Network.IncSerial()
@@ -320,7 +325,11 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err
}
am.updateAccountPeers(ctx, account)
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
am.updateAccountPeers(ctx, account)
} else {
log.WithContext(ctx).Tracef("Skipping account peers update for route: %s", routeToSave.ID)
}
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
@@ -350,7 +359,11 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
am.updateAccountPeers(ctx, account)
if isRouteChangeAffectPeers(account, routy) {
am.updateAccountPeers(ctx, account)
} else {
log.WithContext(ctx).Tracef("Skipping account peers update for route: %s", routy.ID)
}
return nil
}
@@ -641,3 +654,9 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
}
return &portInfo
}
// isRouteChangeAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers
func isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
}

View File

@@ -6,6 +6,7 @@ import (
"net"
"net/netip"
"testing"
"time"
"github.com/rs/xid"
"github.com/stretchr/testify/assert"
@@ -1777,3 +1778,281 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
})
}
func TestRouteAccountPeersUpdate(t *testing.T) {
manager, err := createRouterManager(t)
require.NoError(t, err, "failed to create account manager")
account, err := initTestRouteAccount(t, manager)
require.NoError(t, err, "failed to init testing account")
err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
{
ID: "groupA",
Name: "GroupA",
Peers: []string{},
},
{
ID: "groupB",
Name: "GroupB",
Peers: []string{},
},
{
ID: "groupC",
Name: "GroupC",
Peers: []string{},
},
})
assert.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID)
})
// Creating a route with no routing peer and no peers in PeerGroups or Groups should not update account peers and not send peer update
t.Run("creating route no routing peer and no peers in groups", func(t *testing.T) {
route := route.Route{
ID: "testingRoute1",
Network: netip.MustParsePrefix("100.65.250.202/32"),
NetID: "superNet",
NetworkType: route.IPv4Network,
PeerGroups: []string{"groupA"},
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{"groupA"},
}
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
_, err := manager.CreateRoute(
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer,
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
route.Groups, []string{}, true, userID, route.KeepRoute,
)
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Creating a route with no routing peer and having peers in groups should update account peers and send peer update
t.Run("creating a route with peers in PeerGroups and Groups", func(t *testing.T) {
route := route.Route{
ID: "testingRoute2",
Network: netip.MustParsePrefix("192.0.2.0/32"),
NetID: "superNet",
NetworkType: route.IPv4Network,
PeerGroups: []string{routeGroup3},
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup3},
}
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
_, err := manager.CreateRoute(
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer,
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
route.Groups, []string{}, true, userID, route.KeepRoute,
)
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
baseRoute := route.Route{
ID: "testingRoute3",
Network: netip.MustParsePrefix("192.168.0.0/16"),
NetID: "superNet",
NetworkType: route.IPv4Network,
Peer: peer1ID,
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup1},
}
// Creating route should update account peers and send peer update
t.Run("creating route with a routing peer", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
newRoute, err := manager.CreateRoute(
context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer,
baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric,
baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute,
)
require.NoError(t, err)
baseRoute = *newRoute
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Updating the route should update account peers and send peer update when there is peers in group
t.Run("updating route", func(t *testing.T) {
baseRoute.Groups = []string{routeGroup1, routeGroup2}
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute)
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Updating unchanged route should update account peers and not send peer update
t.Run("updating unchanged route", func(t *testing.T) {
baseRoute.Groups = []string{routeGroup1, routeGroup2}
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute)
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Deleting the route should update account peers and send peer update
t.Run("deleting route", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.DeleteRoute(context.Background(), account.Id, baseRoute.ID, userID)
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Adding peer to route peer groups that do not have any peers should update account peers and send peer update
t.Run("adding peer to route peer groups that do not have any peers", func(t *testing.T) {
newRoute := route.Route{
Network: netip.MustParsePrefix("192.168.12.0/16"),
NetID: "superNet",
NetworkType: route.IPv4Network,
PeerGroups: []string{"groupB"},
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup1},
}
_, err := manager.CreateRoute(
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
)
require.NoError(t, err)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "groupB",
Name: "GroupB",
Peers: []string{peer1ID},
})
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Adding peer to route groups that do not have any peers should update account peers and send peer update
t.Run("adding peer to route groups that do not have any peers", func(t *testing.T) {
newRoute := route.Route{
Network: netip.MustParsePrefix("192.168.13.0/16"),
NetID: "superNet",
NetworkType: route.IPv4Network,
PeerGroups: []string{"groupB"},
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{"groupC"},
}
_, err := manager.CreateRoute(
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
)
require.NoError(t, err)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "groupC",
Name: "GroupC",
Peers: []string{peer1ID},
})
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
}

View File

@@ -323,8 +323,6 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
}
}()
am.updateAccountPeers(ctx, account)
return newKey, nil
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
@@ -352,3 +353,73 @@ func TestSetupKey_Copy(t *testing.T) {
key.UpdatedAt, key.AutoGroups)
}
func TestSetupKeyAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
})
assert.NoError(t, err)
policy := Policy{
ID: "policy",
Enabled: true,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"group"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
})
var setupKey *SetupKey
// Creating setup key should not update account peers and not send peer update
t.Run("creating setup key", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
setupKey, err = manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, 999, userID, false)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Saving setup key should not update account peers and not send peer update
t.Run("saving setup key", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
_, err = manager.SaveSetupKey(context.Background(), account.Id, setupKey, userID)
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}

View File

@@ -323,6 +323,29 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
return nil
}
func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error {
accountCopy := Account{
Domain: domain,
DomainCategory: category,
IsDomainPrimaryAccount: isPrimaryDomain,
}
fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"}
result := s.db.WithContext(ctx).Model(&Account{}).
Select(fieldsToUpdate).
Where(idQueryCondition, accountID).
Updates(&accountCopy)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "account %s", accountID)
}
return nil
}
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
var peerCopy nbpeer.Peer
peerCopy.Status = &peerStatus
@@ -518,6 +541,20 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
return &user, nil
}
func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) {
var users []*User
result := s.db.Find(&users, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting users from store")
}
return users, nil
}
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
var groups []*nbgroup.Group
result := s.db.Find(&groups, accountIDCondition, accountID)
@@ -1117,8 +1154,16 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
var group nbgroup.Group
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID)
// TODO: This fix is accepted for now, but if we need to handle this more frequently
// we may need to reconsider changing the types.
query := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations)
if s.storeEngine == PostgresStoreEngine {
query = query.Order("json_array_length(peers::json) DESC")
} else {
query = query.Order("json_array_length(peers) DESC")
}
result := query.First(&group, "name = ? and account_id = ?", groupName, accountID)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "group not found")

View File

@@ -1191,3 +1191,76 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
})
assert.NoError(t, err)
}
func TestSqlite_GetAccoundUsers(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
}
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err)
users, err := store.GetAccountUsers(context.Background(), accountID)
require.NoError(t, err)
require.Len(t, users, len(account.Users))
}
func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
}
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
t.Run("Should update attributes with public domain", func(t *testing.T) {
require.NoError(t, err)
domain := "example.com"
category := "public"
IsDomainPrimaryAccount := false
err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount)
require.NoError(t, err)
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err)
require.Equal(t, domain, account.Domain)
require.Equal(t, category, account.DomainCategory)
require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount)
})
t.Run("Should update attributes with private domain", func(t *testing.T) {
require.NoError(t, err)
domain := "test.com"
category := "private"
IsDomainPrimaryAccount := true
err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount)
require.NoError(t, err)
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err)
require.Equal(t, domain, account.Domain)
require.Equal(t, category, account.DomainCategory)
require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount)
})
t.Run("Should fail when account does not exist", func(t *testing.T) {
require.NoError(t, err)
domain := "test.com"
category := "private"
IsDomainPrimaryAccount := true
err = store.UpdateAccountDomainAttributes(context.Background(), "non-existing-account-id", domain, category, IsDomainPrimaryAccount)
require.Error(t, err)
})
}
func TestSqlite_GetGroupByName(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
}
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, "All", accountID)
require.NoError(t, err)
require.Equal(t, "All", group.Name)
}

View File

@@ -58,9 +58,11 @@ type Store interface {
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
SaveAccount(ctx context.Context, account *Account) error
DeleteAccount(ctx context.Context, account *Account) error
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error)
SaveUsers(accountID string, users map[string]*User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error

View File

@@ -26,8 +26,11 @@ CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`accoun
CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0);
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["cs1tnh0hhcjnqoiuebeg"]',0,0);
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,'');
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,'');
INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003','f4f6d672-63fb-11ec-90d6-0242ac120003','','SoMeHaShEdToKeN','2023-02-27 00:00:00+00:00','user','2023-01-01 00:00:00+00:00','2023-02-01 00:00:00+00:00');
INSERT INTO installations VALUES(1,'');
INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]');
INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL);

View File

@@ -2,9 +2,13 @@ package server
import (
"context"
"fmt"
"runtime/debug"
"sync"
"time"
"github.com/netbirdio/netbird/management/server/differs"
"github.com/r3labs/diff/v3"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/proto"
@@ -14,14 +18,17 @@ import (
const channelBufferSize = 100
type UpdateMessage struct {
Update *proto.SyncResponse
Update *proto.SyncResponse
NetworkMap *NetworkMap
}
type PeersUpdateManager struct {
// peerChannels is an update channel indexed by Peer.ID
peerChannels map[string]chan *UpdateMessage
// peerNetworkMaps is the UpdateMessage indexed by Peer.ID.
peerUpdateMessage map[string]*UpdateMessage
// channelsMux keeps the mutex to access peerChannels
channelsMux *sync.Mutex
channelsMux *sync.RWMutex
// metrics provides method to collect application metrics
metrics telemetry.AppMetrics
}
@@ -29,9 +36,10 @@ type PeersUpdateManager struct {
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager {
return &PeersUpdateManager{
peerChannels: make(map[string]chan *UpdateMessage),
channelsMux: &sync.Mutex{},
metrics: metrics,
peerChannels: make(map[string]chan *UpdateMessage),
peerUpdateMessage: make(map[string]*UpdateMessage),
channelsMux: &sync.RWMutex{},
metrics: metrics,
}
}
@@ -40,7 +48,17 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
start := time.Now()
var found, dropped bool
// skip sending sync update to the peer if there is no change in update message,
// it will not check on turn credential refresh as we do not send network map or client posture checks
if update.NetworkMap != nil {
updated := p.handlePeerMessageUpdate(ctx, peerID, update)
if !updated {
return
}
}
p.channelsMux.Lock()
defer func() {
p.channelsMux.Unlock()
if p.metrics != nil {
@@ -48,6 +66,16 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
}
}()
if update.NetworkMap != nil {
lastSentUpdate := p.peerUpdateMessage[peerID]
if lastSentUpdate != nil && lastSentUpdate.Update.NetworkMap.GetSerial() > update.Update.NetworkMap.GetSerial() {
log.WithContext(ctx).Debugf("peer %s new network map serial: %d not greater than last sent: %d, skip sending update",
peerID, update.Update.NetworkMap.GetSerial(), lastSentUpdate.Update.NetworkMap.GetSerial())
return
}
p.peerUpdateMessage[peerID] = update
}
if channel, ok := p.peerChannels[peerID]; ok {
found = true
select {
@@ -80,6 +108,7 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c
closed = true
delete(p.peerChannels, peerID)
close(channel)
delete(p.peerUpdateMessage, peerID)
}
// mbragin: todo shouldn't it be more? or configurable?
channel := make(chan *UpdateMessage, channelBufferSize)
@@ -94,6 +123,7 @@ func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) {
if channel, ok := p.peerChannels[peerID]; ok {
delete(p.peerChannels, peerID)
close(channel)
delete(p.peerUpdateMessage, peerID)
}
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
@@ -170,3 +200,77 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool {
return ok
}
// handlePeerMessageUpdate checks if the update message for a peer is new and should be sent.
func (p *PeersUpdateManager) handlePeerMessageUpdate(ctx context.Context, peerID string, update *UpdateMessage) bool {
p.channelsMux.RLock()
lastSentUpdate := p.peerUpdateMessage[peerID]
p.channelsMux.RUnlock()
if lastSentUpdate != nil {
updated, err := isNewPeerUpdateMessage(ctx, lastSentUpdate, update)
if err != nil {
log.WithContext(ctx).Errorf("error checking for SyncResponse updates: %v", err)
return false
}
if !updated {
log.WithContext(ctx).Debugf("peer %s network map is not updated, skip sending update", peerID)
return false
}
}
return true
}
// isNewPeerUpdateMessage checks if the given current update message is a new update that should be sent.
func isNewPeerUpdateMessage(ctx context.Context, lastSentUpdate, currUpdateToSend *UpdateMessage) (isNew bool, err error) {
defer func() {
if r := recover(); r != nil {
log.WithContext(ctx).Panicf("comparing peer update messages. Trace: %s", debug.Stack())
}
isNew, err = true, nil
}()
if lastSentUpdate.Update.NetworkMap.GetSerial() > currUpdateToSend.Update.NetworkMap.GetSerial() {
log.Tracef("new network map serial: %d not greater than last sent: %d, skip sending update",
lastSentUpdate.Update.NetworkMap.GetSerial(),
currUpdateToSend.Update.NetworkMap.GetSerial(),
)
return false, nil
}
differ, err := diff.NewDiffer(
diff.CustomValueDiffers(&differs.NetIPAddr{}),
diff.CustomValueDiffers(&differs.NetIPPrefix{}),
)
if err != nil {
return false, fmt.Errorf("failed to create differ: %v", err)
}
lastSentFiles := getChecksFiles(lastSentUpdate.Update.Checks)
currFiles := getChecksFiles(currUpdateToSend.Update.Checks)
changelog, err := differ.Diff(lastSentFiles, currFiles)
if err != nil {
return false, fmt.Errorf("failed to diff checks: %v", err)
}
if len(changelog) > 0 {
return true, nil
}
changelog, err = differ.Diff(lastSentUpdate.NetworkMap, currUpdateToSend.NetworkMap)
if err != nil {
return false, fmt.Errorf("failed to diff network map: %v", err)
}
return len(changelog) > 0, nil
}
// getChecksFiles returns a list of files from the given checks.
func getChecksFiles(checks []*proto.Checks) []string {
files := make([]string, 0, len(checks))
for _, check := range checks {
files = append(files, check.GetFiles()...)
}
return files
}

View File

@@ -2,10 +2,19 @@ package server
import (
"context"
"net"
"net/netip"
"testing"
"time"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
nbroute "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/util"
"github.com/stretchr/testify/assert"
)
// var peersUpdater *PeersUpdateManager
@@ -77,3 +86,465 @@ func TestCloseChannel(t *testing.T) {
t.Error("Error closing the channel")
}
}
func TestHandlePeerMessageUpdate(t *testing.T) {
tests := []struct {
name string
peerID string
existingUpdate *UpdateMessage
newUpdate *UpdateMessage
expectedResult bool
}{
{
name: "update message with turn credentials update",
peerID: "peer",
newUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
WiretrusteeConfig: &proto.WiretrusteeConfig{},
},
},
expectedResult: true,
},
{
name: "update message for peer without existing update",
peerID: "peer1",
newUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 1},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
},
expectedResult: true,
},
{
name: "update message with no changes in update",
peerID: "peer2",
existingUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 1},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
},
newUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 1},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
},
expectedResult: false,
},
{
name: "update message with changes in checks",
peerID: "peer3",
existingUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 1},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
},
newUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 2},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
},
expectedResult: true,
},
{
name: "update message with lower serial number",
peerID: "peer4",
existingUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 2},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
},
newUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 1},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
},
expectedResult: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := NewPeersUpdateManager(nil)
ctx := context.Background()
if tt.existingUpdate != nil {
p.peerUpdateMessage[tt.peerID] = tt.existingUpdate
}
result := p.handlePeerMessageUpdate(ctx, tt.peerID, tt.newUpdate)
assert.Equal(t, tt.expectedResult, result)
})
}
}
func TestIsNewPeerUpdateMessage(t *testing.T) {
t.Run("Unchanged value", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err)
assert.False(t, message)
})
t.Run("Unchanged value with serial incremented", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err)
assert.False(t, message)
})
t.Run("Updating routes network", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newUpdateMessage2.NetworkMap.Routes[0].Network = netip.MustParsePrefix("1.1.1.1/32")
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err)
assert.True(t, message)
})
t.Run("Updating routes groups", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newUpdateMessage2.NetworkMap.Routes[0].Groups = []string{"randomGroup1"}
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err)
assert.True(t, message)
})
t.Run("Updating network map peers", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newPeer := &nbpeer.Peer{
IP: net.ParseIP("192.168.1.4"),
SSHEnabled: true,
Key: "peer4-key",
DNSLabel: "peer4",
SSHKey: "peer4-ssh-key",
}
newUpdateMessage2.NetworkMap.Peers = append(newUpdateMessage2.NetworkMap.Peers, newPeer)
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err)
assert.True(t, message)
})
t.Run("Updating process check", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err)
assert.False(t, message)
newUpdateMessage3 := createMockUpdateMessage(t)
newUpdateMessage3.Update.Checks = []*proto.Checks{}
newUpdateMessage3.Update.NetworkMap.Serial++
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage3)
assert.NoError(t, err)
assert.True(t, message)
newUpdateMessage4 := createMockUpdateMessage(t)
check := &posture.Checks{
Checks: posture.ChecksDefinition{
ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{
{
LinuxPath: "/usr/local/netbird",
MacPath: "/usr/bin/netbird",
},
},
},
},
}
newUpdateMessage4.Update.Checks = []*proto.Checks{toProtocolCheck(check)}
newUpdateMessage4.Update.NetworkMap.Serial++
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage4)
assert.NoError(t, err)
assert.True(t, message)
newUpdateMessage5 := createMockUpdateMessage(t)
check = &posture.Checks{
Checks: posture.ChecksDefinition{
ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{
{
LinuxPath: "/usr/bin/netbird",
WindowsPath: "C:\\Program Files\\netbird\\netbird.exe",
MacPath: "/usr/local/netbird",
},
},
},
},
}
newUpdateMessage5.Update.Checks = []*proto.Checks{toProtocolCheck(check)}
newUpdateMessage5.Update.NetworkMap.Serial++
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage5)
assert.NoError(t, err)
assert.True(t, message)
})
t.Run("Updating DNS configuration", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newDomain := "newexample.com"
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains = append(
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains,
newDomain,
)
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err)
assert.True(t, message)
})
t.Run("Updating peer IP", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newUpdateMessage2.NetworkMap.Peers[0].IP = net.ParseIP("192.168.1.10")
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err)
assert.True(t, message)
})
t.Run("Updating firewall rule", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newUpdateMessage2.NetworkMap.FirewallRules[0].Port = "443"
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err)
assert.True(t, message)
})
t.Run("Add new firewall rule", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newRule := &FirewallRule{
PeerIP: "192.168.1.3",
Direction: firewallRuleDirectionOUT,
Action: string(PolicyTrafficActionDrop),
Protocol: string(PolicyRuleProtocolUDP),
Port: "53",
}
newUpdateMessage2.NetworkMap.FirewallRules = append(newUpdateMessage2.NetworkMap.FirewallRules, newRule)
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err)
assert.True(t, message)
})
t.Run("Removing nameserver", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers = make([]nbdns.NameServer, 0)
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err)
assert.True(t, message)
})
t.Run("Updating name server IP", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].IP = netip.MustParseAddr("8.8.4.4")
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err)
assert.True(t, message)
})
t.Run("Updating custom DNS zone", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newUpdateMessage2.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData = "100.64.0.2"
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err)
assert.True(t, message)
})
}
func createMockUpdateMessage(t *testing.T) *UpdateMessage {
t.Helper()
_, ipNet, err := net.ParseCIDR("192.168.1.0/24")
if err != nil {
t.Fatal(err)
}
domainList, err := domain.FromStringList([]string{"example.com"})
if err != nil {
t.Fatal(err)
}
config := &Config{
Signal: &Host{
Proto: "https",
URI: "signal.uri",
Username: "",
Password: "",
},
Stuns: []*Host{{URI: "stun.uri", Proto: UDP}},
TURNConfig: &TURNConfig{
Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}},
},
}
peer := &nbpeer.Peer{
IP: net.ParseIP("192.168.1.1"),
SSHEnabled: true,
Key: "peer-key",
DNSLabel: "peer1",
SSHKey: "peer1-ssh-key",
}
secretManager := NewTimeBasedAuthSecretsManager(
NewPeersUpdateManager(nil),
&TURNConfig{
TimeBasedCredentials: false,
CredentialsTTL: util.Duration{
Duration: defaultDuration,
},
Secret: "secret",
Turns: []*Host{TurnTestHost},
},
&Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: util.Duration{Duration: time.Hour},
Secret: "secret",
},
)
networkMap := &NetworkMap{
Network: &Network{Net: *ipNet, Serial: 1000},
Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}},
OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}},
Routes: []*nbroute.Route{
{
ID: "route1",
Network: netip.MustParsePrefix("10.0.0.0/24"),
KeepRoute: true,
NetID: "route1",
Peer: "peer1",
NetworkType: 1,
Masquerade: true,
Metric: 9999,
Enabled: true,
Groups: []string{"test1", "test2"},
},
{
ID: "route2",
Domains: domainList,
KeepRoute: true,
NetID: "route2",
Peer: "peer1",
NetworkType: 1,
Masquerade: true,
Metric: 9999,
Enabled: true,
Groups: []string{"test1", "test2"},
},
},
DNSConfig: nbdns.Config{
ServiceEnable: true,
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: []nbdns.NameServer{{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
}},
Primary: true,
Domains: []string{"example.com"},
Enabled: true,
SearchDomainsEnabled: true,
},
{
ID: "ns1",
NameServers: []nbdns.NameServer{{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
}},
Groups: []string{"group1"},
Primary: true,
Domains: []string{"example.com"},
Enabled: true,
SearchDomainsEnabled: true,
},
},
CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}},
},
FirewallRules: []*FirewallRule{
{PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"},
},
}
dnsName := "example.com"
checks := []*posture.Checks{
{
Checks: posture.ChecksDefinition{
ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{
{
LinuxPath: "/usr/bin/netbird",
WindowsPath: "C:\\Program Files\\netbird\\netbird.exe",
MacPath: "/usr/bin/netbird",
},
},
},
},
},
}
dnsCache := &DNSConfigCache{}
turnToken, err := secretManager.GenerateTurnToken()
if err != nil {
t.Fatal(err)
}
relayToken, err := secretManager.GenerateRelayToken()
if err != nil {
t.Fatal(err)
}
return &UpdateMessage{
Update: toSyncResponse(context.Background(), config, peer, turnToken, relayToken, networkMap, dnsName, checks, dnsCache),
NetworkMap: networkMap,
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"slices"
"strings"
"time"
@@ -473,7 +474,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
}
func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *Account, initiatorUserID, targetUserID string) error {
meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
meta, updateAccountPeers, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
if err != nil {
return err
}
@@ -485,15 +486,24 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account
}
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
am.updateAccountPeers(ctx, account)
if updateAccountPeers {
am.updateAccountPeers(ctx, account)
} else {
log.WithContext(ctx).Tracef("Skipping account peers update for user: %s", targetUserID)
}
return nil
}
func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) error {
func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) (bool, error) {
peers, err := account.FindUserPeers(targetUserID)
if err != nil {
return status.Errorf(status.Internal, "failed to find user peers")
return false, status.Errorf(status.Internal, "failed to find user peers")
}
hadPeers := len(peers) > 0
if !hadPeers {
return false, nil
}
peerIDs := make([]string, 0, len(peers))
@@ -501,7 +511,7 @@ func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorU
peerIDs = append(peerIDs, peer.ID)
}
return am.deletePeers(ctx, account, peerIDs, initiatorUserID)
return hadPeers, am.deletePeers(ctx, account, peerIDs, initiatorUserID)
}
// InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period.
@@ -745,6 +755,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
updatedUsers := make([]*UserInfo, 0, len(updates))
var (
expiredPeers []*nbpeer.Peer
userIDs []string
eventsToStore []func()
)
@@ -753,6 +764,8 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
}
userIDs = append(userIDs, update.Id)
oldUser := account.Users[update.Id]
if oldUser == nil {
if !addIfNotExists {
@@ -816,8 +829,10 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
return nil, err
}
if account.Settings.GroupsPropagationEnabled {
if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) {
am.updateAccountPeers(ctx, account)
} else {
log.WithContext(ctx).Tracef("Skipping account peers update for user: %v", userIDs)
}
for _, storeEvent := range eventsToStore {
@@ -1167,7 +1182,10 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
return status.Errorf(status.PermissionDenied, "only users with admin power can delete users")
}
var allErrors error
var (
allErrors error
updateAccountPeers bool
)
deletedUsersMeta := make(map[string]map[string]any)
for _, targetUserID := range targetUserIDs {
@@ -1193,12 +1211,16 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
continue
}
meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
meta, hadPeers, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
if err != nil {
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete user %s: %s", targetUserID, err))
continue
}
if hadPeers {
updateAccountPeers = true
}
delete(account.Users, targetUserID)
deletedUsersMeta[targetUserID] = meta
}
@@ -1208,7 +1230,11 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
return fmt.Errorf("failed to delete users: %w", err)
}
am.updateAccountPeers(ctx, account)
if updateAccountPeers {
am.updateAccountPeers(ctx, account)
} else {
log.WithContext(ctx).Tracef("Skipping account peers update for user: %v", targetUserIDs)
}
for targetUserID, meta := range deletedUsersMeta {
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
@@ -1217,11 +1243,11 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
return allErrors
}
func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, error) {
func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, bool, error) {
tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID)
if err != nil {
log.WithContext(ctx).Errorf("failed to resolve email address: %s", err)
return nil, err
return nil, false, err
}
if !isNil(am.idpManager) {
@@ -1232,16 +1258,16 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun
err = am.deleteUserFromIDP(ctx, targetUserID, account.Id)
if err != nil {
log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserID)
return nil, err
return nil, false, err
}
} else {
log.WithContext(ctx).Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err)
}
}
err = am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account)
hadPeers, err := am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account)
if err != nil {
return nil, err
return nil, false, err
}
u, err := account.FindUser(targetUserID)
@@ -1254,7 +1280,7 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun
tuCreatedAt = u.CreatedAt
}
return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil
return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, hadPeers, nil
}
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
@@ -1333,3 +1359,13 @@ func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserDa
}
return nil, false
}
// areUsersLinkedToPeers checks if any of the given userIDs are linked to any of the peers in the account.
func areUsersLinkedToPeers(account *Account, userIDs []string) bool {
for _, peer := range account.Peers {
if slices.Contains(userIDs, peer.UserID) {
return true
}
}
return false
}

View File

@@ -10,9 +10,12 @@ import (
"github.com/eko/gocache/v3/cache"
cacheStore "github.com/eko/gocache/v3/store"
"github.com/google/go-cmp/cmp"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
gocache "github.com/patrickmn/go-cache"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
@@ -1264,3 +1267,165 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
})
}
}
func TestUserAccountPeersUpdate(t *testing.T) {
// account groups propagation is enabled
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
})
require.NoError(t, err)
policy := Policy{
ID: "policy",
Enabled: true,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
})
// Creating a new regular user should not update account peers and not send peer update
t.Run("creating new regular user with no groups", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
Id: "regularUser1",
AccountID: account.Id,
Role: UserRoleUser,
Issued: UserIssuedAPI,
}, true)
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// updating user with no linked peers should not update account peers and not send peer update
t.Run("updating user with no linked peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
Id: "regularUser1",
AccountID: account.Id,
Role: UserRoleUser,
Issued: UserIssuedAPI,
}, false)
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// deleting user with no linked peers should not update account peers and not send peer update
t.Run("deleting user with no linked peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err = manager.DeleteUser(context.Background(), account.Id, userID, "regularUser1")
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// create a user and add new peer with the user
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
Id: "regularUser2",
AccountID: account.Id,
Role: UserRoleAdmin,
Issued: UserIssuedAPI,
}, true)
require.NoError(t, err)
key, err := wgtypes.GeneratePrivateKey()
require.NoError(t, err)
expectedPeerKey := key.PublicKey().String()
peer4, _, _, err := manager.AddPeer(context.Background(), "", "regularUser2", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
require.NoError(t, err)
// updating user with linked peers should update account peers and send peer update
t.Run("updating user with linked peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
Id: "regularUser2",
AccountID: account.Id,
Role: UserRoleAdmin,
Issued: UserIssuedAPI,
}, false)
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
peer4UpdMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer4.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer4.ID)
})
// deleting user with linked peers should update account peers and send peer update
t.Run("deleting user with linked peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, peer4UpdMsg)
close(done)
}()
err = manager.DeleteUser(context.Background(), account.Id, userID, "regularUser2")
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
}

View File

@@ -16,8 +16,10 @@ const (
type Metrics struct {
metric.Meter
TransferBytesSent metric.Int64Counter
TransferBytesRecv metric.Int64Counter
TransferBytesSent metric.Int64Counter
TransferBytesRecv metric.Int64Counter
AuthenticationTime metric.Float64Histogram
PeerStoreTime metric.Float64Histogram
peers metric.Int64UpDownCounter
peerActivityChan chan string
@@ -52,11 +54,23 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
return nil, err
}
authTime, err := meter.Float64Histogram("relay_peer_authentication_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...))
if err != nil {
return nil, err
}
peerStoreTime, err := meter.Float64Histogram("relay_peer_store_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...))
if err != nil {
return nil, err
}
m := &Metrics{
Meter: meter,
TransferBytesSent: bytesSent,
TransferBytesRecv: bytesRecv,
peers: peers,
Meter: meter,
TransferBytesSent: bytesSent,
TransferBytesRecv: bytesRecv,
AuthenticationTime: authTime,
PeerStoreTime: peerStoreTime,
peers: peers,
ctx: ctx,
peerActivityChan: make(chan string, 10),
@@ -89,6 +103,16 @@ func (m *Metrics) PeerConnected(id string) {
m.peerLastActive[id] = time.Time{}
}
// RecordAuthenticationTime measures the time taken for peer authentication
func (m *Metrics) RecordAuthenticationTime(duration time.Duration) {
m.AuthenticationTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6)
}
// RecordPeerStoreTime measures the time to store the peer in map
func (m *Metrics) RecordPeerStoreTime(duration time.Duration) {
m.PeerStoreTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6)
}
// PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections
func (m *Metrics) PeerDisconnected(id string) {
m.peers.Add(m.ctx, -1)
@@ -134,3 +158,19 @@ func (m *Metrics) readPeerActivity() {
}
}
}
func getStandardBucketBoundaries() []float64 {
return []float64{
0.1,
0.5,
1,
5,
10,
50,
100,
500,
1000,
5000,
10000,
}
}

153
relay/server/handshake.go Normal file
View File

@@ -0,0 +1,153 @@
package server
import (
"fmt"
"net"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/messages"
//nolint:staticcheck
"github.com/netbirdio/netbird/relay/messages/address"
//nolint:staticcheck
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
)
// preparedMsg contains the marshalled success response messages
type preparedMsg struct {
responseHelloMsg []byte
responseAuthMsg []byte
}
func newPreparedMsg(instanceURL string) (*preparedMsg, error) {
rhm, err := marshalResponseHelloMsg(instanceURL)
if err != nil {
return nil, err
}
ram, err := messages.MarshalAuthResponse(instanceURL)
if err != nil {
return nil, fmt.Errorf("failed to marshal auth response msg: %w", err)
}
return &preparedMsg{
responseHelloMsg: rhm,
responseAuthMsg: ram,
}, nil
}
func marshalResponseHelloMsg(instanceURL string) ([]byte, error) {
addr := &address.Address{URL: instanceURL}
addrData, err := addr.Marshal()
if err != nil {
return nil, fmt.Errorf("failed to marshal response address: %w", err)
}
//nolint:staticcheck
responseMsg, err := messages.MarshalHelloResponse(addrData)
if err != nil {
return nil, fmt.Errorf("failed to marshal hello response: %w", err)
}
return responseMsg, nil
}
type handshake struct {
conn net.Conn
validator auth.Validator
preparedMsg *preparedMsg
handshakeMethodAuth bool
peerID string
}
func (h *handshake) handshakeReceive() ([]byte, error) {
buf := make([]byte, messages.MaxHandshakeSize)
n, err := h.conn.Read(buf)
if err != nil {
return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err)
}
_, err = messages.ValidateVersion(buf[:n])
if err != nil {
return nil, fmt.Errorf("validate version from %s: %w", h.conn.RemoteAddr(), err)
}
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
if err != nil {
return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err)
}
var (
bytePeerID []byte
peerID string
)
switch msgType {
//nolint:staticcheck
case messages.MsgTypeHello:
bytePeerID, peerID, err = h.handleHelloMsg(buf[messages.SizeOfProtoHeader:n])
case messages.MsgTypeAuth:
h.handshakeMethodAuth = true
bytePeerID, peerID, err = h.handleAuthMsg(buf[messages.SizeOfProtoHeader:n])
default:
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
}
if err != nil {
return nil, err
}
h.peerID = peerID
return bytePeerID, nil
}
func (h *handshake) handshakeResponse() error {
var responseMsg []byte
if h.handshakeMethodAuth {
responseMsg = h.preparedMsg.responseAuthMsg
} else {
responseMsg = h.preparedMsg.responseHelloMsg
}
if _, err := h.conn.Write(responseMsg); err != nil {
return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err)
}
return nil
}
func (h *handshake) handleHelloMsg(buf []byte) ([]byte, string, error) {
//nolint:staticcheck
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
if err != nil {
return nil, "", fmt.Errorf("unmarshal hello message: %w", err)
}
peerID := messages.HashIDToString(rawPeerID)
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr())
authMsg, err := authmsg.UnmarshalMsg(authData)
if err != nil {
return nil, "", fmt.Errorf("unmarshal auth message: %w", err)
}
//nolint:staticcheck
if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
}
return rawPeerID, peerID, nil
}
func (h *handshake) handleAuthMsg(buf []byte) ([]byte, string, error) {
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
if err != nil {
return nil, "", fmt.Errorf("unmarshal hello message: %w", err)
}
peerID := messages.HashIDToString(rawPeerID)
if err := h.validator.Validate(authPayload); err != nil {
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
}
return rawPeerID, peerID, nil
}

View File

@@ -7,16 +7,13 @@ import (
"net/url"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/messages"
//nolint:staticcheck
"github.com/netbirdio/netbird/relay/messages/address"
//nolint:staticcheck
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
"github.com/netbirdio/netbird/relay/metrics"
)
@@ -28,6 +25,7 @@ type Relay struct {
store *Store
instanceURL string
preparedMsg *preparedMsg
closed bool
closeMu sync.RWMutex
@@ -69,6 +67,12 @@ func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, valida
return nil, fmt.Errorf("get instance URL: %v", err)
}
r.preparedMsg, err = newPreparedMsg(r.instanceURL)
if err != nil {
metricsCancel()
return nil, fmt.Errorf("prepare message: %v", err)
}
return r, nil
}
@@ -100,17 +104,22 @@ func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
// Accept start to handle a new peer connection
func (r *Relay) Accept(conn net.Conn) {
acceptTime := time.Now()
r.closeMu.RLock()
defer r.closeMu.RUnlock()
if r.closed {
return
}
peerID, err := r.handshake(conn)
h := handshake{
conn: conn,
validator: r.validator,
preparedMsg: r.preparedMsg,
}
peerID, err := h.handshakeReceive()
if err != nil {
log.Errorf("failed to handshake: %s", err)
cErr := conn.Close()
if cErr != nil {
if cErr := conn.Close(); cErr != nil {
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
}
return
@@ -118,7 +127,9 @@ func (r *Relay) Accept(conn net.Conn) {
peer := NewPeer(r.metrics, peerID, conn, r.store)
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
storeTime := time.Now()
r.store.AddPeer(peer)
r.metrics.RecordPeerStoreTime(time.Since(storeTime))
r.metrics.PeerConnected(peer.String())
go func() {
peer.Work()
@@ -126,6 +137,12 @@ func (r *Relay) Accept(conn net.Conn) {
peer.log.Debugf("relay connection closed")
r.metrics.PeerDisconnected(peer.String())
}()
if err := h.handshakeResponse(); err != nil {
log.Errorf("failed to send handshake response, close peer: %s", err)
peer.Close()
}
r.metrics.RecordAuthenticationTime(time.Since(acceptTime))
}
// Shutdown closes the relay server
@@ -151,99 +168,3 @@ func (r *Relay) Shutdown(ctx context.Context) {
func (r *Relay) InstanceURL() string {
return r.instanceURL
}
func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
buf := make([]byte, messages.MaxHandshakeSize)
n, err := conn.Read(buf)
if err != nil {
return nil, fmt.Errorf("read from %s: %w", conn.RemoteAddr(), err)
}
_, err = messages.ValidateVersion(buf[:n])
if err != nil {
return nil, fmt.Errorf("validate version from %s: %w", conn.RemoteAddr(), err)
}
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
if err != nil {
return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
}
var (
responseMsg []byte
peerID []byte
)
switch msgType {
//nolint:staticcheck
case messages.MsgTypeHello:
peerID, responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
case messages.MsgTypeAuth:
peerID, responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
default:
return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr())
}
if err != nil {
return nil, err
}
_, err = conn.Write(responseMsg)
if err != nil {
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
}
return peerID, nil
}
func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) {
//nolint:staticcheck
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
if err != nil {
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
}
peerID := messages.HashIDToString(rawPeerID)
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr)
authMsg, err := authmsg.UnmarshalMsg(authData)
if err != nil {
return nil, nil, fmt.Errorf("unmarshal auth message: %w", err)
}
//nolint:staticcheck
if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err)
}
addr := &address.Address{URL: r.instanceURL}
addrData, err := addr.Marshal()
if err != nil {
return nil, nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err)
}
//nolint:staticcheck
responseMsg, err := messages.MarshalHelloResponse(addrData)
if err != nil {
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err)
}
return rawPeerID, responseMsg, nil
}
func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) {
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
if err != nil {
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
}
peerID := messages.HashIDToString(rawPeerID)
if err := r.validator.Validate(authPayload); err != nil {
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err)
}
responseMsg, err := messages.MarshalAuthResponse(r.instanceURL)
if err != nil {
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err)
}
return rawPeerID, responseMsg, nil
}

View File

@@ -86,7 +86,7 @@ download_release_binary() {
# Unzip the app and move to INSTALL_DIR
unzip -q -o "$BINARY_NAME"
mv "netbird_ui_${OS_TYPE}_${ARCH}/" "$INSTALL_DIR/"
mv -v "netbird_ui_${OS_TYPE}/" "$INSTALL_DIR/" || mv -v "netbird_ui_${OS_TYPE}_${ARCH}/" "$INSTALL_DIR/"
else
${SUDO} mkdir -p "$INSTALL_DIR"
tar -xzvf "$BINARY_NAME"

View File

@@ -1,126 +0,0 @@
package util_test
import (
"crypto/md5"
"encoding/hex"
"io"
"os"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/netbirdio/netbird/util"
)
var _ = Describe("Client", func() {
var (
tmpDir string
)
type TestConfig struct {
SomeMap map[string]string
SomeArray []string
SomeField int
}
BeforeEach(func() {
var err error
tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*")
Expect(err).NotTo(HaveOccurred())
})
AfterEach(func() {
err := os.RemoveAll(tmpDir)
Expect(err).NotTo(HaveOccurred())
})
Describe("Config", func() {
Context("in JSON format", func() {
It("should be written and read successfully", func() {
m := make(map[string]string)
m["key1"] = "value1"
m["key2"] = "value2"
arr := []string{"value1", "value2"}
written := &TestConfig{
SomeMap: m,
SomeArray: arr,
SomeField: 99,
}
err := util.WriteJson(tmpDir+"/testconfig.json", written)
Expect(err).NotTo(HaveOccurred())
read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
Expect(err).NotTo(HaveOccurred())
Expect(read).NotTo(BeNil())
Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"]))
Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"]))
Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr))
Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField))
})
})
})
Describe("Copying file contents", func() {
Context("from one file to another", func() {
It("should be successful", func() {
src := tmpDir + "/copytest_src"
dst := tmpDir + "/copytest_dst"
err := util.WriteJson(src, []string{"1", "2", "3"})
Expect(err).NotTo(HaveOccurred())
err = util.CopyFileContents(src, dst)
Expect(err).NotTo(HaveOccurred())
hashSrc := md5.New()
hashDst := md5.New()
srcFile, err := os.Open(src)
Expect(err).NotTo(HaveOccurred())
dstFile, err := os.Open(dst)
Expect(err).NotTo(HaveOccurred())
_, err = io.Copy(hashSrc, srcFile)
Expect(err).NotTo(HaveOccurred())
_, err = io.Copy(hashDst, dstFile)
Expect(err).NotTo(HaveOccurred())
err = srcFile.Close()
Expect(err).NotTo(HaveOccurred())
err = dstFile.Close()
Expect(err).NotTo(HaveOccurred())
Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16])))
})
})
})
Describe("Handle config file without full path", func() {
Context("config file handling", func() {
It("should be successful", func() {
written := &TestConfig{
SomeField: 123,
}
cfgFile := "test_cfg.json"
defer os.Remove(cfgFile)
err := util.WriteJson(cfgFile, written)
Expect(err).NotTo(HaveOccurred())
read, err := util.ReadJson(cfgFile, &TestConfig{})
Expect(err).NotTo(HaveOccurred())
Expect(read).NotTo(BeNil())
})
})
})
})

View File

@@ -1,12 +1,142 @@
package util
import (
"crypto/md5"
"encoding/hex"
"io"
"os"
"reflect"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
type TestConfig struct {
SomeMap map[string]string
SomeArray []string
SomeField int
}
func TestConfigJSON(t *testing.T) {
tests := []struct {
name string
config *TestConfig
expectedError bool
}{
{
name: "Valid JSON config",
config: &TestConfig{
SomeMap: map[string]string{"key1": "value1", "key2": "value2"},
SomeArray: []string{"value1", "value2"},
SomeField: 99,
},
expectedError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir := t.TempDir()
err := WriteJson(tmpDir+"/testconfig.json", tt.config)
require.NoError(t, err)
read, err := ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
require.NoError(t, err)
require.NotNil(t, read)
require.Equal(t, tt.config.SomeMap["key1"], read.(*TestConfig).SomeMap["key1"])
require.Equal(t, tt.config.SomeMap["key2"], read.(*TestConfig).SomeMap["key2"])
require.ElementsMatch(t, tt.config.SomeArray, read.(*TestConfig).SomeArray)
require.Equal(t, tt.config.SomeField, read.(*TestConfig).SomeField)
})
}
}
func TestCopyFileContents(t *testing.T) {
tests := []struct {
name string
srcContent []string
expectedError bool
}{
{
name: "Copy file contents successfully",
srcContent: []string{"1", "2", "3"},
expectedError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir := t.TempDir()
src := tmpDir + "/copytest_src"
dst := tmpDir + "/copytest_dst"
err := WriteJson(src, tt.srcContent)
require.NoError(t, err)
err = CopyFileContents(src, dst)
require.NoError(t, err)
hashSrc := md5.New()
hashDst := md5.New()
srcFile, err := os.Open(src)
require.NoError(t, err)
defer func() {
_ = srcFile.Close()
}()
dstFile, err := os.Open(dst)
require.NoError(t, err)
defer func() {
_ = dstFile.Close()
}()
_, err = io.Copy(hashSrc, srcFile)
require.NoError(t, err)
_, err = io.Copy(hashDst, dstFile)
require.NoError(t, err)
require.Equal(t, hex.EncodeToString(hashSrc.Sum(nil)[:16]), hex.EncodeToString(hashDst.Sum(nil)[:16]))
})
}
}
func TestHandleConfigFileWithoutFullPath(t *testing.T) {
tests := []struct {
name string
config *TestConfig
expectedError bool
}{
{
name: "Handle config file without full path",
config: &TestConfig{
SomeField: 123,
},
expectedError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfgFile := "test_cfg.json"
defer func() {
_ = os.Remove(cfgFile)
}()
err := WriteJson(cfgFile, tt.config)
require.NoError(t, err)
read, err := ReadJson(cfgFile, &TestConfig{})
require.NoError(t, err)
require.NotNil(t, read)
})
}
}
func TestReadJsonWithEnvSub(t *testing.T) {
type Config struct {
CertFile string `json:"CertFile"`