Compare commits

...

56 Commits

Author SHA1 Message Date
Maycon Santos
268e801ec5 Ignore network monitor checks for software interfaces (#2302)
ignore checks for Teredo and ISATAP interfaces
2024-07-22 19:44:15 +02:00
Maycon Santos
788f130941 Retry management connection only on context canceled (#2301) 2024-07-22 15:49:25 +02:00
Maycon Santos
926e11b086 Remove default allow for UDP on unmatched packet (#2300)
This fixes an issue where UDP rules were ineffective for userspace clients (Windows/macOS)
2024-07-22 15:35:17 +02:00
Carlos Hernandez
0a8c78deb1 Minor fix local dns search domain (#2287) 2024-07-19 16:44:12 +02:00
Maycon Santos
c815ad86fd Fix macOS DNS unclean shutdown restore call on startup (#2286)
previously, we called the restore method from the startup when there was an unclean shutdown. But it never had the state keys to clean since they are stored in memory

this change addresses the issue by falling back to default values when restoring the host's DNS
2024-07-18 18:06:09 +02:00
Carlos Hernandez
ef1a39cb01 Refactor macOS system DNS configuration (#2284)
On macOS use the recommended settings for providing split DNS. As per
the docs an empty string will force the configuration to be the default.
In order to to support split DNS an additional service config is added
for the local server and search domain settings.

see: https://developer.apple.com/documentation/devicemanagement/vpn/dns
2024-07-18 16:39:41 +02:00
Maycon Santos
c900fa81bb Remove copy functions from signal (#2285)
remove migration function for wiretrustee directories to netbird
2024-07-18 12:15:14 +02:00
Maycon Santos
9a6de52dd0 Check if route interface is a Microsoft ISATAP device (#2282)
check if the nexthop interfaces are Microsoft ISATAP devices and ignore their suffixes when comparing them
2024-07-17 23:49:09 +02:00
Maycon Santos
19147f518e Add faster availability DNS probe and update test domain to .com (#2280)
* Add faster availability DNS probe and update test domain to .com

- Count success queries and compare it before doing after network map probes.

- Reduce the first dns probe to 500ms

- Updated test domain with com instead of . due to Palo alto DNS proxy server issues

* use fqdn

* Update client/internal/dns/upstream.go

Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>

---------

Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>
2024-07-17 23:48:37 +02:00
Viktor Liu
e78ec2e985 Don't add exclusion routes for IPs that are part of connected networks (#2258)
This prevents arp/ndp issues on macOS leading to unreachability of that IP.
2024-07-17 19:50:06 +02:00
pascal-fischer
95d725f2c1 Wait on daemon down (#2279) 2024-07-17 16:26:06 +02:00
benniekiss
4fad0e521f Support custom SSL certificates for the signal service (#2257) 2024-07-16 20:44:21 +02:00
ctrl-zzz
a711e116a3 fix: save peer status correctly in sqlstore (#2262)
* fix: save peer status correctly in sqlstore

https://github.com/netbirdio/netbird/issues/2110#issuecomment-2162768273

* feat: update test function

* refactor: simplify status update
2024-07-16 18:38:12 +03:00
Maycon Santos
668d229b67 Fix metric label typo (#2278) 2024-07-16 16:55:57 +02:00
Maycon Santos
7c595e8493 Add get_registration_delay_milliseconds metric (#2275) 2024-07-16 15:36:51 +02:00
Jakub Kołodziejczak
f9c59a7131 Refactor log util (#2276) 2024-07-16 11:50:35 +02:00
Jakub Kołodziejczak
1d6f5482dd feat(client): send logs to syslog (#2259) 2024-07-16 10:19:58 +02:00
Carlos Hernandez
12ff93ba72 Ignore no unique route updates (#2266) 2024-07-16 10:19:01 +02:00
Maycon Santos
88d1c5a0fd fix forwarded metrics (#2273) 2024-07-16 10:14:30 +02:00
Bethuel Mmbaga
1537b0f5e7 Add batch save/update for groups and users (#2245)
* Add functionality to update multiple users

* Remove SaveUsers from DefaultAccountManager

* Add SaveGroups method to AccountManager interface

* Refactoring

* Add SaveUsers and SaveGroups methods to store interface

* Refactor method SaveAccount to SaveUsers and SaveGroups

The method SaveAccount in user.go and group.go files was split into two separate methods. Now, user-specific data is handled by SaveUsers and group-specific data is handled by SaveGroups method. This provides a cleaner and more efficient way to save user and group data.

* Add account ID to user and group in SqlStore

* Refactor SaveUsers and SaveGroups in store

* Remove unnecessary ID assignment in SaveUsers and SaveGroups
2024-07-15 17:04:06 +03:00
Maycon Santos
2577100096 Limit GUI process execution to one per UID (#2267)
replaces PID with checking process name and path and UID checks
2024-07-15 14:53:52 +02:00
Zoltan Papp
bc09348f5a Add logging option for wg device (#2271) 2024-07-15 14:45:18 +02:00
Edouard Vanbelle
d5ba2ef6ec fix 2260: fallback serial to Board (#2263) 2024-07-15 14:43:50 +02:00
pascal-fischer
47752e1573 Support DNS routes on iOS (#2254) 2024-07-15 10:40:57 +02:00
Maycon Santos
58fbc1249c Fix parameter limit issue for Postgres store (#2261)
Added CreateBatchSize for both SQL stores and updated tests to test large accounts with Postgres, too. Increased the account peer size to 6K.
2024-07-12 09:28:53 +02:00
dependabot[bot]
1cc341a268 Bump google.golang.org/grpc from 1.64.0 to 1.64.1 (#2248)
Bumps [google.golang.org/grpc](https://github.com/grpc/grpc-go) from 1.64.0 to 1.64.1.
- [Release notes](https://github.com/grpc/grpc-go/releases)
- [Commits](https://github.com/grpc/grpc-go/compare/v1.64.0...v1.64.1)

---
updated-dependencies:
- dependency-name: google.golang.org/grpc
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-07-12 08:52:27 +02:00
Viktor Liu
89df6e7242 Get client ui locale on windows natively (#2251) 2024-07-12 08:25:33 +02:00
Maycon Santos
f74646a3ac Add release version to windows binaries and update sign pipeline version (#2256) 2024-07-11 19:06:55 +02:00
pascal-fischer
e8c2fafccd Avoid empty domain overwrite (#2252) 2024-07-10 14:08:35 +02:00
Maycon Santos
85e991ff78 Fix issue with canceled context before pushing metrics and decreasing pushing interval (#2235)
Fix a bug where the post context was canceled before sending metrics to the server.

The interval time was decreased, and an optional environment variable NETBIRD_METRICS_INTERVAL_IN_SECONDS was added to control the interval time.

* update doc URL
2024-07-04 19:15:59 +02:00
Maycon Santos
f9845e53a0 Sort routes by ID and remove DNS routes from overlapping list (#2234) 2024-07-04 16:50:07 +02:00
pascal-fischer
765aba2c1c Add context to throughout the project and update logging (#2209)
propagate context from all the API calls and log request ID, account ID and peer ID

---------

Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
2024-07-03 11:33:02 +02:00
Zoltan Papp
7cb81f1d70 Fix nil pointer exception in case of error (#2230) 2024-07-02 18:18:14 +02:00
Viktor Liu
cea19de667 Debounce network monitor restarts (#2225) 2024-07-02 17:09:00 +02:00
Bethuel Mmbaga
29e5eceb6b Fix linux serial number retrieval (#2206)
* Change source of serial number in sysInfo function

The serial number returned by the sysInfo function in info_linux.go has been fixed. Previously, it was incorrectly fetched from the Chassis object. Now it is correctly fetched from the Product object. This aligns better with the expected system info retrieval method.

* Fallback to product.Serial in sys info

In case of the chassis is "Default String" or empty then try to use product.serial

---------

Co-authored-by: Zoltán Papp <zoltan.pmail@gmail.com>
2024-07-02 13:19:08 +02:00
dependabot[bot]
0f63737330 Bump golang.org/x/image from 0.10.0 to 0.18.0 (#2205)
Bumps [golang.org/x/image](https://github.com/golang/image) from 0.10.0 to 0.18.0.
- [Commits](https://github.com/golang/image/compare/v0.10.0...v0.18.0)

---
updated-dependencies:
- dependency-name: golang.org/x/image
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-07-02 13:12:28 +02:00
Viktor Liu
bf518c5fba Remove interface network monitor checks (#2223) 2024-07-02 12:41:15 +02:00
Maycon Santos
eab6183a8e Add stack trace when saving empty domains (#2228)
added temporary domain check for existing accounts to trace where the issue originated

Refactor save account due to complexity score
2024-07-02 12:40:26 +02:00
Yxnt
4517da8b3a Feat: Client UI Multiple Language Support (#2192)
Signed-off-by: Yxnt <jyxnt1@gmail.com>
2024-07-02 12:47:26 +03:00
Maycon Santos
9c0d923124 fix: client/Dockerfile to reduce vulnerabilities (#2220)
The following vulnerabilities are fixed with an upgrade:
- https://snyk.io/vuln/SNYK-ALPINE318-BUSYBOX-7249236
- https://snyk.io/vuln/SNYK-ALPINE318-BUSYBOX-7249236
- https://snyk.io/vuln/SNYK-ALPINE318-BUSYBOX-7249265
- https://snyk.io/vuln/SNYK-ALPINE318-BUSYBOX-7249265
- https://snyk.io/vuln/SNYK-ALPINE318-BUSYBOX-7249419

Co-authored-by: snyk-bot <snyk-bot@snyk.io>
2024-07-02 09:42:30 +02:00
Maycon Santos
6857734c48 add MACOSX_DEPLOYMENT_TARGET environment to control GUI build target (#2221)
Add MACOSX_DEPLOYMENT_TARGET and MACOS_DEPLOYMENT_TARGET to target build compatible with macOS 11+ instead of relying on the builder's local Xcode version.
2024-07-01 17:59:09 +02:00
Maycon Santos
3b019800f8 Remove DNSSEC parameters and configure AuthenticatedData (#2208) 2024-06-27 18:36:24 +02:00
Maycon Santos
4cd4f88666 Add multiple tabs for route selection (#2198)
Add all routes, overlapping and exit routes tabs
2024-06-27 14:32:30 +02:00
Maycon Santos
d2157bda66 Set EDNS0 when no extra options are set by the dns client (#2195) 2024-06-25 17:18:04 +02:00
Maycon Santos
43a8ba97e3 Add log config and removed domain (#2194)
removed domainname for coturn service as it is needed only for SSL configs

Added log configuration for each service with a rotation and max size

ensure ZITADEL_DATABASE=postgres works
2024-06-25 13:54:09 +02:00
Robert Neumann
17874771cc Feature/Use Zitadel Postgres Integration by default (#2181)
replaces cockroachDB as default DB for Zitadel in the getting started script to deploy script. Users can switch back to cockroachDB by setting the environment variable ZITADEL_DATABASE to cockroach.
2024-06-25 11:10:11 +02:00
Viktor Liu
f6ccf6b97a Improve windows network monitor (#2184)
* Allow other states for windows neighbor network monitor

* Allow windows route network monitor to check for multiple default routes
2024-06-25 10:35:51 +02:00
Viktor Liu
6aae797baf Add loopback ignore rule to nat chains (#2190)
This makes sure loopback traffic is not affected by NAT
2024-06-25 09:43:36 +02:00
Maycon Santos
aca054e51e Using macOS-latest to build GUI (#2189) 2024-06-25 09:34:02 +02:00
Maycon Santos
10cee8f46e Use selector to display dns routes in GUI (#2185)
Use select widget for dns routes on GUI
2024-06-24 16:18:00 +02:00
Viktor Liu
628673db20 Lower retry interval on dns resolve failure (#2176) 2024-06-24 11:55:07 +02:00
Bethuel Mmbaga
eaa31c2dc6 Optimize process checks database read (#2182)
* Add posture checks to peer management

This commit includes posture checks to the peer management logic. The AddPeer, SyncPeer and LoginPeer functions now return a list of posture checks along with the peer and network map.

* Update peer methods to return posture checks

* Refactor

* return early if there is no posture checks

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2024-06-22 17:41:16 +03:00
Zoltan Papp
25723e9b07 Do not use eBPF proxy in case of USP mode (#2180) 2024-06-22 15:33:10 +02:00
Robert Neumann
3cf4d5758f Update Zitadel and CockroachDB Container Image Version (#2169)
* fix type in docker compose

* Update docker compose cockroachdb to latest-23.2 and zitadel to 2.54.3
2024-06-22 12:44:45 +02:00
Bethuel Mmbaga
fc15ee6351 auto migrate older management to sqlite (#2170) 2024-06-20 19:45:57 +02:00
Viktor Liu
4a3e78fb0f Fix windows network monitor next hop ip log (#2168) 2024-06-20 16:59:33 +02:00
189 changed files with 5188 additions and 3653 deletions

View File

@@ -10,8 +10,10 @@ on:
env:
SIGN_PIPE_VER: "v0.0.11"
SIGN_PIPE_VER: "v0.0.12"
GORELEASER_VER: "v1.14.1"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
@@ -23,6 +25,13 @@ jobs:
env:
flags: ""
steps:
- name: Parse semver string
id: semver_parser
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
-
@@ -68,18 +77,18 @@ jobs:
- name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
- name: Install rsrc
run: go install github.com/akavel/rsrc@v0.10.2
- name: Generate windows rsrc amd64
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_amd64.syso
- name: Generate windows rsrc arm64
run: rsrc -arch arm64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm64.syso
- name: Generate windows rsrc arm
run: rsrc -arch arm -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm.syso
- name: Generate windows rsrc 386
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
-
name: Run GoReleaser
- name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso 386
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_386.syso
- name: Generate windows syso arm
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_arm.syso
- name: Generate windows syso arm64
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_arm64.syso
- name: Generate windows syso amd64
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_amd64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
@@ -121,6 +130,13 @@ jobs:
release_ui:
runs-on: ubuntu-latest
steps:
- name: Parse semver string
id: semver_parser
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
@@ -151,10 +167,11 @@ jobs:
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
- name: Install rsrc
run: go install github.com/akavel/rsrc@v0.10.2
- name: Generate windows rsrc
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso
- name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/ui/resources_windows_amd64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4
with:
@@ -173,7 +190,7 @@ jobs:
retention-days: 3
release_ui_darwin:
runs-on: macos-11
runs-on: macos-latest
steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV

View File

@@ -178,34 +178,79 @@ jobs:
- name: Checkout code
uses: actions/checkout@v3
- name: run script
- name: run script with Zitadel PostgreSQL
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
- name: test Caddy file gen
- name: test Caddy file gen postgres
run: test -f Caddyfile
- name: test docker-compose file gen
- name: test docker-compose file gen postgres
run: test -f docker-compose.yml
- name: test management.json file gen
- name: test management.json file gen postgres
run: test -f management.json
- name: test turnserver.conf file gen
- name: test turnserver.conf file gen postgres
run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
- name: test zitadel.env file gen
- name: test zitadel.env file gen postgres
run: test -f zitadel.env
- name: test dashboard.env file gen
- name: test dashboard.env file gen postgres
run: test -f dashboard.env
- name: test zdb.env file gen postgres
run: test -f zdb.env
- name: Postgres run cleanup
run: |
docker-compose down --volumes --rmi all
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
- name: run script with Zitadel CockroachDB
run: bash -x infrastructure_files/getting-started-with-zitadel.sh
env:
NETBIRD_DOMAIN: use-ip
ZITADEL_DATABASE: cockroach
- name: test Caddy file gen CockroachDB
run: test -f Caddyfile
- name: test docker-compose file gen CockroachDB
run: test -f docker-compose.yml
- name: test management.json file gen CockroachDB
run: test -f management.json
- name: test turnserver.conf file gen CockroachDB
run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
- name: test zitadel.env file gen CockroachDB
run: test -f zitadel.env
- name: test dashboard.env file gen CockroachDB
run: test -f dashboard.env
test-download-geolite2-script:
runs-on: ubuntu-latest
steps:
- name: Install jq
run: sudo apt-get update && sudo apt-get install -y unzip sqlite3
- name: Checkout code
uses: actions/checkout@v3
- name: test script
run: bash -x infrastructure_files/download-geolite2.sh
- name: test mmdb file exists
run: test -f GeoLite2-City.mmdb
- name: test geonames file exists
run: test -f geonames.db

View File

@@ -3,8 +3,10 @@ builds:
- id: netbird-ui-darwin
dir: client/ui
binary: netbird-ui
env: [CGO_ENABLED=1]
env:
- CGO_ENABLED=1
- MACOSX_DEPLOYMENT_TARGET=11.0
- MACOS_DEPLOYMENT_TARGET=11.0
goos:
- darwin
goarch:

View File

@@ -1,4 +1,4 @@
FROM alpine:3.18.5
FROM alpine:3.19
RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"]

View File

@@ -26,7 +26,7 @@ var downCmd = &cobra.Command{
return err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr)

View File

@@ -121,7 +121,7 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout")
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.")
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")

View File

@@ -76,7 +76,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
t.Fatal(err)
}
s := grpc.NewServer()
store, cleanUp, err := mgmt.NewTestStoreFromJson(config.Datadir)
store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil {
t.Fatal(err)
}
@@ -87,13 +87,13 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
if err != nil {
return nil, nil
}
iv, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -74,12 +74,12 @@ func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
return nil
}
err = i.insertRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
if err != nil {
return err
}
err = i.insertRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
if err != nil {
return err
}
@@ -101,6 +101,7 @@ func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string,
}
delete(i.rules, ruleKey)
}
err = i.iptablesClient.Insert(table, chain, 1, rule...)
if err != nil {
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
@@ -317,6 +318,13 @@ func (i *routerManager) createChain(table, newChain string) error {
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
}
// Add the loopback return rule to the NAT chain
loopbackRule := []string{"-o", "lo", "-j", "RETURN"}
err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...)
if err != nil {
return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err)
}
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
if err != nil {
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
@@ -326,6 +334,30 @@ func (i *routerManager) createChain(table, newChain string) error {
return nil
}
// addNATRule appends an iptables rule pair to the nat chain
func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, pair.Source, pair.Destination)
existingRule, found := i.rules[ruleKey]
if found {
err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
}
delete(i.rules, ruleKey)
}
// inserting after loopback ignore rule
err := i.iptablesClient.Insert(table, chain, 2, rule...)
if err != nil {
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
}
i.rules[ruleKey] = rule
return nil
}
// genRuleSpec generates rule specification
func genRuleSpec(jump, source, destination string) []string {
return []string{"-s", source, "-d", destination, "-j", jump}

View File

@@ -95,7 +95,7 @@ func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.InsertRoutingRules(pair)
return m.router.AddRoutingRules(pair)
}
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {

View File

@@ -22,6 +22,8 @@ const (
userDataAcceptForwardRuleSrc = "frwacceptsrc"
userDataAcceptForwardRuleDst = "frwacceptdst"
loopbackInterface = "lo\x00"
)
// some presets for building nftable rules
@@ -126,6 +128,22 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeNAT,
})
// Add RETURN rule for loopback interface
loRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte(loopbackInterface),
},
&expr.Verdict{Kind: expr.VerdictReturn},
},
}
r.conn.InsertRule(loRule)
err := r.refreshRulesMap()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
@@ -138,28 +156,28 @@ func (r *router) createContainers() error {
return nil
}
// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (r *router) AddRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap()
if err != nil {
return err
}
err = r.insertRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
if err != nil {
return err
}
err = r.insertRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
if err != nil {
return err
}
if pair.Masquerade {
err = r.insertRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
if err != nil {
return err
}
err = r.insertRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
if err != nil {
return err
}
@@ -177,8 +195,8 @@ func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
return nil
}
// insertRoutingRule inserts a nftable rule to the conn client flush queue
func (r *router) insertRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
// addRoutingRule inserts a nftable rule to the conn client flush queue
func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
@@ -199,7 +217,7 @@ func (r *router) insertRoutingRule(format, chainName string, pair manager.Router
}
}
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainName],
Exprs: expression,

View File

@@ -47,7 +47,7 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
require.NoError(t, err, "shouldn't return error")
err = manager.InsertRoutingRules(testCase.InputPair)
err = manager.AddRoutingRules(testCase.InputPair)
defer func() {
_ = manager.RemoveRoutingRules(testCase.InputPair)
}()

View File

@@ -337,7 +337,6 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
return rule.drop, true
}
return rule.drop, true
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return rule.drop, true
}

View File

@@ -15,6 +15,12 @@ type hostManager interface {
restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error
}
type SystemDNSSettings struct {
Domains []string
ServerIP string
ServerPort int
}
type HostDNSConfig struct {
Domains []DomainConfig `json:"domains"`
RouteAll bool `json:"routeAll"`

View File

@@ -7,6 +7,7 @@ import (
"bytes"
"fmt"
"io"
"net"
"net/netip"
"os/exec"
"strconv"
@@ -18,7 +19,7 @@ import (
const (
netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS"
globalIPv4State = "State:/Network/Global/IPv4"
primaryServiceSetupKeyFormat = "Setup:/Network/Service/%s/DNS"
primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS"
keySupplementalMatchDomains = "SupplementalMatchDomains"
keySupplementalMatchDomainsNoSearch = "SupplementalMatchDomainsNoSearch"
keyServerAddresses = "ServerAddresses"
@@ -28,12 +29,12 @@ const (
scutilPath = "/usr/sbin/scutil"
searchSuffix = "Search"
matchSuffix = "Match"
localSuffix = "Local"
)
type systemConfigurator struct {
// primaryServiceID primary interface in the system. AKA the interface with the default route
primaryServiceID string
createdKeys map[string]struct{}
createdKeys map[string]struct{}
systemDNSSettings SystemDNSSettings
}
func newHostManager() (hostManager, error) {
@@ -49,20 +50,6 @@ func (s *systemConfigurator) supportCustomPort() bool {
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
var err error
if config.RouteAll {
err = s.addDNSSetupForAll(config.ServerIP, config.ServerPort)
if err != nil {
return fmt.Errorf("add dns setup for all: %w", err)
}
} else if s.primaryServiceID != "" {
err = s.removeKeyFromSystemConfig(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID))
if err != nil {
return fmt.Errorf("remote key from system config: %w", err)
}
s.primaryServiceID = ""
log.Infof("removed %s:%d as main DNS resolver for this peer", config.ServerIP, config.ServerPort)
}
// create a file for unclean shutdown detection
if err := createUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to create unclean shutdown file: %s", err)
@@ -73,6 +60,19 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
matchDomains []string
)
err = s.recordSystemDNSSettings(true)
if err != nil {
log.Errorf("unable to update record of System's DNS config: %s", err.Error())
}
if config.RouteAll {
searchDomains = append(searchDomains, "\"\"")
err = s.addLocalDNS()
if err != nil {
log.Infof("failed to enable split DNS")
}
}
for _, dConf := range config.Domains {
if dConf.Disabled {
continue
@@ -110,23 +110,17 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
}
func (s *systemConfigurator) restoreHostDNS() error {
lines := ""
for key := range s.createdKeys {
lines += buildRemoveKeyOperation(key)
keys := s.getRemovableKeysWithDefaults()
for _, key := range keys {
keyType := "search"
if strings.Contains(key, matchSuffix) {
keyType = "match"
}
log.Infof("removing %s domains from system", keyType)
}
if s.primaryServiceID != "" {
lines += buildRemoveKeyOperation(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID))
log.Infof("restoring DNS resolver configuration for system")
}
_, err := runSystemConfigCommand(wrapCommand(lines))
if err != nil {
log.Errorf("got an error while cleaning the system configuration: %s", err)
return fmt.Errorf("clean system: %w", err)
err := s.removeKeyFromSystemConfig(key)
if err != nil {
log.Errorf("failed to remove %s domains from system: %s", keyType, err)
}
}
if err := removeUncleanShutdownIndicator(); err != nil {
@@ -136,6 +130,19 @@ func (s *systemConfigurator) restoreHostDNS() error {
return nil
}
func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
if len(s.createdKeys) == 0 {
// return defaults for startup calls
return []string{getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix), getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)}
}
keys := make([]string, 0, len(s.createdKeys))
for key := range s.createdKeys {
keys = append(keys, key)
}
return keys
}
func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
line := buildRemoveKeyOperation(key)
_, err := runSystemConfigCommand(wrapCommand(line))
@@ -148,6 +155,97 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
return nil
}
func (s *systemConfigurator) addLocalDNS() error {
if s.systemDNSSettings.ServerIP == "" || len(s.systemDNSSettings.Domains) == 0 {
err := s.recordSystemDNSSettings(true)
log.Errorf("Unable to get system DNS configuration")
return err
}
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 {
err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort)
if err != nil {
return fmt.Errorf("couldn't add local network DNS conf: %w", err)
}
} else {
log.Info("Not enabling local DNS server")
}
return nil
}
func (s *systemConfigurator) recordSystemDNSSettings(force bool) error {
if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 && !force {
return nil
}
systemDNSSettings, err := s.getSystemDNSSettings()
if err != nil {
return fmt.Errorf("couldn't get current DNS config: %w", err)
}
s.systemDNSSettings = systemDNSSettings
return nil
}
func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
primaryServiceKey, _, err := s.getPrimaryService()
if err != nil || primaryServiceKey == "" {
return SystemDNSSettings{}, fmt.Errorf("couldn't find the primary service key: %w", err)
}
dnsServiceKey := getKeyWithInput(primaryServiceStateKeyFormat, primaryServiceKey)
line := buildCommandLine("show", dnsServiceKey, "")
stdinCommands := wrapCommand(line)
b, err := runSystemConfigCommand(stdinCommands)
if err != nil {
return SystemDNSSettings{}, fmt.Errorf("sending the command: %w", err)
}
var dnsSettings SystemDNSSettings
inSearchDomainsArray := false
inServerAddressesArray := false
scanner := bufio.NewScanner(bytes.NewReader(b))
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
switch {
case strings.HasPrefix(line, "DomainName :"):
domainName := strings.TrimSpace(strings.Split(line, ":")[1])
dnsSettings.Domains = append(dnsSettings.Domains, domainName)
case line == "SearchDomains : <array> {":
inSearchDomainsArray = true
continue
case line == "ServerAddresses : <array> {":
inServerAddressesArray = true
continue
case line == "}":
inSearchDomainsArray = false
inServerAddressesArray = false
}
if inSearchDomainsArray {
searchDomain := strings.Split(line, " : ")[1]
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
} else if inServerAddressesArray {
address := strings.Split(line, " : ")[1]
if ip := net.ParseIP(address); ip != nil && ip.To4() != nil {
dnsSettings.ServerIP = address
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
}
}
}
if err := scanner.Err(); err != nil {
return dnsSettings, err
}
// default to 53 port
dnsSettings.ServerPort = 53
return dnsSettings, nil
}
func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error {
err := s.addDNSState(key, domains, ip, port, true)
if err != nil {
@@ -194,23 +292,6 @@ func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port
return nil
}
func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error {
primaryServiceKey, existingNameserver, err := s.getPrimaryService()
if err != nil || primaryServiceKey == "" {
return fmt.Errorf("couldn't find the primary service key: %w", err)
}
err = s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver)
if err != nil {
return fmt.Errorf("add dns setup: %w", err)
}
log.Infof("configured %s:%d as main DNS resolver for this peer", dnsServer, port)
s.primaryServiceID = primaryServiceKey
return nil
}
func (s *systemConfigurator) getPrimaryService() (string, string, error) {
line := buildCommandLine("show", globalIPv4State, "")
stdinCommands := wrapCommand(line)
@@ -239,19 +320,6 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) {
return primaryService, router, nil
}
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, existingDNSServer string) error {
lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0))
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer+" "+existingDNSServer)
lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port))
addDomainCommand := buildCreateStateWithOperation(setupKey, lines)
stdinCommands := wrapCommand(addDomainCommand)
_, err := runSystemConfigCommand(stdinCommands)
if err != nil {
return fmt.Errorf("applying dns setup, error: %w", err)
}
return nil
}
func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
if err := s.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via scutil: %w", err)

View File

@@ -24,7 +24,7 @@ const (
probeTimeout = 2 * time.Second
)
const testRecord = "."
const testRecord = "com."
type upstreamClient interface {
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
@@ -42,6 +42,7 @@ type upstreamResolverBase struct {
upstreamServers []string
disabled bool
failsCount atomic.Int32
successCount atomic.Int32
failsTillDeact int32
mutex sync.Mutex
reactivatePeriod time.Duration
@@ -78,6 +79,11 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}()
log.WithField("question", r.Question[0]).Trace("received an upstream question")
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
if r.Extra == nil {
r.SetEdns0(4096, false)
r.MsgHdr.AuthenticatedData = true
}
select {
case <-u.ctx.Done():
@@ -119,6 +125,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
u.successCount.Add(1)
log.Tracef("took %s to query the upstream %s", t, upstream)
err = w.WriteMsg(rm)
@@ -167,6 +174,11 @@ func (u *upstreamResolverBase) probeAvailability() {
default:
}
// avoid probe if upstreams could resolve at least one query and fails count is less than failsTillDeact
if u.successCount.Load() > 0 && u.failsCount.Load() < u.failsTillDeact {
return
}
var success bool
var mu sync.Mutex
var wg sync.WaitGroup
@@ -178,7 +190,7 @@ func (u *upstreamResolverBase) probeAvailability() {
wg.Add(1)
go func() {
defer wg.Done()
err := u.testNameserver(upstream)
err := u.testNameserver(upstream, 500*time.Millisecond)
if err != nil {
errors = multierror.Append(errors, err)
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
@@ -219,7 +231,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
}
for _, upstream := range u.upstreamServers {
if err := u.testNameserver(upstream); err != nil {
if err := u.testNameserver(upstream, probeTimeout); err != nil {
log.Tracef("upstream check for %s: %s", upstream, err)
} else {
// at least one upstream server is available, stop probing
@@ -239,6 +251,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers)
u.failsCount.Store(0)
u.successCount.Add(1)
u.reactivate()
u.disabled = false
}
@@ -260,13 +273,14 @@ func (u *upstreamResolverBase) disable(err error) {
}
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
u.successCount.Store(0)
u.deactivate(err)
u.disabled = true
go u.waitUntilResponse()
}
func (u *upstreamResolverBase) testNameserver(server string) error {
ctx, cancel := context.WithTimeout(u.ctx, probeTimeout)
func (u *upstreamResolverBase) testNameserver(server string, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(u.ctx, timeout)
defer cancel()
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)

View File

@@ -266,8 +266,23 @@ func (e *Engine) Stop() error {
e.close()
e.wgConnWorker.Wait()
log.Infof("stopped Netbird Engine")
return nil
maxWaitTime := 5 * time.Second
timeout := time.After(maxWaitTime)
for {
if !e.IsWGIfaceUp() {
log.Infof("stopped Netbird Engine")
return nil
}
select {
case <-timeout:
return fmt.Errorf("timeout when waiting for interface shutdown")
default:
time.Sleep(100 * time.Millisecond)
}
}
}
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
@@ -282,8 +297,6 @@ func (e *Engine) Start() error {
}
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, e.config.WgPort)
wgIface, err := e.newWgIface()
if err != nil {
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err)
@@ -291,6 +304,9 @@ func (e *Engine) Start() error {
}
e.wgInterface = wgIface
userspace := e.wgInterface.IsUserspaceBind()
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort)
if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled")
if e.config.RosenpassPermissive {
@@ -1464,6 +1480,15 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs)
}
func (e *Engine) restartEngine() {
if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
if err := e.Start(); err != nil {
log.Errorf("Failed to start engine: %v", err)
}
}
func (e *Engine) startNetworkMonitor() {
if !e.config.NetworkMonitor {
log.Infof("Network monitor is disabled, not starting")
@@ -1472,14 +1497,29 @@ func (e *Engine) startNetworkMonitor() {
e.networkMonitor = networkmonitor.New()
go func() {
var mu sync.Mutex
var debounceTimer *time.Timer
// Start the network monitor with a callback, Start will block until the monitor is stopped,
// a network change is detected, or an error occurs on start up
err := e.networkMonitor.Start(e.ctx, func() {
log.Infof("Network monitor detected network change, restarting engine")
if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
if err := e.Start(); err != nil {
log.Errorf("Failed to start engine: %v", err)
// This function is called when a network change is detected
mu.Lock()
defer mu.Unlock()
if debounceTimer != nil {
debounceTimer.Stop()
}
// Set a new timer to debounce rapid network changes
debounceTimer = time.AfterFunc(1*time.Second, func() {
// This function is called after the debounce period
mu.Lock()
defer mu.Unlock()
log.Infof("Network monitor detected network change, restarting engine")
e.restartEngine()
})
})
if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
log.Errorf("Network monitor: %v", err)
@@ -1508,3 +1548,20 @@ func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files)
})
}
func (e *Engine) IsWGIfaceUp() bool {
if e == nil || e.wgInterface == nil {
return false
}
iface, err := net.InterfaceByName(e.wgInterface.Name())
if err != nil {
log.Debugf("failed to get interface by name %s: %v", e.wgInterface.Name(), err)
return false
}
if iface.Flags&net.FlagUp != 0 {
return true
}
return false
}

View File

@@ -174,7 +174,7 @@ func TestEngine_SSH(t *testing.T) {
t.Fatal(err)
}
//time.Sleep(250 * time.Millisecond)
// time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=")
@@ -1057,7 +1057,7 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := server.NewTestStoreFromJson(config.Datadir)
store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil {
return nil, "", err
}
@@ -1068,13 +1068,13 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
if err != nil {
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
if err != nil {
return nil, "", err
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
return nil, "", err
}

View File

@@ -4,6 +4,7 @@ package networkmonitor
import (
"context"
"errors"
"fmt"
"syscall"
"unsafe"
@@ -21,11 +22,20 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
return fmt.Errorf("failed to open routing socket: %v", err)
}
defer func() {
if err := unix.Close(fd); err != nil {
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Errorf("Network monitor: failed to close routing socket: %v", err)
}
}()
go func() {
<-ctx.Done()
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Debugf("Network monitor: closed routing socket")
}
}()
for {
select {
case <-ctx.Done():
@@ -34,7 +44,9 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
buf := make([]byte, 2048)
n, err := unix.Read(fd, buf)
if err != nil {
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
}
continue
}
if n < unix.SizeofRtMsghdr {
@@ -45,24 +57,6 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
switch msg.Type {
// handle interface state changes
case unix.RTM_IFINFO:
ifinfo, err := parseInterfaceMessage(buf[:n])
if err != nil {
log.Errorf("Network monitor: error parsing interface message: %v", err)
continue
}
if msg.Flags&unix.IFF_UP != 0 {
continue
}
if (nexthopv4.Intf == nil || ifinfo.Index != nexthopv4.Intf.Index) && (nexthopv6.Intf == nil || ifinfo.Index != nexthopv6.Intf.Index) {
continue
}
log.Infof("Network monitor: monitored interface (%s) is down.", ifinfo.Name)
go callback()
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
@@ -94,24 +88,6 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
}
}
func parseInterfaceMessage(buf []byte) (*route.InterfaceMessage, error) {
msgs, err := route.ParseRIB(route.RIBTypeInterface, buf)
if err != nil {
return nil, fmt.Errorf("parse RIB: %v", err)
}
if len(msgs) != 1 {
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
}
msg, ok := msgs[0].(*route.InterfaceMessage)
if !ok {
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
}
return msg, nil
}
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
if err != nil {

View File

@@ -19,14 +19,9 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
return errors.New("no interfaces available")
}
linkChan := make(chan netlink.LinkUpdate)
done := make(chan struct{})
defer close(done)
if err := netlink.LinkSubscribe(linkChan, done); err != nil {
return fmt.Errorf("subscribe to link updates: %v", err)
}
routeChan := make(chan netlink.RouteUpdate)
if err := netlink.RouteSubscribe(routeChan, done); err != nil {
return fmt.Errorf("subscribe to route updates: %v", err)
@@ -38,25 +33,6 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
case <-ctx.Done():
return ErrStopped
// handle interface state changes
case update := <-linkChan:
if (nexthopv4.Intf == nil || update.Index != int32(nexthopv4.Intf.Index)) && (nexthopv6.Intf == nil || update.Index != int32(nexthopv6.Intf.Index)) {
continue
}
switch update.Header.Type {
case syscall.RTM_DELLINK:
log.Infof("Network monitor: monitored interface (%s) is gone", update.Link.Attrs().Name)
go callback()
return nil
case syscall.RTM_NEWLINK:
if (update.IfInfomsg.Flags&syscall.IFF_RUNNING) == 0 && update.Link.Attrs().OperState == netlink.OperDown {
log.Infof("Network monitor: monitored interface (%s) is down.", update.Link.Attrs().Name)
go callback()
return nil
}
}
// handle route changes
case route := <-routeChan:
// default route and main table

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/netip"
"strings"
"time"
log "github.com/sirupsen/logrus"
@@ -33,12 +34,8 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
return fmt.Errorf("get neighbors: %w", err)
}
if n, ok := initialNeighbors[nexthopv4.IP]; ok {
neighborv4 = &n
}
if n, ok := initialNeighbors[nexthopv6.IP]; ok {
neighborv6 = &n
}
neighborv4 = assignNeighbor(nexthopv4, initialNeighbors)
neighborv6 = assignNeighbor(nexthopv6, initialNeighbors)
}
log.Debugf("Network monitor: initial IPv4 neighbor: %v, IPv6 neighbor: %v", neighborv4, neighborv6)
@@ -58,6 +55,16 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
}
}
func assignNeighbor(nexthop systemops.Nexthop, initialNeighbors map[netip.Addr]systemops.Neighbor) *systemops.Neighbor {
if n, ok := initialNeighbors[nexthop.IP]; ok &&
n.State != unreachable &&
n.State != incomplete &&
n.State != tbd {
return &n
}
return nil
}
func changed(
nexthopv4 systemops.Nexthop,
neighborv4 *systemops.Neighbor,
@@ -87,37 +94,69 @@ func changed(
}
// routeChanged checks if the default routes still point to our nexthop/interface
func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes map[netip.Prefix]systemops.Route) bool {
func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route) bool {
if !nexthop.IP.IsValid() {
return false
}
var unspec netip.Prefix
if nexthop.IP.Is6() {
unspec = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
} else {
unspec = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
if isSoftInterface(nexthop.Intf.Name) {
log.Tracef("network monitor: ignoring default route change for soft interface %s", nexthop.Intf.Name)
return false
}
if r, ok := routes[unspec]; ok {
if r.Nexthop != nexthop.IP || compareIntf(r.Interface, intf) != 0 {
oldIntf, newIntf := "<nil>", "<nil>"
if intf != nil {
oldIntf = intf.Name
}
if r.Interface != nil {
newIntf = r.Interface.Name
}
log.Infof("network monitor: default route changed: %s from %s (%s) to %s (%s)", r.Destination, nexthop.IP, oldIntf, nexthop.IP, newIntf)
return true
}
} else {
log.Infof("network monitor: default route is gone")
unspec := getUnspecifiedPrefix(nexthop.IP)
defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec)
log.Tracef("network monitor: all default routes:\n%s", strings.Join(defaultRoutes, "\n"))
if !foundMatchingRoute {
logRouteChange(nexthop.IP, intf)
return true
}
return false
}
func getUnspecifiedPrefix(ip netip.Addr) netip.Prefix {
if ip.Is6() {
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
}
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
}
func processRoutes(nexthop systemops.Nexthop, nexthopIntf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) {
var defaultRoutes []string
foundMatchingRoute := false
for _, r := range routes {
if r.Destination == unspec {
routeInfo := formatRouteInfo(r)
defaultRoutes = append(defaultRoutes, routeInfo)
if r.Nexthop == nexthop.IP && compareIntf(r.Interface, nexthopIntf) == 0 {
foundMatchingRoute = true
log.Debugf("network monitor: found matching default route: %s", routeInfo)
}
}
}
return defaultRoutes, foundMatchingRoute
}
func formatRouteInfo(r systemops.Route) string {
newIntf := "<nil>"
if r.Interface != nil {
newIntf = r.Interface.Name
}
return fmt.Sprintf("Nexthop: %s, Interface: %s", r.Nexthop, newIntf)
}
func logRouteChange(ip netip.Addr, intf *net.Interface) {
oldIntf := "<nil>"
if intf != nil {
oldIntf = intf.Name
}
log.Infof("network monitor: default route for %s (%s) is gone or changed", ip, oldIntf)
}
func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, neighbors map[netip.Addr]systemops.Neighbor) bool {
@@ -127,7 +166,7 @@ func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, ne
// TODO: consider non-local nexthops, e.g. on point-to-point interfaces
if n, ok := neighbors[nexthop.IP]; ok {
if n.State != reachable && n.State != permanent {
if n.State == unreachable || n.State == incomplete {
log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State))
return true
} else if n.InterfaceIndex != neighbor.InterfaceIndex {
@@ -165,18 +204,13 @@ func getNeighbors() (map[netip.Addr]systemops.Neighbor, error) {
return neighbours, nil
}
func getRoutes() (map[netip.Prefix]systemops.Route, error) {
func getRoutes() ([]systemops.Route, error) {
entries, err := systemops.GetRoutes()
if err != nil {
return nil, fmt.Errorf("get routes: %w", err)
}
routes := make(map[netip.Prefix]systemops.Route, len(entries))
for _, entry := range entries {
routes[entry.Destination] = entry
}
return routes, nil
return entries, nil
}
func stateFromInt(state uint8) string {
@@ -203,14 +237,18 @@ func stateFromInt(state uint8) string {
}
func compareIntf(a, b *net.Interface) int {
if a == nil && b == nil {
switch {
case a == nil && b == nil:
return 0
}
if a == nil {
case a == nil:
return -1
}
if b == nil {
case b == nil:
return 1
default:
return a.Index - b.Index
}
return a.Index - b.Index
}
func isSoftInterface(name string) bool {
return strings.Contains(strings.ToLower(name), "isatap") || strings.Contains(strings.ToLower(name), "teredo")
}

View File

@@ -36,7 +36,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
}
func TestConn_GetKey(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -51,7 +51,7 @@ func TestConn_GetKey(t *testing.T) {
}
func TestConn_OnRemoteOffer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -88,7 +88,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
}
func TestConn_OnRemoteAnswer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -124,7 +124,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait()
}
func TestConn_Status(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -154,7 +154,7 @@ func TestConn_Status(t *testing.T) {
}
func TestConn_Close(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()

View File

@@ -3,6 +3,7 @@ package routemanager
import (
"context"
"fmt"
"reflect"
"time"
"github.com/hashicorp/go-multierror"
@@ -309,22 +310,33 @@ func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
}()
}
func (c *clientNetwork) handleUpdate(update routesUpdate) {
func (c *clientNetwork) handleUpdate(update routesUpdate) bool {
isUpdateMapDifferent := false
updateMap := make(map[route.ID]*route.Route)
for _, r := range update.routes {
updateMap[r.ID] = r
}
if len(c.routes) != len(updateMap) {
isUpdateMapDifferent = true
}
for id, r := range c.routes {
_, found := updateMap[id]
if !found {
close(c.routePeersNotifiers[r.Peer])
delete(c.routePeersNotifiers, r.Peer)
isUpdateMapDifferent = true
continue
}
if !reflect.DeepEqual(c.routes[id], updateMap[id]) {
isUpdateMapDifferent = true
}
}
c.routes = updateMap
return isUpdateMapDifferent
}
// peersStateAndUpdateWatcher is the main point of reacting on client network routing events.
@@ -351,13 +363,19 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
log.Debugf("Received a new client network route update for [%v]", c.handler)
c.handleUpdate(update)
// hash update somehow
isTrueRouteUpdate := c.handleUpdate(update)
c.updateSerial = update.updateSerial
err := c.recalculateRouteAndUpdatePeerAndSystem()
if err != nil {
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
if isTrueRouteUpdate {
log.Debug("Client network update contains different routes, recalculating routes")
err := c.recalculateRouteAndUpdatePeerAndSystem()
if err != nil {
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
}
} else {
log.Debug("Route update is not different, skipping route recalculation")
}
c.startPeersStatusChangeWatcher()

View File

@@ -23,7 +23,8 @@ import (
const (
DefaultInterval = time.Minute
minInterval = 2 * time.Second
minInterval = 2 * time.Second
failureInterval = 5 * time.Second
addAllowedIP = "add allowed IP %s: %w"
)
@@ -160,7 +161,12 @@ func (r *Route) startResolver(ctx context.Context) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
r.update(ctx)
if err := r.update(ctx); err != nil {
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
if interval > failureInterval {
ticker.Reset(failureInterval)
}
}
for {
select {
@@ -168,17 +174,28 @@ func (r *Route) startResolver(ctx context.Context) {
log.Debugf("Stopping dynamic route resolver for domains [%v]", r)
return
case <-ticker.C:
r.update(ctx)
if err := r.update(ctx); err != nil {
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
// Use a lower ticker interval if the update fails
if interval > failureInterval {
ticker.Reset(failureInterval)
}
} else if interval > failureInterval {
// Reset to the original interval if the update succeeds
ticker.Reset(interval)
}
}
}
}
func (r *Route) update(ctx context.Context) {
func (r *Route) update(ctx context.Context) error {
if resolved, err := r.resolveDomains(); err != nil {
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
return fmt.Errorf("resolve domains: %w", err)
} else if err := r.updateDynamicRoutes(ctx, resolved); err != nil {
log.Errorf("Failed to update dynamic routes for [%v]: %v", r, err)
return fmt.Errorf("update dynamic routes: %w", err)
}
return nil
}
func (r *Route) resolveDomains() (domainMap, error) {

View File

@@ -16,6 +16,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
@@ -50,7 +51,7 @@ type DefaultManager struct {
statusRecorder *peer.Status
wgInterface *iface.WGIface
pubKey string
notifier *notifier
notifier *notifier.Notifier
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
dnsRouteInterval time.Duration
@@ -65,7 +66,8 @@ func NewManager(
initialRoutes []*route.Route,
) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
sysOps := systemops.NewSysOps(wgInterface)
notifier := notifier.NewNotifier()
sysOps := systemops.NewSysOps(wgInterface, notifier)
dm := &DefaultManager{
ctx: mCTX,
@@ -77,7 +79,7 @@ func NewManager(
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
notifier: newNotifier(),
notifier: notifier,
}
dm.routeRefCounter = refcounter.New(
@@ -107,7 +109,7 @@ func NewManager(
if runtime.GOOS == "android" {
cr := dm.clientRoutes(initialRoutes)
dm.notifier.setInitialClientRoutes(cr)
dm.notifier.SetInitialClientRoutes(cr)
}
return dm
}
@@ -186,7 +188,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
m.updateClientNetworks(updateSerial, filteredClientRoutes)
m.notifier.onNewRoutes(filteredClientRoutes)
m.notifier.OnNewRoutes(filteredClientRoutes)
if m.serverRouter != nil {
err := m.serverRouter.updateRoutes(newServerRoutesMap)
@@ -199,14 +201,14 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
}
}
// SetRouteChangeListener set RouteListener for route change notifier
// SetRouteChangeListener set RouteListener for route change Notifier
func (m *DefaultManager) SetRouteChangeListener(listener listener.NetworkChangeListener) {
m.notifier.setListener(listener)
m.notifier.SetListener(listener)
}
// InitialRouteRange return the list of initial routes. It used by mobile systems
func (m *DefaultManager) InitialRouteRange() []string {
return m.notifier.getInitialRouteRanges()
return m.notifier.GetInitialRouteRanges()
}
// GetRouteSelector returns the route selector
@@ -226,7 +228,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
networks = m.routeSelector.FilterSelected(networks)
m.notifier.onNewRoutes(networks)
m.notifier.OnNewRoutes(networks)
m.stopObsoleteClients(networks)

View File

@@ -1,6 +1,7 @@
package routemanager
package notifier
import (
"net/netip"
"runtime"
"sort"
"strings"
@@ -10,7 +11,7 @@ import (
"github.com/netbirdio/netbird/route"
)
type notifier struct {
type Notifier struct {
initialRouteRanges []string
routeRanges []string
@@ -18,17 +19,17 @@ type notifier struct {
listenerMux sync.Mutex
}
func newNotifier() *notifier {
return &notifier{}
func NewNotifier() *Notifier {
return &Notifier{}
}
func (n *notifier) setListener(listener listener.NetworkChangeListener) {
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
n.listener = listener
}
func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) {
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
nets := make([]string, 0)
for _, r := range clientRoutes {
nets = append(nets, r.Network.String())
@@ -37,7 +38,10 @@ func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) {
n.initialRouteRanges = nets
}
func (n *notifier) onNewRoutes(idMap route.HAMap) {
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
if runtime.GOOS != "android" {
return
}
newNets := make([]string, 0)
for _, routes := range idMap {
for _, r := range routes {
@@ -62,7 +66,30 @@ func (n *notifier) onNewRoutes(idMap route.HAMap) {
n.notify()
}
func (n *notifier) notify() {
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
newNets := make([]string, 0)
for _, prefix := range prefixes {
newNets = append(newNets, prefix.String())
}
sort.Strings(newNets)
switch runtime.GOOS {
case "android":
if !n.hasDiff(n.initialRouteRanges, newNets) {
return
}
default:
if !n.hasDiff(n.routeRanges, newNets) {
return
}
}
n.routeRanges = newNets
n.notify()
}
func (n *Notifier) notify() {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
if n.listener == nil {
@@ -74,7 +101,7 @@ func (n *notifier) notify() {
}(n.listener)
}
func (n *notifier) hasDiff(a []string, b []string) bool {
func (n *Notifier) hasDiff(a []string, b []string) bool {
if len(a) != len(b) {
return true
}
@@ -86,7 +113,7 @@ func (n *notifier) hasDiff(a []string, b []string) bool {
return false
}
func (n *notifier) getInitialRouteRanges() []string {
func (n *Notifier) GetInitialRouteRanges() []string {
return addIPv6RangeIfNeeded(n.initialRouteRanges)
}

View File

@@ -3,7 +3,9 @@ package systemops
import (
"net"
"net/netip"
"sync"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/iface"
)
@@ -18,10 +20,19 @@ type ExclusionCounter = refcounter.Counter[any, Nexthop]
type SysOps struct {
refCounter *ExclusionCounter
wgInterface *iface.WGIface
// prefixes is tracking all the current added prefixes im memory
// (this is used in iOS as all route updates require a full table update)
//nolint
prefixes map[netip.Prefix]struct{}
//nolint
mu sync.Mutex
// notifier is used to notify the system of route changes (also used on mobile)
notifier *notifier.Notifier
}
func NewSysOps(wgInterface *iface.WGIface) *SysOps {
func NewSysOps(wgInterface *iface.WGIface, notifier *notifier.Notifier) *SysOps {
return &SysOps{
wgInterface: wgInterface,
notifier: notifier,
}
}

View File

@@ -1,4 +1,4 @@
//go:build ios || android
//go:build android
package systemops

View File

@@ -36,7 +36,7 @@ func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
intf := &net.Interface{Name: "lo0"}
r := NewSysOps(nil)
r := NewSysOps(nil, nil)
var wg sync.WaitGroup
for i := 0; i < 1024; i++ {

View File

@@ -50,7 +50,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn
nexthop, err := r.addRouteToNonVPNIntf(prefix, r.wgInterface, initialNexthop)
if errors.Is(err, vars.ErrRouteNotAllowed) || errors.Is(err, vars.ErrRouteNotFound) {
log.Tracef("Adding for prefix %s: %v", prefix, err)
// These errors are not critical but also we should not track and try to remove the routes either.
// These errors are not critical, but also we should not track and try to remove the routes either.
return nexthop, refcounter.ErrIgnore
}
return nexthop, err
@@ -135,6 +135,11 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIfac
return Nexthop{}, vars.ErrRouteNotAllowed
}
// Check if the prefix is part of any local subnets
if isLocal, subnet := r.isPrefixInLocalSubnets(prefix); isLocal {
return Nexthop{}, fmt.Errorf("prefix %s is part of local subnet %s: %w", prefix, subnet, vars.ErrRouteNotAllowed)
}
// Determine the exit interface and next hop for the prefix, so we can add a specific route
nexthop, err := GetNextHop(addr)
if err != nil {
@@ -167,6 +172,36 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIfac
return exitNextHop, nil
}
func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) {
localInterfaces, err := net.Interfaces()
if err != nil {
log.Errorf("Failed to get local interfaces: %v", err)
return false, nil
}
for _, intf := range localInterfaces {
addrs, err := intf.Addrs()
if err != nil {
log.Errorf("Failed to get addresses for interface %s: %v", intf.Name, err)
continue
}
for _, addr := range addrs {
ipnet, ok := addr.(*net.IPNet)
if !ok {
log.Errorf("Failed to convert address to IPNet: %v", addr)
continue
}
if ipnet.Contains(prefix.Addr().AsSlice()) {
return true, ipnet
}
}
}
return false, nil
}
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
// in two /1 prefixes to avoid replacing the existing default route
func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {

View File

@@ -68,7 +68,7 @@ func TestAddRemoveRoutes(t *testing.T) {
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
r := NewSysOps(wgInterface)
r := NewSysOps(wgInterface, nil)
_, _, err = r.SetupRouting(nil)
require.NoError(t, err)
@@ -224,7 +224,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
r := NewSysOps(wgInterface)
r := NewSysOps(wgInterface, nil)
// Prepare the environment
if testCase.preExistingPrefix.IsValid() {
@@ -379,7 +379,7 @@ func setupTestEnv(t *testing.T) {
assert.NoError(t, wgInterface.Close())
})
r := NewSysOps(wgInterface)
r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil)
require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() {

View File

@@ -0,0 +1,64 @@
//go:build ios
package systemops
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net"
)
func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
r.mu.Lock()
defer r.mu.Unlock()
r.prefixes = make(map[netip.Prefix]struct{})
return nil, nil, nil
}
func (r *SysOps) CleanupRouting() error {
r.mu.Lock()
defer r.mu.Unlock()
r.prefixes = make(map[netip.Prefix]struct{})
r.notify()
return nil
}
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, _ *net.Interface) error {
r.mu.Lock()
defer r.mu.Unlock()
r.prefixes[prefix] = struct{}{}
r.notify()
return nil
}
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, _ *net.Interface) error {
r.mu.Lock()
defer r.mu.Unlock()
delete(r.prefixes, prefix)
r.notify()
return nil
}
func EnableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) {
return false, netip.Prefix{}
}
func (r *SysOps) notify() {
prefixes := make([]netip.Prefix, 0, len(r.prefixes))
for prefix := range r.prefixes {
prefixes = append(prefixes, prefix)
}
r.notifier.OnNewPrefixes(prefixes)
}

View File

@@ -73,7 +73,7 @@ var testCases = []testCase{
{
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
destination: "10.0.0.2:53",
expectedSourceIP: "10.0.0.1",
expectedSourceIP: "127.0.0.1",
expectedDestPrefix: "10.0.0.0/8",
expectedNextHop: "0.0.0.0",
expectedInterface: "Loopback Pseudo-Interface 1",
@@ -110,7 +110,7 @@ var testCases = []testCase{
{
name: "To more specific route (local) without custom dialer via physical interface",
destination: "127.0.10.2:53",
expectedSourceIP: "10.0.0.1",
expectedSourceIP: "127.0.0.1",
expectedDestPrefix: "127.0.0.0/8",
expectedNextHop: "0.0.0.0",
expectedInterface: "Loopback Pseudo-Interface 1",
@@ -181,31 +181,6 @@ func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOut
return combinedOutput
}
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
t.Helper()
ip, ipNet, err := net.ParseCIDR(ipAddressCIDR)
require.NoError(t, err)
subnetMaskSize, _ := ipNet.Mask.Size()
script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize)
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
require.NoError(t, err, "Failed to assign IP address to loopback adapter")
// Wait for the IP address to be applied
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
err = waitForIPAddress(ctx, interfaceName, ip.String())
require.NoError(t, err, "IP address not applied within timeout")
t.Cleanup(func() {
script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String())
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
require.NoError(t, err, "Failed to remove IP address from loopback adapter")
})
return interfaceName
}
func fetchOriginalGateway() (*RouteInfo, error) {
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object Nexthop, RouteMetric, InterfaceAlias | ConvertTo-Json")
output, err := cmd.CombinedOutput()
@@ -231,30 +206,6 @@ func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix
assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch")
}
func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
out, err := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)).CombinedOutput()
if err != nil {
return err
}
ipAddresses := strings.Split(strings.TrimSpace(string(out)), "\n")
for _, ip := range ipAddresses {
if strings.TrimSpace(ip) == expectedIPAddress {
return nil
}
}
}
}
}
func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput {
var combined FindNetRouteOutput
@@ -285,5 +236,25 @@ func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput {
func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper()
createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8")
addDummyRoute(t, "10.0.0.0/8")
}
func addDummyRoute(t *testing.T, dstCIDR string) {
t.Helper()
script := fmt.Sprintf(`New-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -PolicyStore ActiveStore`, dstCIDR)
output, err := exec.Command("powershell", "-Command", script).CombinedOutput()
if err != nil {
t.Logf("Failed to add dummy route: %v\nOutput: %s", err, output)
t.FailNow()
}
t.Cleanup(func() {
script = fmt.Sprintf(`Remove-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -Confirm:$false`, dstCIDR)
output, err := exec.Command("powershell", "-Command", script).CombinedOutput()
if err != nil {
t.Logf("Failed to remove dummy route: %v\nOutput: %s", err, output)
}
})
}

View File

@@ -8,9 +8,13 @@ import (
log "github.com/sirupsen/logrus"
)
func NewFactory(ctx context.Context, wgPort int) *Factory {
func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
f := &Factory{wgPort: wgPort}
if userspace {
return f
}
ebpfProxy := NewWGEBPFProxy(ctx, wgPort)
err := ebpfProxy.listen()
if err != nil {

View File

@@ -4,6 +4,6 @@ package wgproxy
import "context"
func NewFactory(ctx context.Context, wgPort int) *Factory {
func NewFactory(ctx context.Context, _ bool, wgPort int) *Factory {
return &Factory{wgPort: wgPort}
}

View File

@@ -19,6 +19,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
@@ -47,6 +48,7 @@ type CustomLogger interface {
type selectRoute struct {
NetID string
Network netip.Prefix
Domains domain.List
Selected bool
}
@@ -279,6 +281,7 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
route := &selectRoute{
NetID: string(id),
Network: rt[0].Network,
Domains: rt[0].Domains,
Selected: routeSelector.IsSelected(id),
}
routes = append(routes, route)
@@ -299,17 +302,40 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
return iPrefix < jPrefix
})
resolvedDomains := c.recorder.GetResolvedDomainsStates()
return prepareRouteSelectionDetails(routes, resolvedDomains), nil
}
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain][]netip.Prefix) *RoutesSelectionDetails {
var routeSelection []RoutesSelectionInfo
for _, r := range routes {
domainList := make([]DomainInfo, 0)
for _, d := range r.Domains {
domainResp := DomainInfo{
Domain: d.SafeString(),
}
if prefixes, exists := resolvedDomains[d]; exists {
var ipStrings []string
for _, prefix := range prefixes {
ipStrings = append(ipStrings, prefix.Addr().String())
}
domainResp.ResolvedIPs = strings.Join(ipStrings, ", ")
}
domainList = append(domainList, domainResp)
}
domainDetails := DomainDetails{items: domainList}
routeSelection = append(routeSelection, RoutesSelectionInfo{
ID: r.NetID,
Network: r.Network.String(),
Domains: &domainDetails,
Selected: r.Selected,
})
}
routeSelectionDetails := RoutesSelectionDetails{items: routeSelection}
return &routeSelectionDetails, nil
return &routeSelectionDetails
}
func (c *Client) SelectRoute(id string) error {

View File

@@ -16,9 +16,25 @@ type RoutesSelectionDetails struct {
type RoutesSelectionInfo struct {
ID string
Network string
Domains *DomainDetails
Selected bool
}
type DomainCollection interface {
Add(s DomainInfo) DomainCollection
Get(i int) *DomainInfo
Size() int
}
type DomainDetails struct {
items []DomainInfo
}
type DomainInfo struct {
Domain string
ResolvedIPs string
}
// Add new PeerInfo to the collection
func (array RoutesSelectionDetails) Add(s RoutesSelectionInfo) RoutesSelectionDetails {
array.items = append(array.items, s)
@@ -34,3 +50,16 @@ func (array RoutesSelectionDetails) Get(i int) *RoutesSelectionInfo {
func (array RoutesSelectionDetails) Size() int {
return len(array.items)
}
func (array DomainDetails) Add(s DomainInfo) DomainCollection {
array.items = append(array.items, s)
return array
}
func (array DomainDetails) Get(i int) *DomainInfo {
return &array.items[i]
}
func (array DomainDetails) Size() int {
return len(array.items)
}

View File

@@ -582,7 +582,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
}
// Down engine work in the daemon.
func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
@@ -593,7 +593,25 @@ func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownRespo
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)
return &proto.DownResponse{}, nil
maxWaitTime := 5 * time.Second
timeout := time.After(maxWaitTime)
engine := s.connectClient.Engine()
for {
if !engine.IsWGIfaceUp() {
return &proto.DownResponse{}, nil
}
select {
case <-ctx.Done():
return &proto.DownResponse{}, nil
case <-timeout:
return nil, fmt.Errorf("failed to shut down properly")
default:
time.Sleep(100 * time.Millisecond)
}
}
}
// Status returns the daemon status

View File

@@ -108,7 +108,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := server.NewTestStoreFromJson(config.Datadir)
store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil {
return nil, "", err
}
@@ -119,13 +119,13 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
if err != nil {
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
if err != nil {
return nil, "", err
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
return nil, "", err
}

View File

@@ -8,6 +8,7 @@ import (
"context"
"os"
"os/exec"
"regexp"
"runtime"
"strings"
"time"
@@ -89,5 +90,17 @@ func _getInfo() string {
func sysInfo() (serialNumber string, productName string, manufacturer string) {
var si sysinfo.SysInfo
si.GetSysInfo()
return si.Chassis.Serial, si.Product.Name, si.Product.Vendor
isascii := regexp.MustCompile("^[[:ascii:]]+$")
serial := si.Chassis.Serial
if (serial == "Default string" || serial == "") && si.Product.Serial != "" {
serial = si.Product.Serial
}
if (!isascii.MatchString(serial)) && si.Board.Serial != "" {
serial = si.Board.Serial
}
name := si.Product.Name
if (!isascii.MatchString(name)) && si.Board.Name != "" {
name = si.Board.Name
}
return serial, name, si.Product.Vendor
}

View File

@@ -15,7 +15,6 @@ import (
"strconv"
"strings"
"sync"
"syscall"
"time"
"unicode"
@@ -34,6 +33,7 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
@@ -62,8 +62,25 @@ func main() {
var errorMSG string
flag.StringVar(&errorMSG, "error-msg", "", "displays a error message window")
tmpDir := "/tmp"
if runtime.GOOS == "windows" {
tmpDir = os.TempDir()
}
var saveLogsInFile bool
flag.BoolVar(&saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", tmpDir))
flag.Parse()
if saveLogsInFile {
logFile := path.Join(tmpDir, fmt.Sprintf("netbird-ui-%d.log", os.Getpid()))
err := util.InitLog("trace", logFile)
if err != nil {
log.Errorf("error while initializing log: %v", err)
return
}
}
a := app.NewWithID("NetBird")
a.SetIcon(fyne.NewStaticResource("netbird", iconDisconnectedPNG))
@@ -76,10 +93,15 @@ func main() {
if showSettings || showRoutes {
a.Run()
} else {
if err := checkPIDFile(); err != nil {
log.Errorf("check PID file: %v", err)
running, err := isAnotherProcessRunning()
if err != nil {
log.Errorf("error while checking process: %v", err)
}
if running {
log.Warn("another process is running")
return
}
client.setDefaultFonts()
systray.Run(client.onTrayReady, client.onTrayExit)
}
}
@@ -860,19 +882,3 @@ func openURL(url string) error {
}
return err
}
// checkPIDFile exists and return error, or write new.
func checkPIDFile() error {
pidFile := path.Join(os.TempDir(), "wiretrustee-ui.pid")
if piddata, err := os.ReadFile(pidFile); err == nil {
if pid, err := strconv.Atoi(string(piddata)); err == nil {
if process, err := os.FindProcess(pid); err == nil {
if err := process.Signal(syscall.Signal(0)); err == nil {
return fmt.Errorf("process already exists: %d", pid)
}
}
}
}
return os.WriteFile(pidFile, []byte(fmt.Sprintf("%d", os.Getpid())), 0o664) //nolint:gosec
}

26
client/ui/font_bsd.go Normal file
View File

@@ -0,0 +1,26 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package main
import (
"os"
"runtime"
log "github.com/sirupsen/logrus"
)
const defaultFontPath = "/Library/Fonts/Arial Unicode.ttf"
func (s *serviceClient) setDefaultFonts() {
// TODO: add other bsd paths
if runtime.GOOS != "darwin" {
return
}
if _, err := os.Stat(defaultFontPath); err != nil {
log.Errorf("Failed to find default font file: %v", err)
return
}
os.Setenv("FYNE_FONT", defaultFontPath)
}

7
client/ui/font_linux.go Normal file
View File

@@ -0,0 +1,7 @@
//go:build !386
package main
func (s *serviceClient) setDefaultFonts() {
//TODO: Linux Multiple Language Support
}

91
client/ui/font_windows.go Normal file
View File

@@ -0,0 +1,91 @@
package main
import (
"os"
"path"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
func (s *serviceClient) setDefaultFonts() {
defaultFontPath := s.getWindowsFontFilePath()
if _, err := os.Stat(defaultFontPath); err != nil {
log.Errorf("Failed to find default font file: %v", err)
return
}
os.Setenv("FYNE_FONT", defaultFontPath)
}
func (s *serviceClient) getWindowsFontFilePath() string {
var (
fontFolder = "C:/Windows/Fonts"
fontMapping = map[string]string{
"default": "Segoeui.ttf",
"zh-CN": "Msyh.ttc",
"am-ET": "Ebrima.ttf",
"nirmala": "Nirmala.ttf",
"chr-CHER-US": "Gadugi.ttf",
"zh-HK": "Msjh.ttc",
"zh-TW": "Msjh.ttc",
"ja-JP": "Yugothm.ttc",
"km-KH": "Leelawui.ttf",
"ko-KR": "Malgun.ttf",
"th-TH": "Leelawui.ttf",
"ti-ET": "Ebrima.ttf",
}
nirMalaLang = []string{
"as-IN",
"bn-BD",
"bn-IN",
"gu-IN",
"hi-IN",
"kn-IN",
"kok-IN",
"ml-IN",
"mr-IN",
"ne-NP",
"or-IN",
"pa-IN",
"si-LK",
"ta-IN",
"te-IN",
}
)
// getUserDefaultLocaleName.Call() panics if the func is not found
defer func() {
if r := recover(); r != nil {
log.Errorf("Recovered from panic: %v", r)
}
}()
kernel32 := windows.NewLazySystemDLL("kernel32.dll")
getUserDefaultLocaleName := kernel32.NewProc("GetUserDefaultLocaleName")
buf := make([]uint16, 85) // LOCALE_NAME_MAX_LENGTH is usually 85
r, _, err := getUserDefaultLocaleName.Call(uintptr(unsafe.Pointer(&buf[0])), uintptr(len(buf)))
// returns 0 on failure, err is always non-nil
// https://learn.microsoft.com/en-us/windows/win32/api/winnls/nf-winnls-getuserdefaultlocalename
if r == 0 {
log.Errorf("GetUserDefaultLocaleName call failed: %v", err)
return path.Join(fontFolder, fontMapping["default"])
}
defaultLanguage := windows.UTF16ToString(buf)
for _, lang := range nirMalaLang {
if defaultLanguage == lang {
return path.Join(fontFolder, fontMapping["nirmala"])
}
}
if font, ok := fontMapping[defaultLanguage]; ok {
return path.Join(fontFolder, font)
}
return path.Join(fontFolder, fontMapping["default"])
}

37
client/ui/process.go Normal file
View File

@@ -0,0 +1,37 @@
package main
import (
"os"
"path/filepath"
"strings"
"github.com/shirou/gopsutil/v3/process"
)
func isAnotherProcessRunning() (bool, error) {
processes, err := process.Processes()
if err != nil {
return false, err
}
pid := os.Getpid()
processName := strings.ToLower(filepath.Base(os.Args[0]))
for _, p := range processes {
if int(p.Pid) == pid {
continue
}
runningProcessPath, err := p.Exe()
// most errors are related to short-lived processes
if err != nil {
continue
}
if strings.Contains(strings.ToLower(runningProcessPath), processName) && isProcessOwnedByCurrentUser(p) {
return true, nil
}
}
return false, nil
}

View File

@@ -0,0 +1,26 @@
//go:build !windows
package main
import (
"os"
"github.com/shirou/gopsutil/v3/process"
log "github.com/sirupsen/logrus"
)
func isProcessOwnedByCurrentUser(p *process.Process) bool {
currentUserID := os.Getuid()
uids, err := p.Uids()
if err != nil {
log.Errorf("get process uids: %v", err)
return false
}
for _, id := range uids {
log.Debugf("checking process uid: %d", id)
if int(id) == currentUserID {
return true
}
}
return false
}

View File

@@ -0,0 +1,24 @@
package main
import (
"os/user"
"github.com/shirou/gopsutil/v3/process"
log "github.com/sirupsen/logrus"
)
func isProcessOwnedByCurrentUser(p *process.Process) bool {
processUsername, err := p.Username()
if err != nil {
log.Errorf("get process username error: %v", err)
return false
}
currUser, err := user.Current()
if err != nil {
log.Errorf("get current user error: %v", err)
return false
}
return processUsername == currUser.Username
}

View File

@@ -4,6 +4,7 @@ package main
import (
"fmt"
"sort"
"strings"
"time"
@@ -17,28 +18,57 @@ import (
"github.com/netbirdio/netbird/client/proto"
)
const (
allRoutesText = "All routes"
overlappingRoutesText = "Overlapping routes"
exitNodeRoutesText = "Exit-node routes"
allRoutes filter = "all"
overlappingRoutes filter = "overlapping"
exitNodeRoutes filter = "exit-node"
getClientFMT = "get client: %v"
)
type filter string
func (s *serviceClient) showRoutesUI() {
s.wRoutes = s.app.NewWindow("NetBird Routes")
grid := container.New(layout.NewGridLayout(3))
go s.updateRoutes(grid)
allGrid := container.New(layout.NewGridLayout(3))
go s.updateRoutes(allGrid, allRoutes)
overlappingGrid := container.New(layout.NewGridLayout(3))
exitNodeGrid := container.New(layout.NewGridLayout(3))
routeCheckContainer := container.NewVBox()
routeCheckContainer.Add(grid)
tabs := container.NewAppTabs(
container.NewTabItem(allRoutesText, allGrid),
container.NewTabItem(overlappingRoutesText, overlappingGrid),
container.NewTabItem(exitNodeRoutesText, exitNodeGrid),
)
tabs.OnSelected = func(item *container.TabItem) {
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
}
tabs.OnUnselected = func(item *container.TabItem) {
grid, _ := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
grid.Objects = nil
}
routeCheckContainer.Add(tabs)
scrollContainer := container.NewVScroll(routeCheckContainer)
scrollContainer.SetMinSize(fyne.NewSize(200, 300))
buttonBox := container.NewHBox(
layout.NewSpacer(),
widget.NewButton("Refresh", func() {
s.updateRoutes(grid)
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
}),
widget.NewButton("Select all", func() {
s.selectAllRoutes()
s.updateRoutes(grid)
_, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
s.selectAllFilteredRoutes(f)
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
}),
widget.NewButton("Deselect All", func() {
s.deselectAllRoutes()
s.updateRoutes(grid)
_, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
s.deselectAllFilteredRoutes(f)
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
}),
layout.NewSpacer(),
)
@@ -48,18 +78,12 @@ func (s *serviceClient) showRoutesUI() {
s.wRoutes.SetContent(content)
s.wRoutes.Show()
s.startAutoRefresh(5*time.Second, grid)
s.startAutoRefresh(10*time.Second, tabs, allGrid, overlappingGrid, exitNodeGrid)
}
func (s *serviceClient) updateRoutes(grid *fyne.Container) {
routes, err := s.fetchRoutes()
if err != nil {
log.Errorf("get client: %v", err)
s.showError(fmt.Errorf("get client: %v", err))
return
}
func (s *serviceClient) updateRoutes(grid *fyne.Container, f filter) {
grid.Objects = nil
grid.Refresh()
idHeader := widget.NewLabelWithStyle(" ID", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
networkHeader := widget.NewLabelWithStyle("Network/Domains", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
resolvedIPsHeader := widget.NewLabelWithStyle("Resolved IPs", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
@@ -67,7 +91,15 @@ func (s *serviceClient) updateRoutes(grid *fyne.Container) {
grid.Add(idHeader)
grid.Add(networkHeader)
grid.Add(resolvedIPsHeader)
for _, route := range routes {
filteredRoutes, err := s.getFilteredRoutes(f)
if err != nil {
return
}
sortRoutesByIDs(filteredRoutes)
for _, route := range filteredRoutes {
r := route
checkBox := widget.NewCheck(r.GetID(), func(checked bool) {
@@ -80,35 +112,104 @@ func (s *serviceClient) updateRoutes(grid *fyne.Container) {
grid.Add(checkBox)
network := r.GetNetwork()
domains := r.GetDomains()
if len(domains) > 0 {
network = strings.Join(domains, ", ")
}
grid.Add(widget.NewLabel(network))
if len(domains) > 0 {
var resolvedIPsList []string
for _, domain := range r.GetDomains() {
if ipList, exists := r.GetResolvedIPs()[domain]; exists {
resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", ")))
}
}
// TODO: limit width
resolvedIPsLabel := widget.NewLabel(strings.Join(resolvedIPsList, ", "))
grid.Add(resolvedIPsLabel)
} else {
if len(domains) == 0 {
grid.Add(widget.NewLabel(network))
grid.Add(widget.NewLabel(""))
continue
}
// our selectors are only for display
noopFunc := func(_ string) {
// do nothing
}
domainsSelector := widget.NewSelect(domains, noopFunc)
domainsSelector.Selected = domains[0]
grid.Add(domainsSelector)
var resolvedIPsList []string
for _, domain := range domains {
if ipList, exists := r.GetResolvedIPs()[domain]; exists {
resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", ")))
}
}
if len(resolvedIPsList) == 0 {
grid.Add(widget.NewLabel(""))
continue
}
// TODO: limit width within the selector display
resolvedIPsSelector := widget.NewSelect(resolvedIPsList, noopFunc)
resolvedIPsSelector.Selected = resolvedIPsList[0]
resolvedIPsSelector.Resize(fyne.NewSize(100, 100))
grid.Add(resolvedIPsSelector)
}
s.wRoutes.Content().Refresh()
grid.Refresh()
}
func (s *serviceClient) getFilteredRoutes(f filter) ([]*proto.Route, error) {
routes, err := s.fetchRoutes()
if err != nil {
log.Errorf(getClientFMT, err)
s.showError(fmt.Errorf(getClientFMT, err))
return nil, err
}
switch f {
case overlappingRoutes:
return getOverlappingRoutes(routes), nil
case exitNodeRoutes:
return getExitNodeRoutes(routes), nil
default:
}
return routes, nil
}
func getOverlappingRoutes(routes []*proto.Route) []*proto.Route {
var filteredRoutes []*proto.Route
existingRange := make(map[string][]*proto.Route)
for _, route := range routes {
if len(route.Domains) > 0 {
continue
}
if r, exists := existingRange[route.GetNetwork()]; exists {
r = append(r, route)
existingRange[route.GetNetwork()] = r
} else {
existingRange[route.GetNetwork()] = []*proto.Route{route}
}
}
for _, r := range existingRange {
if len(r) > 1 {
filteredRoutes = append(filteredRoutes, r...)
}
}
return filteredRoutes
}
func getExitNodeRoutes(routes []*proto.Route) []*proto.Route {
var filteredRoutes []*proto.Route
for _, route := range routes {
if route.Network == "0.0.0.0/0" {
filteredRoutes = append(filteredRoutes, route)
}
}
return filteredRoutes
}
func sortRoutesByIDs(routes []*proto.Route) {
sort.Slice(routes, func(i, j int) bool {
return strings.ToLower(routes[i].GetID()) < strings.ToLower(routes[j].GetID())
})
}
func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
return nil, fmt.Errorf("get client: %v", err)
return nil, fmt.Errorf(getClientFMT, err)
}
resp, err := conn.ListRoutes(s.ctx, &proto.ListRoutesRequest{})
@@ -122,8 +223,8 @@ func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) {
func (s *serviceClient) selectRoute(id string, checked bool) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Errorf("get client: %v", err)
s.showError(fmt.Errorf("get client: %v", err))
log.Errorf(getClientFMT, err)
s.showError(fmt.Errorf(getClientFMT, err))
return
}
@@ -149,16 +250,14 @@ func (s *serviceClient) selectRoute(id string, checked bool) {
}
}
func (s *serviceClient) selectAllRoutes() {
func (s *serviceClient) selectAllFilteredRoutes(f filter) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Errorf("get client: %v", err)
log.Errorf(getClientFMT, err)
return
}
req := &proto.SelectRoutesRequest{
All: true,
}
req := s.getRoutesRequest(f, true)
if _, err := conn.SelectRoutes(s.ctx, req); err != nil {
log.Errorf("failed to select all routes: %v", err)
s.showError(fmt.Errorf("failed to select all routes: %v", err))
@@ -168,16 +267,14 @@ func (s *serviceClient) selectAllRoutes() {
log.Debug("All routes selected")
}
func (s *serviceClient) deselectAllRoutes() {
func (s *serviceClient) deselectAllFilteredRoutes(f filter) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Errorf("get client: %v", err)
log.Errorf(getClientFMT, err)
return
}
req := &proto.SelectRoutesRequest{
All: true,
}
req := s.getRoutesRequest(f, false)
if _, err := conn.DeselectRoutes(s.ctx, req); err != nil {
log.Errorf("failed to deselect all routes: %v", err)
s.showError(fmt.Errorf("failed to deselect all routes: %v", err))
@@ -187,17 +284,34 @@ func (s *serviceClient) deselectAllRoutes() {
log.Debug("All routes deselected")
}
func (s *serviceClient) getRoutesRequest(f filter, appendRoute bool) *proto.SelectRoutesRequest {
req := &proto.SelectRoutesRequest{}
if f == allRoutes {
req.All = true
} else {
routes, err := s.getFilteredRoutes(f)
if err != nil {
return nil
}
for _, route := range routes {
req.RouteIDs = append(req.RouteIDs, route.GetID())
}
req.Append = appendRoute
}
return req
}
func (s *serviceClient) showError(err error) {
wrappedMessage := wrapText(err.Error(), 50)
dialog.ShowError(fmt.Errorf("%s", wrappedMessage), s.wRoutes)
}
func (s *serviceClient) startAutoRefresh(interval time.Duration, grid *fyne.Container) {
func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) {
ticker := time.NewTicker(interval)
go func() {
for range ticker.C {
s.updateRoutes(grid)
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodesGrid)
}
}()
@@ -206,6 +320,23 @@ func (s *serviceClient) startAutoRefresh(interval time.Duration, grid *fyne.Cont
})
}
func (s *serviceClient) updateRoutesBasedOnDisplayTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) {
grid, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodesGrid)
s.wRoutes.Content().Refresh()
s.updateRoutes(grid, f)
}
func getGridAndFilterFromTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) (*fyne.Container, filter) {
switch tabs.Selected().Text {
case overlappingRoutesText:
return overlappingGrid, overlappingRoutes
case exitNodeRoutesText:
return exitNodesGrid, exitNodeRoutes
default:
return allGrid, allRoutes
}
}
// wrapText inserts newlines into the text to ensure that each line is
// no longer than 'lineLength' runes.
func wrapText(text string, lineLength int) string {

View File

@@ -7,6 +7,18 @@ import (
"strings"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/context"
)
type ExecutionContext string
const (
ExecutionContextKey = "executionContext"
HTTPSource ExecutionContext = "HTTP"
GRPCSource ExecutionContext = "GRPC"
SystemSource ExecutionContext = "SYSTEM"
)
// ContextHook is a custom hook for add the source information for the entry
@@ -30,6 +42,27 @@ func (hook ContextHook) Levels() []logrus.Level {
func (hook ContextHook) Fire(entry *logrus.Entry) error {
src := hook.parseSrc(entry.Caller.File)
entry.Data["source"] = fmt.Sprintf("%s:%v", src, entry.Caller.Line)
if entry.Context == nil {
return nil
}
source, ok := entry.Context.Value(ExecutionContextKey).(ExecutionContext)
if !ok {
return nil
}
entry.Data["context"] = source
switch source {
case HTTPSource:
addHTTPFields(entry)
case GRPCSource:
addGRPCFields(entry)
case SystemSource:
addSystemFields(entry)
}
return nil
}
@@ -59,3 +92,42 @@ func (hook ContextHook) parseSrc(filePath string) string {
file := path.Base(filePath)
return fmt.Sprintf("%s/%s", pkg, file)
}
func addHTTPFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID
}
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
entry.Data[context.AccountIDKey] = ctxAccountID
}
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
entry.Data[context.UserIDKey] = ctxInitiatorID
}
}
func addGRPCFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID
}
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
entry.Data[context.AccountIDKey] = ctxAccountID
}
if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok {
entry.Data[context.PeerIDKey] = ctxDeviceID
}
}
func addSystemFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID
}
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
entry.Data[context.UserIDKey] = ctxInitiatorID
}
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
entry.Data[context.AccountIDKey] = ctxAccountID
}
if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok {
entry.Data[context.PeerIDKey] = ctxDeviceID
}
}

View File

@@ -1,6 +1,8 @@
package formatter
import "github.com/sirupsen/logrus"
import (
"github.com/sirupsen/logrus"
)
// SetTextFormatter set the text formatter for given logger.
func SetTextFormatter(logger *logrus.Logger) {
@@ -9,6 +11,13 @@ func SetTextFormatter(logger *logrus.Logger) {
logger.AddHook(NewContextHook())
}
// SetJSONFormatter set the JSON formatter for given logger.
func SetJSONFormatter(logger *logrus.Logger) {
logger.Formatter = &logrus.JSONFormatter{}
logger.ReportCaller = true
logger.AddHook(NewContextHook())
}
// SetLogcatFormatter set the logcat formatter for given logger.
func SetLogcatFormatter(logger *logrus.Logger) {
logger.Formatter = NewLogcatFormatter()

17
go.mod
View File

@@ -19,12 +19,12 @@ require (
github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.2.1-beta.2
golang.org/x/crypto v0.23.0
golang.org/x/sys v0.20.0
golang.org/x/crypto v0.24.0
golang.org/x/sys v0.21.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.64.0
google.golang.org/grpc v1.64.1
google.golang.org/protobuf v1.34.1
gopkg.in/natefinch/lumberjack.v2 v2.0.0
)
@@ -44,7 +44,6 @@ require (
github.com/golang/mock v1.6.0
github.com/google/go-cmp v0.6.0
github.com/google/gopacket v1.1.19
github.com/google/martian/v3 v3.0.0
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
github.com/gopacket/gopacket v1.1.1
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
@@ -58,7 +57,7 @@ require (
github.com/miekg/dns v1.1.43
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
@@ -84,10 +83,10 @@ require (
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028
golang.org/x/net v0.25.0
golang.org/x/net v0.26.0
golang.org/x/oauth2 v0.19.0
golang.org/x/sync v0.7.0
golang.org/x/term v0.20.0
golang.org/x/term v0.21.0
google.golang.org/api v0.177.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/postgres v1.5.7
@@ -189,8 +188,8 @@ require (
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
go.opentelemetry.io/otel/sdk v1.26.0 // indirect
go.opentelemetry.io/otel/trace v1.26.0 // indirect
golang.org/x/image v0.10.0 // indirect
golang.org/x/text v0.15.0 // indirect
golang.org/x/image v0.18.0 // indirect
golang.org/x/text v0.16.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect

36
go.sum
View File

@@ -209,8 +209,6 @@ github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSN
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
github.com/google/martian/v3 v3.0.0 h1:pMen7vLs8nvgEYhywH3KDWJIJTeEr2ULsVWHWYHQyBs=
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A=
github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc=
github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o=
@@ -335,8 +333,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd h1:IzGGIJMpz07aPs3R6/4sxZv63JoCMddftLpVodUK+Ec=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=
@@ -534,16 +532,16 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20200430140353-33d19683fad8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/image v0.10.0 h1:gXjUUtwtx5yOE0VKWq1CH4IJAClq4UGgUA3i+rpON9M=
golang.org/x/image v0.10.0/go.mod h1:jtrku+n79PfroUbvDdeUWMAI+heR786BofxrbiSF+J0=
golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ=
golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
@@ -565,7 +563,6 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191004110552-13f9640d40b9/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@@ -580,8 +577,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
golang.org/x/oauth2 v0.19.0 h1:9+E/EZBCbTLNrbN35fHv/a/d/mOBatymz1zbtQrXpIg=
@@ -637,8 +634,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@@ -646,8 +643,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA=
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
@@ -655,11 +652,10 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -705,8 +701,8 @@ google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
google.golang.org/grpc v1.64.0 h1:KH3VH9y/MgNQg1dE7b3XfVK0GsPSIzJwdF617gUSbvY=
google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg=
google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA=
google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=

View File

@@ -64,7 +64,7 @@ func (t *wgTunDevice) Create(routes []string, dns string, searchDomains []string
t.wrapper = newDeviceWrapper(tunDevice)
log.Debugf("attaching to interface %v", name)
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()

View File

@@ -49,7 +49,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
t.device = device.NewDevice(
t.wrapper,
t.iceBind,
device.NewLogger(device.LogLevelSilent, "[netbird] "),
device.NewLogger(wgLogLevel(), "[netbird] "),
)
err = t.assignAddr()

View File

@@ -64,7 +64,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
t.wrapper = newDeviceWrapper(tunDevice)
log.Debug("Attaching to interface")
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()

View File

@@ -54,7 +54,7 @@ func (t *tunNetstackDevice) Create() (wgConfigurer, error) {
t.device = device.NewDevice(
t.wrapper,
t.iceBind,
device.NewLogger(device.LogLevelSilent, "[netbird] "),
device.NewLogger(wgLogLevel(), "[netbird] "),
)
t.configurer = newWGUSPConfigurer(t.device, t.name)

View File

@@ -57,7 +57,7 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
t.device = device.NewDevice(
t.wrapper,
t.iceBind,
device.NewLogger(device.LogLevelSilent, "[netbird] "),
device.NewLogger(wgLogLevel(), "[netbird] "),
)
err = t.assignAddr()

View File

@@ -41,6 +41,7 @@ func newTunDevice(name string, address WGAddress, port int, key string, mtu int,
}
func (t *tunDevice) Create() (wgConfigurer, error) {
log.Info("create tun interface")
tunDevice, err := tun.CreateTUN(t.name, t.mtu)
if err != nil {
return nil, err
@@ -52,7 +53,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
t.device = device.NewDevice(
t.wrapper,
t.iceBind,
device.NewLogger(device.LogLevelSilent, "[netbird] "),
device.NewLogger(wgLogLevel(), "[netbird] "),
)
luid := winipcfg.LUID(t.nativeTunDevice.LUID())

15
iface/wg_log.go Normal file
View File

@@ -0,0 +1,15 @@
package iface
import (
"os"
"golang.zx2c4.com/wireguard/device"
)
func wgLogLevel() int {
if os.Getenv("NB_WG_DEBUG") == "true" {
return device.LogLevelVerbose
} else {
return device.LogLevelSilent
}
}

View File

@@ -28,7 +28,11 @@ services:
- LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL
volumes:
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt/
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Signal
signal:
image: netbirdio/signal:$NETBIRD_SIGNAL_TAG
@@ -40,6 +44,11 @@ services:
# # port and command for Let's Encrypt validation
# - 443:443
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Management
management:
@@ -63,12 +72,16 @@ services:
"--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN",
"--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"
]
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Coturn
coturn:
image: coturn/coturn:$COTURN_TAG
restart: unless-stopped
domainname: $TURN_DOMAIN
#domainname: $TURN_DOMAIN # only needed when TLS is enabled
volumes:
- ./turnserver.conf:/etc/turnserver.conf:ro
# - ./privkey.pem:/etc/coturn/private/privkey.pem:ro
@@ -76,7 +89,11 @@ services:
network_mode: host
command:
- -c /etc/turnserver.conf
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
volumes:
$MGMT_VOLUMENAME:
$SIGNAL_VOLUMENAME:

View File

@@ -50,7 +50,7 @@ check_jq() {
wait_crdb() {
set +e
while true; do
if $DOCKER_COMPOSE_COMMAND exec -T crdb curl -sf -o /dev/null 'http://localhost:8080/health?ready=1'; then
if $DOCKER_COMPOSE_COMMAND exec -T zdb curl -sf -o /dev/null 'http://localhost:8080/health?ready=1'; then
break
fi
echo -n " ."
@@ -61,14 +61,16 @@ wait_crdb() {
}
init_crdb() {
echo -e "\nInitializing Zitadel's CockroachDB\n\n"
$DOCKER_COMPOSE_COMMAND up -d crdb
echo ""
# shellcheck disable=SC2028
echo -n "Waiting cockroachDB to become ready "
wait_crdb
$DOCKER_COMPOSE_COMMAND exec -T crdb /bin/bash -c "cp /cockroach/certs/* /zitadel-certs/ && cockroach cert create-client --overwrite --certs-dir /zitadel-certs/ --ca-key /zitadel-certs/ca.key zitadel_user && chown -R 1000:1000 /zitadel-certs/"
handle_request_command_status $? "init_crdb failed" ""
if [[ $ZITADEL_DATABASE == "cockroach" ]]; then
echo -e "\nInitializing Zitadel's CockroachDB\n\n"
$DOCKER_COMPOSE_COMMAND up -d zdb
echo ""
# shellcheck disable=SC2028
echo -n "Waiting CockroachDB to become ready"
wait_crdb
$DOCKER_COMPOSE_COMMAND exec -T zdb /bin/bash -c "cp /cockroach/certs/* /zitadel-certs/ && cockroach cert create-client --overwrite --certs-dir /zitadel-certs/ --ca-key /zitadel-certs/ca.key zitadel_user && chown -R 1000:1000 /zitadel-certs/"
handle_request_command_status $? "init_crdb failed" ""
fi
}
get_main_ip_address() {
@@ -156,7 +158,7 @@ create_new_application() {
"'"$BASE_REDIRECT_URL2"'"
],
"postLogoutRedirectUris": [
"'"$LOGOUT_URL"'"
"'"$LOGOUT_URL"'"
],
"RESPONSETypes": [
"OIDC_RESPONSE_TYPE_CODE"
@@ -461,6 +463,20 @@ initEnvironment() {
exit 1
fi
if [[ $ZITADEL_DATABASE == "cockroach" ]]; then
echo "Use CockroachDB as Zitadel database."
ZDB=$(renderDockerComposeCockroachDB)
ZITADEL_DB_ENV=$(renderZitadelCockroachDBEnv)
else
echo "Use Postgres as default Zitadel database."
echo "For using CockroachDB please the environment variable 'export ZITADEL_DATABASE=cockroach'."
POSTGRES_ROOT_PASSWORD="$(openssl rand -base64 32 | sed 's/=//g')@"
POSTGRES_ZITADEL_PASSWORD="$(openssl rand -base64 32 | sed 's/=//g')@"
ZDB=$(renderDockerComposePostgres)
ZITADEL_DB_ENV=$(renderZitadelPostgresEnv)
renderPostgresEnv > zdb.env
fi
echo Rendering initial files...
renderDockerCompose > docker-compose.yml
renderCaddyfile > Caddyfile
@@ -474,7 +490,7 @@ initEnvironment() {
init_crdb
echo -e "\nStarting Zidatel IDP for user management\n\n"
echo -e "\nStarting Zitadel IDP for user management\n\n"
$DOCKER_COMPOSE_COMMAND up -d caddy zitadel
init_zitadel
@@ -634,15 +650,15 @@ renderManagementJson() {
"ExtraConfig": {
"ManagementEndpoint": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/management/v1"
}
},
"DeviceAuthorizationFlow": {
"Provider": "hosted",
"ProviderConfig": {
"Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI",
"ClientID": "$NETBIRD_AUTH_CLIENT_ID_CLI",
"Scope": "openid"
}
},
},
"DeviceAuthorizationFlow": {
"Provider": "hosted",
"ProviderConfig": {
"Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI",
"ClientID": "$NETBIRD_AUTH_CLIENT_ID_CLI",
"Scope": "openid"
}
},
"PKCEAuthorizationFlow": {
"ProviderConfig": {
"Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI",
@@ -679,16 +695,6 @@ renderZitadelEnv() {
cat <<EOF
ZITADEL_LOG_LEVEL=debug
ZITADEL_MASTERKEY=$ZITADEL_MASTERKEY
ZITADEL_DATABASE_COCKROACH_HOST=crdb
ZITADEL_DATABASE_COCKROACH_USER_USERNAME=zitadel_user
ZITADEL_DATABASE_COCKROACH_USER_SSL_MODE=verify-full
ZITADEL_DATABASE_COCKROACH_USER_SSL_ROOTCERT="/crdb-certs/ca.crt"
ZITADEL_DATABASE_COCKROACH_USER_SSL_CERT="/crdb-certs/client.zitadel_user.crt"
ZITADEL_DATABASE_COCKROACH_USER_SSL_KEY="/crdb-certs/client.zitadel_user.key"
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_MODE=verify-full
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_ROOTCERT="/crdb-certs/ca.crt"
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_CERT="/crdb-certs/client.root.crt"
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_KEY="/crdb-certs/client.root.key"
ZITADEL_EXTERNALSECURE=$ZITADEL_EXTERNALSECURE
ZITADEL_TLS_ENABLED="false"
ZITADEL_EXTERNALPORT=$NETBIRD_PORT
@@ -698,6 +704,43 @@ ZITADEL_FIRSTINSTANCE_ORG_MACHINE_MACHINE_USERNAME=zitadel-admin-sa
ZITADEL_FIRSTINSTANCE_ORG_MACHINE_MACHINE_NAME=Admin
ZITADEL_FIRSTINSTANCE_ORG_MACHINE_PAT_SCOPES=openid
ZITADEL_FIRSTINSTANCE_ORG_MACHINE_PAT_EXPIRATIONDATE=$ZIDATE_TOKEN_EXPIRATION_DATE
$ZITADEL_DB_ENV
EOF
}
renderZitadelCockroachDBEnv() {
cat <<EOF
ZITADEL_DATABASE_COCKROACH_HOST=zdb
ZITADEL_DATABASE_COCKROACH_USER_USERNAME=zitadel_user
ZITADEL_DATABASE_COCKROACH_USER_SSL_MODE=verify-full
ZITADEL_DATABASE_COCKROACH_USER_SSL_ROOTCERT="/zdb-certs/ca.crt"
ZITADEL_DATABASE_COCKROACH_USER_SSL_CERT="/zdb-certs/client.zitadel_user.crt"
ZITADEL_DATABASE_COCKROACH_USER_SSL_KEY="/zdb-certs/client.zitadel_user.key"
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_MODE=verify-full
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_ROOTCERT="/zdb-certs/ca.crt"
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_CERT="/zdb-certs/client.root.crt"
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_KEY="/zdb-certs/client.root.key"
EOF
}
renderZitadelPostgresEnv() {
cat <<EOF
ZITADEL_DATABASE_POSTGRES_HOST=zdb
ZITADEL_DATABASE_POSTGRES_PORT=5432
ZITADEL_DATABASE_POSTGRES_DATABASE=zitadel
ZITADEL_DATABASE_POSTGRES_USER_USERNAME=zitadel
ZITADEL_DATABASE_POSTGRES_USER_PASSWORD=$POSTGRES_ZITADEL_PASSWORD
ZITADEL_DATABASE_POSTGRES_USER_SSL_MODE=disable
ZITADEL_DATABASE_POSTGRES_ADMIN_USERNAME=root
ZITADEL_DATABASE_POSTGRES_ADMIN_PASSWORD=$POSTGRES_ROOT_PASSWORD
ZITADEL_DATABASE_POSTGRES_ADMIN_SSL_MODE=disable
EOF
}
renderPostgresEnv() {
cat <<EOF
POSTGRES_USER=root
POSTGRES_PASSWORD=$POSTGRES_ROOT_PASSWORD
EOF
}
@@ -717,18 +760,28 @@ services:
volumes:
- netbird_caddy_data:/data
- ./Caddyfile:/etc/caddy/Caddyfile
#UI dashboard
# UI dashboard
dashboard:
image: netbirdio/dashboard:latest
restart: unless-stopped
networks: [netbird]
env_file:
- ./dashboard.env
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Signal
signal:
image: netbirdio/signal:latest
restart: unless-stopped
networks: [netbird]
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Management
management:
image: netbirdio/management:latest
@@ -746,52 +799,49 @@ services:
"--dns-domain=netbird.selfhosted",
"--idp-sign-key-refresh-enabled",
]
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Coturn, AKA relay server
coturn:
image: coturn/coturn
restart: unless-stopped
domainname: netbird.relay.selfhosted
#domainname: netbird.relay.selfhosted
volumes:
- ./turnserver.conf:/etc/turnserver.conf:ro
network_mode: host
command:
- -c /etc/turnserver.conf
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Zitadel - identity provider
zitadel:
restart: 'always'
networks: [netbird]
image: 'ghcr.io/zitadel/zitadel:v2.31.3'
image: 'ghcr.io/zitadel/zitadel:v2.54.3'
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
env_file:
- ./zitadel.env
depends_on:
crdb:
zdb:
condition: 'service_healthy'
volumes:
- ./machinekey:/machinekey
- netbird_zitadel_certs:/crdb-certs:ro
# CockroachDB for zitadel
crdb:
restart: 'always'
networks: [netbird]
image: 'cockroachdb/cockroach:v22.2.2'
command: 'start-single-node --advertise-addr crdb'
volumes:
- netbird_crdb_data:/cockroach/cockroach-data
- netbird_crdb_certs:/cockroach/certs
- netbird_zitadel_certs:/zitadel-certs
healthcheck:
test: [ "CMD", "curl", "-f", "http://localhost:8080/health?ready=1" ]
interval: '10s'
timeout: '30s'
retries: 5
start_period: '20s'
volumes:
- netbird_zitadel_certs:/zdb-certs:ro
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
$ZDB
netbird_zdb_data:
netbird_management:
netbird_caddy_data:
netbird_crdb_data:
netbird_crdb_certs:
netbird_zitadel_certs:
networks:
@@ -799,4 +849,59 @@ networks:
EOF
}
renderDockerComposeCockroachDB() {
cat <<EOF
# CockroachDB for Zitadel
zdb:
restart: 'always'
networks: [netbird]
image: 'cockroachdb/cockroach:latest-v23.2'
command: 'start-single-node --advertise-addr zdb'
volumes:
- netbird_zdb_data:/cockroach/cockroach-data
- netbird_zdb_certs:/cockroach/certs
- netbird_zitadel_certs:/zitadel-certs
healthcheck:
test: [ "CMD", "curl", "-f", "http://localhost:8080/health?ready=1" ]
interval: '10s'
timeout: '30s'
retries: 5
start_period: '20s'
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
volumes:
netbird_zdb_certs:
EOF
}
renderDockerComposePostgres() {
cat <<EOF
# Postgres for Zitadel
zdb:
restart: 'always'
networks: [netbird]
image: 'postgres:16-alpine'
env_file:
- ./zdb.env
volumes:
- netbird_zdb_data:/var/lib/postgresql/data:rw
healthcheck:
test: ["CMD-SHELL", "pg_isready", "-d", "db_prod"]
interval: 5s
timeout: 60s
retries: 10
start_period: 5s
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
volumes:
EOF
}
initEnvironment

View File

@@ -62,7 +62,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err)
}
s := grpc.NewServer()
store, cleanUp, err := mgmt.NewTestStoreFromJson(config.Datadir)
store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil {
t.Fatal(err)
}
@@ -70,13 +70,13 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
ia, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -2,7 +2,6 @@ package client
import (
"context"
"crypto/tls"
"fmt"
"io"
"sync"
@@ -11,15 +10,11 @@ import (
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"github.com/cenkalti/backoff/v4"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption"
@@ -51,26 +46,21 @@ type GrpcClient struct {
// NewClient creates a new client to Management service
func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
var conn *grpc.ClientConn
if tlsEnabled {
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
operation := func() error {
var err error
conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
if err != nil {
log.Printf("createConnection error: %v", err)
return err
}
return nil
}
mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout)
defer cancel()
conn, err := grpc.DialContext(
mgmCtx,
addr,
transportOption,
nbgrpc.WithCustomDialer(),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}))
err := backoff.Retry(operation, nbgrpc.Backoff(ctx))
if err != nil {
log.Errorf("failed creating connection to Management Service %v", err)
log.Errorf("failed creating connection to Management Service: %v", err)
return nil, err
}
@@ -326,25 +316,44 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
if !c.ready() {
return nil, fmt.Errorf(errMsgNoMgmtConnection)
}
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
if err != nil {
log.Errorf("failed to encrypt message: %s", err)
return nil, err
}
mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout)
defer cancel()
resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: loginReq,
})
var resp *proto.EncryptedMessage
operation := func() error {
mgmCtx, cancel := context.WithTimeout(context.Background(), ConnectTimeout)
defer cancel()
var err error
resp, err = c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: loginReq,
})
if err != nil {
// retry only on context canceled
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.Canceled {
return err
}
return backoff.Permanent(err)
}
return nil
}
err = backoff.Retry(operation, nbgrpc.Backoff(c.ctx))
if err != nil {
log.Errorf("failed to login to Management Service: %v", err)
return nil, err
}
loginResp := &proto.LoginResponse{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
if err != nil {
log.Errorf("failed to decrypt registration message: %s", err)
log.Errorf("failed to decrypt login response: %s", err)
return nil, err
}

View File

@@ -20,6 +20,7 @@ import (
"time"
"github.com/google/uuid"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
@@ -35,8 +36,10 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter"
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/idp"
@@ -77,6 +80,10 @@ var (
Short: "start NetBird Management Server",
PreRunE: func(cmd *cobra.Command, args []string) error {
flag.Parse()
//nolint
ctx := context.WithValue(cmd.Context(), formatter.ExecutionContextKey, formatter.SystemSource)
err := util.InitLog(logLevel, logFile)
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
@@ -85,7 +92,7 @@ var (
// detect whether user specified a port
userPort := cmd.Flag("port").Changed
config, err = loadMgmtConfig(mgmtConfig)
config, err = loadMgmtConfig(ctx, mgmtConfig)
if err != nil {
return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err)
}
@@ -116,6 +123,11 @@ var (
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
//nolint
ctx = context.WithValue(ctx, formatter.ExecutionContextKey, formatter.SystemSource)
err := handleRebrand(cmd)
if err != nil {
return fmt.Errorf("failed to migrate files %v", err)
@@ -131,11 +143,11 @@ var (
if err != nil {
return err
}
err = appMetrics.Expose(mgmtMetricsPort, "/metrics")
err = appMetrics.Expose(ctx, mgmtMetricsPort, "/metrics")
if err != nil {
return err
}
store, err := server.NewStore(config.StoreConfig.Engine, config.Datadir, appMetrics)
store, err := server.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics)
if err != nil {
return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err)
}
@@ -143,7 +155,7 @@ var (
var idpManager idp.Manager
if config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(*config.IdpManagerConfig, appMetrics)
idpManager, err = idp.NewManager(ctx, *config.IdpManagerConfig, appMetrics)
if err != nil {
return fmt.Errorf("failed retrieving a new idp manager with err: %v", err)
}
@@ -152,32 +164,32 @@ var (
if disableSingleAccMode {
mgmtSingleAccModeDomain = ""
}
eventStore, key, err := integrations.InitEventStore(config.Datadir, config.DataStoreEncryptionKey)
eventStore, key, err := integrations.InitEventStore(ctx, config.Datadir, config.DataStoreEncryptionKey)
if err != nil {
return fmt.Errorf("failed to initialize database: %s", err)
}
if config.DataStoreEncryptionKey != key {
log.Infof("update config with activity store key")
log.WithContext(ctx).Infof("update config with activity store key")
config.DataStoreEncryptionKey = key
err := updateMgmtConfig(mgmtConfig, config)
err := updateMgmtConfig(ctx, mgmtConfig, config)
if err != nil {
return fmt.Errorf("failed to write out store encryption key: %s", err)
}
}
geo, err := geolocation.NewGeolocation(config.Datadir)
geo, err := geolocation.NewGeolocation(ctx, config.Datadir)
if err != nil {
log.Warnf("could not initialize geo location service: %v, we proceed without geo support", err)
log.WithContext(ctx).Warnf("could not initialize geo location service: %v, we proceed without geo support", err)
} else {
log.Infof("geo location service has been initialized from %s", config.Datadir)
log.WithContext(ctx).Infof("geo location service has been initialized from %s", config.Datadir)
}
integratedPeerValidator, err := integrations.NewIntegratedValidator(eventStore)
integratedPeerValidator, err := integrations.NewIntegratedValidator(ctx, eventStore)
if err != nil {
return fmt.Errorf("failed to initialize integrated peer validator: %v", err)
}
accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator)
if err != nil {
return fmt.Errorf("failed to build default manager: %v", err)
@@ -188,13 +200,13 @@ var (
trustedPeers := config.ReverseProxy.TrustedPeers
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) {
log.Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
log.WithContext(ctx).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
trustedPeers = defaultTrustedPeers
}
trustedHTTPProxies := config.ReverseProxy.TrustedHTTPProxies
trustedProxiesCount := config.ReverseProxy.TrustedHTTPProxiesCount
if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 {
log.Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
log.WithContext(ctx).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
"This is not recommended way to extract X-Forwarded-For. Consider using one of these options.")
}
realipOpts := []realip.Option{
@@ -206,8 +218,8 @@ var (
gRPCOpts := []grpc.ServerOption{
grpc.KeepaliveEnforcementPolicy(kaep),
grpc.KeepaliveParams(kasp),
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...)),
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...)),
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor),
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
}
var certManager *autocert.Manager
@@ -224,7 +236,7 @@ var (
} else if config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "" {
tlsConfig, err = loadTLSConfig(config.HttpConfig.CertFile, config.HttpConfig.CertKey)
if err != nil {
log.Errorf("cannot load TLS credentials: %v", err)
log.WithContext(ctx).Errorf("cannot load TLS credentials: %v", err)
return err
}
transportCredentials := credentials.NewTLS(tlsConfig)
@@ -233,6 +245,7 @@ var (
}
jwtValidator, err := jwtclaims.NewJWTValidator(
ctx,
config.HttpConfig.AuthIssuer,
config.GetAuthAudiences(),
config.HttpConfig.AuthKeysLocation,
@@ -249,26 +262,24 @@ var (
KeysLocation: config.HttpConfig.AuthKeysLocation,
}
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}
ephemeralManager := server.NewEphemeralManager(store, accountManager)
ephemeralManager.LoadInitialPeers()
ephemeralManager.LoadInitialPeers(ctx)
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, appMetrics, ephemeralManager)
srv, err := server.NewServer(ctx, config, accountManager, peersUpdateManager, turnManager, appMetrics, ephemeralManager)
if err != nil {
return fmt.Errorf("failed creating gRPC API handler: %v", err)
}
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
installationID, err := getInstallationID(store)
installationID, err := getInstallationID(ctx, store)
if err != nil {
log.Errorf("cannot load TLS credentials: %v", err)
log.WithContext(ctx).Errorf("cannot load TLS credentials: %v", err)
return err
}
@@ -278,18 +289,18 @@ var (
idpManager = config.IdpManagerConfig.ManagerType
}
metricsWorker := metrics.NewWorker(ctx, installationID, store, peersUpdateManager, idpManager)
go metricsWorker.Run()
go metricsWorker.Run(ctx)
}
var compatListener net.Listener
if mgmtPort != ManagementLegacyPort {
// The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it
// are using port 33073. For compatibility purposes we keep running a 2nd gRPC server on port 33073.
compatListener, err = serveGRPC(gRPCAPIHandler, ManagementLegacyPort)
compatListener, err = serveGRPC(ctx, gRPCAPIHandler, ManagementLegacyPort)
if err != nil {
return err
}
log.Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
log.WithContext(ctx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
}
rootHandler := handlerFunc(gRPCAPIHandler, httpAPIHandler)
@@ -306,8 +317,8 @@ var (
if err != nil {
return fmt.Errorf("failed creating TLS listener on port %d: %v", mgmtPort, err)
}
log.Infof("running HTTP server (LetsEncrypt challenge handler): %s", cml.Addr().String())
serveHTTP(cml, certManager.HTTPHandler(nil))
log.WithContext(ctx).Infof("running HTTP server (LetsEncrypt challenge handler): %s", cml.Addr().String())
serveHTTP(ctx, cml, certManager.HTTPHandler(nil))
}
} else if tlsConfig != nil {
listener, err = tls.Listen("tcp", fmt.Sprintf(":%d", mgmtPort), tlsConfig)
@@ -321,14 +332,14 @@ var (
}
}
log.Infof("management server version %s", version.NetbirdVersion())
log.Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String())
serveGRPCWithHTTP(listener, rootHandler, tlsEnabled)
log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion())
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String())
serveGRPCWithHTTP(ctx, listener, rootHandler, tlsEnabled)
SetupCloseHandler()
<-stopCh
integratedPeerValidator.Stop()
integratedPeerValidator.Stop(ctx)
if geo != nil {
_ = geo.Stop()
}
@@ -339,39 +350,68 @@ var (
_ = certManager.Listener().Close()
}
gRPCAPIHandler.Stop()
_ = store.Close()
_ = eventStore.Close()
log.Infof("stopped Management Service")
_ = store.Close(ctx)
_ = eventStore.Close(ctx)
log.WithContext(ctx).Infof("stopped Management Service")
return nil
},
}
)
func notifyStop(msg string) {
func unaryInterceptor(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
reqID := uuid.New().String()
//nolint
ctx = context.WithValue(ctx, formatter.ExecutionContextKey, formatter.GRPCSource)
//nolint
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
return handler(ctx, req)
}
func streamInterceptor(
srv interface{},
ss grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
reqID := uuid.New().String()
wrapped := grpcMiddleware.WrapServerStream(ss)
//nolint
ctx := context.WithValue(ss.Context(), formatter.ExecutionContextKey, formatter.GRPCSource)
//nolint
wrapped.WrappedContext = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
return handler(srv, wrapped)
}
func notifyStop(ctx context.Context, msg string) {
select {
case stopCh <- 1:
log.Error(msg)
log.WithContext(ctx).Error(msg)
default:
// stop has been already called, nothing to report
}
}
func getInstallationID(store server.Store) (string, error) {
func getInstallationID(ctx context.Context, store server.Store) (string, error) {
installationID := store.GetInstallationID()
if installationID != "" {
return installationID, nil
}
installationID = strings.ToUpper(uuid.New().String())
err := store.SaveInstallationID(installationID)
err := store.SaveInstallationID(ctx, installationID)
if err != nil {
return "", err
}
return installationID, nil
}
func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) {
func serveGRPC(ctx context.Context, grpcServer *grpc.Server, port int) (net.Listener, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
return nil, err
@@ -379,22 +419,22 @@ func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) {
go func() {
err := grpcServer.Serve(listener)
if err != nil {
notifyStop(fmt.Sprintf("failed running gRPC server on port %d: %v", port, err))
notifyStop(ctx, fmt.Sprintf("failed running gRPC server on port %d: %v", port, err))
}
}()
return listener, nil
}
func serveHTTP(httpListener net.Listener, handler http.Handler) {
func serveHTTP(ctx context.Context, httpListener net.Listener, handler http.Handler) {
go func() {
err := http.Serve(httpListener, handler)
if err != nil {
notifyStop(fmt.Sprintf("failed running HTTP server: %v", err))
notifyStop(ctx, fmt.Sprintf("failed running HTTP server: %v", err))
}
}()
}
func serveGRPCWithHTTP(listener net.Listener, handler http.Handler, tlsEnabled bool) {
func serveGRPCWithHTTP(ctx context.Context, listener net.Listener, handler http.Handler, tlsEnabled bool) {
go func() {
var err error
if tlsEnabled {
@@ -411,7 +451,7 @@ func serveGRPCWithHTTP(listener net.Listener, handler http.Handler, tlsEnabled b
if err != nil {
select {
case stopCh <- 1:
log.Errorf("failed to serve HTTP and gRPC server: %v", err)
log.WithContext(ctx).Errorf("failed to serve HTTP and gRPC server: %v", err)
default:
// stop has been already called, nothing to report
}
@@ -431,7 +471,7 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle
})
}
func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) {
loadedConfig := &server.Config{}
_, err := util.ReadJson(mgmtConfigPath, loadedConfig)
if err != nil {
@@ -452,26 +492,26 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
oidcEndpoint := loadedConfig.HttpConfig.OIDCConfigEndpoint
if oidcEndpoint != "" {
// if OIDCConfigEndpoint is specified, we can load DeviceAuthEndpoint and TokenEndpoint automatically
log.Infof("loading OIDC configuration from the provided IDP configuration endpoint %s", oidcEndpoint)
oidcConfig, err := fetchOIDCConfig(oidcEndpoint)
log.WithContext(ctx).Infof("loading OIDC configuration from the provided IDP configuration endpoint %s", oidcEndpoint)
oidcConfig, err := fetchOIDCConfig(ctx, oidcEndpoint)
if err != nil {
return nil, err
}
log.Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint)
log.WithContext(ctx).Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint)
log.Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s",
log.WithContext(ctx).Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s",
oidcConfig.Issuer, loadedConfig.HttpConfig.AuthIssuer)
loadedConfig.HttpConfig.AuthIssuer = oidcConfig.Issuer
log.Infof("overriding HttpConfig.AuthKeysLocation (JWT certs) with a new value %s, previously configured value: %s",
log.WithContext(ctx).Infof("overriding HttpConfig.AuthKeysLocation (JWT certs) with a new value %s, previously configured value: %s",
oidcConfig.JwksURI, loadedConfig.HttpConfig.AuthKeysLocation)
loadedConfig.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(server.NONE)) {
log.Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.TokenEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint)
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
log.Infof("overriding DeviceAuthorizationFlow.DeviceAuthEndpoint with a new value: %s, previously configured value: %s",
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.DeviceAuthEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.DeviceAuthEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint)
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint = oidcConfig.DeviceAuthEndpoint
@@ -479,7 +519,7 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
if err != nil {
return nil, err
}
log.Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s",
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s",
u.Host, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain)
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host
@@ -489,10 +529,10 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
}
if loadedConfig.PKCEAuthorizationFlow != nil {
log.Infof("overriding PKCEAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
log.WithContext(ctx).Infof("overriding PKCEAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.TokenEndpoint, loadedConfig.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint)
loadedConfig.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
log.Infof("overriding PKCEAuthorizationFlow.AuthorizationEndpoint with a new value: %s, previously configured value: %s",
log.WithContext(ctx).Infof("overriding PKCEAuthorizationFlow.AuthorizationEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.AuthorizationEndpoint, loadedConfig.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint)
loadedConfig.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint = oidcConfig.AuthorizationEndpoint
}
@@ -501,8 +541,8 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
return loadedConfig, err
}
func updateMgmtConfig(path string, config *server.Config) error {
return util.DirectWriteJson(path, config)
func updateMgmtConfig(ctx context.Context, path string, config *server.Config) error {
return util.DirectWriteJson(ctx, path, config)
}
// OIDCConfigResponse used for parsing OIDC config response
@@ -515,7 +555,7 @@ type OIDCConfigResponse struct {
}
// fetchOIDCConfig fetches OIDC configuration from the IDP
func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) {
func fetchOIDCConfig(ctx context.Context, oidcEndpoint string) (OIDCConfigResponse, error) {
res, err := http.Get(oidcEndpoint)
if err != nil {
return OIDCConfigResponse{}, fmt.Errorf("failed fetching OIDC configuration from endpoint %s %v", oidcEndpoint, err)
@@ -524,7 +564,7 @@ func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) {
defer func() {
err := res.Body.Close()
if err != nil {
log.Debugf("failed closing response body %v", err)
log.WithContext(ctx).Debugf("failed closing response body %v", err)
}
}()

View File

@@ -1,13 +1,16 @@
package cmd
import (
"context"
"flag"
"fmt"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/util"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/util"
)
var shortUp = "Migrate JSON file store to SQLite store. Please make a backup of the JSON file before running this command."
@@ -26,10 +29,13 @@ var upCmd = &cobra.Command{
return fmt.Errorf("failed initializing log %v", err)
}
if err := server.MigrateFileStoreToSqlite(mgmtDataDir); err != nil {
//nolint
ctx := context.WithValue(cmd.Context(), formatter.ExecutionContextKey, formatter.SystemSource)
if err := server.MigrateFileStoreToSqlite(ctx, mgmtDataDir); err != nil {
return err
}
log.Info("Migration finished successfully")
log.WithContext(ctx).Info("Migration finished successfully")
return nil
},

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,7 @@
package server
import (
"context"
"crypto/sha256"
b64 "encoding/base64"
"encoding/json"
@@ -29,11 +30,11 @@ import (
type MocIntegratedValidator struct {
}
func (a MocIntegratedValidator) ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
return nil
}
func (a MocIntegratedValidator) ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
return update, nil
}
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
@@ -44,15 +45,15 @@ func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[s
return validatedPeers, nil
}
func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
return peer
}
func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
return false, false, nil
}
func (MocIntegratedValidator) PeerDeleted(_, _ string) error {
func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
return nil
}
@@ -60,7 +61,7 @@ func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)
}
func (MocIntegratedValidator) Stop() {
func (MocIntegratedValidator) Stop(_ context.Context) {
}
func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) {
@@ -85,7 +86,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Ac
setupKey = key.Key
}
_, _, err := manager.AddPeer(setupKey, userID, peer)
_, _, _, err := manager.AddPeer(context.Background(), setupKey, userID, peer)
if err != nil {
t.Error("expected to add new peer successfully after creating new account, but failed", err)
}
@@ -395,7 +396,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
}
for _, testCase := range tt {
account := newAccountWithId("account-1", userID, "netbird.io")
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io")
account.UpdateSettings(&testCase.accountSettings)
account.Network = network
account.Peers = testCase.peers
@@ -409,7 +410,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
validatedPeers[p] = struct{}{}
}
networkMap := account.GetPeerNetworkMap(testCase.peerID, "netbird.io", validatedPeers)
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, "netbird.io", validatedPeers)
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
}
@@ -419,7 +420,7 @@ func TestNewAccount(t *testing.T) {
domain := "netbird.io"
userId := "account_creator"
accountID := "account_id"
account := newAccountWithId(accountID, userId, domain)
account := newAccountWithId(context.Background(), accountID, userId, domain)
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
}
@@ -430,7 +431,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
return
}
account, err := manager.GetOrCreateAccountByUser(userID, "")
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
if err != nil {
t.Fatal(err)
}
@@ -439,7 +440,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
return
}
account, err = manager.Store.GetAccountByUser(userID)
account, err = manager.Store.GetAccountByUser(context.Background(), userID)
if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userID)
return
@@ -630,11 +631,11 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
initAccount, err := manager.GetAccountByUserOrAccountID(testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
initAccount, err := manager.GetAccountByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
require.NoError(t, err, "create init user failed")
if testCase.inputUpdateAttrs {
err = manager.updateAccountDomainAttributes(initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
require.NoError(t, err, "update init user failed")
}
@@ -642,7 +643,7 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
testCase.inputClaims.AccountId = initAccount.Id
}
account, _, err := manager.GetAccountFromToken(testCase.inputClaims)
account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims)
require.NoError(t, err, "support function failed")
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
@@ -661,12 +662,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
userId := "user-id"
domain := "test.domain"
initAccount := newAccountWithId("", userId, domain)
initAccount := newAccountWithId(context.Background(), "", userId, domain)
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID := initAccount.Id
acc, err := manager.GetAccountByUserOrAccountID(userId, accountID, domain)
acc, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, accountID, domain)
require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization
// that happens inside the GetAccountByUserOrAccountID where the id is getting generated
@@ -682,18 +683,18 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
}
t.Run("JWT groups disabled", func(t *testing.T) {
account, _, err := manager.GetAccountFromToken(claims)
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
require.Len(t, account.Groups, 1, "only ALL group should exists")
})
t.Run("JWT groups enabled without claim name", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true
err := manager.Store.SaveAccount(initAccount)
err := manager.Store.SaveAccount(context.Background(), initAccount)
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(), 1, "only one account should exist")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
account, _, err := manager.GetAccountFromToken(claims)
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
})
@@ -701,11 +702,11 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
t.Run("JWT groups enabled", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true
initAccount.Settings.JWTGroupsClaimName = "idp-groups"
err := manager.Store.SaveAccount(initAccount)
err := manager.Store.SaveAccount(context.Background(), initAccount)
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(), 1, "only one account should exist")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
account, _, err := manager.GetAccountFromToken(claims)
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
require.Len(t, account.Groups, 3, "groups should be added to the account")
@@ -728,7 +729,7 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
func TestAccountManager_GetAccountFromPAT(t *testing.T) {
store := newStore(t)
account := newAccountWithId("account_id", "testuser", "")
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token))
@@ -742,7 +743,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
},
},
}
err := store.SaveAccount(account)
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@@ -751,7 +752,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
Store: store,
}
account, user, pat, err := am.GetAccountFromPAT(token)
account, user, pat, err := am.GetAccountFromPAT(context.Background(), token)
if err != nil {
t.Fatalf("Error when getting Account from PAT: %s", err)
}
@@ -763,7 +764,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
store := newStore(t)
account := newAccountWithId("account_id", "testuser", "")
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token))
@@ -778,7 +779,7 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
},
},
}
err := store.SaveAccount(account)
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@@ -787,12 +788,12 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
Store: store,
}
err = am.MarkPATUsed("tokenId")
err = am.MarkPATUsed(context.Background(), "tokenId")
if err != nil {
t.Fatalf("Error when marking PAT used: %s", err)
}
account, err = am.Store.GetAccount("account_id")
account, err = am.Store.GetAccount(context.Background(), "account_id")
if err != nil {
t.Fatalf("Error when getting account: %s", err)
}
@@ -807,7 +808,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
}
userId := "test_user"
account, err := manager.GetOrCreateAccountByUser(userId, "")
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, "")
if err != nil {
t.Fatal(err)
}
@@ -815,7 +816,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
t.Fatalf("expected to create an account for a user %s", userId)
}
account, err = manager.Store.GetAccountByUser(userId)
account, err = manager.Store.GetAccountByUser(context.Background(), userId)
if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId)
}
@@ -834,7 +835,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
userId := "test_user"
domain := "hotmail.com"
account, err := manager.GetOrCreateAccountByUser(userId, domain)
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, domain)
if err != nil {
t.Fatal(err)
}
@@ -848,7 +849,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
domain = "gmail.com"
account, err = manager.GetOrCreateAccountByUser(userId, domain)
account, err = manager.GetOrCreateAccountByUser(context.Background(), userId, domain)
if err != nil {
t.Fatalf("got the following error while retrieving existing acc: %v", err)
}
@@ -871,7 +872,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
userId := "test_user"
account, err := manager.GetAccountByUserOrAccountID(userId, "", "")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, "", "")
if err != nil {
t.Fatal(err)
}
@@ -880,20 +881,20 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
return
}
_, err = manager.GetAccountByUserOrAccountID("", account.Id, "")
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
if err != nil {
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id)
}
_, err = manager.GetAccountByUserOrAccountID("", "", "")
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", "", "")
if err == nil {
t.Errorf("expected an error when user and account IDs are empty")
}
}
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) {
account := newAccountWithId(accountID, userID, domain)
err := am.Store.SaveAccount(account)
account := newAccountWithId(context.Background(), accountID, userID, domain)
err := am.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, err
}
@@ -915,7 +916,7 @@ func TestAccountManager_GetAccount(t *testing.T) {
}
// AddAccount has been already tested so we can assume it is correct and compare results
getAccount, err := manager.Store.GetAccount(account.Id)
getAccount, err := manager.Store.GetAccount(context.Background(), account.Id)
if err != nil {
t.Fatal(err)
return
@@ -952,12 +953,12 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
t.Fatal(err)
}
err = manager.DeleteAccount(account.Id, userId)
err = manager.DeleteAccount(context.Background(), account.Id, userId)
if err != nil {
t.Fatal(err)
}
getAccount, err := manager.Store.GetAccount(account.Id)
getAccount, err := manager.Store.GetAccount(context.Background(), account.Id)
if err == nil {
t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount))
}
@@ -978,7 +979,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
serial := account.Network.CurrentSerial() // should be 0
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
if err != nil {
t.Fatal("error creating setup key")
return
@@ -997,7 +998,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
expectedPeerKey := key.PublicKey().String()
expectedSetupKey := setupKey.Key
peer, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
@@ -1006,7 +1007,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
return
}
account, err = manager.Store.GetAccount(account.Id)
account, err = manager.Store.GetAccount(context.Background(), account.Id)
if err != nil {
t.Fatal(err)
return
@@ -1045,7 +1046,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
return
}
account, err := manager.GetOrCreateAccountByUser(userID, "netbird.cloud")
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "netbird.cloud")
if err != nil {
t.Fatal(err)
}
@@ -1065,7 +1066,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
expectedPeerKey := key.PublicKey().String()
expectedUserID := userID
peer, _, err := manager.AddPeer("", userID, &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
@@ -1074,7 +1075,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
return
}
account, err = manager.Store.GetAccount(account.Id)
account, err = manager.Store.GetAccount(context.Background(), account.Id)
if err != nil {
t.Fatal(err)
return
@@ -1121,7 +1122,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
t.Fatal(err)
}
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
if err != nil {
t.Fatal("error creating setup key")
return
@@ -1140,7 +1141,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
}
expectedPeerKey := key.PublicKey().String()
peer, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
@@ -1156,14 +1157,14 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
peer2 := getPeer()
peer3 := getPeer()
account, err = manager.Store.GetAccount(account.Id)
account, err = manager.Store.GetAccount(context.Background(), account.Id)
if err != nil {
t.Fatal(err)
return
}
updMsg := manager.peersUpdateManager.CreateChannel(peer1.ID)
defer manager.peersUpdateManager.CloseChannel(peer1.ID)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
group := group.Group{
ID: "group-id",
@@ -1197,7 +1198,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
}
}()
if err := manager.SaveGroup(account.Id, userID, &group); err != nil {
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
t.Errorf("save group: %v", err)
return
}
@@ -1217,7 +1218,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
}
}()
if err := manager.DeletePolicy(account.Id, account.Policies[0].ID, userID); err != nil {
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
t.Errorf("delete default rule: %v", err)
return
}
@@ -1237,7 +1238,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
}
}()
if err := manager.SavePolicy(account.Id, userID, &policy); err != nil {
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil {
t.Errorf("delete default rule: %v", err)
return
}
@@ -1256,7 +1257,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
}
}()
if err := manager.DeletePeer(account.Id, peer3.ID, userID); err != nil {
if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil {
t.Errorf("delete peer: %v", err)
return
}
@@ -1277,9 +1278,9 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
}()
// clean policy is pre requirement for delete group
_ = manager.DeletePolicy(account.Id, policy.ID, userID)
_ = manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID)
if err := manager.DeleteGroup(account.Id, "", group.ID); err != nil {
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
t.Errorf("delete group: %v", err)
return
}
@@ -1301,7 +1302,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
t.Fatal(err)
}
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
if err != nil {
t.Fatal("error creating setup key")
return
@@ -1315,7 +1316,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
peerKey := key.PublicKey().String()
peer, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
Key: peerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: peerKey},
})
@@ -1324,12 +1325,12 @@ func TestAccountManager_DeletePeer(t *testing.T) {
return
}
err = manager.DeletePeer(account.Id, peerKey, userID)
err = manager.DeletePeer(context.Background(), account.Id, peerKey, userID)
if err != nil {
return
}
account, err = manager.Store.GetAccount(account.Id)
account, err = manager.Store.GetAccount(context.Background(), account.Id)
if err != nil {
t.Fatal(err)
return
@@ -1357,7 +1358,7 @@ func getEvent(t *testing.T, accountID string, manager AccountManager, eventType
case <-time.After(time.Second):
t.Fatal("no PeerAddedWithSetupKey event was generated")
default:
events, err := manager.GetEvents(accountID, userID)
events, err := manager.GetEvents(context.Background(), accountID, userID)
if err != nil {
t.Fatal(err)
}
@@ -1389,7 +1390,7 @@ func TestGetUsersFromAccount(t *testing.T) {
account.Users[user.Id] = user
}
userInfos, err := manager.GetUsersFromAccount(accountId, "1")
userInfos, err := manager.GetUsersFromAccount(context.Background(), accountId, "1")
if err != nil {
t.Fatal(err)
}
@@ -1500,7 +1501,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
},
}
routes := account.getRoutesToSync("peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
routes := account.getRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
assert.Len(t, routes, 2)
routeIDs := make(map[route.ID]struct{}, 2)
@@ -1510,7 +1511,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
assert.Contains(t, routeIDs, route.ID("route-2"))
assert.Contains(t, routeIDs, route.ID("route-3"))
emptyRoutes := account.getRoutesToSync("peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
emptyRoutes := account.getRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
assert.Len(t, emptyRoutes, 0)
}
@@ -1645,7 +1646,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
assert.NotNil(t, account.Settings)
@@ -1657,23 +1658,23 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peer, _, err := manager.AddPeer("", userID, &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
})
require.NoError(t, err, "unable to add peer")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
@@ -1682,10 +1683,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(2)
manager.peerLoginExpiry = &MockScheduler{
CancelFunc: func(IDs []string) {
CancelFunc: func(ctx context.Context, IDs []string) {
wg.Done()
},
ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
wg.Done()
},
}
@@ -1693,11 +1694,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
// disable expiration first
update := peer.Copy()
update.LoginExpirationEnabled = false
_, err = manager.UpdatePeer(account.Id, userID, update)
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
require.NoError(t, err, "unable to update peer")
// enabling expiration should trigger the routine
update.LoginExpirationEnabled = true
_, err = manager.UpdatePeer(account.Id, userID, update)
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
require.NoError(t, err, "unable to update peer")
failed := waitTimeout(wg, time.Second)
@@ -1710,18 +1711,18 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
_, _, err = manager.AddPeer("", userID, &nbpeer.Peer{
_, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
})
require.NoError(t, err, "unable to add peer")
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
@@ -1730,18 +1731,18 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
wg := &sync.WaitGroup{}
wg.Add(2)
manager.peerLoginExpiry = &MockScheduler{
CancelFunc: func(IDs []string) {
CancelFunc: func(ctx context.Context, IDs []string) {
wg.Done()
},
ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
wg.Done()
},
}
account, err = manager.GetAccountByUserOrAccountID(userID, "", "")
account, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second)
@@ -1754,35 +1755,35 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
_, _, err = manager.AddPeer("", userID, &nbpeer.Peer{
_, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
})
require.NoError(t, err, "unable to add peer")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{}
wg.Add(2)
manager.peerLoginExpiry = &MockScheduler{
CancelFunc: func(IDs []string) {
CancelFunc: func(ctx context.Context, IDs []string) {
wg.Done()
},
ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
wg.Done()
},
}
// enabling PeerLoginExpirationEnabled should trigger the expiration job
account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
@@ -1795,7 +1796,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
wg.Add(1)
// disabling PeerLoginExpirationEnabled should trigger cancel
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false,
})
@@ -1810,10 +1811,10 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
updated, err := manager.UpdateAccountSettings(account.Id, userID, &Settings{
updated, err := manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false,
})
@@ -1821,19 +1822,19 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
account, err = manager.GetAccountByUserOrAccountID("", account.Id, "")
account, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
require.NoError(t, err, "unable to get account by ID")
assert.False(t, account.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, account.Settings.PeerLoginExpiration, time.Hour)
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
PeerLoginExpiration: time.Second,
PeerLoginExpirationEnabled: false,
})
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour * 24 * 181,
PeerLoginExpirationEnabled: false,
})
@@ -2294,7 +2295,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
}
eventStore := &activity.InMemoryEventStore{}
manager, err := BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
if err != nil {
return nil, err
}
@@ -2305,7 +2306,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
func createStore(t *testing.T) (Store, error) {
t.Helper()
dataDir := t.TempDir()
store, cleanUp, err := NewTestStoreFromJson(dataDir)
store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
if err != nil {
return nil, err
}

View File

@@ -1,6 +1,7 @@
package sqlite
import (
"context"
"database/sql"
"encoding/json"
"fmt"
@@ -86,7 +87,7 @@ type Store struct {
}
// NewSQLiteStore creates a new Store with an event table if not exists.
func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) {
func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (*Store, error) {
dbFile := filepath.Join(dataDir, eventSinkDB)
db, err := sql.Open("sqlite3", dbFile)
if err != nil {
@@ -111,7 +112,7 @@ func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) {
return nil, err
}
err = updateDeletedUsersTable(db)
err = updateDeletedUsersTable(ctx, db)
if err != nil {
_ = db.Close()
return nil, err
@@ -153,7 +154,7 @@ func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) {
return s, nil
}
func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
func (store *Store) processResult(ctx context.Context, result *sql.Rows) ([]*activity.Event, error) {
events := make([]*activity.Event, 0)
var cryptErr error
for result.Next() {
@@ -235,14 +236,14 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
}
if cryptErr != nil {
log.Warnf("%s", cryptErr)
log.WithContext(ctx).Warnf("%s", cryptErr)
}
return events, nil
}
// Get returns "limit" number of events from index ordered descending or ascending by a timestamp
func (store *Store) Get(accountID string, offset, limit int, descending bool) ([]*activity.Event, error) {
func (store *Store) Get(ctx context.Context, accountID string, offset, limit int, descending bool) ([]*activity.Event, error) {
stmt := store.selectDescStatement
if !descending {
stmt = store.selectAscStatement
@@ -254,11 +255,11 @@ func (store *Store) Get(accountID string, offset, limit int, descending bool) ([
}
defer result.Close() //nolint
return store.processResult(result)
return store.processResult(ctx, result)
}
// Save an event in the SQLite events table end encrypt the "email" element in meta map
func (store *Store) Save(event *activity.Event) (*activity.Event, error) {
func (store *Store) Save(_ context.Context, event *activity.Event) (*activity.Event, error) {
var jsonMeta string
meta, err := store.saveDeletedUserEmailAndNameInEncrypted(event)
if err != nil {
@@ -317,15 +318,15 @@ func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event
}
// Close the Store
func (store *Store) Close() error {
func (store *Store) Close(_ context.Context) error {
if store.db != nil {
return store.db.Close()
}
return nil
}
func updateDeletedUsersTable(db *sql.DB) error {
log.Debugf("check deleted_users table version")
func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error {
log.WithContext(ctx).Debugf("check deleted_users table version")
rows, err := db.Query(`PRAGMA table_info(deleted_users);`)
if err != nil {
return err
@@ -360,7 +361,7 @@ func updateDeletedUsersTable(db *sql.DB) error {
return nil
}
log.Debugf("update delted_users table")
log.WithContext(ctx).Debugf("update delted_users table")
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
return err
}

View File

@@ -1,6 +1,7 @@
package sqlite
import (
"context"
"fmt"
"testing"
"time"
@@ -13,17 +14,17 @@ import (
func TestNewSQLiteStore(t *testing.T) {
dataDir := t.TempDir()
key, _ := GenerateKey()
store, err := NewSQLiteStore(dataDir, key)
store, err := NewSQLiteStore(context.Background(), dataDir, key)
if err != nil {
t.Fatal(err)
return
}
defer store.Close() //nolint
defer store.Close(context.Background()) //nolint
accountID := "account_1"
for i := 0; i < 10; i++ {
_, err = store.Save(&activity.Event{
_, err = store.Save(context.Background(), &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "user_" + fmt.Sprint(i),
@@ -36,7 +37,7 @@ func TestNewSQLiteStore(t *testing.T) {
}
}
result, err := store.Get(accountID, 0, 10, false)
result, err := store.Get(context.Background(), accountID, 0, 10, false)
if err != nil {
t.Fatal(err)
return
@@ -45,7 +46,7 @@ func TestNewSQLiteStore(t *testing.T) {
assert.Len(t, result, 10)
assert.True(t, result[0].Timestamp.Before(result[len(result)-1].Timestamp))
result, err = store.Get(accountID, 0, 5, true)
result, err = store.Get(context.Background(), accountID, 0, 5, true)
if err != nil {
t.Fatal(err)
return

View File

@@ -1,15 +1,18 @@
package activity
import "sync"
import (
"context"
"sync"
)
// Store provides an interface to store or stream events.
type Store interface {
// Save an event in the store
Save(event *Event) (*Event, error)
Save(ctx context.Context, event *Event) (*Event, error)
// Get returns "limit" number of events from the "offset" index ordered descending or ascending by a timestamp
Get(accountID string, offset, limit int, descending bool) ([]*Event, error)
Get(ctx context.Context, accountID string, offset, limit int, descending bool) ([]*Event, error)
// Close the sink flushing events if necessary
Close() error
Close(ctx context.Context) error
}
// InMemoryEventStore implements the Store interface storing data in-memory
@@ -20,7 +23,7 @@ type InMemoryEventStore struct {
}
// Save sets the Event.ID to 1
func (store *InMemoryEventStore) Save(event *Event) (*Event, error) {
func (store *InMemoryEventStore) Save(_ context.Context, event *Event) (*Event, error) {
store.mu.Lock()
defer store.mu.Unlock()
if store.events == nil {
@@ -33,7 +36,7 @@ func (store *InMemoryEventStore) Save(event *Event) (*Event, error) {
}
// Get returns a list of ALL events that belong to the given accountID without taking offset, limit and order into consideration
func (store *InMemoryEventStore) Get(accountID string, offset, limit int, descending bool) ([]*Event, error) {
func (store *InMemoryEventStore) Get(_ context.Context, accountID string, offset, limit int, descending bool) ([]*Event, error) {
store.mu.Lock()
defer store.mu.Unlock()
events := make([]*Event, 0)
@@ -46,7 +49,7 @@ func (store *InMemoryEventStore) Get(accountID string, offset, limit int, descen
}
// Close cleans up the event list
func (store *InMemoryEventStore) Close() error {
func (store *InMemoryEventStore) Close(_ context.Context) error {
store.mu.Lock()
defer store.mu.Unlock()
store.events = make([]*Event, 0)

View File

@@ -0,0 +1,8 @@
package context
const (
RequestIDKey = "requestID"
AccountIDKey = "accountID"
UserIDKey = "userID"
PeerIDKey = "peerID"
)

View File

@@ -1,6 +1,7 @@
package server
import (
"context"
"fmt"
"strconv"
@@ -34,11 +35,11 @@ func (d DNSSettings) Copy() DNSSettings {
}
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string) (*DNSSettings, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
@@ -56,11 +57,11 @@ func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string)
}
// SaveDNSSettings validates a user role and updates the account's DNS settings
func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
@@ -89,7 +90,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string
account.DNSSettings = dnsSettingsToSave.Copy()
account.Network.IncSerial()
if err = am.Store.SaveAccount(account); err != nil {
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
@@ -97,17 +98,17 @@ func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string
for _, id := range addedGroups {
group := account.GetGroup(id)
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
}
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
for _, id := range removedGroups {
group := account.GetGroup(id)
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
}
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
return nil
}
@@ -149,9 +150,9 @@ func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig {
return protoUpdate
}
func getPeersCustomZone(account *Account, dnsDomain string) nbdns.CustomZone {
func getPeersCustomZone(ctx context.Context, account *Account, dnsDomain string) nbdns.CustomZone {
if dnsDomain == "" {
log.Errorf("no dns domain is set, returning empty zone")
log.WithContext(ctx).Errorf("no dns domain is set, returning empty zone")
return nbdns.CustomZone{}
}
@@ -161,7 +162,7 @@ func getPeersCustomZone(account *Account, dnsDomain string) nbdns.CustomZone {
for _, peer := range account.Peers {
if peer.DNSLabel == "" {
log.Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name)
log.WithContext(ctx).Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name)
continue
}
@@ -210,14 +211,14 @@ func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
return false
}
func addPeerLabelsToAccount(account *Account, peerLabels lookupMap) {
func addPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels lookupMap) {
for _, peer := range account.Peers {
label, err := getPeerHostLabel(peer.Name, peerLabels)
if err != nil {
log.Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err)
log.WithContext(ctx).Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err)
label, err = getPeerHostLabel(peer.Meta.Hostname, peerLabels)
if err != nil {
log.Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err)
log.WithContext(ctx).Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err)
continue
}
}

View File

@@ -1,6 +1,7 @@
package server
import (
"context"
"net/netip"
"testing"
@@ -35,7 +36,7 @@ func TestGetDNSSettings(t *testing.T) {
t.Fatal("failed to init testing account")
}
dnsSettings, err := am.GetDNSSettings(account.Id, dnsAdminUserID)
dnsSettings, err := am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID)
if err != nil {
t.Fatalf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err)
}
@@ -48,12 +49,12 @@ func TestGetDNSSettings(t *testing.T) {
DisabledManagementGroups: []string{group1ID},
}
err = am.Store.SaveAccount(account)
err = am.Store.SaveAccount(context.Background(), account)
if err != nil {
t.Error("failed to save testing account with new DNS settings")
}
dnsSettings, err = am.GetDNSSettings(account.Id, dnsAdminUserID)
dnsSettings, err = am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID)
if err != nil {
t.Errorf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err)
}
@@ -62,7 +63,7 @@ func TestGetDNSSettings(t *testing.T) {
t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups)
}
_, err = am.GetDNSSettings(account.Id, dnsRegularUserID)
_, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID)
if err == nil {
t.Errorf("An error should be returned when getting the DNS settings with a regular user")
}
@@ -122,7 +123,7 @@ func TestSaveDNSSettings(t *testing.T) {
t.Error("failed to init testing account")
}
err = am.SaveDNSSettings(account.Id, testCase.userID, testCase.inputSettings)
err = am.SaveDNSSettings(context.Background(), account.Id, testCase.userID, testCase.inputSettings)
if err != nil {
if testCase.shouldFail {
return
@@ -130,7 +131,7 @@ func TestSaveDNSSettings(t *testing.T) {
t.Error(err)
}
updatedAccount, err := am.Store.GetAccount(account.Id)
updatedAccount, err := am.Store.GetAccount(context.Background(), account.Id)
if err != nil {
t.Errorf("should be able to retrieve updated account, got err: %s", err)
}
@@ -164,7 +165,7 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
t.Error("failed to init testing account")
}
newAccountDNSConfig, err := am.GetNetworkMap(peer1.ID)
newAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID)
require.NoError(t, err)
require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS config should have one custom zone for peers")
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled")
@@ -173,14 +174,14 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
dnsSettings := account.DNSSettings.Copy()
dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID)
account.DNSSettings = dnsSettings
err = am.Store.SaveAccount(account)
err = am.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
updatedAccountDNSConfig, err := am.GetNetworkMap(peer1.ID)
updatedAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID)
require.NoError(t, err)
require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS config should have no custom zone when peer belongs to a disabled group")
require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS config should have local DNS service disabled when peer belongs to a disabled group")
peer2AccountDNSConfig, err := am.GetNetworkMap(peer2.ID)
peer2AccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer2.ID)
require.NoError(t, err)
require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS config should have one custom zone for peers not in the disabled group")
require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS config should have DNS service enabled for peers not in the disabled group")
@@ -194,13 +195,13 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err
}
eventStore := &activity.InMemoryEventStore{}
return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{})
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{})
}
func createDNSStore(t *testing.T) (Store, error) {
t.Helper()
dataDir := t.TempDir()
store, cleanUp, err := NewTestStoreFromJson(dataDir)
store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
if err != nil {
return nil, err
}
@@ -244,28 +245,28 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
domain := "example.com"
account := newAccountWithId(dnsAccountID, dnsAdminUserID, domain)
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain)
account.Users[dnsRegularUserID] = &User{
Id: dnsRegularUserID,
Role: UserRoleUser,
}
err := am.Store.SaveAccount(account)
err := am.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, err
}
savedPeer1, _, err := am.AddPeer("", dnsAdminUserID, peer1)
savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1)
if err != nil {
return nil, err
}
_, _, err = am.AddPeer("", dnsAdminUserID, peer2)
_, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2)
if err != nil {
return nil, err
}
account, err = am.Store.GetAccount(account.Id)
account, err = am.Store.GetAccount(context.Background(), account.Id)
if err != nil {
return nil, err
}
@@ -312,10 +313,10 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
Groups: []string{allGroup.ID},
}
err = am.Store.SaveAccount(account)
err = am.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, err
}
return am.Store.GetAccount(account.Id)
return am.Store.GetAccount(context.Background(), account.Id)
}

View File

@@ -1,6 +1,7 @@
package server
import (
"context"
"sync"
"time"
@@ -51,13 +52,15 @@ func NewEphemeralManager(store Store, accountManager AccountManager) *EphemeralM
// LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head
// of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new
// head.
func (e *EphemeralManager) LoadInitialPeers() {
func (e *EphemeralManager) LoadInitialPeers(ctx context.Context) {
e.peersLock.Lock()
defer e.peersLock.Unlock()
e.loadEphemeralPeers()
e.loadEphemeralPeers(ctx)
if e.headPeer != nil {
e.timer = time.AfterFunc(ephemeralLifeTime, e.cleanup)
e.timer = time.AfterFunc(ephemeralLifeTime, func() {
e.cleanup(ctx)
})
}
}
@@ -73,12 +76,12 @@ func (e *EphemeralManager) Stop() {
// OnPeerConnected remove the peer from the linked list of ephemeral peers. Because it has been called when the peer
// is active the manager will not delete it while it is active.
func (e *EphemeralManager) OnPeerConnected(peer *nbpeer.Peer) {
func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Peer) {
if !peer.Ephemeral {
return
}
log.Tracef("remove peer from ephemeral list: %s", peer.ID)
log.WithContext(ctx).Tracef("remove peer from ephemeral list: %s", peer.ID)
e.peersLock.Lock()
defer e.peersLock.Unlock()
@@ -94,16 +97,16 @@ func (e *EphemeralManager) OnPeerConnected(peer *nbpeer.Peer) {
// OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer
// is inactive it will be deleted after the ephemeralLifeTime period.
func (e *EphemeralManager) OnPeerDisconnected(peer *nbpeer.Peer) {
func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) {
if !peer.Ephemeral {
return
}
log.Tracef("add peer to ephemeral list: %s", peer.ID)
log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID)
a, err := e.store.GetAccountByPeerID(peer.ID)
a, err := e.store.GetAccountByPeerID(context.Background(), peer.ID)
if err != nil {
log.Errorf("failed to add peer to ephemeral list: %s", err)
log.WithContext(ctx).Errorf("failed to add peer to ephemeral list: %s", err)
return
}
@@ -116,12 +119,14 @@ func (e *EphemeralManager) OnPeerDisconnected(peer *nbpeer.Peer) {
e.addPeer(peer.ID, a, newDeadLine())
if e.timer == nil {
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), e.cleanup)
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() {
e.cleanup(ctx)
})
}
}
func (e *EphemeralManager) loadEphemeralPeers() {
accounts := e.store.GetAllAccounts()
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
accounts := e.store.GetAllAccounts(context.Background())
t := newDeadLine()
count := 0
for _, a := range accounts {
@@ -132,10 +137,10 @@ func (e *EphemeralManager) loadEphemeralPeers() {
}
}
}
log.Debugf("loaded ephemeral peer(s): %d", count)
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count)
}
func (e *EphemeralManager) cleanup() {
func (e *EphemeralManager) cleanup(ctx context.Context) {
log.Tracef("on ephemeral cleanup")
deletePeers := make(map[string]*ephemeralPeer)
@@ -154,7 +159,9 @@ func (e *EphemeralManager) cleanup() {
}
if e.headPeer != nil {
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), e.cleanup)
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() {
e.cleanup(ctx)
})
} else {
e.timer = nil
}
@@ -162,10 +169,10 @@ func (e *EphemeralManager) cleanup() {
e.peersLock.Unlock()
for id, p := range deletePeers {
log.Debugf("delete ephemeral peer: %s", id)
err := e.accountManager.DeletePeer(p.account.Id, id, activity.SystemInitiator)
log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id)
err := e.accountManager.DeletePeer(ctx, p.account.Id, id, activity.SystemInitiator)
if err != nil {
log.Errorf("failed to delete ephemeral peer: %s", err)
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
}
}
}

View File

@@ -1,6 +1,7 @@
package server
import (
"context"
"fmt"
"testing"
"time"
@@ -13,11 +14,11 @@ type MockStore struct {
account *Account
}
func (s *MockStore) GetAllAccounts() []*Account {
func (s *MockStore) GetAllAccounts(_ context.Context) []*Account {
return []*Account{s.account}
}
func (s *MockStore) GetAccountByPeerID(peerId string) (*Account, error) {
func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Account, error) {
_, ok := s.account.Peers[peerId]
if ok {
return s.account, nil
@@ -31,7 +32,7 @@ type MocAccountManager struct {
store *MockStore
}
func (a MocAccountManager) DeletePeer(accountID, peerID, userID string) error {
func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error {
delete(a.store.account.Peers, peerID)
return nil //nolint:nil
}
@@ -52,9 +53,9 @@ func TestNewManager(t *testing.T) {
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
mgr := NewEphemeralManager(store, am)
mgr.loadEphemeralPeers()
mgr.loadEphemeralPeers(context.Background())
startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup()
mgr.cleanup(context.Background())
if len(store.account.Peers) != numberOfPeers {
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers))
@@ -77,11 +78,11 @@ func TestNewManagerPeerConnected(t *testing.T) {
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
mgr := NewEphemeralManager(store, am)
mgr.loadEphemeralPeers()
mgr.OnPeerConnected(store.account.Peers["ephemeral_peer_0"])
mgr.loadEphemeralPeers(context.Background())
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup()
mgr.cleanup(context.Background())
expected := numberOfPeers + 1
if len(store.account.Peers) != expected {
@@ -105,15 +106,15 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
mgr := NewEphemeralManager(store, am)
mgr.loadEphemeralPeers()
mgr.loadEphemeralPeers(context.Background())
for _, v := range store.account.Peers {
mgr.OnPeerConnected(v)
mgr.OnPeerConnected(context.Background(), v)
}
mgr.OnPeerDisconnected(store.account.Peers["ephemeral_peer_0"])
mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup()
mgr.cleanup(context.Background())
expected := numberOfPeers + numberOfEphemeralPeers - 1
if len(store.account.Peers) != expected {
@@ -122,7 +123,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
}
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) {
store.account = newAccountWithId("my account", "", "")
store.account = newAccountWithId(context.Background(), "my account", "", "")
for i := 0; i < numberOfPeers; i++ {
peerId := fmt.Sprintf("peer_%d", i)

View File

@@ -1,6 +1,7 @@
package server
import (
"context"
"fmt"
"time"
@@ -11,11 +12,11 @@ import (
)
// GetEvents returns a list of activity events of an account
func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activity.Event, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
@@ -29,7 +30,7 @@ func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activit
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view events")
}
events, err := am.eventStore.Get(accountID, 0, 10000, true)
events, err := am.eventStore.Get(ctx, accountID, 0, 10000, true)
if err != nil {
return nil, err
}
@@ -54,10 +55,10 @@ func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activit
return filtered, nil
}
func (am *DefaultAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
func (am *DefaultAccountManager) StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
go func() {
_, err := am.eventStore.Save(&activity.Event{
_, err := am.eventStore.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activityID,
InitiatorID: initiatorID,
@@ -67,7 +68,7 @@ func (am *DefaultAccountManager) StoreEvent(initiatorID, targetID, accountID str
})
if err != nil {
// todo add metric
log.Errorf("received an error while storing an activity event, error: %s", err)
log.WithContext(ctx).Errorf("received an error while storing an activity event, error: %s", err)
}
}()

View File

@@ -1,6 +1,7 @@
package server
import (
"context"
"testing"
"time"
@@ -13,7 +14,7 @@ func generateAndStoreEvents(t *testing.T, manager *DefaultAccountManager, typ ac
accountID string, count int) {
t.Helper()
for i := 0; i < count; i++ {
_, err := manager.eventStore.Save(&activity.Event{
_, err := manager.eventStore.Save(context.Background(), &activity.Event{
Timestamp: time.Now().UTC(),
Activity: typ,
InitiatorID: initiatorID,
@@ -35,32 +36,32 @@ func TestDefaultAccountManager_GetEvents(t *testing.T) {
accountID := "accountID"
t.Run("get empty events list", func(t *testing.T) {
events, err := manager.GetEvents(accountID, userID)
events, err := manager.GetEvents(context.Background(), accountID, userID)
if err != nil {
return
}
assert.Len(t, events, 0)
_ = manager.eventStore.Close() //nolint
_ = manager.eventStore.Close(context.Background()) //nolint
})
t.Run("get events", func(t *testing.T) {
generateAndStoreEvents(t, manager, activity.PeerAddedByUser, userID, "peer", accountID, 10)
events, err := manager.GetEvents(accountID, userID)
events, err := manager.GetEvents(context.Background(), accountID, userID)
if err != nil {
return
}
assert.Len(t, events, 10)
_ = manager.eventStore.Close() //nolint
_ = manager.eventStore.Close(context.Background()) //nolint
})
t.Run("get events without duplicates", func(t *testing.T) {
generateAndStoreEvents(t, manager, activity.UserJoined, userID, "", accountID, 10)
events, err := manager.GetEvents(accountID, userID)
events, err := manager.GetEvents(context.Background(), accountID, userID)
if err != nil {
return
}
assert.Len(t, events, 1)
_ = manager.eventStore.Close() //nolint
_ = manager.eventStore.Close(context.Background()) //nolint
})
}

View File

@@ -1,6 +1,7 @@
package server
import (
"context"
"os"
"path/filepath"
"strings"
@@ -48,8 +49,8 @@ type FileStore struct {
type StoredAccount struct{}
// NewFileStore restores a store from the file located in the datadir
func NewFileStore(dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
fs, err := restore(filepath.Join(dataDir, storeFileName))
func NewFileStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
fs, err := restore(ctx, filepath.Join(dataDir, storeFileName))
if err != nil {
return nil, err
}
@@ -58,27 +59,27 @@ func NewFileStore(dataDir string, metrics telemetry.AppMetrics) (*FileStore, err
}
// NewFilestoreFromSqliteStore restores a store from Sqlite and stores to Filestore json in the file located in datadir
func NewFilestoreFromSqliteStore(sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
store, err := NewFileStore(dataDir, metrics)
func NewFilestoreFromSqliteStore(ctx context.Context, sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
store, err := NewFileStore(ctx, dataDir, metrics)
if err != nil {
return nil, err
}
err = store.SaveInstallationID(sqlStore.GetInstallationID())
err = store.SaveInstallationID(ctx, sqlStore.GetInstallationID())
if err != nil {
return nil, err
}
for _, account := range sqlStore.GetAllAccounts() {
for _, account := range sqlStore.GetAllAccounts(ctx) {
store.Accounts[account.Id] = account
}
return store, store.persist(store.storeFile)
return store, store.persist(ctx, store.storeFile)
}
// restore the state of the store from the file.
// Creates a new empty store file if doesn't exist
func restore(file string) (*FileStore, error) {
func restore(ctx context.Context, file string) (*FileStore, error) {
if _, err := os.Stat(file); os.IsNotExist(err) {
// create a new FileStore if previously didn't exist (e.g. first run)
s := &FileStore{
@@ -95,7 +96,7 @@ func restore(file string) (*FileStore, error) {
storeFile: file,
}
err = s.persist(file)
err = s.persist(ctx, file)
if err != nil {
return nil, err
}
@@ -165,7 +166,7 @@ func restore(file string) (*FileStore, error) {
// for data migration. Can be removed once most base will be with labels
existingLabels := account.getPeerDNSLabels()
if len(existingLabels) != len(account.Peers) {
addPeerLabelsToAccount(account, existingLabels)
addPeerLabelsToAccount(ctx, account, existingLabels)
}
// TODO: delete this block after migration
@@ -178,7 +179,7 @@ func restore(file string) (*FileStore, error) {
allGroup, err := account.GetGroupAll()
if err != nil {
log.Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err)
log.WithContext(ctx).Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err)
// if the All group didn't exist we probably don't have routes to update
continue
}
@@ -236,7 +237,7 @@ func restore(file string) (*FileStore, error) {
}
// we need this persist to apply changes we made to account.Peers (we set them to Disconnected)
err = store.persist(store.storeFile)
err = store.persist(ctx, store.storeFile)
if err != nil {
return nil, err
}
@@ -246,7 +247,7 @@ func restore(file string) (*FileStore, error) {
// persist account data to a file
// It is recommended to call it with locking FileStore.mux
func (s *FileStore) persist(file string) error {
func (s *FileStore) persist(ctx context.Context, file string) error {
start := time.Now()
err := util.WriteJson(file, s)
if err != nil {
@@ -256,23 +257,23 @@ func (s *FileStore) persist(file string) error {
if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took)
}
log.Debugf("took %d ms to persist the FileStore", took.Milliseconds())
log.WithContext(ctx).Debugf("took %d ms to persist the FileStore", took.Milliseconds())
return nil
}
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
func (s *FileStore) AcquireGlobalLock() (unlock func()) {
log.Debugf("acquiring global lock")
func (s *FileStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
log.WithContext(ctx).Debugf("acquiring global lock")
start := time.Now()
s.globalAccountLock.Lock()
unlock = func() {
s.globalAccountLock.Unlock()
log.Debugf("released global lock in %v", time.Since(start))
log.WithContext(ctx).Debugf("released global lock in %v", time.Since(start))
}
took := time.Since(start)
log.Debugf("took %v to acquire global lock", took)
log.WithContext(ctx).Debugf("took %v to acquire global lock", took)
if s.metrics != nil {
s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took)
}
@@ -281,8 +282,8 @@ func (s *FileStore) AcquireGlobalLock() (unlock func()) {
}
// AcquireAccountWriteLock acquires account lock for writing to a resource and returns a function that releases the lock
func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
log.Debugf("acquiring lock for account %s", accountID)
func (s *FileStore) AcquireAccountWriteLock(ctx context.Context, accountID string) (unlock func()) {
log.WithContext(ctx).Debugf("acquiring lock for account %s", accountID)
start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
mtx := value.(*sync.Mutex)
@@ -290,7 +291,7 @@ func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
unlock = func() {
mtx.Unlock()
log.Debugf("released lock for account %s in %v", accountID, time.Since(start))
log.WithContext(ctx).Debugf("released lock for account %s in %v", accountID, time.Since(start))
}
return unlock
@@ -298,11 +299,11 @@ func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
// AcquireAccountReadLock AcquireAccountWriteLock acquires account lock for reading a resource and returns a function that releases the lock
// This method is still returns a write lock as file store can't handle read locks
func (s *FileStore) AcquireAccountReadLock(accountID string) (unlock func()) {
return s.AcquireAccountWriteLock(accountID)
func (s *FileStore) AcquireAccountReadLock(ctx context.Context, accountID string) (unlock func()) {
return s.AcquireAccountWriteLock(ctx, accountID)
}
func (s *FileStore) SaveAccount(account *Account) error {
func (s *FileStore) SaveAccount(ctx context.Context, account *Account) error {
s.mux.Lock()
defer s.mux.Unlock()
@@ -338,10 +339,10 @@ func (s *FileStore) SaveAccount(account *Account) error {
s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id
}
return s.persist(s.storeFile)
return s.persist(ctx, s.storeFile)
}
func (s *FileStore) DeleteAccount(account *Account) error {
func (s *FileStore) DeleteAccount(ctx context.Context, account *Account) error {
s.mux.Lock()
defer s.mux.Unlock()
@@ -373,7 +374,7 @@ func (s *FileStore) DeleteAccount(account *Account) error {
delete(s.Accounts, account.Id)
return s.persist(s.storeFile)
return s.persist(ctx, s.storeFile)
}
// DeleteHashedPAT2TokenIDIndex removes an entry from the indexing map HashedPAT2TokenID
@@ -397,7 +398,7 @@ func (s *FileStore) DeleteTokenID2UserIDIndex(tokenID string) error {
}
// GetAccountByPrivateDomain returns account by private domain
func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
func (s *FileStore) GetAccountByPrivateDomain(_ context.Context, domain string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -415,7 +416,7 @@ func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
}
// GetAccountBySetupKey returns account by setup key id
func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -433,7 +434,7 @@ func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
}
// GetTokenIDByHashedToken returns the id of a personal access token by its hashed secret
func (s *FileStore) GetTokenIDByHashedToken(token string) (string, error) {
func (s *FileStore) GetTokenIDByHashedToken(_ context.Context, token string) (string, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -446,7 +447,7 @@ func (s *FileStore) GetTokenIDByHashedToken(token string) (string, error) {
}
// GetUserByTokenID returns a User object a tokenID belongs to
func (s *FileStore) GetUserByTokenID(tokenID string) (*User, error) {
func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -469,7 +470,7 @@ func (s *FileStore) GetUserByTokenID(tokenID string) (*User, error) {
}
// GetAllAccounts returns all accounts
func (s *FileStore) GetAllAccounts() (all []*Account) {
func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
s.mux.Lock()
defer s.mux.Unlock()
for _, a := range s.Accounts {
@@ -490,7 +491,7 @@ func (s *FileStore) getAccount(accountID string) (*Account, error) {
}
// GetAccount returns an account for ID
func (s *FileStore) GetAccount(accountID string) (*Account, error) {
func (s *FileStore) GetAccount(_ context.Context, accountID string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -503,7 +504,7 @@ func (s *FileStore) GetAccount(accountID string) (*Account, error) {
}
// GetAccountByUser returns a user account
func (s *FileStore) GetAccountByUser(userID string) (*Account, error) {
func (s *FileStore) GetAccountByUser(_ context.Context, userID string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -521,7 +522,7 @@ func (s *FileStore) GetAccountByUser(userID string) (*Account, error) {
}
// GetAccountByPeerID returns an account for a given peer ID
func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) {
func (s *FileStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -539,7 +540,7 @@ func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) {
// check Account.Peers for a match
if _, ok := account.Peers[peerID]; !ok {
delete(s.PeerID2AccountID, peerID)
log.Warnf("removed stale peerID %s to accountID %s index", peerID, accountID)
log.WithContext(ctx).Warnf("removed stale peerID %s to accountID %s index", peerID, accountID)
return nil, status.NewPeerNotFoundError(peerID)
}
@@ -547,7 +548,7 @@ func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) {
}
// GetAccountByPeerPubKey returns an account for a given peer WireGuard public key
func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
func (s *FileStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -572,14 +573,14 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
}
if stale {
delete(s.PeerKeyID2AccountID, peerKey)
log.Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID)
log.WithContext(ctx).Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID)
return nil, status.NewPeerNotFoundError(peerKey)
}
return account.Copy(), nil
}
func (s *FileStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
func (s *FileStore) GetAccountIDByPeerPubKey(_ context.Context, peerKey string) (string, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -603,7 +604,7 @@ func (s *FileStore) GetAccountIDByUserID(userID string) (string, error) {
return accountID, nil
}
func (s *FileStore) GetAccountIDBySetupKey(setupKey string) (string, error) {
func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (string, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -615,7 +616,7 @@ func (s *FileStore) GetAccountIDBySetupKey(setupKey string) (string, error) {
return accountID, nil
}
func (s *FileStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) {
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbpeer.Peer, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -638,7 +639,7 @@ func (s *FileStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) {
return nil, status.NewPeerNotFoundError(peerKey)
}
func (s *FileStore) GetAccountSettings(accountID string) (*Settings, error) {
func (s *FileStore) GetAccountSettings(_ context.Context, accountID string) (*Settings, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -656,13 +657,13 @@ func (s *FileStore) GetInstallationID() string {
}
// SaveInstallationID saves the installation ID
func (s *FileStore) SaveInstallationID(ID string) error {
func (s *FileStore) SaveInstallationID(ctx context.Context, ID string) error {
s.mux.Lock()
defer s.mux.Unlock()
s.InstallationID = ID
return s.persist(s.storeFile)
return s.persist(ctx, s.storeFile)
}
// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things.
@@ -732,16 +733,24 @@ func (s *FileStore) GetPostureCheckByChecksDefinition(accountID string, checks *
}
// Close the FileStore persisting data to disk
func (s *FileStore) Close() error {
func (s *FileStore) Close(ctx context.Context) error {
s.mux.Lock()
defer s.mux.Unlock()
log.Infof("closing FileStore")
log.WithContext(ctx).Infof("closing FileStore")
return s.persist(s.storeFile)
return s.persist(ctx, s.storeFile)
}
// GetStoreEngine returns FileStoreEngine
func (s *FileStore) GetStoreEngine() StoreEngine {
return FileStoreEngine
}
func (s *FileStore) SaveUsers(accountID string, users map[string]*User) error {
return status.Errorf(status.Internal, "SaveUsers is not implemented")
}
func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
return status.Errorf(status.Internal, "SaveGroups is not implemented")
}

View File

@@ -1,6 +1,7 @@
package server
import (
"context"
"crypto/sha256"
"net"
"path/filepath"
@@ -27,12 +28,12 @@ func TestStalePeerIndices(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil {
return
}
account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b")
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
peerID := "some_peer"
@@ -42,24 +43,24 @@ func TestStalePeerIndices(t *testing.T) {
Key: peerKey,
}
err = store.SaveAccount(account)
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
account.DeletePeer(peerID)
err = store.SaveAccount(account)
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
_, err = store.GetAccountByPeerID(peerID)
_, err = store.GetAccountByPeerID(context.Background(), peerID)
require.Error(t, err, "expecting to get an error when found stale index")
_, err = store.GetAccountByPeerPubKey(peerKey)
_, err = store.GetAccountByPeerPubKey(context.Background(), peerKey)
require.Error(t, err, "expecting to get an error when found stale index")
}
func TestNewStore(t *testing.T) {
store := newStore(t)
defer store.Close()
defer store.Close(context.Background())
if store.Accounts == nil || len(store.Accounts) != 0 {
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
@@ -88,9 +89,9 @@ func TestNewStore(t *testing.T) {
func TestSaveAccount(t *testing.T) {
store := newStore(t)
defer store.Close()
defer store.Close(context.Background())
account := newAccountWithId("account_id", "testuser", "")
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &nbpeer.Peer{
@@ -103,7 +104,7 @@ func TestSaveAccount(t *testing.T) {
}
// SaveAccount should trigger persist
err := store.SaveAccount(account)
err := store.SaveAccount(context.Background(), account)
if err != nil {
return
}
@@ -133,11 +134,11 @@ func TestDeleteAccount(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil {
t.Fatal(err)
}
defer store.Close()
defer store.Close(context.Background())
var account *Account
for _, a := range store.Accounts {
@@ -147,7 +148,7 @@ func TestDeleteAccount(t *testing.T) {
require.NotNil(t, account, "failed to restore a FileStore file and get at least one account")
err = store.DeleteAccount(account)
err = store.DeleteAccount(context.Background(), account)
require.NoError(t, err, "failed to delete account, error: %v", err)
_, ok := store.Accounts[account.Id]
@@ -183,9 +184,9 @@ func TestDeleteAccount(t *testing.T) {
func TestStore(t *testing.T) {
store := newStore(t)
defer store.Close()
defer store.Close(context.Background())
account := newAccountWithId("account_id", "testuser", "")
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey",
SetupKey: "peerkeysetupkey",
@@ -228,12 +229,12 @@ func TestStore(t *testing.T) {
})
// SaveAccount should trigger persist
err := store.SaveAccount(account)
err := store.SaveAccount(context.Background(), account)
if err != nil {
return
}
restored, err := NewFileStore(store.storeFile, nil)
restored, err := NewFileStore(context.Background(), store.storeFile, nil)
if err != nil {
return
}
@@ -281,7 +282,7 @@ func TestRestore(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil {
return
}
@@ -319,7 +320,7 @@ func TestRestoreGroups_Migration(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil {
return
}
@@ -332,11 +333,11 @@ func TestRestoreGroups_Migration(t *testing.T) {
Name: "All",
},
}
err = store.SaveAccount(account)
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err, "failed to save account")
// restore account with default group with empty Issue field
if store, err = NewFileStore(storeDir, nil); err != nil {
if store, err = NewFileStore(context.Background(), storeDir, nil); err != nil {
return
}
account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
@@ -353,18 +354,18 @@ func TestGetAccountByPrivateDomain(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil {
return
}
existingDomain := "test.com"
account, err := store.GetAccountByPrivateDomain(existingDomain)
account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain)
require.NoError(t, err, "should found account")
require.Equal(t, existingDomain, account.Domain, "domains should match")
_, err = store.GetAccountByPrivateDomain("missing-domain.com")
_, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com")
require.Error(t, err, "should return error on domain lookup")
}
@@ -382,7 +383,7 @@ func TestFileStore_GetAccount(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil {
t.Fatal(err)
}
@@ -393,7 +394,7 @@ func TestFileStore_GetAccount(t *testing.T) {
return
}
account, err := store.GetAccount(expected.Id)
account, err := store.GetAccount(context.Background(), expected.Id)
if err != nil {
t.Fatal(err)
}
@@ -424,13 +425,13 @@ func TestFileStore_GetTokenIDByHashedToken(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil {
t.Fatal(err)
}
hashedToken := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].HashedToken
tokenID, err := store.GetTokenIDByHashedToken(hashedToken)
tokenID, err := store.GetTokenIDByHashedToken(context.Background(), hashedToken)
if err != nil {
t.Fatal(err)
}
@@ -441,7 +442,7 @@ func TestFileStore_GetTokenIDByHashedToken(t *testing.T) {
func TestFileStore_DeleteHashedPAT2TokenIDIndex(t *testing.T) {
store := newStore(t)
defer store.Close()
defer store.Close(context.Background())
store.HashedPAT2TokenID["someHashedToken"] = "someTokenId"
err := store.DeleteHashedPAT2TokenIDIndex("someHashedToken")
@@ -478,13 +479,13 @@ func TestFileStore_GetTokenIDByHashedToken_Failure(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil {
t.Fatal(err)
}
wrongToken := sha256.Sum256([]byte("someNotValidTokenThatFails1234"))
_, err = store.GetTokenIDByHashedToken(string(wrongToken[:]))
_, err = store.GetTokenIDByHashedToken(context.Background(), string(wrongToken[:]))
assert.Error(t, err, "GetTokenIDByHashedToken should throw error if token invalid")
}
@@ -503,13 +504,13 @@ func TestFileStore_GetUserByTokenID(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil {
t.Fatal(err)
}
tokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID
user, err := store.GetUserByTokenID(tokenID)
user, err := store.GetUserByTokenID(context.Background(), tokenID)
if err != nil {
t.Fatal(err)
}
@@ -531,13 +532,13 @@ func TestFileStore_GetUserByTokenID_Failure(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil {
t.Fatal(err)
}
wrongTokenID := "someNonExistingTokenID"
_, err = store.GetUserByTokenID(wrongTokenID)
_, err = store.GetUserByTokenID(context.Background(), wrongTokenID)
assert.Error(t, err, "GetUserByTokenID should throw error if tokenID invalid")
}
@@ -550,7 +551,7 @@ func TestFileStore_SavePeerStatus(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil {
return
}
@@ -576,7 +577,7 @@ func TestFileStore_SavePeerStatus(t *testing.T) {
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()},
}
err = store.SaveAccount(account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatal(err)
}
@@ -602,11 +603,11 @@ func TestFileStore_SavePeerLocation(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil {
return
}
account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b")
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
peer := &nbpeer.Peer{
@@ -625,7 +626,7 @@ func TestFileStore_SavePeerLocation(t *testing.T) {
assert.Error(t, err)
account.Peers[peer.ID] = peer
err = store.SaveAccount(account)
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
peer.Location.ConnectionIP = net.ParseIP("35.1.1.1")
@@ -636,7 +637,7 @@ func TestFileStore_SavePeerLocation(t *testing.T) {
err = store.SavePeerLocation(account.Id, account.Peers[peer.ID])
assert.NoError(t, err)
account, err = store.GetAccount(account.Id)
account, err = store.GetAccount(context.Background(), account.Id)
require.NoError(t, err)
actual := account.Peers[peer.ID].Location
@@ -645,7 +646,7 @@ func TestFileStore_SavePeerLocation(t *testing.T) {
func newStore(t *testing.T) *FileStore {
t.Helper()
store, err := NewFileStore(t.TempDir(), nil)
store, err := NewFileStore(context.Background(), t.TempDir(), nil)
if err != nil {
t.Errorf("failed creating a new store")
}

View File

@@ -2,6 +2,7 @@ package geolocation
import (
"bytes"
"context"
"fmt"
"net"
"os"
@@ -52,7 +53,7 @@ type Country struct {
CountryName string
}
func NewGeolocation(dataDir string) (*Geolocation, error) {
func NewGeolocation(ctx context.Context, dataDir string) (*Geolocation, error) {
if err := loadGeolocationDatabases(dataDir); err != nil {
return nil, fmt.Errorf("failed to load MaxMind databases: %v", err)
}
@@ -68,7 +69,7 @@ func NewGeolocation(dataDir string) (*Geolocation, error) {
return nil, err
}
locationDB, err := NewSqliteStore(dataDir)
locationDB, err := NewSqliteStore(ctx, dataDir)
if err != nil {
return nil, err
}
@@ -83,7 +84,7 @@ func NewGeolocation(dataDir string) (*Geolocation, error) {
stopCh: make(chan struct{}),
}
go geo.reloader()
go geo.reloader(ctx)
return geo, nil
}
@@ -165,19 +166,19 @@ func (gl *Geolocation) Stop() error {
return nil
}
func (gl *Geolocation) reloader() {
func (gl *Geolocation) reloader(ctx context.Context) {
for {
select {
case <-gl.stopCh:
return
case <-time.After(gl.reloadCheckInterval):
if err := gl.locationDB.reload(); err != nil {
log.Errorf("geonames db reload failed: %s", err)
if err := gl.locationDB.reload(ctx); err != nil {
log.WithContext(ctx).Errorf("geonames db reload failed: %s", err)
}
newSha256sum1, err := calculateFileSHA256(gl.mmdbPath)
if err != nil {
log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
continue
}
if !bytes.Equal(gl.sha256sum, newSha256sum1) {
@@ -186,30 +187,30 @@ func (gl *Geolocation) reloader() {
time.Sleep(50 * time.Millisecond)
newSha256sum2, err := calculateFileSHA256(gl.mmdbPath)
if err != nil {
log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
continue
}
if !bytes.Equal(newSha256sum1, newSha256sum2) {
log.Errorf("sha256 sum changed during reloading of '%s'", gl.mmdbPath)
log.WithContext(ctx).Errorf("sha256 sum changed during reloading of '%s'", gl.mmdbPath)
continue
}
err = gl.reload(newSha256sum2)
err = gl.reload(ctx, newSha256sum2)
if err != nil {
log.Errorf("mmdb reload failed: %s", err)
log.WithContext(ctx).Errorf("mmdb reload failed: %s", err)
}
} else {
log.Tracef("No changes in '%s', no need to reload. Next check is in %.0f seconds.",
log.WithContext(ctx).Tracef("No changes in '%s', no need to reload. Next check is in %.0f seconds.",
gl.mmdbPath, gl.reloadCheckInterval.Seconds())
}
}
}
}
func (gl *Geolocation) reload(newSha256sum []byte) error {
func (gl *Geolocation) reload(ctx context.Context, newSha256sum []byte) error {
gl.mux.Lock()
defer gl.mux.Unlock()
log.Infof("Reloading '%s'", gl.mmdbPath)
log.WithContext(ctx).Infof("Reloading '%s'", gl.mmdbPath)
err := gl.db.Close()
if err != nil {
@@ -224,7 +225,7 @@ func (gl *Geolocation) reload(newSha256sum []byte) error {
gl.db = db
gl.sha256sum = newSha256sum
log.Infof("Successfully reloaded '%s'", gl.mmdbPath)
log.WithContext(ctx).Infof("Successfully reloaded '%s'", gl.mmdbPath)
return nil
}

View File

@@ -2,6 +2,7 @@ package geolocation
import (
"bytes"
"context"
"fmt"
"path/filepath"
"runtime"
@@ -50,10 +51,10 @@ type SqliteStore struct {
sha256sum []byte
}
func NewSqliteStore(dataDir string) (*SqliteStore, error) {
func NewSqliteStore(ctx context.Context, dataDir string) (*SqliteStore, error) {
file := filepath.Join(dataDir, GeoSqliteDBFile)
db, err := connectDB(file)
db, err := connectDB(ctx, file)
if err != nil {
return nil, err
}
@@ -115,13 +116,13 @@ func (s *SqliteStore) GetCitiesByCountry(countryISOCode string) ([]City, error)
}
// reload attempts to reload the SqliteStore's database if the database file has changed.
func (s *SqliteStore) reload() error {
func (s *SqliteStore) reload(ctx context.Context) error {
s.mux.Lock()
defer s.mux.Unlock()
newSha256sum1, err := calculateFileSHA256(s.filePath)
if err != nil {
log.Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err)
log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err)
}
if !bytes.Equal(s.sha256sum, newSha256sum1) {
@@ -136,11 +137,11 @@ func (s *SqliteStore) reload() error {
return fmt.Errorf("sha256 sum changed during reloading of '%s'", s.filePath)
}
log.Infof("Reloading '%s'", s.filePath)
log.WithContext(ctx).Infof("Reloading '%s'", s.filePath)
_ = s.close()
s.closed = true
newDb, err := connectDB(s.filePath)
newDb, err := connectDB(ctx, s.filePath)
if err != nil {
return err
}
@@ -148,9 +149,9 @@ func (s *SqliteStore) reload() error {
s.closed = false
s.db = newDb
log.Infof("Successfully reloaded '%s'", s.filePath)
log.WithContext(ctx).Infof("Successfully reloaded '%s'", s.filePath)
} else {
log.Tracef("No changes in '%s', no need to reload", s.filePath)
log.WithContext(ctx).Tracef("No changes in '%s', no need to reload", s.filePath)
}
return nil
@@ -168,10 +169,10 @@ func (s *SqliteStore) close() error {
}
// connectDB connects to an SQLite database and prepares it by setting up an in-memory database.
func connectDB(filePath string) (*gorm.DB, error) {
func connectDB(ctx context.Context, filePath string) (*gorm.DB, error) {
start := time.Now()
defer func() {
log.Debugf("took %v to setup geoname db", time.Since(start))
log.WithContext(ctx).Debugf("took %v to setup geoname db", time.Since(start))
}()
_, err := fileExists(filePath)

View File

@@ -1,6 +1,7 @@
package server
import (
"context"
"fmt"
"github.com/rs/xid"
@@ -21,11 +22,11 @@ func (e *GroupLinkError) Error() string {
}
// GetGroup object of the peers
func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
@@ -48,11 +49,11 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*n
}
// GetAllGroups returns all groups in an account
func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
@@ -75,11 +76,11 @@ func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) (
}
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
@@ -108,93 +109,123 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*n
}
// SaveGroup object of the peers
func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *nbgroup.Group) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup})
}
account, err := am.Store.GetAccount(accountID)
// SaveGroups adds new groups to the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
var eventsToStore []func()
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
for _, newGroup := range newGroups {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
existingGroup, err := account.FindGroupByName(newGroup.Name)
if err != nil {
s, ok := status.FromError(err)
if !ok || s.ErrorType != status.NotFound {
return err
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := account.FindGroupByName(newGroup.Name)
if err != nil {
s, ok := status.FromError(err)
if !ok || s.ErrorType != status.NotFound {
return err
}
}
// Avoid duplicate groups only for the API issued groups.
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
newGroup.ID = xid.New().String()
}
for _, peerID := range newGroup.Peers {
if account.Peers[peerID] == nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
}
// avoid duplicate groups only for the API issued groups. Integration or JWT groups can be duplicated as they are
// coming from the IdP that we don't have control of.
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
oldGroup := account.Groups[newGroup.ID]
account.Groups[newGroup.ID] = newGroup
newGroup.ID = xid.New().String()
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account)
eventsToStore = append(eventsToStore, events...)
}
for _, peerID := range newGroup.Peers {
if account.Peers[peerID] == nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
}
oldGroup, exists := account.Groups[newGroup.ID]
account.Groups[newGroup.ID] = newGroup
account.Network.IncSerial()
if err = am.Store.SaveAccount(account); err != nil {
if err = am.Store.SaveGroups(account.Id, account.Groups); err != nil {
return err
}
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
for _, storeEvent := range eventsToStore {
storeEvent()
}
return nil
}
// prepareGroupEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() {
var eventsToStore []func()
// the following snippet tracks the activity and stores the group events in the event store.
// It has to happen after all the operations have been successfully performed.
addedPeers := make([]string, 0)
removedPeers := make([]string, 0)
if exists {
if oldGroup != nil {
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
} else {
addedPeers = append(addedPeers, newGroup.Peers...)
am.StoreEvent(userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
})
}
for _, p := range addedPeers {
peer := account.Peers[p]
if peer == nil {
log.Errorf("peer %s not found under account %s while saving group", p, accountID)
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
continue
}
am.StoreEvent(userID, peer.ID, accountID, activity.GroupAddedToPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
"peer_fqdn": peer.FQDN(am.GetDNSDomain()),
})
peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
})
})
}
for _, p := range removedPeers {
peer := account.Peers[p]
if peer == nil {
log.Errorf("peer %s not found under account %s while saving group", p, accountID)
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
continue
}
am.StoreEvent(userID, peer.ID, accountID, activity.GroupRemovedFromPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
"peer_fqdn": peer.FQDN(am.GetDNSDomain()),
})
peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
})
})
}
return nil
return eventsToStore
}
// difference returns the elements in `a` that aren't in `b`.
@@ -213,11 +244,11 @@ func difference(a, b []string) []string {
}
// DeleteGroup object of the peers
func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountId)
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountId)
defer unlock()
account, err := am.Store.GetAccount(accountId)
account, err := am.Store.GetAccount(ctx, accountId)
if err != nil {
return err
}
@@ -315,23 +346,23 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string)
delete(account.Groups, groupID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(account); err != nil {
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
return nil
}
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
@@ -345,11 +376,11 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group,
}
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
@@ -371,21 +402,21 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(account); err != nil {
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
return nil
}
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
@@ -399,13 +430,13 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID stri
for i, itemID := range group.Peers {
if itemID == peerID {
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
if err := am.Store.SaveAccount(account); err != nil {
if err := am.Store.SaveAccount(ctx, account); err != nil {
return err
}
}
}
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
return nil
}

View File

@@ -1,6 +1,7 @@
package server
import (
"context"
"errors"
"testing"
@@ -26,7 +27,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
}
for _, group := range account.Groups {
group.Issued = nbgroup.GroupIssuedIntegration
err = am.SaveGroup(account.Id, groupAdminUserID, group)
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group)
if err != nil {
t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedIntegration)
}
@@ -34,7 +35,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
for _, group := range account.Groups {
group.Issued = nbgroup.GroupIssuedJWT
err = am.SaveGroup(account.Id, groupAdminUserID, group)
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group)
if err != nil {
t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedJWT)
}
@@ -42,7 +43,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
for _, group := range account.Groups {
group.Issued = nbgroup.GroupIssuedAPI
group.ID = ""
err = am.SaveGroup(account.Id, groupAdminUserID, group)
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group)
if err == nil {
t.Errorf("should not create api group with the same name, %s", group.Name)
}
@@ -104,7 +105,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
err = am.DeleteGroup(account.Id, groupAdminUserID, testCase.groupID)
err = am.DeleteGroup(context.Background(), account.Id, groupAdminUserID, testCase.groupID)
if err == nil {
t.Errorf("delete %s group successfully", testCase.groupID)
return
@@ -225,7 +226,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
Id: "example user",
AutoGroups: []string{groupForUsers.ID},
}
account := newAccountWithId(accountID, groupAdminUserID, domain)
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
account.Routes[routeResource.ID] = routeResource
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
@@ -233,18 +234,18 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
account.SetupKeys[setupKey.Id] = setupKey
account.Users[user.Id] = user
err := am.Store.SaveAccount(account)
err := am.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, err
}
_ = am.SaveGroup(accountID, groupAdminUserID, groupForRoute)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForRoute2)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForNameServerGroups)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForPolicies)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForSetupKeys)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForUsers)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForIntegration)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
return am.Store.GetAccount(account.Id)
return am.Store.GetAccount(context.Background(), account.Id)
}

View File

@@ -11,12 +11,14 @@ import (
pb "github.com/golang/protobuf/proto" // nolint
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/netbirdio/netbird/management/server/posture"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -40,7 +42,7 @@ type GRPCServer struct {
}
// NewServer creates a new Management server
func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) {
func NewServer(ctx context.Context, config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
@@ -50,6 +52,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
jwtValidator, err = jwtclaims.NewJWTValidator(
ctx,
config.HttpConfig.AuthIssuer,
config.GetAuthAudiences(),
config.HttpConfig.AuthKeysLocation,
@@ -59,7 +62,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)
}
} else {
log.Debug("unable to use http config to create new jwt middleware")
log.WithContext(ctx).Debug("unable to use http config to create new jwt middleware")
}
if appMetrics != nil {
@@ -126,47 +129,61 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequest()
}
realIP := getRealIP(srv.Context())
log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String())
ctx := srv.Context()
realIP := getRealIP(ctx)
syncReq := &proto.SyncRequest{}
peerKey, err := s.parseRequest(req, syncReq)
peerKey, err := s.parseRequest(ctx, req, syncReq)
if err != nil {
return err
}
//nolint
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
// this case should not happen and already indicates an issue but we don't want the system to fail due to being unable to log in detail
accountID = "UNKNOWN"
}
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String())
if syncReq.GetMeta() == nil {
log.Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
}
peer, netMap, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), extractPeerMeta(syncReq.GetMeta()), realIP)
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
if err != nil {
return mapError(err)
return mapError(ctx, err)
}
err = s.sendInitialSync(peerKey, peer, netMap, srv)
err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv)
if err != nil {
log.Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
return err
}
updates := s.peersUpdateManager.CreateChannel(peer.ID)
updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID)
s.ephemeralManager.OnPeerConnected(peer)
s.ephemeralManager.OnPeerConnected(ctx, peer)
if s.config.TURNConfig.TimeBasedCredentials {
s.turnCredentialsManager.SetupRefresh(peer.ID)
s.turnCredentialsManager.SetupRefresh(ctx, peer.ID)
}
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
}
return s.handleUpdates(peerKey, peer, updates, srv)
return s.handleUpdates(ctx, peerKey, peer, updates, srv)
}
// handleUpdates sends updates to the connected peer until the updates channel is closed.
func (s *GRPCServer) handleUpdates(peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
for {
select {
// condition when there are some updates
@@ -176,21 +193,21 @@ func (s *GRPCServer) handleUpdates(peerKey wgtypes.Key, peer *nbpeer.Peer, updat
}
if !open {
log.Debugf("updates channel for peer %s was closed", peerKey.String())
s.cancelPeerRoutines(peer)
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
s.cancelPeerRoutines(ctx, peer)
return nil
}
log.Debugf("received an update for peer %s", peerKey.String())
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
if err := s.sendUpdate(peerKey, peer, update, srv); err != nil {
if err := s.sendUpdate(ctx, peerKey, peer, update, srv); err != nil {
return err
}
// condition when client <-> server connection has been terminated
case <-srv.Context().Done():
// happens when connection drops, e.g. client disconnects
log.Debugf("stream of peer %s has been closed", peerKey.String())
s.cancelPeerRoutines(peer)
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
s.cancelPeerRoutines(ctx, peer)
return srv.Context().Err()
}
}
@@ -198,10 +215,10 @@ func (s *GRPCServer) handleUpdates(peerKey wgtypes.Key, peer *nbpeer.Peer, updat
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server.
func (s *GRPCServer) sendUpdate(peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
if err != nil {
s.cancelPeerRoutines(peer)
s.cancelPeerRoutines(ctx, peer)
return status.Errorf(codes.Internal, "failed processing update message")
}
err = srv.SendMsg(&proto.EncryptedMessage{
@@ -209,37 +226,37 @@ func (s *GRPCServer) sendUpdate(peerKey wgtypes.Key, peer *nbpeer.Peer, update *
Body: encryptedResp,
})
if err != nil {
s.cancelPeerRoutines(peer)
s.cancelPeerRoutines(ctx, peer)
return status.Errorf(codes.Internal, "failed sending update message")
}
log.Debugf("sent an update to peer %s", peerKey.String())
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
return nil
}
func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) {
s.peersUpdateManager.CloseChannel(peer.ID)
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) {
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
s.turnCredentialsManager.CancelRefresh(peer.ID)
_ = s.accountManager.CancelPeerRoutines(peer)
s.ephemeralManager.OnPeerDisconnected(peer)
_ = s.accountManager.CancelPeerRoutines(ctx, peer)
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
}
func (s *GRPCServer) validateToken(jwtToken string) (string, error) {
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
if s.jwtValidator == nil {
return "", status.Error(codes.Internal, "no jwt validator set")
}
token, err := s.jwtValidator.ValidateAndParse(jwtToken)
token, err := s.jwtValidator.ValidateAndParse(ctx, jwtToken)
if err != nil {
return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err)
}
claims := s.jwtClaimsExtractor.FromToken(token)
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
_, _, err = s.accountManager.GetAccountFromToken(claims)
_, _, err = s.accountManager.GetAccountFromToken(ctx, claims)
if err != nil {
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
}
if err := s.accountManager.CheckUserAccessByJWTGroups(claims); err != nil {
if err := s.accountManager.CheckUserAccessByJWTGroups(ctx, claims); err != nil {
return "", status.Errorf(codes.PermissionDenied, err.Error())
}
@@ -247,7 +264,7 @@ func (s *GRPCServer) validateToken(jwtToken string) (string, error) {
}
// maps internal internalStatus.Error to gRPC status.Error
func mapError(err error) error {
func mapError(ctx context.Context, err error) error {
if e, ok := internalStatus.FromError(err); ok {
switch e.Type() {
case internalStatus.PermissionDenied:
@@ -263,11 +280,11 @@ func mapError(err error) error {
default:
}
}
log.Errorf("got an unhandled error: %s", err)
log.WithContext(ctx).Errorf("got an unhandled error: %s", err)
return status.Errorf(codes.Internal, "failed handling request")
}
func extractPeerMeta(meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
if meta == nil {
return nbpeer.PeerSystemMeta{}
}
@@ -281,7 +298,7 @@ func extractPeerMeta(meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
for _, addr := range meta.GetNetworkAddresses() {
netAddr, err := netip.ParsePrefix(addr.GetNetIP())
if err != nil {
log.Warnf("failed to parse netip address, %s: %v", addr.GetNetIP(), err)
log.WithContext(ctx).Warnf("failed to parse netip address, %s: %v", addr.GetNetIP(), err)
continue
}
networkAddresses = append(networkAddresses, nbpeer.NetworkAddress{
@@ -321,10 +338,10 @@ func extractPeerMeta(meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
}
}
func (s *GRPCServer) parseRequest(req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
log.Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey)
log.WithContext(ctx).Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey)
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey)
}
@@ -351,22 +368,32 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
s.appMetrics.GRPCMetrics().CountLoginRequest()
}
realIP := getRealIP(ctx)
log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String())
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String())
loginReq := &proto.LoginRequest{}
peerKey, err := s.parseRequest(req, loginReq)
peerKey, err := s.parseRequest(ctx, req, loginReq)
if err != nil {
return nil, err
}
//nolint
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
// this case should not happen and already indicates an issue but we don't want the system to fail due to being unable to log in detail
accountID = "UNKNOWN"
}
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
if loginReq.GetMeta() == nil {
msg := status.Errorf(codes.FailedPrecondition,
"peer system meta has to be provided to log in. Peer %s, remote addr %s", peerKey.String(), realIP)
log.Warn(msg)
log.WithContext(ctx).Warn(msg)
return nil, msg
}
userID, err := s.processJwtToken(loginReq, peerKey)
userID, err := s.processJwtToken(ctx, loginReq, peerKey)
if err != nil {
return nil, err
}
@@ -376,33 +403,33 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
sshKey = loginReq.GetPeerKeys().GetSshPubKey()
}
peer, netMap, err := s.accountManager.LoginPeer(PeerLogin{
peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, PeerLogin{
WireGuardPubKey: peerKey.String(),
SSHKey: string(sshKey),
Meta: extractPeerMeta(loginReq.GetMeta()),
Meta: extractPeerMeta(ctx, loginReq.GetMeta()),
UserID: userID,
SetupKey: loginReq.GetSetupKey(),
ConnectionIP: realIP,
})
if err != nil {
log.Warnf("failed logging in peer %s: %s", peerKey, err)
return nil, mapError(err)
log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err)
return nil, mapError(ctx, err)
}
// if the login request contains setup key then it is a registration request
if loginReq.GetSetupKey() != "" {
s.ephemeralManager.OnPeerDisconnected(peer)
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
}
// if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil),
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()),
Checks: toProtocolChecks(s.accountManager, peerKey.String()),
Checks: toProtocolChecks(ctx, postureChecks),
}
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
if err != nil {
log.Warnf("failed encrypting peer %s message", peer.ID)
log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID)
return nil, status.Errorf(codes.Internal, "failed logging in peer")
}
@@ -417,16 +444,16 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
//
// The user ID can be empty if the token is not provided, which is acceptable if the peer is already
// registered or if it uses a setup key to register.
func (s *GRPCServer) processJwtToken(loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) {
func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) {
userID := ""
if loginReq.GetJwtToken() != "" {
var err error
for i := 0; i < 3; i++ {
userID, err = s.validateToken(loginReq.GetJwtToken())
userID, err = s.validateToken(ctx, loginReq.GetJwtToken())
if err == nil {
break
}
log.Warnf("failed validating JWT token sent from peer %s with error %v. "+
log.WithContext(ctx).Warnf("failed validating JWT token sent from peer %s with error %v. "+
"Trying again as it may be due to the IdP cache issue", peerKey.String(), err)
time.Sleep(200 * time.Millisecond)
}
@@ -520,7 +547,7 @@ func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePee
return remotePeers
}
func toSyncResponse(accountManager AccountManager, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse {
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials)
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
@@ -551,7 +578,7 @@ func toSyncResponse(accountManager AccountManager, config *Config, peer *nbpeer.
FirewallRules: firewallRules,
FirewallRulesIsEmpty: len(firewallRules) == 0,
},
Checks: toProtocolChecks(accountManager, peer.Key),
Checks: toProtocolChecks(ctx, checks),
}
}
@@ -561,7 +588,7 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em
}
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, srv proto.ManagementService_SyncServer) error {
func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error {
// make secret time based TURN credentials optional
var turnCredentials *TURNCredentials
if s.config.TURNConfig.TimeBasedCredentials {
@@ -570,7 +597,7 @@ func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, net
} else {
turnCredentials = nil
}
plainResp := toSyncResponse(s.accountManager, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain())
plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil {
@@ -583,7 +610,7 @@ func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, net
})
if err != nil {
log.Errorf("failed sending SyncResponse %v", err)
log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err)
return status.Errorf(codes.Internal, "error handling request")
}
@@ -597,14 +624,14 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey)
log.Warn(errMSG)
log.WithContext(ctx).Warn(errMSG)
return nil, status.Error(codes.InvalidArgument, errMSG)
}
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.DeviceAuthorizationFlowRequest{})
if err != nil {
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
log.Warn(errMSG)
log.WithContext(ctx).Warn(errMSG)
return nil, status.Error(codes.InvalidArgument, errMSG)
}
@@ -645,18 +672,18 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
// GetPKCEAuthorizationFlow returns a pkce authorization flow information
// This is used for initiating an Oauth 2 pkce authorization grant flow
// which will be used by our clients to Login
func (s *GRPCServer) GetPKCEAuthorizationFlow(_ context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetPKCEAuthorizationFlow request.", req.WgPubKey)
log.Warn(errMSG)
log.WithContext(ctx).Warn(errMSG)
return nil, status.Error(codes.InvalidArgument, errMSG)
}
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.PKCEAuthorizationFlowRequest{})
if err != nil {
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
log.Warn(errMSG)
log.WithContext(ctx).Warn(errMSG)
return nil, status.Error(codes.InvalidArgument, errMSG)
}
@@ -692,10 +719,10 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(_ context.Context, req *proto.Encr
// peer's under the same account of any updates.
func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
realIP := getRealIP(ctx)
log.Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String())
log.WithContext(ctx).Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String())
syncMetaReq := &proto.SyncMetaRequest{}
peerKey, err := s.parseRequest(req, syncMetaReq)
peerKey, err := s.parseRequest(ctx, req, syncMetaReq)
if err != nil {
return nil, err
}
@@ -703,27 +730,21 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage)
if syncMetaReq.GetMeta() == nil {
msg := status.Errorf(codes.FailedPrecondition,
"peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
log.Warn(msg)
log.WithContext(ctx).Warn(msg)
return nil, msg
}
err = s.accountManager.SyncPeerMeta(peerKey.String(), extractPeerMeta(syncMetaReq.GetMeta()))
err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()))
if err != nil {
return nil, mapError(err)
return nil, mapError(ctx, err)
}
return &proto.Empty{}, nil
}
// toProtocolChecks returns posture checks for the peer that needs to be evaluated on the client side.
func toProtocolChecks(accountManager AccountManager, peerKey string) []*proto.Checks {
postureChecks, err := accountManager.GetPeerAppliedPostureChecks(peerKey)
if err != nil {
log.Errorf("failed getting peer's: %s posture checks: %v", peerKey, err)
return nil
}
protoChecks := make([]*proto.Checks, 0)
// toProtocolChecks converts posture checks to protocol checks.
func toProtocolChecks(ctx context.Context, postureChecks []*posture.Checks) []*proto.Checks {
protoChecks := make([]*proto.Checks, 0, len(postureChecks))
for _, postureCheck := range postureChecks {
protoChecks = append(protoChecks, toProtocolCheck(postureCheck))
}
@@ -732,7 +753,7 @@ func toProtocolChecks(accountManager AccountManager, peerKey string) []*proto.Ch
}
// toProtocolCheck converts a posture.Checks to a proto.Checks.
func toProtocolCheck(postureCheck posture.Checks) *proto.Checks {
func toProtocolCheck(postureCheck *posture.Checks) *proto.Checks {
protoCheck := &proto.Checks{}
if check := postureCheck.Checks.ProcessCheck; check != nil {

View File

@@ -35,34 +35,34 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *
// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
if !(user.HasAdminPower() || user.IsServiceUser) {
util.WriteError(status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w)
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w)
return
}
resp := toAccountResponse(account)
util.WriteJSONObject(w, []*api.Account{resp})
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
}
// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
_, user, err := h.accountManager.GetAccountFromToken(claims)
_, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
accountID := vars["accountId"]
if len(accountID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid accountID ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid accountID ID"), w)
return
}
@@ -96,15 +96,15 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
}
updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, settings)
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, user.Id, settings)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
resp := toAccountResponse(updatedAccount)
util.WriteJSONObject(w, &resp)
util.WriteJSONObject(r.Context(), w, &resp)
}
// DeleteAccount is a HTTP DELETE handler to delete an account
@@ -118,17 +118,17 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request)
vars := mux.Vars(r)
targetAccountID := vars["accountId"]
if len(targetAccountID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid account ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid account ID"), w)
return
}
err := h.accountManager.DeleteAccount(targetAccountID, claims.UserId)
err := h.accountManager.DeleteAccount(r.Context(), targetAccountID, claims.UserId)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, emptyObject{})
util.WriteJSONObject(r.Context(), w, emptyObject{})
}
func toAccountResponse(account *server.Account) *api.Account {

View File

@@ -2,6 +2,7 @@ package http
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@@ -22,10 +23,10 @@ import (
func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler {
return &AccountsHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return account, admin, nil
},
UpdateAccountSettingsFunc: func(accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit {
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")

View File

@@ -32,16 +32,16 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg
// GetDNSSettings returns the DNS settings for the account
func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
log.Error(err)
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
dnsSettings, err := h.accountManager.GetDNSSettings(account.Id, user.Id)
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@@ -49,15 +49,15 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque
DisabledManagementGroups: dnsSettings.DisabledManagementGroups,
}
util.WriteJSONObject(w, apiDNSSettings)
util.WriteJSONObject(r.Context(), w, apiDNSSettings)
}
// UpdateDNSSettings handles update to DNS settings of an account
func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@@ -72,9 +72,9 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re
DisabledManagementGroups: req.DisabledManagementGroups,
}
err = h.accountManager.SaveDNSSettings(account.Id, user.Id, updateDNSSettings)
err = h.accountManager.SaveDNSSettings(r.Context(), account.Id, user.Id, updateDNSSettings)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@@ -82,5 +82,5 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re
DisabledManagementGroups: updateDNSSettings.DisabledManagementGroups,
}
util.WriteJSONObject(w, &resp)
util.WriteJSONObject(r.Context(), w, &resp)
}

View File

@@ -2,6 +2,7 @@ package http
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@@ -42,16 +43,16 @@ var testingDNSSettingsAccount = &server.Account{
func initDNSSettingsTestData() *DNSSettingsHandler {
return &DNSSettingsHandler{
accountManager: &mock_server.MockAccountManager{
GetDNSSettingsFunc: func(accountID string, userID string) (*server.DNSSettings, error) {
GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) {
return &testingDNSSettingsAccount.DNSSettings, nil
},
SaveDNSSettingsFunc: func(accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error {
SaveDNSSettingsFunc: func(ctx context.Context, accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error {
if dnsSettingsToSave != nil {
return nil
}
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
},
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil
},
},

View File

@@ -1,6 +1,7 @@
package http
import (
"context"
"fmt"
"net/http"
@@ -33,16 +34,16 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev
// GetAllEvents list of the given account
func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
log.Error(err)
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountEvents, err := h.accountManager.GetEvents(account.Id, user.Id)
accountEvents, err := h.accountManager.GetEvents(r.Context(), account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
events := make([]*api.Event, len(accountEvents))
@@ -50,20 +51,20 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
events[i] = toEventResponse(e)
}
err = h.fillEventsWithUserInfo(events, account.Id, user.Id)
err = h.fillEventsWithUserInfo(r.Context(), events, account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, events)
util.WriteJSONObject(r.Context(), w, events)
}
func (h *EventsHandler) fillEventsWithUserInfo(events []*api.Event, accountId, userId string) error {
func (h *EventsHandler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error {
// build email, name maps based on users
userInfos, err := h.accountManager.GetUsersFromAccount(accountId, userId)
userInfos, err := h.accountManager.GetUsersFromAccount(ctx, accountId, userId)
if err != nil {
log.Errorf("failed to get users from account: %s", err)
log.WithContext(ctx).Errorf("failed to get users from account: %s", err)
return err
}
@@ -80,7 +81,7 @@ func (h *EventsHandler) fillEventsWithUserInfo(events []*api.Event, accountId, u
if event.InitiatorEmail == "" {
event.InitiatorEmail, ok = emails[event.InitiatorId]
if !ok {
log.Warnf("failed to resolve email for initiator: %s", event.InitiatorId)
log.WithContext(ctx).Warnf("failed to resolve email for initiator: %s", event.InitiatorId)
}
}

View File

@@ -1,6 +1,7 @@
package http
import (
"context"
"encoding/json"
"io"
"net/http"
@@ -22,13 +23,13 @@ import (
func initEventsTestData(account string, user *server.User, events ...*activity.Event) *EventsHandler {
return &EventsHandler{
accountManager: &mock_server.MockAccountManager{
GetEventsFunc: func(accountID, userID string) ([]*activity.Event, error) {
GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) {
if accountID == account {
return events, nil
}
return []*activity.Event{}, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
@@ -37,7 +38,7 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E
},
}, user, nil
},
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
return make([]*server.UserInfo, 0), nil
},
},

View File

@@ -1,6 +1,7 @@
package http
import (
"context"
"encoding/json"
"io"
"net/http"
@@ -35,13 +36,13 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
err = util.CopyFileContents(geonamesDBPath, path.Join(tempDir, geolocation.GeoSqliteDBFile))
assert.NoError(t, err)
geo, err := geolocation.NewGeolocation(tempDir)
geo, err := geolocation.NewGeolocation(context.Background(), tempDir)
assert.NoError(t, err)
t.Cleanup(func() { _ = geo.Stop() })
return &GeolocationsHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,

View File

@@ -40,19 +40,19 @@ func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geoloca
// GetAllCountries retrieves a list of all countries
func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
if l.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
return
}
allCountries, err := l.geolocationManager.GetAllCountries()
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@@ -60,32 +60,32 @@ func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Req
for _, country := range allCountries {
countries = append(countries, toCountryResponse(country))
}
util.WriteJSONObject(w, countries)
util.WriteJSONObject(r.Context(), w, countries)
}
// GetCitiesByCountry retrieves a list of cities based on the given country code
func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
countryCode := vars["country"]
if !countryCodeRegex.MatchString(countryCode) {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid country code"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid country code"), w)
return
}
if l.geolocationManager == nil {
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
"Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w)
return
}
allCities, err := l.geolocationManager.GetCitiesByCountry(countryCode)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@@ -93,12 +93,12 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.
for _, city := range allCities {
cities = append(cities, toCityResponse(city))
}
util.WriteJSONObject(w, cities)
util.WriteJSONObject(r.Context(), w, cities)
}
func (l *GeolocationsHandler) authenticateUser(r *http.Request) error {
claims := l.claimsExtractor.FromRequestContext(r)
_, user, err := l.accountManager.GetAccountFromToken(claims)
_, user, err := l.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
return err
}

View File

@@ -35,16 +35,16 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr
// GetAllGroups list for the account
func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
log.Error(err)
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
groups, err := h.accountManager.GetAllGroups(account.Id, user.Id)
groups, err := h.accountManager.GetAllGroups(r.Context(), account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@@ -53,42 +53,42 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
groupsResponse = append(groupsResponse, toGroupResponse(account, group))
}
util.WriteJSONObject(w, groupsResponse)
util.WriteJSONObject(r.Context(), w, groupsResponse)
}
// UpdateGroup handles update to a group identified by a given ID
func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
groupID, ok := vars["groupId"]
if !ok {
util.WriteError(status.Errorf(status.InvalidArgument, "group ID field is missing"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group ID field is missing"), w)
return
}
if len(groupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "group ID can't be empty"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group ID can't be empty"), w)
return
}
eg, ok := account.Groups[groupID]
if !ok {
util.WriteError(status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
return
}
allGroup, err := account.GetGroupAll()
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
if allGroup.ID == groupID {
util.WriteError(status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
return
}
@@ -100,7 +100,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
}
if req.Name == "" {
util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
return
}
@@ -118,21 +118,21 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
IntegrationReference: eg.IntegrationReference,
}
if err := h.accountManager.SaveGroup(account.Id, user.Id, &group); err != nil {
log.Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
util.WriteError(err, w)
if err := h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group); err != nil {
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, toGroupResponse(account, &group))
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
}
// CreateGroup handles group creation request
func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@@ -144,7 +144,7 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
}
if req.Name == "" {
util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
return
}
@@ -160,62 +160,62 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
Issued: nbgroup.GroupIssuedAPI,
}
err = h.accountManager.SaveGroup(account.Id, user.Id, &group)
err = h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, toGroupResponse(account, &group))
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
}
// DeleteGroup handles group deletion request
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
aID := account.Id
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
}
allGroup, err := account.GetGroupAll()
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
if allGroup.ID == groupID {
util.WriteError(status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
return
}
err = h.accountManager.DeleteGroup(aID, user.Id, groupID)
err = h.accountManager.DeleteGroup(r.Context(), aID, user.Id, groupID)
if err != nil {
_, ok := err.(*server.GroupLinkError)
if ok {
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
return
}
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, emptyObject{})
util.WriteJSONObject(r.Context(), w, emptyObject{})
}
// GetGroup returns a group
func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@@ -223,19 +223,19 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
case http.MethodGet:
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
}
group, err := h.accountManager.GetGroup(account.Id, groupID, user.Id)
group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, toGroupResponse(account, group))
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group))
default:
util.WriteError(status.Errorf(status.NotFound, "HTTP method not found"), w)
util.WriteError(r.Context(), status.Errorf(status.NotFound, "HTTP method not found"), w)
return
}
}

View File

@@ -2,6 +2,7 @@ package http
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@@ -32,13 +33,13 @@ var TestPeers = map[string]*nbpeer.Peer{
func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
return &GroupsHandler{
accountManager: &mock_server.MockAccountManager{
SaveGroupFunc: func(accountID, userID string, group *nbgroup.Group) error {
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error {
if !strings.HasPrefix(group.ID, "id-") {
group.ID = "id-was-set"
}
return nil
},
GetGroupFunc: func(_, groupID, _ string) (*nbgroup.Group, error) {
GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) {
if groupID != "idofthegroup" {
return nil, status.Errorf(status.NotFound, "not found")
}
@@ -55,7 +56,7 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
Issued: nbgroup.GroupIssuedAPI,
}, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
@@ -70,7 +71,7 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
},
}, user, nil
},
DeleteGroupFunc: func(accountID, userId, groupID string) error {
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
if groupID == "linked-grp" {
return &server.GroupLinkError{
Resource: "something",

View File

@@ -9,6 +9,7 @@ import (
"github.com/rs/cors"
"github.com/netbirdio/management-integrations/integrations"
s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/middleware"
@@ -57,6 +58,11 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
corsMiddleware := cors.AllowAll()
claimsExtractor = jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
)
acMiddleware := middleware.NewAccessControl(
authCfg.Audience,
authCfg.UserIDClaim,

View File

@@ -1,6 +1,7 @@
package middleware
import (
"context"
"net/http"
"regexp"
@@ -15,7 +16,7 @@ import (
)
// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
type GetUser func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
type AccessControl struct {
@@ -46,15 +47,15 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
claims := a.claimsExtract.FromRequestContext(r)
user, err := a.getUser(claims)
user, err := a.getUser(r.Context(), claims)
if err != nil {
log.Errorf("failed to get user from claims: %s", err)
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
log.WithContext(r.Context()).Errorf("failed to get user from claims: %s", err)
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid JWT"), w)
return
}
if user.IsBlocked() {
util.WriteError(status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w)
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w)
return
}
@@ -63,12 +64,12 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
if tokenPathRegexp.MatchString(r.URL.Path) {
log.Debugf("valid Path")
log.WithContext(r.Context()).Debugf("valid Path")
h.ServeHTTP(w, r)
return
}
util.WriteError(status.Errorf(status.PermissionDenied, "only users with admin power can perform this operation"), w)
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only users with admin power can perform this operation"), w)
return
}
}

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -19,16 +20,16 @@ import (
)
// GetAccountFromPATFunc function
type GetAccountFromPATFunc func(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
// ValidateAndParseTokenFunc function
type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error)
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
// MarkPATUsedFunc function
type MarkPATUsedFunc func(token string) error
type MarkPATUsedFunc func(ctx context.Context, token string) error
// CheckUserAccessByJWTGroupsFunc function
type CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
@@ -85,23 +86,27 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
case "bearer":
err := m.checkJWTFromRequest(w, r, auth)
if err != nil {
log.Errorf("Error when validating JWT claims: %s", err.Error())
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
log.WithContext(r.Context()).Errorf("Error when validating JWT claims: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, r)
case "token":
err := m.checkPATFromRequest(w, r, auth)
if err != nil {
log.Debugf("Error when validating PAT claims: %s", err.Error())
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
log.WithContext(r.Context()).Debugf("Error when validating PAT claims: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, r)
default:
util.WriteError(status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
return
}
claims := m.claimsExtractor.FromRequestContext(r)
//nolint
ctx := context.WithValue(r.Context(), nbContext.UserIDKey, claims.UserId)
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, claims.AccountId)
h.ServeHTTP(w, r.WithContext(ctx))
})
}
@@ -114,7 +119,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
return fmt.Errorf("Error extracting token: %w", err)
}
validatedToken, err := m.validateAndParseToken(token)
validatedToken, err := m.validateAndParseToken(r.Context(), token)
if err != nil {
return err
}
@@ -123,7 +128,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
return nil
}
if err := m.verifyUserAccess(validatedToken); err != nil {
if err := m.verifyUserAccess(r.Context(), validatedToken); err != nil {
return err
}
@@ -138,9 +143,9 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
// verifyUserAccess checks if a user, based on a validated JWT token,
// is allowed access, particularly in cases where the admin enabled JWT
// group propagation and designated certain groups with access permissions.
func (m *AuthMiddleware) verifyUserAccess(validatedToken *jwt.Token) error {
func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *jwt.Token) error {
authClaims := m.claimsExtractor.FromToken(validatedToken)
return m.checkUserAccessByJWTGroups(authClaims)
return m.checkUserAccessByJWTGroups(ctx, authClaims)
}
// CheckPATFromRequest checks if the PAT is valid
@@ -152,7 +157,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
return fmt.Errorf("Error extracting token: %w", err)
}
account, user, pat, err := m.getAccountFromPAT(token)
account, user, pat, err := m.getAccountFromPAT(r.Context(), token)
if err != nil {
return fmt.Errorf("invalid Token: %w", err)
}
@@ -160,7 +165,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
return fmt.Errorf("token expired")
}
err = m.markPATUsed(pat.ID)
err = m.markPATUsed(r.Context(), pat.ID)
if err != nil {
return err
}

View File

@@ -1,6 +1,7 @@
package middleware
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
@@ -15,15 +16,16 @@ import (
)
const (
audience = "audience"
userIDClaim = "userIDClaim"
accountID = "accountID"
domain = "domain"
userID = "userID"
tokenID = "tokenID"
PAT = "nbp_PAT"
JWT = "JWT"
wrongToken = "wrongToken"
audience = "audience"
userIDClaim = "userIDClaim"
accountID = "accountID"
domain = "domain"
domainCategory = "domainCategory"
userID = "userID"
tokenID = "tokenID"
PAT = "nbp_PAT"
JWT = "JWT"
wrongToken = "wrongToken"
)
var testAccount = &server.Account{
@@ -47,14 +49,14 @@ var testAccount = &server.Account{
},
}
func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
if token == PAT {
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
}
return nil, nil, nil, fmt.Errorf("PAT invalid")
}
func mockValidateAndParseToken(token string) (*jwt.Token, error) {
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
if token == JWT {
return &jwt.Token{
Claims: jwt.MapClaims{
@@ -67,14 +69,14 @@ func mockValidateAndParseToken(token string) (*jwt.Token, error) {
return nil, fmt.Errorf("JWT invalid")
}
func mockMarkPATUsed(token string) error {
func mockMarkPATUsed(_ context.Context, token string) error {
if token == tokenID {
return nil
}
return fmt.Errorf("Should never get reached")
}
func mockCheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error {
func mockCheckUserAccessByJWTGroups(_ context.Context, claims jwtclaims.AuthorizationClaims) error {
if testAccount.Id != claims.AccountId {
return fmt.Errorf("account with id %s does not exist", claims.AccountId)
}

View File

@@ -56,7 +56,7 @@ func ShouldBypass(requestPath string, h http.Handler, w http.ResponseWriter, r *
for bypassPath := range bypassPaths {
matched, err := path.Match(bypassPath, requestPath)
if err != nil {
log.Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err)
log.WithContext(r.Context()).Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err)
continue
}
if matched {

View File

@@ -36,16 +36,16 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg
// GetAllNameservers returns the list of nameserver groups for the account
func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
log.Error(err)
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
nsGroups, err := h.accountManager.ListNameServerGroups(account.Id, user.Id)
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@@ -54,15 +54,15 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re
apiNameservers = append(apiNameservers, toNameserverGroupResponse(r))
}
util.WriteJSONObject(w, apiNameservers)
util.WriteJSONObject(r.Context(), w, apiNameservers)
}
// CreateNameserverGroup handles nameserver group creation request
func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@@ -75,33 +75,33 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
nsList, err := toServerNSList(req.Nameservers)
if err != nil {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
return
}
nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled)
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
resp := toNameserverGroupResponse(nsGroup)
util.WriteJSONObject(w, &resp)
util.WriteJSONObject(r.Context(), w, &resp)
}
// UpdateNameserverGroup handles update to a nameserver group identified by a given ID
func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return
}
@@ -114,7 +114,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
nsList, err := toServerNSList(req.Nameservers)
if err != nil {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
return
}
@@ -130,66 +130,66 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
SearchDomainsEnabled: req.SearchDomainsEnabled,
}
err = h.accountManager.SaveNameServerGroup(account.Id, user.Id, updatedNSGroup)
err = h.accountManager.SaveNameServerGroup(r.Context(), account.Id, user.Id, updatedNSGroup)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
resp := toNameserverGroupResponse(updatedNSGroup)
util.WriteJSONObject(w, &resp)
util.WriteJSONObject(r.Context(), w, &resp)
}
// DeleteNameserverGroup handles nameserver group deletion request
func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return
}
err = h.accountManager.DeleteNameServerGroup(account.Id, nsGroupID, user.Id)
err = h.accountManager.DeleteNameServerGroup(r.Context(), account.Id, nsGroupID, user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, emptyObject{})
util.WriteJSONObject(r.Context(), w, emptyObject{})
}
// GetNameserverGroup handles a nameserver group Get request identified by ID
func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
log.Error(err)
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return
}
nsGroup, err := h.accountManager.GetNameServerGroup(account.Id, user.Id, nsGroupID)
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), account.Id, user.Id, nsGroupID)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
resp := toNameserverGroupResponse(nsGroup)
util.WriteJSONObject(w, &resp)
util.WriteJSONObject(r.Context(), w, &resp)
}
func toServerNSList(apiNSList []api.Nameserver) ([]nbdns.NameServer, error) {

View File

@@ -2,6 +2,7 @@ package http
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@@ -61,13 +62,13 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{
func initNameserversTestData() *NameserversHandler {
return &NameserversHandler{
accountManager: &mock_server.MockAccountManager{
GetNameServerGroupFunc: func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
GetNameServerGroupFunc: func(_ context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
if nsGroupID == existingNSGroupID {
return baseExistingNSGroup.Copy(), nil
}
return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
},
CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, _ string, searchDomains bool) (*nbdns.NameServerGroup, error) {
CreateNameServerGroupFunc: func(_ context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, _ string, searchDomains bool) (*nbdns.NameServerGroup, error) {
return &nbdns.NameServerGroup{
ID: existingNSGroupID,
Name: name,
@@ -80,16 +81,16 @@ func initNameserversTestData() *NameserversHandler {
SearchDomainsEnabled: searchDomains,
}, nil
},
DeleteNameServerGroupFunc: func(accountID, nsGroupID, _ string) error {
DeleteNameServerGroupFunc: func(_ context.Context, accountID, nsGroupID, _ string) error {
return nil
},
SaveNameServerGroupFunc: func(accountID, _ string, nsGroupToSave *nbdns.NameServerGroup) error {
SaveNameServerGroupFunc: func(_ context.Context, accountID, _ string, nsGroupToSave *nbdns.NameServerGroup) error {
if nsGroupToSave.ID == existingNSGroupID {
return nil
}
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
},
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingNSAccount, testingAccount.Users["test_user"], nil
},
},

Some files were not shown because too many files have changed in this diff Show More