Compare commits

...

29 Commits

Author SHA1 Message Date
Viktor Liu
f30995707a Set log level debug 2024-11-27 13:47:03 +01:00
Viktor Liu
5c8bfb7cea Add debug 2024-11-27 13:41:05 +01:00
Zoltan Papp
fc4b37f7bc Exit from processConnResults after all tries (#2621)
* Exit from processConnResults after all tries

If all server is unavailable then the server picker never return
because we never close the result channel.
Count the number of the results and exit when we reached the
expected size
2024-09-19 13:49:28 +02:00
Zoltan Papp
6f0fd1d1b3 - Increase queue size and drop the overflowed messages (#2617)
- Explicit close the net.Conn in user space wgProxy when close the wgProxy
- Add extra logs
2024-09-19 13:49:09 +02:00
Zoltan Papp
28cbb4b70f [client] Cancel the context of wg watcher when the go routine exit (#2612) 2024-09-17 12:10:17 +02:00
Zoltan Papp
1104c9c048 [client] Fix race condition while read/write conn status in peer conn (#2607) 2024-09-17 11:15:14 +02:00
Maycon Santos
5bc601111d [relay] Add health check attempt threshold (#2609)
* Add health check attempt threshold for receiver

* Add health check attempt threshold for sender
2024-09-17 10:04:17 +02:00
Zoltan Papp
b74951f29e [client] Enforce permissions on Win (#2568)
Enforce folder permission on Windows, giving only administrators and system access to the NetBird folder.
2024-09-16 22:42:37 +02:00
Zoltan Papp
97e10e440c Fix leaked server connections (#2596)
Fix leaked server connections

close unused connections in the client lib
close deprecated connection in the server lib
The Server Picker is reusable in the guard if we want in the future. So we can support the server address changes.

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>

* Add logging

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2024-09-16 16:11:10 +02:00
pascal-fischer
6c50b0c84b [management] Add transaction to addPeer (#2469)
This PR removes the GetAccount and SaveAccount operations from the AddPeer and instead makes use of gorm.Transaction to add the new peer.
2024-09-16 15:47:03 +02:00
pascal-fischer
730dd1733e [signal] Fix signal active peers metrics (#2591) 2024-09-15 16:46:55 +02:00
Bethuel Mmbaga
82739e2832 [management] fix legacy decrypting of empty values (#2595)
* allow legacy decrypting on empty values

* validate source size and padding limits

* added tests

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2024-09-15 16:22:46 +02:00
Maycon Santos
fa7767e612 Fix get management and signal state race condition (#2570)
* Fix get management and signal state race condition

* fix get full status lock
2024-09-15 16:07:26 +02:00
benniekiss
f1171198de [management] Add command flag to set metrics port for signal and relay service, and update management port (#2599)
* add flags to customize metrics port for relay and signal

* change management default metrics port to match other services
2024-09-14 10:34:32 +02:00
Zoltan Papp
9e041b7f82 Fix blocked net.Conn Close call (#2600) 2024-09-14 10:27:37 +02:00
Zoltan Papp
b4c8cf0a67 Change heartbeat timeout (#2598) 2024-09-14 10:12:54 +02:00
Carlos Hernandez
1ef51a4ffa [client] Ensure engine is stopped before starting it back (#2565)
Before starting a new instance of the engine, check if it is nil and stop the current instance
2024-09-13 16:46:59 +02:00
Maycon Santos
f6d57e7a96 [misc] Support configurable max log size with var NB_LOG_MAX_SIZE_MB (#2592)
* Support configurable max log size with var NB_LOG_MAX_SIZE_MB

* add better logs
2024-09-12 19:56:55 +02:00
Zoltan Papp
ab892b8cf9 Fix wg handshake checking (#2590)
* Fix wg handshake checking

* Ensure in the initial handshake reading

* Change the handshake period
2024-09-12 19:18:02 +02:00
Gianluca Boiano
33c9b2d989 fix: install.sh: avoid call of netbird executable after rpm installation (#2589) 2024-09-12 17:32:47 +02:00
Bethuel Mmbaga
170e842422 [management] Add accessible peers endpoint (#2579)
* move accessible peer to separate endpoint in api doc

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add endpoint to get accessible peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Update management/server/http/api/openapi.yml

Co-authored-by: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>

* Update management/server/http/api/openapi.yml

Co-authored-by: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>

* Update management/server/http/peers_handler.go

Co-authored-by: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
2024-09-12 16:19:27 +03:00
Maycon Santos
4c130a0291 Update Go version to 1.23 (#2588) 2024-09-12 13:46:28 +02:00
Maycon Santos
afb9673bc4 [misc] Update core github actions (#2584) 2024-09-11 21:49:05 +02:00
Bethuel Mmbaga
cf6210a6f4 [management] Add GCM encryption and migrate legacy encrypted events (#2569)
* Add AES-GCM encryption

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* migrate legacy encrypted data to AES-GCM encryption

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor and use transaction when migrating data

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add events migration tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* skip migrating record on error

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Preallocate capacity for nonce to avoid allocations in Seal

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-11 20:09:57 +03:00
Maycon Santos
c59a39d27d Update service package version (#2582) 2024-09-11 19:05:10 +02:00
Maycon Santos
47adb976f8 Remove pre-release step from workflow (#2583) 2024-09-11 18:59:19 +02:00
Zoltan Papp
9cfc8f8aa4 [relay] change log levels (#2580) 2024-09-11 18:36:19 +02:00
Viktor Liu
2d1bf3982d [relay] Improve relay messages (#2574)
Co-authored-by: Zoltán Papp <zoltan.pmail@gmail.com>
2024-09-11 16:20:30 +02:00
Viktor Liu
50ebbe482e [client] Don't overwrite allowed IPs when updating the wg peer's endpoint address (#2578)
This will fix broken routes on routing clients when upgrading/downgrading from/to relayed connections.
2024-09-11 16:05:13 +02:00
84 changed files with 3481 additions and 854 deletions

View File

@@ -18,14 +18,14 @@ jobs:
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v4 uses: actions/setup-go@v5
with: with:
go-version: "1.21.x" go-version: "1.23.x"
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v3 uses: actions/cache@v4
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: macos-go-${{ hashFiles('**/go.sum') }} key: macos-go-${{ hashFiles('**/go.sum') }}

View File

@@ -19,13 +19,13 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v4 uses: actions/setup-go@v5
with: with:
go-version: "1.21.x" go-version: "1.23.x"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v3 uses: actions/cache@v4
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -33,7 +33,7 @@ jobs:
${{ runner.os }}-go- ${{ runner.os }}-go-
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -49,18 +49,18 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./... run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
test_client_on_docker: test_client_on_docker:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v4 uses: actions/setup-go@v5
with: with:
go-version: "1.21.x" go-version: "1.23.x"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v3 uses: actions/cache@v4
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -68,7 +68,7 @@ jobs:
${{ runner.os }}-go- ${{ runner.os }}-go-
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev

View File

@@ -17,13 +17,13 @@ jobs:
runs-on: windows-latest runs-on: windows-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install Go - name: Install Go
uses: actions/setup-go@v4 uses: actions/setup-go@v5
id: go id: go
with: with:
go-version: "1.21.x" go-version: "1.23.x"
- name: Download wintun - name: Download wintun
uses: carlosperate/download-file-action@v2 uses: carlosperate/download-file-action@v2

View File

@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: codespell - name: codespell
uses: codespell-project/actions-codespell@v2 uses: codespell-project/actions-codespell@v2
with: with:
@@ -32,15 +32,15 @@ jobs:
timeout-minutes: 15 timeout-minutes: 15
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Check for duplicate constants - name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest' if: matrix.os == 'ubuntu-latest'
run: | run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep . ! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go - name: Install Go
uses: actions/setup-go@v4 uses: actions/setup-go@v5
with: with:
go-version: "1.21.x" go-version: "1.23.x"
cache: false cache: false
- name: Install dependencies - name: Install dependencies
if: matrix.os == 'ubuntu-latest' if: matrix.os == 'ubuntu-latest'

View File

@@ -21,7 +21,7 @@ jobs:
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: run install script - name: run install script
env: env:

View File

@@ -15,23 +15,23 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install Go - name: Install Go
uses: actions/setup-go@v4 uses: actions/setup-go@v5
with: with:
go-version: "1.21.x" go-version: "1.23.x"
- name: Setup Android SDK - name: Setup Android SDK
uses: android-actions/setup-android@v3 uses: android-actions/setup-android@v3
with: with:
cmdline-tools-version: 8512546 cmdline-tools-version: 8512546
- name: Setup Java - name: Setup Java
uses: actions/setup-java@v3 uses: actions/setup-java@v4
with: with:
java-version: "11" java-version: "11"
distribution: "adopt" distribution: "adopt"
- name: NDK Cache - name: NDK Cache
id: ndk-cache id: ndk-cache
uses: actions/cache@v3 uses: actions/cache@v4
with: with:
path: /usr/local/lib/android/sdk/ndk path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620 key: ndk-cache-23.1.7779620
@@ -50,11 +50,11 @@ jobs:
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install Go - name: Install Go
uses: actions/setup-go@v4 uses: actions/setup-go@v5
with: with:
go-version: "1.21.x" go-version: "1.23.x"
- name: install gomobile - name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init - name: gomobile init

View File

@@ -36,18 +36,18 @@ jobs:
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
- -
name: Checkout name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v4
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- -
name: Set up Go name: Set up Go
uses: actions/setup-go@v4 uses: actions/setup-go@v5
with: with:
go-version: "1.21" go-version: "1.23"
cache: false cache: false
- -
name: Cache Go modules name: Cache Go modules
uses: actions/cache@v3 uses: actions/cache@v4
with: with:
path: | path: |
~/go/pkg/mod ~/go/pkg/mod
@@ -93,28 +93,28 @@ jobs:
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- -
name: upload non tags for debug purposes name: upload non tags for debug purposes
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v4
with: with:
name: release name: release
path: dist/ path: dist/
retention-days: 3 retention-days: 3
- -
name: upload linux packages name: upload linux packages
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v4
with: with:
name: linux-packages name: linux-packages
path: dist/netbird_linux** path: dist/netbird_linux**
retention-days: 3 retention-days: 3
- -
name: upload windows packages name: upload windows packages
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v4
with: with:
name: windows-packages name: windows-packages
path: dist/netbird_windows** path: dist/netbird_windows**
retention-days: 3 retention-days: 3
- -
name: upload macos packages name: upload macos packages
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v4
with: with:
name: macos-packages name: macos-packages
path: dist/netbird_darwin** path: dist/netbird_darwin**
@@ -133,17 +133,17 @@ jobs:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v4
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v4 uses: actions/setup-go@v5
with: with:
go-version: "1.21" go-version: "1.23"
cache: false cache: false
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v3 uses: actions/cache@v4
with: with:
path: | path: |
~/go/pkg/mod ~/go/pkg/mod
@@ -176,7 +176,7 @@ jobs:
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- name: upload non tags for debug purposes - name: upload non tags for debug purposes
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v4
with: with:
name: release-ui name: release-ui
path: dist/ path: dist/
@@ -189,18 +189,18 @@ jobs:
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
- -
name: Checkout name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v4
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- -
name: Set up Go name: Set up Go
uses: actions/setup-go@v4 uses: actions/setup-go@v5
with: with:
go-version: "1.21" go-version: "1.23"
cache: false cache: false
- -
name: Cache Go modules name: Cache Go modules
uses: actions/cache@v3 uses: actions/cache@v4
with: with:
path: | path: |
~/go/pkg/mod ~/go/pkg/mod
@@ -225,7 +225,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- -
name: upload non tags for debug purposes name: upload non tags for debug purposes
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v4
with: with:
name: release-ui-darwin name: release-ui-darwin
path: dist/ path: dist/

View File

@@ -50,12 +50,12 @@ jobs:
run: sudo apt-get install -y curl run: sudo apt-get install -y curl
- name: Install Go - name: Install Go
uses: actions/setup-go@v4 uses: actions/setup-go@v5
with: with:
go-version: "1.21.x" go-version: "1.23.x"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v3 uses: actions/cache@v4
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -63,7 +63,7 @@ jobs:
${{ runner.os }}-go- ${{ runner.os }}-go-
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: cp setup.env - name: cp setup.env
run: cp infrastructure_files/tests/setup.env infrastructure_files/ run: cp infrastructure_files/tests/setup.env infrastructure_files/
@@ -219,10 +219,7 @@ jobs:
run: sudo apt-get install -y jq run: sudo apt-get install -y jq
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: handle insisting image # remove after release
run: docker pull netbirdio/relay:latest || docker pull netbirdio/signal:latest && docker tag netbirdio/signal:latest netbirdio/relay:latest
- name: run script with Zitadel PostgreSQL - name: run script with Zitadel PostgreSQL
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
@@ -259,9 +256,6 @@ jobs:
docker compose down --volumes --rmi all docker compose down --volumes --rmi all
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
- name: handle insisting image gen CockroachDB # remove after release
run: docker pull netbirdio/relay:latest || docker pull netbirdio/signal:latest && docker tag netbirdio/signal:latest netbirdio/relay:latest
- name: run script with Zitadel CockroachDB - name: run script with Zitadel CockroachDB
run: bash -x infrastructure_files/getting-started-with-zitadel.sh run: bash -x infrastructure_files/getting-started-with-zitadel.sh
env: env:

View File

@@ -117,6 +117,11 @@ type Config struct {
// ReadConfig read config file and return with Config. If it is not exists create a new with default values // ReadConfig read config file and return with Config. If it is not exists create a new with default values
func ReadConfig(configPath string) (*Config, error) { func ReadConfig(configPath string) (*Config, error) {
if configFileIsExists(configPath) { if configFileIsExists(configPath) {
err := util.EnforcePermission(configPath)
if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err)
}
config := &Config{} config := &Config{}
if _, err := util.ReadJson(configPath, config); err != nil { if _, err := util.ReadJson(configPath, config); err != nil {
return nil, err return nil, err
@@ -159,13 +164,17 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = WriteOutConfig(input.ConfigPath, cfg) err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg)
return cfg, err return cfg, err
} }
if isPreSharedKeyHidden(input.PreSharedKey) { if isPreSharedKeyHidden(input.PreSharedKey) {
input.PreSharedKey = nil input.PreSharedKey = nil
} }
err := util.EnforcePermission(input.ConfigPath)
if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err)
}
return update(input) return update(input)
} }

View File

@@ -158,6 +158,7 @@ func (c *ConnectClient) run(
} }
defer c.statusRecorder.ClientStop() defer c.statusRecorder.ClientStop()
runningChanOpen := true
operation := func() error { operation := func() error {
// if context cancelled we not start new backoff cycle // if context cancelled we not start new backoff cycle
if c.isContextCancelled() { if c.isContextCancelled() {
@@ -267,6 +268,12 @@ func (c *ConnectClient) run(
checks := loginResp.GetChecks() checks := loginResp.GetChecks()
c.engineMutex.Lock() c.engineMutex.Lock()
if c.engine != nil && c.engine.ctx.Err() != nil {
log.Info("Stopping Netbird Engine")
if err := c.engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
}
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks) c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
c.engineMutex.Unlock() c.engineMutex.Unlock()
@@ -279,9 +286,10 @@ func (c *ConnectClient) run(
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected) state.Set(StatusConnected)
if runningChan != nil { if runningChan != nil && runningChanOpen {
runningChan <- nil runningChan <- nil
close(runningChan) close(runningChan)
runningChanOpen = false
} }
<-engineCtx.Done() <-engineCtx.Done()

View File

@@ -1115,10 +1115,7 @@ func (e *Engine) close() {
} }
// stop/restore DNS first so dbus and friends don't complain because of a missing interface // stop/restore DNS first so dbus and friends don't complain because of a missing interface
if e.dnsServer != nil { e.stopDNSServer()
e.dnsServer.Stop()
e.dnsServer = nil
}
if e.routeManager != nil { if e.routeManager != nil {
e.routeManager.Stop() e.routeManager.Stop()
@@ -1360,12 +1357,16 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
} }
func (e *Engine) restartEngine() { func (e *Engine) restartEngine() {
log.Info("restarting engine")
CtxGetState(e.ctx).Set(StatusConnecting)
if err := e.Stop(); err != nil { if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err) log.Errorf("Failed to stop engine: %v", err)
} }
if err := e.Start(); err != nil {
log.Errorf("Failed to start engine: %v", err) _ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
} log.Infof("cancelling client, engine will be recreated")
e.clientCancel()
} }
func (e *Engine) startNetworkMonitor() { func (e *Engine) startNetworkMonitor() {
@@ -1387,6 +1388,7 @@ func (e *Engine) startNetworkMonitor() {
defer mu.Unlock() defer mu.Unlock()
if debounceTimer != nil { if debounceTimer != nil {
log.Infof("Network monitor: detected network change, reset debounceTimer")
debounceTimer.Stop() debounceTimer.Stop()
} }
@@ -1396,7 +1398,7 @@ func (e *Engine) startNetworkMonitor() {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
log.Infof("Network monitor detected network change, restarting engine") log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine() e.restartEngine()
}) })
}) })
@@ -1421,6 +1423,20 @@ func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
return false, netip.Prefix{}, nil return false, netip.Prefix{}, nil
} }
func (e *Engine) stopDNSServer() {
err := fmt.Errorf("DNS server stopped")
nsGroupStates := e.statusRecorder.GetDNSStates()
for i := range nsGroupStates {
nsGroupStates[i].Enabled = false
nsGroupStates[i].Error = err
}
e.statusRecorder.UpdateDNSStates(nsGroupStates)
if e.dnsServer != nil {
e.dnsServer.Stop()
e.dnsServer = nil
}
}
// isChecksEqual checks if two slices of checks are equal. // isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool { return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {

View File

@@ -89,8 +89,8 @@ type Conn struct {
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string, wgIP string) onDisconnected func(remotePeer string, wgIP string)
statusRelay ConnStatus statusRelay *AtomicConnStatus
statusICE ConnStatus statusICE *AtomicConnStatus
currentConnPriority ConnPriority currentConnPriority ConnPriority
opened bool // this flag is used to prevent close in case of not opened connection opened bool // this flag is used to prevent close in case of not opened connection
@@ -131,8 +131,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
signaler: signaler, signaler: signaler,
relayManager: relayManager, relayManager: relayManager,
allowedIPsIP: allowedIPsIP.String(), allowedIPsIP: allowedIPsIP.String(),
statusRelay: StatusDisconnected, statusRelay: NewAtomicConnStatus(),
statusICE: StatusDisconnected, statusICE: NewAtomicConnStatus(),
iCEDisconnected: make(chan bool, 1), iCEDisconnected: make(chan bool, 1),
relayDisconnected: make(chan bool, 1), relayDisconnected: make(chan bool, 1),
} }
@@ -323,11 +323,11 @@ func (conn *Conn) reconnectLoopWithRetry() {
} }
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay == StatusDisconnected || conn.statusICE == StatusDisconnected { if conn.statusRelay.Get() == StatusDisconnected || conn.statusICE.Get() == StatusDisconnected {
conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE) conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
} }
} else { } else {
if conn.statusICE == StatusDisconnected { if conn.statusICE.Get() == StatusDisconnected {
conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE) conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE)
} }
} }
@@ -419,7 +419,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
conn.log.Debugf("ICE connection is ready") conn.log.Debugf("ICE connection is ready")
conn.statusICE = StatusConnected conn.statusICE.Set(StatusConnected)
defer conn.updateIceState(iceConnInfo) defer conn.updateIceState(iceConnInfo)
@@ -492,8 +492,8 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
conn.currentConnPriority = connPriorityRelay conn.currentConnPriority = connPriorityRelay
} }
changed := conn.statusICE != newState && newState != StatusConnecting changed := conn.statusICE.Get() != newState && newState != StatusConnecting
conn.statusICE = newState conn.statusICE.Set(newState)
select { select {
case conn.iCEDisconnected <- changed: case conn.iCEDisconnected <- changed:
@@ -518,11 +518,14 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
defer conn.mu.Unlock() defer conn.mu.Unlock()
if conn.ctx.Err() != nil { if conn.ctx.Err() != nil {
if err := rci.relayedConn.Close(); err != nil {
log.Warnf("failed to close unnecessary relayed connection: %v", err)
}
return return
} }
conn.log.Debugf("Relay connection is ready to use") conn.log.Debugf("Relay connection is ready to use")
conn.statusRelay = StatusConnected conn.statusRelay.Set(StatusConnected)
wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx) wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
endpoint, err := wgProxy.AddTurnConn(rci.relayedConn) endpoint, err := wgProxy.AddTurnConn(rci.relayedConn)
@@ -530,6 +533,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return return
} }
conn.log.Infof("created new wgProxy for relay connection: %s", endpoint)
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
conn.endpointRelay = endpointUdpAddr conn.endpointRelay = endpointUdpAddr
@@ -538,7 +542,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
if conn.currentConnPriority > connPriorityRelay { if conn.currentConnPriority > connPriorityRelay {
if conn.statusICE == StatusConnected { if conn.statusICE.Get() == StatusConnected {
log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
return return
} }
@@ -559,8 +563,8 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
conn.log.Errorf("Failed to update wg peer configuration: %v", err) conn.log.Errorf("Failed to update wg peer configuration: %v", err)
return return
} }
wgConfigWorkaround()
conn.workerRelay.EnableWgWatcher(conn.ctx) conn.workerRelay.EnableWgWatcher(conn.ctx)
wgConfigWorkaround()
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
if err := conn.wgProxyRelay.CloseConn(); err != nil { if err := conn.wgProxyRelay.CloseConn(); err != nil {
@@ -594,8 +598,8 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
conn.wgProxyRelay = nil conn.wgProxyRelay = nil
} }
changed := conn.statusRelay != StatusDisconnected changed := conn.statusRelay.Get() != StatusDisconnected
conn.statusRelay = StatusDisconnected conn.statusRelay.Set(StatusDisconnected)
select { select {
case conn.relayDisconnected <- changed: case conn.relayDisconnected <- changed:
@@ -661,8 +665,8 @@ func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) {
} }
func (conn *Conn) setStatusToDisconnected() { func (conn *Conn) setStatusToDisconnected() {
conn.statusRelay = StatusDisconnected conn.statusRelay.Set(StatusDisconnected)
conn.statusICE = StatusDisconnected conn.statusICE.Set(StatusDisconnected)
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
@@ -706,7 +710,7 @@ func (conn *Conn) waitInitialRandomSleepTime() {
} }
func (conn *Conn) isRelayed() bool { func (conn *Conn) isRelayed() bool {
if conn.statusRelay == StatusDisconnected && (conn.statusICE == StatusDisconnected || conn.statusICE == StatusConnecting) { if conn.statusRelay.Get() == StatusDisconnected && (conn.statusICE.Get() == StatusDisconnected || conn.statusICE.Get() == StatusConnecting) {
return false return false
} }
@@ -718,11 +722,11 @@ func (conn *Conn) isRelayed() bool {
} }
func (conn *Conn) evalStatus() ConnStatus { func (conn *Conn) evalStatus() ConnStatus {
if conn.statusRelay == StatusConnected || conn.statusICE == StatusConnected { if conn.statusRelay.Get() == StatusConnected || conn.statusICE.Get() == StatusConnected {
return StatusConnected return StatusConnected
} }
if conn.statusRelay == StatusConnecting || conn.statusICE == StatusConnecting { if conn.statusRelay.Get() == StatusConnecting || conn.statusICE.Get() == StatusConnecting {
return StatusConnecting return StatusConnecting
} }
@@ -733,12 +737,12 @@ func (conn *Conn) isConnected() bool {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
if conn.statusICE != StatusConnected && conn.statusICE != StatusConnecting { if conn.statusICE.Get() != StatusConnected && conn.statusICE.Get() != StatusConnecting {
return false return false
} }
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay != StatusConnected { if conn.statusRelay.Get() != StatusConnected {
return false return false
} }
} }
@@ -775,9 +779,8 @@ func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr,
ep, err := wgProxy.AddTurnConn(iceConnInfo.RemoteConn) ep, err := wgProxy.AddTurnConn(iceConnInfo.RemoteConn)
if err != nil { if err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
err = wgProxy.CloseConn() if errClose := wgProxy.CloseConn(); errClose != nil {
if err != nil { conn.log.Warnf("failed to close turn proxy connection: %v", errClose)
conn.log.Warnf("failed to close turn proxy connection: %v", err)
} }
return nil, nil, err return nil, nil, err
} }

View File

@@ -1,6 +1,10 @@
package peer package peer
import log "github.com/sirupsen/logrus" import (
"sync/atomic"
log "github.com/sirupsen/logrus"
)
const ( const (
// StatusConnected indicate the peer is in connected state // StatusConnected indicate the peer is in connected state
@@ -12,7 +16,34 @@ const (
) )
// ConnStatus describe the status of a peer's connection // ConnStatus describe the status of a peer's connection
type ConnStatus int type ConnStatus int32
// AtomicConnStatus is a thread-safe wrapper for ConnStatus
type AtomicConnStatus struct {
status atomic.Int32
}
// NewAtomicConnStatus creates a new AtomicConnStatus with the given initial status
func NewAtomicConnStatus() *AtomicConnStatus {
acs := &AtomicConnStatus{}
acs.Set(StatusDisconnected)
return acs
}
// Get returns the current connection status
func (acs *AtomicConnStatus) Get() ConnStatus {
return ConnStatus(acs.status.Load())
}
// Set updates the connection status
func (acs *AtomicConnStatus) Set(status ConnStatus) {
acs.status.Store(int32(status))
}
// String returns the string representation of the current status
func (acs *AtomicConnStatus) String() string {
return acs.Get().String()
}
func (s ConnStatus) String() string { func (s ConnStatus) String() string {
switch s { switch s {

View File

@@ -158,8 +158,13 @@ func TestConn_Status(t *testing.T) {
for _, table := range tables { for _, table := range tables {
t.Run(table.name, func(t *testing.T) { t.Run(table.name, func(t *testing.T) {
conn.statusICE = table.statusIce si := NewAtomicConnStatus()
conn.statusRelay = table.statusRelay si.Set(table.statusIce)
conn.statusICE = si
sr := NewAtomicConnStatus()
sr.Set(table.statusRelay)
conn.statusRelay = sr
got := conn.Status() got := conn.Status()
assert.Equal(t, got, table.want, "they should be equal") assert.Equal(t, got, table.want, "they should be equal")

View File

@@ -597,6 +597,8 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
} }
func (d *Status) GetRosenpassState() RosenpassState { func (d *Status) GetRosenpassState() RosenpassState {
d.mux.Lock()
defer d.mux.Unlock()
return RosenpassState{ return RosenpassState{
d.rosenpassEnabled, d.rosenpassEnabled,
d.rosenpassPermissive, d.rosenpassPermissive,
@@ -604,6 +606,8 @@ func (d *Status) GetRosenpassState() RosenpassState {
} }
func (d *Status) GetManagementState() ManagementState { func (d *Status) GetManagementState() ManagementState {
d.mux.Lock()
defer d.mux.Unlock()
return ManagementState{ return ManagementState{
d.mgmAddress, d.mgmAddress,
d.managementState, d.managementState,
@@ -645,6 +649,8 @@ func (d *Status) IsLoginRequired() bool {
} }
func (d *Status) GetSignalState() SignalState { func (d *Status) GetSignalState() SignalState {
d.mux.Lock()
defer d.mux.Unlock()
return SignalState{ return SignalState{
d.signalAddress, d.signalAddress,
d.signalState, d.signalState,
@@ -654,6 +660,8 @@ func (d *Status) GetSignalState() SignalState {
// GetRelayStates returns the stun/turn/permanent relay states // GetRelayStates returns the stun/turn/permanent relay states
func (d *Status) GetRelayStates() []relay.ProbeResult { func (d *Status) GetRelayStates() []relay.ProbeResult {
d.mux.Lock()
defer d.mux.Unlock()
if d.relayMgr == nil { if d.relayMgr == nil {
return d.relayStates return d.relayStates
} }
@@ -684,6 +692,8 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
} }
func (d *Status) GetDNSStates() []NSGroupState { func (d *Status) GetDNSStates() []NSGroupState {
d.mux.Lock()
defer d.mux.Unlock()
return d.nsGroupStates return d.nsGroupStates
} }
@@ -695,18 +705,19 @@ func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
// GetFullStatus gets full status // GetFullStatus gets full status
func (d *Status) GetFullStatus() FullStatus { func (d *Status) GetFullStatus() FullStatus {
d.mux.Lock()
defer d.mux.Unlock()
fullStatus := FullStatus{ fullStatus := FullStatus{
ManagementState: d.GetManagementState(), ManagementState: d.GetManagementState(),
SignalState: d.GetSignalState(), SignalState: d.GetSignalState(),
LocalPeerState: d.localPeer,
Relays: d.GetRelayStates(), Relays: d.GetRelayStates(),
RosenpassState: d.GetRosenpassState(), RosenpassState: d.GetRosenpassState(),
NSGroupStates: d.GetDNSStates(), NSGroupStates: d.GetDNSStates(),
} }
d.mux.Lock()
defer d.mux.Unlock()
fullStatus.LocalPeerState = d.localPeer
for _, status := range d.peers { for _, status := range d.peers {
fullStatus.Peers = append(fullStatus.Peers, status) fullStatus.Peers = append(fullStatus.Peers, status)
} }

View File

@@ -14,7 +14,7 @@ import (
) )
var ( var (
wgHandshakePeriod = 2 * time.Minute wgHandshakePeriod = 3 * time.Minute
wgHandshakeOvertime = 30 * time.Second wgHandshakeOvertime = 30 * time.Second
) )
@@ -109,10 +109,10 @@ func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
} }
ctx, ctxCancel := context.WithCancel(ctx) ctx, ctxCancel := context.WithCancel(ctx)
go w.wgStateCheck(ctx)
w.ctxWgWatch = ctx w.ctxWgWatch = ctx
w.ctxCancelWgWatch = ctxCancel w.ctxCancelWgWatch = ctxCancel
w.wgStateCheck(ctx, ctxCancel)
} }
func (w *WorkerRelay) DisableWgWatcher() { func (w *WorkerRelay) DisableWgWatcher() {
@@ -157,37 +157,51 @@ func (w *WorkerRelay) CloseConn() {
} }
} }
// wgStateCheck help to check the state of the wireguard handshake and relay connection // wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WorkerRelay) wgStateCheck(ctx context.Context) { func (w *WorkerRelay) wgStateCheck(ctx context.Context, ctxCancel context.CancelFunc) {
timer := time.NewTimer(wgHandshakeOvertime) w.log.Debugf("WireGuard watcher started")
defer timer.Stop() lastHandshake, err := w.wgState()
expected := wgHandshakeOvertime if err != nil {
for { w.log.Warnf("failed to read wg stats: %v", err)
select { lastHandshake = time.Time{}
case <-timer.C: }
lastHandshake, err := w.wgState()
if err != nil {
w.log.Errorf("failed to read wg stats: %v", err)
continue
}
w.log.Tracef("last handshake: %v", lastHandshake)
if time.Since(lastHandshake) > expected { go func(lastHandshake time.Time) {
w.log.Infof("Wireguard handshake timed out, closing relay connection") timer := time.NewTimer(wgHandshakeOvertime)
w.relayLock.Lock() defer timer.Stop()
_ = w.relayedConn.Close() defer ctxCancel()
w.relayLock.Unlock()
w.callBacks.OnDisconnected() for {
select {
case <-timer.C:
handshake, err := w.wgState()
if err != nil {
w.log.Errorf("failed to read wg stats: %v", err)
timer.Reset(wgHandshakeOvertime)
continue
}
w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake)
if handshake.Equal(lastHandshake) {
w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake)
w.relayLock.Lock()
_ = w.relayedConn.Close()
w.relayLock.Unlock()
w.callBacks.OnDisconnected()
return
}
resetTime := time.Until(handshake.Add(wgHandshakePeriod + wgHandshakeOvertime))
lastHandshake = handshake
timer.Reset(resetTime)
case <-ctx.Done():
w.log.Debugf("WireGuard watcher stopped")
return return
} }
resetTime := time.Until(lastHandshake.Add(wgHandshakePeriod + wgHandshakeOvertime))
timer.Reset(resetTime)
expected = wgHandshakePeriod
case <-ctx.Done():
w.log.Debugf("WireGuard watcher stopped")
return
} }
} }(lastHandshake)
} }
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool { func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {

View File

@@ -32,8 +32,8 @@ func NewWGUserSpaceProxy(ctx context.Context, wgPort int) *WGUserSpaceProxy {
} }
// AddTurnConn start the proxy with the given remote conn // AddTurnConn start the proxy with the given remote conn
func (p *WGUserSpaceProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) { func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) {
p.remoteConn = turnConn p.remoteConn = remoteConn
var err error var err error
p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
@@ -54,6 +54,14 @@ func (p *WGUserSpaceProxy) CloseConn() error {
if p.localConn == nil { if p.localConn == nil {
return nil return nil
} }
if p.remoteConn == nil {
return nil
}
if err := p.remoteConn.Close(); err != nil {
log.Warnf("failed to close remote conn: %s", err)
}
return p.localConn.Close() return p.localConn.Close()
} }
@@ -65,6 +73,8 @@ func (p *WGUserSpaceProxy) Free() error {
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer // proxyToRemote proxies everything from Wireguard to the RemoteKey peer
// blocks // blocks
func (p *WGUserSpaceProxy) proxyToRemote() { func (p *WGUserSpaceProxy) proxyToRemote() {
defer log.Infof("exit from proxyToRemote: %s", p.localConn.LocalAddr())
buf := make([]byte, 1500) buf := make([]byte, 1500)
for { for {
select { select {
@@ -93,7 +103,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard // proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
// blocks // blocks
func (p *WGUserSpaceProxy) proxyToLocal() { func (p *WGUserSpaceProxy) proxyToLocal() {
defer p.cancel()
defer log.Infof("exit from proxyToLocal: %s", p.localConn.LocalAddr())
buf := make([]byte, 1500) buf := make([]byte, 1500)
for { for {
select { select {
@@ -103,7 +114,6 @@ func (p *WGUserSpaceProxy) proxyToLocal() {
n, err := p.remoteConn.Read(buf) n, err := p.remoteConn.Read(buf)
if err != nil { if err != nil {
if err == io.EOF { if err == io.EOF {
p.cancel()
return return
} }
log.Errorf("failed to read from remote conn: %s", err) log.Errorf("failed to read from remote conn: %s", err)

View File

@@ -1,13 +1,41 @@
package main package main
import ( import (
"net/http"
_ "net/http/pprof"
"os" "os"
"path/filepath"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/cmd" "github.com/netbirdio/netbird/client/cmd"
) )
func main() { func main() {
// only if env is not set
if os.Getenv("NB_LOG_LEVEL") == "" {
if err := os.Setenv("NB_LOG_LEVEL", "debug"); err != nil {
log.Errorf("Failed setting log-level: %v", err)
}
}
if err := os.Setenv("NB_LOG_MAX_SIZE_MB", "100"); err != nil {
log.Errorf("Failed setting log-size: %v", err)
}
if err := os.Setenv("NB_WINDOWS_PANIC_LOG", filepath.Join(os.Getenv("ProgramData"), "netbird", "netbird.err")); err != nil {
log.Errorf("Failed setting panic log path: %v", err)
}
go startPprofServer()
if err := cmd.Execute(); err != nil { if err := cmd.Execute(); err != nil {
os.Exit(1) os.Exit(1)
} }
} }
func startPprofServer() {
pprofAddr := "localhost:6969"
log.Infof("Starting pprof debugging server on %s", pprofAddr)
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
log.Infof("pprof server failed: %v", err)
}
}

View File

@@ -110,6 +110,35 @@ func (s *Server) Start() error {
ctx, cancel := context.WithCancel(s.rootCtx) ctx, cancel := context.WithCancel(s.rootCtx)
s.actCancel = cancel s.actCancel = cancel
go func() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if statusResp, err := s.Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true}); err != nil {
log.Infof("Error getting status: %v", err)
} else if statusResp.FullStatus != nil {
log.Infof("Status --------")
for _, peer := range statusResp.FullStatus.Peers {
log.Infof("[Peer Connection] Name: %s, IP: %s, Key: %s, Connection Status: %s, Relayed: %v, RelayedAddress: %v, Last WireGuard Handshake: %v",
peer.Fqdn,
peer.IP,
peer.PubKey,
peer.ConnStatus,
peer.Relayed,
peer.RelayAddress,
peer.LastWireguardHandshake.AsTime().Format("15:04:05"),
)
}
}
}
}
}()
// if configuration exists, we just start connections. if is new config we skip and set status NeedsLogin // if configuration exists, we just start connections. if is new config we skip and set status NeedsLogin
// on failure we return error to retry // on failure we return error to retry
config, err := internal.UpdateConfig(s.latestConfigInput) config, err := internal.UpdateConfig(s.latestConfigInput)

4
go.mod
View File

@@ -1,6 +1,6 @@
module github.com/netbirdio/netbird module github.com/netbirdio/netbird
go 1.21.0 go 1.23.0
require ( require (
cunicu.li/go-rosenpass v0.4.0 cunicu.li/go-rosenpass v0.4.0
@@ -232,7 +232,7 @@ require (
k8s.io/apimachinery v0.26.2 // indirect k8s.io/apimachinery v0.26.2 // indirect
) )
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240904111318-17777758453a replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949

4
go.sum
View File

@@ -523,8 +523,8 @@ github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6R
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI= github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
github.com/netbirdio/service v0.0.0-20240904111318-17777758453a h1:2EcDFDT39Odz5EC38pOSyjCd3bLUjPi7pMQpH6k+zzk= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240904111318-17777758453a/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080 h1:mXJkoWLdqJTlkQ7DgQ536kcXHXIdUPeagkN8i4eFDdg= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080 h1:mXJkoWLdqJTlkQ7DgQ536kcXHXIdUPeagkN8i4eFDdg=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=

View File

@@ -56,8 +56,9 @@ func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepA
return err return err
} }
peer := wgtypes.PeerConfig{ peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true, ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: []net.IPNet{*ipNet}, AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive, PersistentKeepaliveInterval: &keepAlive,
Endpoint: endpoint, Endpoint: endpoint,

View File

@@ -64,8 +64,9 @@ func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAliv
return err return err
} }
peer := wgtypes.PeerConfig{ peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true, ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: []net.IPNet{*ipNet}, AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive, PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey, PresharedKey: preSharedKey,

View File

@@ -54,7 +54,7 @@ func Execute() error {
func init() { func init() {
stopCh = make(chan int) stopCh = make(chan int)
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise") mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 8081, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics") mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location") mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
mgmtCmd.Flags().StringVar(&mgmtConfig, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file") mgmtCmd.Flags().StringVar(&mgmtConfig, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")

View File

@@ -263,6 +263,11 @@ type AccountSettings struct {
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
} }
// Subclass used in gorm to only load network and not whole account
type AccountNetwork struct {
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
}
type UserPermissions struct { type UserPermissions struct {
DashboardView string `json:"dashboard_view"` DashboardView string `json:"dashboard_view"`
} }
@@ -700,14 +705,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string {
return grps return grps
} }
func (a *Account) getUserGroups(userID string) ([]string, error) {
user, err := a.FindUser(userID)
if err != nil {
return nil, err
}
return user.AutoGroups, nil
}
func (a *Account) getPeerDNSManagementStatus(peerID string) bool { func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
peerGroups := a.getPeerGroups(peerID) peerGroups := a.getPeerGroups(peerID)
enabled := true enabled := true
@@ -734,14 +731,6 @@ func (a *Account) getPeerGroups(peerID string) lookupMap {
return groupList return groupList
} }
func (a *Account) getSetupKeyGroups(setupKey string) ([]string, error) {
key, err := a.FindSetupKey(setupKey)
if err != nil {
return nil, err
}
return key.AutoGroups, nil
}
func (a *Account) getTakenIPs() []net.IP { func (a *Account) getTakenIPs() []net.IP {
var takenIps []net.IP var takenIps []net.IP
for _, existingPeer := range a.Peers { for _, existingPeer := range a.Peers {
@@ -2082,7 +2071,7 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
} }
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) { func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) {
user, err := am.Store.GetUserByUserID(ctx, peer.UserID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -2103,6 +2092,25 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee
return false, nil return false, nil
} }
func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Store, accountID string, peerHostName string) (string, error) {
existingLabels, err := store.GetPeerLabelsInAccount(ctx, LockingStrengthShare, accountID)
if err != nil {
return "", fmt.Errorf("failed to get peer dns labels: %w", err)
}
labelMap := ConvertSliceToMap(existingLabels)
newLabel, err := getPeerHostLabel(peerHostName, labelMap)
if err != nil {
return "", fmt.Errorf("failed to get new host label: %w", err)
}
if newLabel == "" {
return "", fmt.Errorf("failed to get new host label: %w", err)
}
return newLabel, nil
}
// addAllGroup to account object if it doesn't exist // addAllGroup to account object if it doesn't exist
func addAllGroup(account *Account) error { func addAllGroup(account *Account) error {
if len(account.Groups) == 0 { if len(account.Groups) == 0 {

View File

@@ -6,13 +6,14 @@ import (
"crypto/cipher" "crypto/cipher"
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"fmt" "errors"
) )
var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05} var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
type FieldEncrypt struct { type FieldEncrypt struct {
block cipher.Block block cipher.Block
gcm cipher.AEAD
} }
func GenerateKey() (string, error) { func GenerateKey() (string, error) {
@@ -35,14 +36,21 @@ func NewFieldEncrypt(key string) (*FieldEncrypt, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
ec := &FieldEncrypt{ ec := &FieldEncrypt{
block: block, block: block,
gcm: gcm,
} }
return ec, nil return ec, nil
} }
func (ec *FieldEncrypt) Encrypt(payload string) string { func (ec *FieldEncrypt) LegacyEncrypt(payload string) string {
plainText := pkcs5Padding([]byte(payload)) plainText := pkcs5Padding([]byte(payload))
cipherText := make([]byte, len(plainText)) cipherText := make([]byte, len(plainText))
cbc := cipher.NewCBCEncrypter(ec.block, iv) cbc := cipher.NewCBCEncrypter(ec.block, iv)
@@ -50,7 +58,22 @@ func (ec *FieldEncrypt) Encrypt(payload string) string {
return base64.StdEncoding.EncodeToString(cipherText) return base64.StdEncoding.EncodeToString(cipherText)
} }
func (ec *FieldEncrypt) Decrypt(data string) (string, error) { // Encrypt encrypts plaintext using AES-GCM
func (ec *FieldEncrypt) Encrypt(payload string) (string, error) {
plaintext := []byte(payload)
nonceSize := ec.gcm.NonceSize()
nonce := make([]byte, nonceSize, len(plaintext)+nonceSize+ec.gcm.Overhead())
if _, err := rand.Read(nonce); err != nil {
return "", err
}
ciphertext := ec.gcm.Seal(nonce, nonce, plaintext, nil)
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
func (ec *FieldEncrypt) LegacyDecrypt(data string) (string, error) {
cipherText, err := base64.StdEncoding.DecodeString(data) cipherText, err := base64.StdEncoding.DecodeString(data)
if err != nil { if err != nil {
return "", err return "", err
@@ -65,17 +88,49 @@ func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
return string(payload), nil return string(payload), nil
} }
// Decrypt decrypts ciphertext using AES-GCM
func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
cipherText, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return "", err
}
nonceSize := ec.gcm.NonceSize()
if len(cipherText) < nonceSize {
return "", errors.New("cipher text too short")
}
nonce, cipherText := cipherText[:nonceSize], cipherText[nonceSize:]
plainText, err := ec.gcm.Open(nil, nonce, cipherText, nil)
if err != nil {
return "", err
}
return string(plainText), nil
}
func pkcs5Padding(ciphertext []byte) []byte { func pkcs5Padding(ciphertext []byte) []byte {
padding := aes.BlockSize - len(ciphertext)%aes.BlockSize padding := aes.BlockSize - len(ciphertext)%aes.BlockSize
padText := bytes.Repeat([]byte{byte(padding)}, padding) padText := bytes.Repeat([]byte{byte(padding)}, padding)
return append(ciphertext, padText...) return append(ciphertext, padText...)
} }
func pkcs5UnPadding(src []byte) ([]byte, error) { func pkcs5UnPadding(src []byte) ([]byte, error) {
srcLen := len(src) srcLen := len(src)
paddingLen := int(src[srcLen-1]) if srcLen == 0 {
if paddingLen >= srcLen || paddingLen > aes.BlockSize { return nil, errors.New("input data is empty")
return nil, fmt.Errorf("padding size error")
} }
paddingLen := int(src[srcLen-1])
if paddingLen == 0 || paddingLen > aes.BlockSize || paddingLen > srcLen {
return nil, errors.New("invalid padding size")
}
// Verify that all padding bytes are the same
for i := 0; i < paddingLen; i++ {
if src[srcLen-1-i] != byte(paddingLen) {
return nil, errors.New("invalid padding")
}
}
return src[:srcLen-paddingLen], nil return src[:srcLen-paddingLen], nil
} }

View File

@@ -1,6 +1,7 @@
package sqlite package sqlite
import ( import (
"bytes"
"testing" "testing"
) )
@@ -15,7 +16,11 @@ func TestGenerateKey(t *testing.T) {
t.Fatalf("failed to init email encryption: %s", err) t.Fatalf("failed to init email encryption: %s", err)
} }
encrypted := ee.Encrypt(testData) encrypted, err := ee.Encrypt(testData)
if err != nil {
t.Fatalf("failed to encrypt data: %s", err)
}
if encrypted == "" { if encrypted == "" {
t.Fatalf("invalid encrypted text") t.Fatalf("invalid encrypted text")
} }
@@ -30,6 +35,32 @@ func TestGenerateKey(t *testing.T) {
} }
} }
func TestGenerateKeyLegacy(t *testing.T) {
testData := "exampl@netbird.io"
key, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err := NewFieldEncrypt(key)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
encrypted := ee.LegacyEncrypt(testData)
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
decrypted, err := ee.LegacyDecrypt(encrypted)
if err != nil {
t.Fatalf("failed to decrypt data: %s", err)
}
if decrypted != testData {
t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted)
}
}
func TestCorruptKey(t *testing.T) { func TestCorruptKey(t *testing.T) {
testData := "exampl@netbird.io" testData := "exampl@netbird.io"
key, err := GenerateKey() key, err := GenerateKey()
@@ -41,7 +72,11 @@ func TestCorruptKey(t *testing.T) {
t.Fatalf("failed to init email encryption: %s", err) t.Fatalf("failed to init email encryption: %s", err)
} }
encrypted := ee.Encrypt(testData) encrypted, err := ee.Encrypt(testData)
if err != nil {
t.Fatalf("failed to encrypt data: %s", err)
}
if encrypted == "" { if encrypted == "" {
t.Fatalf("invalid encrypted text") t.Fatalf("invalid encrypted text")
} }
@@ -61,3 +96,215 @@ func TestCorruptKey(t *testing.T) {
t.Fatalf("incorrect decryption, the result is: %s", res) t.Fatalf("incorrect decryption, the result is: %s", res)
} }
} }
func TestEncryptDecrypt(t *testing.T) {
// Generate a key for encryption/decryption
key, err := GenerateKey()
if err != nil {
t.Fatalf("Failed to generate key: %v", err)
}
// Initialize the FieldEncrypt with the generated key
ec, err := NewFieldEncrypt(key)
if err != nil {
t.Fatalf("Failed to create FieldEncrypt: %v", err)
}
// Test cases
testCases := []struct {
name string
input string
}{
{
name: "Empty String",
input: "",
},
{
name: "Short String",
input: "Hello",
},
{
name: "String with Spaces",
input: "Hello, World!",
},
{
name: "Long String",
input: "The quick brown fox jumps over the lazy dog.",
},
{
name: "Unicode Characters",
input: "こんにちは世界",
},
{
name: "Special Characters",
input: "!@#$%^&*()_+-=[]{}|;':\",./<>?",
},
{
name: "Numeric String",
input: "1234567890",
},
{
name: "Repeated Characters",
input: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
},
{
name: "Multi-block String",
input: "This is a longer string that will span multiple blocks in the encryption algorithm.",
},
{
name: "Non-ASCII and ASCII Mix",
input: "Hello 世界 123",
},
}
for _, tc := range testCases {
t.Run(tc.name+" - Legacy", func(t *testing.T) {
// Legacy Encryption
encryptedLegacy := ec.LegacyEncrypt(tc.input)
if encryptedLegacy == "" {
t.Errorf("LegacyEncrypt returned empty string for input '%s'", tc.input)
}
// Legacy Decryption
decryptedLegacy, err := ec.LegacyDecrypt(encryptedLegacy)
if err != nil {
t.Errorf("LegacyDecrypt failed for input '%s': %v", tc.input, err)
}
// Verify that the decrypted value matches the original input
if decryptedLegacy != tc.input {
t.Errorf("LegacyDecrypt output '%s' does not match original input '%s'", decryptedLegacy, tc.input)
}
})
t.Run(tc.name+" - New", func(t *testing.T) {
// New Encryption
encryptedNew, err := ec.Encrypt(tc.input)
if err != nil {
t.Errorf("Encrypt failed for input '%s': %v", tc.input, err)
}
if encryptedNew == "" {
t.Errorf("Encrypt returned empty string for input '%s'", tc.input)
}
// New Decryption
decryptedNew, err := ec.Decrypt(encryptedNew)
if err != nil {
t.Errorf("Decrypt failed for input '%s': %v", tc.input, err)
}
// Verify that the decrypted value matches the original input
if decryptedNew != tc.input {
t.Errorf("Decrypt output '%s' does not match original input '%s'", decryptedNew, tc.input)
}
})
}
}
func TestPKCS5UnPadding(t *testing.T) {
tests := []struct {
name string
input []byte
expected []byte
expectError bool
}{
{
name: "Valid Padding",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{4}, 4)...),
expected: []byte("Hello, World!"),
},
{
name: "Empty Input",
input: []byte{},
expectError: true,
},
{
name: "Padding Length Zero",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{0}, 4)...),
expectError: true,
},
{
name: "Padding Length Exceeds Block Size",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{17}, 17)...),
expectError: true,
},
{
name: "Padding Length Exceeds Input Length",
input: []byte{5, 5, 5},
expectError: true,
},
{
name: "Invalid Padding Bytes",
input: append([]byte("Hello, World!"), []byte{2, 3, 4, 5}...),
expectError: true,
},
{
name: "Valid Single Byte Padding",
input: append([]byte("Hello, World!"), byte(1)),
expected: []byte("Hello, World!"),
},
{
name: "Invalid Mixed Padding Bytes",
input: append([]byte("Hello, World!"), []byte{3, 3, 2}...),
expectError: true,
},
{
name: "Valid Full Block Padding",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{16}, 16)...),
expected: []byte("Hello, World!"),
},
{
name: "Non-Padding Byte at End",
input: append([]byte("Hello, World!"), []byte{4, 4, 4, 5}...),
expectError: true,
},
{
name: "Valid Padding with Different Text Length",
input: append([]byte("Test"), bytes.Repeat([]byte{12}, 12)...),
expected: []byte("Test"),
},
{
name: "Padding Length Equal to Input Length",
input: bytes.Repeat([]byte{8}, 8),
expected: []byte{},
},
{
name: "Invalid Padding Length Zero (Again)",
input: append([]byte("Test"), byte(0)),
expectError: true,
},
{
name: "Padding Length Greater Than Input",
input: []byte{10},
expectError: true,
},
{
name: "Input Length Not Multiple of Block Size",
input: append([]byte("Invalid Length"), byte(1)),
expected: []byte("Invalid Length"),
},
{
name: "Valid Padding with Non-ASCII Characters",
input: append([]byte("こんにちは"), bytes.Repeat([]byte{2}, 2)...),
expected: []byte("こんにちは"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := pkcs5UnPadding(tt.input)
if tt.expectError {
if err == nil {
t.Errorf("Expected error but got nil")
}
} else {
if err != nil {
t.Errorf("Did not expect error but got: %v", err)
}
if !bytes.Equal(result, tt.expected) {
t.Errorf("Expected output %v, got %v", tt.expected, result)
}
}
})
}
}

View File

@@ -0,0 +1,157 @@
package sqlite
import (
"context"
"database/sql"
"fmt"
log "github.com/sirupsen/logrus"
)
func migrate(ctx context.Context, crypt *FieldEncrypt, db *sql.DB) error {
if _, err := db.Exec(createTableQuery); err != nil {
return err
}
if _, err := db.Exec(creatTableDeletedUsersQuery); err != nil {
return err
}
if err := updateDeletedUsersTable(ctx, db); err != nil {
return fmt.Errorf("failed to update deleted_users table: %v", err)
}
return migrateLegacyEncryptedUsersToGCM(ctx, crypt, db)
}
// updateDeletedUsersTable checks and updates the deleted_users table schema to ensure required columns exist.
func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error {
exists, err := checkColumnExists(db, "deleted_users", "name")
if err != nil {
return err
}
if !exists {
log.WithContext(ctx).Debug("Adding name column to the deleted_users table")
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
if err != nil {
return err
}
log.WithContext(ctx).Debug("Successfully added name column to the deleted_users table")
}
exists, err = checkColumnExists(db, "deleted_users", "enc_algo")
if err != nil {
return err
}
if !exists {
log.WithContext(ctx).Debug("Adding enc_algo column to the deleted_users table")
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN enc_algo TEXT;`)
if err != nil {
return err
}
log.WithContext(ctx).Debug("Successfully added enc_algo column to the deleted_users table")
}
return nil
}
// migrateLegacyEncryptedUsersToGCM migrates previously encrypted data using,
// legacy CBC encryption with a static IV to the new GCM encryption method.
func migrateLegacyEncryptedUsersToGCM(ctx context.Context, crypt *FieldEncrypt, db *sql.DB) error {
log.WithContext(ctx).Debug("Migrating CBC encrypted deleted users to GCM")
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("failed to begin transaction: %v", err)
}
defer func() {
_ = tx.Rollback()
}()
rows, err := tx.Query(fmt.Sprintf(`SELECT id, email, name FROM deleted_users where enc_algo IS NULL OR enc_algo != '%s'`, gcmEncAlgo))
if err != nil {
return fmt.Errorf("failed to execute select query: %v", err)
}
defer rows.Close()
updateStmt, err := tx.Prepare(`UPDATE deleted_users SET email = ?, name = ?, enc_algo = ? WHERE id = ?`)
if err != nil {
return fmt.Errorf("failed to prepare update statement: %v", err)
}
defer updateStmt.Close()
if err = processUserRows(ctx, crypt, rows, updateStmt); err != nil {
return err
}
if err = tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %v", err)
}
log.WithContext(ctx).Debug("Successfully migrated CBC encrypted deleted users to GCM")
return nil
}
// processUserRows processes database rows of user data, decrypts legacy encryption fields, and re-encrypts them using GCM.
func processUserRows(ctx context.Context, crypt *FieldEncrypt, rows *sql.Rows, updateStmt *sql.Stmt) error {
for rows.Next() {
var (
id, decryptedEmail, decryptedName string
email, name *string
)
err := rows.Scan(&id, &email, &name)
if err != nil {
return err
}
if email != nil {
decryptedEmail, err = crypt.LegacyDecrypt(*email)
if err != nil {
log.WithContext(ctx).Warnf("skipping migrating deleted user %s: %v",
id,
fmt.Errorf("failed to decrypt email: %w", err),
)
continue
}
}
if name != nil {
decryptedName, err = crypt.LegacyDecrypt(*name)
if err != nil {
log.WithContext(ctx).Warnf("skipping migrating deleted user %s: %v",
id,
fmt.Errorf("failed to decrypt name: %w", err),
)
continue
}
}
encryptedEmail, err := crypt.Encrypt(decryptedEmail)
if err != nil {
return fmt.Errorf("failed to encrypt email: %w", err)
}
encryptedName, err := crypt.Encrypt(decryptedName)
if err != nil {
return fmt.Errorf("failed to encrypt name: %w", err)
}
_, err = updateStmt.Exec(encryptedEmail, encryptedName, gcmEncAlgo, id)
if err != nil {
return err
}
}
if err := rows.Err(); err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,84 @@
package sqlite
import (
"context"
"database/sql"
"path/filepath"
"testing"
"time"
_ "github.com/mattn/go-sqlite3"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/stretchr/testify/require"
)
func setupDatabase(t *testing.T) *sql.DB {
t.Helper()
dbFile := filepath.Join(t.TempDir(), eventSinkDB)
db, err := sql.Open("sqlite3", dbFile)
require.NoError(t, err, "Failed to open database")
t.Cleanup(func() {
_ = db.Close()
})
_, err = db.Exec(createTableQuery)
require.NoError(t, err, "Failed to create events table")
_, err = db.Exec(`CREATE TABLE deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);`)
require.NoError(t, err, "Failed to create deleted_users table")
return db
}
func TestMigrate(t *testing.T) {
db := setupDatabase(t)
key, err := GenerateKey()
require.NoError(t, err, "Failed to generate key")
crypt, err := NewFieldEncrypt(key)
require.NoError(t, err, "Failed to initialize FieldEncrypt")
legacyEmail := crypt.LegacyEncrypt("testaccount@test.com")
legacyName := crypt.LegacyEncrypt("Test Account")
_, err = db.Exec(`INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) VALUES(?, ?, ?, ?, ?, ?)`,
activity.UserDeleted, time.Now(), "initiatorID", "targetID", "accountID", "")
require.NoError(t, err, "Failed to insert event")
_, err = db.Exec(`INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)`, "targetID", legacyEmail, legacyName)
require.NoError(t, err, "Failed to insert legacy encrypted data")
colExists, err := checkColumnExists(db, "deleted_users", "enc_algo")
require.NoError(t, err, "Failed to check if enc_algo column exists")
require.False(t, colExists, "enc_algo column should not exist before migration")
err = migrate(context.Background(), crypt, db)
require.NoError(t, err, "Migration failed")
colExists, err = checkColumnExists(db, "deleted_users", "enc_algo")
require.NoError(t, err, "Failed to check if enc_algo column exists after migration")
require.True(t, colExists, "enc_algo column should exist after migration")
var encAlgo string
err = db.QueryRow(`SELECT enc_algo FROM deleted_users LIMIT 1`, "").Scan(&encAlgo)
require.NoError(t, err, "Failed to select updated data")
require.Equal(t, gcmEncAlgo, encAlgo, "enc_algo should be set to 'GCM' after migration")
store, err := createStore(crypt, db)
require.NoError(t, err, "Failed to create store")
events, err := store.Get(context.Background(), "accountID", 0, 1, false)
require.NoError(t, err, "Failed to get events")
require.Len(t, events, 1, "Should have one event")
require.Equal(t, activity.UserDeleted, events[0].Activity, "activity should match")
require.Equal(t, "initiatorID", events[0].InitiatorID, "initiator id should match")
require.Equal(t, "targetID", events[0].TargetID, "target id should match")
require.Equal(t, "accountID", events[0].AccountID, "account id should match")
require.Equal(t, "testaccount@test.com", events[0].Meta["email"], "email should match")
require.Equal(t, "Test Account", events[0].Meta["username"], "username should match")
}

View File

@@ -26,7 +26,7 @@ const (
"meta TEXT," + "meta TEXT," +
" target_id TEXT);" " target_id TEXT);"
creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);` creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT, enc_algo TEXT NOT NULL);`
selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta
FROM events FROM events
@@ -69,10 +69,12 @@ const (
and some selfhosted deployments might have duplicates already so we need to clean the table first. and some selfhosted deployments might have duplicates already so we need to clean the table first.
*/ */
insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)` insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name, enc_algo) VALUES(?, ?, ?, ?)`
fallbackName = "unknown" fallbackName = "unknown"
fallbackEmail = "unknown@unknown.com" fallbackEmail = "unknown@unknown.com"
gcmEncAlgo = "GCM"
) )
// Store is the implementation of the activity.Store interface backed by SQLite // Store is the implementation of the activity.Store interface backed by SQLite
@@ -100,58 +102,12 @@ func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (
return nil, err return nil, err
} }
_, err = db.Exec(createTableQuery) if err = migrate(ctx, crypt, db); err != nil {
if err != nil {
_ = db.Close() _ = db.Close()
return nil, err return nil, fmt.Errorf("events database migration: %w", err)
} }
_, err = db.Exec(creatTableDeletedUsersQuery) return createStore(crypt, db)
if err != nil {
_ = db.Close()
return nil, err
}
err = updateDeletedUsersTable(ctx, db)
if err != nil {
_ = db.Close()
return nil, err
}
insertStmt, err := db.Prepare(insertQuery)
if err != nil {
_ = db.Close()
return nil, err
}
selectDescStmt, err := db.Prepare(selectDescQuery)
if err != nil {
_ = db.Close()
return nil, err
}
selectAscStmt, err := db.Prepare(selectAscQuery)
if err != nil {
_ = db.Close()
return nil, err
}
deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
if err != nil {
_ = db.Close()
return nil, err
}
s := &Store{
db: db,
fieldEncrypt: crypt,
insertStatement: insertStmt,
selectDescStatement: selectDescStmt,
selectAscStatement: selectAscStmt,
deleteUserStmt: deleteUserStmt,
}
return s, nil
} }
func (store *Store) processResult(ctx context.Context, result *sql.Rows) ([]*activity.Event, error) { func (store *Store) processResult(ctx context.Context, result *sql.Rows) ([]*activity.Event, error) {
@@ -302,9 +258,16 @@ func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event
return event.Meta, nil return event.Meta, nil
} }
encryptedEmail := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email)) encryptedEmail, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email))
encryptedName := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name)) if err != nil {
_, err := store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName) return nil, err
}
encryptedName, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name))
if err != nil {
return nil, err
}
_, err = store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName, gcmEncAlgo)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -325,43 +288,70 @@ func (store *Store) Close(_ context.Context) error {
return nil return nil
} }
func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error { // createStore initializes and returns a new Store instance with prepared SQL statements.
log.WithContext(ctx).Debugf("check deleted_users table version") func createStore(crypt *FieldEncrypt, db *sql.DB) (*Store, error) {
rows, err := db.Query(`PRAGMA table_info(deleted_users);`) insertStmt, err := db.Prepare(insertQuery)
if err != nil { if err != nil {
return err _ = db.Close()
return nil, err
}
selectDescStmt, err := db.Prepare(selectDescQuery)
if err != nil {
_ = db.Close()
return nil, err
}
selectAscStmt, err := db.Prepare(selectAscQuery)
if err != nil {
_ = db.Close()
return nil, err
}
deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
if err != nil {
_ = db.Close()
return nil, err
}
return &Store{
db: db,
fieldEncrypt: crypt,
insertStatement: insertStmt,
selectDescStatement: selectDescStmt,
selectAscStatement: selectAscStmt,
deleteUserStmt: deleteUserStmt,
}, nil
}
// checkColumnExists checks if a column exists in a specified table
func checkColumnExists(db *sql.DB, tableName, columnName string) (bool, error) {
query := fmt.Sprintf("PRAGMA table_info(%s);", tableName)
rows, err := db.Query(query)
if err != nil {
return false, fmt.Errorf("failed to query table info: %w", err)
} }
defer rows.Close() defer rows.Close()
found := false
for rows.Next() { for rows.Next() {
var ( var cid int
cid int var name, ctype string
name string var notnull, pk int
dataType string var dfltValue sql.NullString
notNull int
dfltVal sql.NullString err = rows.Scan(&cid, &name, &ctype, &notnull, &dfltValue, &pk)
pk int
)
err := rows.Scan(&cid, &name, &dataType, &notNull, &dfltVal, &pk)
if err != nil { if err != nil {
return err return false, fmt.Errorf("failed to scan row: %w", err)
} }
if name == "name" {
found = true if name == columnName {
break return true, nil
} }
} }
err = rows.Err() if err = rows.Err(); err != nil {
if err != nil { return false, err
return err
} }
if found { return false, nil
return nil
}
log.WithContext(ctx).Debugf("update delted_users table")
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
return err
} }

View File

@@ -7,6 +7,7 @@ import (
"time" "time"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
) )
type MockStore struct { type MockStore struct {
@@ -24,7 +25,7 @@ func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Accou
return s.account, nil return s.account, nil
} }
return nil, fmt.Errorf("account not found") return nil, status.NewPeerNotFoundError(peerId)
} }
type MocAccountManager struct { type MocAccountManager struct {

View File

@@ -2,6 +2,8 @@ package server
import ( import (
"context" "context"
"errors"
"net"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@@ -46,6 +48,158 @@ type FileStore struct {
metrics telemetry.AppMetrics `json:"-"` metrics telemetry.AppMetrics `json:"-"`
} }
func (s *FileStore) ExecuteInTransaction(ctx context.Context, f func(store Store) error) error {
return f(s)
}
func (s *FileStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
s.mux.Lock()
defer s.mux.Unlock()
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKeyID)]
if !ok {
return status.NewSetupKeyNotFoundError()
}
account, err := s.getAccount(accountID)
if err != nil {
return err
}
account.SetupKeys[setupKeyID].UsedTimes++
return s.SaveAccount(ctx, account)
}
func (s *FileStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return err
}
allGroup, err := account.GetGroupAll()
if err != nil || allGroup == nil {
return errors.New("all group not found")
}
allGroup.Peers = append(allGroup.Peers, peerID)
return nil
}
func (s *FileStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountId)
if err != nil {
return err
}
account.Groups[groupID].Peers = append(account.Groups[groupID].Peers, peerId)
return nil
}
func (s *FileStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
s.mux.Lock()
defer s.mux.Unlock()
account, ok := s.Accounts[peer.AccountID]
if !ok {
return status.NewAccountNotFoundError(peer.AccountID)
}
account.Peers[peer.ID] = peer
return s.SaveAccount(ctx, account)
}
func (s *FileStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
s.mux.Lock()
defer s.mux.Unlock()
account, ok := s.Accounts[accountId]
if !ok {
return status.NewAccountNotFoundError(accountId)
}
account.Network.Serial++
return s.SaveAccount(ctx, account)
}
func (s *FileStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
s.mux.Lock()
defer s.mux.Unlock()
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(key)]
if !ok {
return nil, status.NewSetupKeyNotFoundError()
}
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
setupKey, ok := account.SetupKeys[key]
if !ok {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
return setupKey, nil
}
func (s *FileStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
var takenIps []net.IP
for _, existingPeer := range account.Peers {
takenIps = append(takenIps, existingPeer.IP)
}
return takenIps, nil
}
func (s *FileStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
existingLabels := []string{}
for _, peer := range account.Peers {
if peer.DNSLabel != "" {
existingLabels = append(existingLabels, peer.DNSLabel)
}
}
return existingLabels, nil
}
func (s *FileStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
return account.Network, nil
}
type StoredAccount struct{} type StoredAccount struct{}
// NewFileStore restores a store from the file located in the datadir // NewFileStore restores a store from the file located in the datadir
@@ -422,7 +576,7 @@ func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*A
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)] accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !ok { if !ok {
return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists") return nil, status.NewSetupKeyNotFoundError()
} }
account, err := s.getAccount(accountID) account, err := s.getAccount(accountID)
@@ -469,7 +623,7 @@ func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User,
return account.Users[userID].Copy(), nil return account.Users[userID].Copy(), nil
} }
func (s *FileStore) GetUserByUserID(_ context.Context, userID string) (*User, error) { func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID string) (*User, error) {
accountID, ok := s.UserID2AccountID[userID] accountID, ok := s.UserID2AccountID[userID]
if !ok { if !ok {
return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists") return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
@@ -513,7 +667,7 @@ func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
func (s *FileStore) getAccount(accountID string) (*Account, error) { func (s *FileStore) getAccount(accountID string) (*Account, error) {
account, ok := s.Accounts[accountID] account, ok := s.Accounts[accountID]
if !ok { if !ok {
return nil, status.Errorf(status.NotFound, "account not found") return nil, status.NewAccountNotFoundError(accountID)
} }
return account, nil return account, nil
@@ -639,13 +793,13 @@ func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)] accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !ok { if !ok {
return "", status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists") return "", status.NewSetupKeyNotFoundError()
} }
return accountID, nil return accountID, nil
} }
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbpeer.Peer, error) { func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, _ LockingStrength, peerKey string) (*nbpeer.Peer, error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@@ -668,7 +822,7 @@ func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbp
return nil, status.NewPeerNotFoundError(peerKey) return nil, status.NewPeerNotFoundError(peerKey)
} }
func (s *FileStore) GetAccountSettings(_ context.Context, accountID string) (*Settings, error) { func (s *FileStore) GetAccountSettings(_ context.Context, _ LockingStrength, accountID string) (*Settings, error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@@ -758,7 +912,7 @@ func (s *FileStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.
} }
// SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things. // SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things.
func (s *FileStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error { func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID string, lastLogin time.Time) error {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()

View File

@@ -251,7 +251,7 @@ components:
- name - name
- ssh_enabled - ssh_enabled
- login_expiration_enabled - login_expiration_enabled
PeerBase: Peer:
allOf: allOf:
- $ref: '#/components/schemas/PeerMinimum' - $ref: '#/components/schemas/PeerMinimum'
- type: object - type: object
@@ -378,25 +378,40 @@ components:
description: User ID of the user that enrolled this peer description: User ID of the user that enrolled this peer
type: string type: string
example: google-oauth2|277474792786460067937 example: google-oauth2|277474792786460067937
os:
description: Peer's operating system and version
type: string
example: linux
country_code:
$ref: '#/components/schemas/CountryCode'
city_name:
$ref: '#/components/schemas/CityName'
geoname_id:
description: Unique identifier from the GeoNames database for a specific geographical location.
type: integer
example: 2643743
connected:
description: Peer to Management connection status
type: boolean
example: true
last_seen:
description: Last time peer connected to Netbird's management service
type: string
format: date-time
example: "2023-05-05T10:05:26.420578Z"
required: required:
- ip - ip
- dns_label - dns_label
- user_id - user_id
Peer: - os
allOf: - country_code
- $ref: '#/components/schemas/PeerBase' - city_name
- type: object - geoname_id
properties: - connected
accessible_peers: - last_seen
description: List of accessible peers
type: array
items:
$ref: '#/components/schemas/AccessiblePeer'
required:
- accessible_peers
PeerBatch: PeerBatch:
allOf: allOf:
- $ref: '#/components/schemas/PeerBase' - $ref: '#/components/schemas/Peer'
- type: object - type: object
properties: properties:
accessible_peers_count: accessible_peers_count:
@@ -1806,6 +1821,38 @@ paths:
"$ref": "#/components/responses/forbidden" "$ref": "#/components/responses/forbidden"
'500': '500':
"$ref": "#/components/responses/internal_error" "$ref": "#/components/responses/internal_error"
/api/peers/{peerId}/accessible-peers:
get:
summary: List accessible Peers
description: Returns a list of peers that the specified peer can connect to within the network.
tags: [ Peers ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: peerId
required: true
schema:
type: string
description: The unique identifier of a peer
responses:
'200':
description: A JSON Array of Accessible Peers
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/AccessiblePeer'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/setup-keys: /api/setup-keys:
get: get:
summary: List all Setup Keys summary: List all Setup Keys

View File

@@ -152,18 +152,36 @@ const (
// AccessiblePeer defines model for AccessiblePeer. // AccessiblePeer defines model for AccessiblePeer.
type AccessiblePeer struct { type AccessiblePeer struct {
// CityName Commonly used English name of the city
CityName CityName `json:"city_name"`
// Connected Peer to Management connection status
Connected bool `json:"connected"`
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
CountryCode CountryCode `json:"country_code"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"` DnsLabel string `json:"dns_label"`
// GeonameId Unique identifier from the GeoNames database for a specific geographical location.
GeonameId int `json:"geoname_id"`
// Id Peer ID // Id Peer ID
Id string `json:"id"` Id string `json:"id"`
// Ip Peer's IP address // Ip Peer's IP address
Ip string `json:"ip"` Ip string `json:"ip"`
// LastSeen Last time peer connected to Netbird's management service
LastSeen time.Time `json:"last_seen"`
// Name Peer's hostname // Name Peer's hostname
Name string `json:"name"` Name string `json:"name"`
// Os Peer's operating system and version
Os string `json:"os"`
// UserId User ID of the user that enrolled this peer // UserId User ID of the user that enrolled this peer
UserId string `json:"user_id"` UserId string `json:"user_id"`
} }
@@ -490,81 +508,6 @@ type OSVersionCheck struct {
// Peer defines model for Peer. // Peer defines model for Peer.
type Peer struct { type Peer struct {
// AccessiblePeers List of accessible peers
AccessiblePeers []AccessiblePeer `json:"accessible_peers"`
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired bool `json:"approval_required"`
// CityName Commonly used English name of the city
CityName CityName `json:"city_name"`
// Connected Peer to Management connection status
Connected bool `json:"connected"`
// ConnectionIp Peer's public connection IP address
ConnectionIp string `json:"connection_ip"`
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
CountryCode CountryCode `json:"country_code"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
// GeonameId Unique identifier from the GeoNames database for a specific geographical location.
GeonameId int `json:"geoname_id"`
// Groups Groups that the peer belongs to
Groups []GroupMinimum `json:"groups"`
// Hostname Hostname of the machine
Hostname string `json:"hostname"`
// Id Peer ID
Id string `json:"id"`
// Ip Peer's IP address
Ip string `json:"ip"`
// KernelVersion Peer's operating system kernel version
KernelVersion string `json:"kernel_version"`
// LastLogin Last time this peer performed log in (authentication). E.g., user authenticated.
LastLogin time.Time `json:"last_login"`
// LastSeen Last time peer connected to Netbird's management service
LastSeen time.Time `json:"last_seen"`
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
// LoginExpired Indicates whether peer's login expired or not
LoginExpired bool `json:"login_expired"`
// Name Peer's hostname
Name string `json:"name"`
// Os Peer's operating system and version
Os string `json:"os"`
// SerialNumber System serial number
SerialNumber string `json:"serial_number"`
// SshEnabled Indicates whether SSH server is enabled on this peer
SshEnabled bool `json:"ssh_enabled"`
// UiVersion Peer's desktop UI version
UiVersion string `json:"ui_version"`
// UserId User ID of the user that enrolled this peer
UserId string `json:"user_id"`
// Version Peer's daemon or cli version
Version string `json:"version"`
}
// PeerBase defines model for PeerBase.
type PeerBase struct {
// ApprovalRequired (Cloud only) Indicates whether peer needs approval // ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired bool `json:"approval_required"` ApprovalRequired bool `json:"approval_required"`

View File

@@ -115,6 +115,7 @@ func (apiHandler *apiHandler) addPeersEndpoint() {
apiHandler.Router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") apiHandler.Router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). apiHandler.Router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS") Methods("GET", "PUT", "DELETE", "OPTIONS")
apiHandler.Router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
} }
func (apiHandler *apiHandler) addUsersEndpoint() { func (apiHandler *apiHandler) addUsersEndpoint() {

View File

@@ -71,12 +71,8 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
return return
} }
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID] _, valid := validPeers[peer.ID]
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid)) util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
} }
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) { func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
@@ -117,13 +113,9 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
return return
} }
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID] _, valid := validPeers[peer.ID]
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid)) util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, valid))
} }
func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
@@ -220,32 +212,66 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv
} }
} }
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, _, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
return
}
dnsDomain := h.accountManager.GetDNSDomain()
validPeers, err := h.accountManager.GetValidatedPeers(account)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
return
}
customZone := account.GetPeersCustomZone(r.Context(), h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil)
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
}
func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.AccessiblePeer { func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.AccessiblePeer {
accessiblePeers := make([]api.AccessiblePeer, 0, len(netMap.Peers)+len(netMap.OfflinePeers)) accessiblePeers := make([]api.AccessiblePeer, 0, len(netMap.Peers)+len(netMap.OfflinePeers))
for _, p := range netMap.Peers { for _, p := range netMap.Peers {
ap := api.AccessiblePeer{ accessiblePeers = append(accessiblePeers, peerToAccessiblePeer(p, dnsDomain))
Id: p.ID,
Name: p.Name,
Ip: p.IP.String(),
DnsLabel: fqdn(p, dnsDomain),
UserId: p.UserID,
}
accessiblePeers = append(accessiblePeers, ap)
} }
for _, p := range netMap.OfflinePeers { for _, p := range netMap.OfflinePeers {
ap := api.AccessiblePeer{ accessiblePeers = append(accessiblePeers, peerToAccessiblePeer(p, dnsDomain))
Id: p.ID,
Name: p.Name,
Ip: p.IP.String(),
DnsLabel: fqdn(p, dnsDomain),
UserId: p.UserID,
}
accessiblePeers = append(accessiblePeers, ap)
} }
return accessiblePeers return accessiblePeers
} }
func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePeer {
return api.AccessiblePeer{
CityName: peer.Location.CityName,
Connected: peer.Status.Connected,
CountryCode: peer.Location.CountryCode,
DnsLabel: fqdn(peer, dnsDomain),
GeonameId: int(peer.Location.GeoNameID),
Id: peer.ID,
Ip: peer.IP.String(),
LastSeen: peer.Status.LastSeen,
Name: peer.Name,
Os: peer.Meta.OS,
UserId: peer.UserID,
}
}
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum { func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
var groupsInfo []api.GroupMinimum var groupsInfo []api.GroupMinimum
groupsChecked := make(map[string]struct{}) groupsChecked := make(map[string]struct{})
@@ -270,7 +296,7 @@ func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMi
return groupsInfo return groupsInfo
} }
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeer []api.AccessiblePeer, approved bool) *api.Peer { func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer {
osVersion := peer.Meta.OSVersion osVersion := peer.Meta.OSVersion
if osVersion == "" { if osVersion == "" {
osVersion = peer.Meta.Core osVersion = peer.Meta.Core
@@ -296,7 +322,6 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
LoginExpirationEnabled: peer.LoginExpirationEnabled, LoginExpirationEnabled: peer.LoginExpirationEnabled,
LastLogin: peer.LastLogin, LastLogin: peer.LastLogin,
LoginExpired: peer.Status.LoginExpired, LoginExpired: peer.Status.LoginExpired,
AccessiblePeers: accessiblePeer,
ApprovalRequired: !approved, ApprovalRequired: !approved,
CountryCode: peer.Location.CountryCode, CountryCode: peer.Location.CountryCode,
CityName: peer.Location.CityName, CityName: peer.Location.CityName,

View File

@@ -627,7 +627,7 @@ func testSyncStatusRace(t *testing.T) {
} }
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), peerWithInvalidStatus.PublicKey().String()) peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peerWithInvalidStatus.PublicKey().String())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -638,8 +638,8 @@ func testSyncStatusRace(t *testing.T) {
} }
func Test_LoginPerformance(t *testing.T) { func Test_LoginPerformance(t *testing.T) {
if os.Getenv("CI") == "true" { if os.Getenv("CI") == "true" || runtime.GOOS == "windows" {
t.Skip("Skipping on CI") t.Skip("Skipping test on CI or Windows")
} }
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite") t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
@@ -655,7 +655,7 @@ func Test_LoginPerformance(t *testing.T) {
// {"M", 250, 1}, // {"M", 250, 1},
// {"L", 500, 1}, // {"L", 500, 1},
// {"XL", 750, 1}, // {"XL", 750, 1},
{"XXL", 2000, 1}, {"XXL", 5000, 1},
} }
log.SetOutput(io.Discard) log.SetOutput(io.Discard)
@@ -700,15 +700,18 @@ func Test_LoginPerformance(t *testing.T) {
} }
defer mgmtServer.GracefulStop() defer mgmtServer.GracefulStop()
t.Logf("management setup complete, start registering peers")
var counter int32 var counter int32
var counterStart int32 var counterStart int32
var wg sync.WaitGroup var wgAccount sync.WaitGroup
var mu sync.Mutex var mu sync.Mutex
messageCalls := []func() error{} messageCalls := []func() error{}
for j := 0; j < bc.accounts; j++ { for j := 0; j < bc.accounts; j++ {
wg.Add(1) wgAccount.Add(1)
var wgPeer sync.WaitGroup
go func(j int, counter *int32, counterStart *int32) { go func(j int, counter *int32, counterStart *int32) {
defer wg.Done() defer wgAccount.Done()
account, err := createAccount(am, fmt.Sprintf("account-%d", j), fmt.Sprintf("user-%d", j), fmt.Sprintf("domain-%d", j)) account, err := createAccount(am, fmt.Sprintf("account-%d", j), fmt.Sprintf("user-%d", j), fmt.Sprintf("domain-%d", j))
if err != nil { if err != nil {
@@ -722,7 +725,9 @@ func Test_LoginPerformance(t *testing.T) {
return return
} }
startTime := time.Now()
for i := 0; i < bc.peers; i++ { for i := 0; i < bc.peers; i++ {
wgPeer.Add(1)
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
t.Logf("failed to generate key: %v", err) t.Logf("failed to generate key: %v", err)
@@ -763,21 +768,29 @@ func Test_LoginPerformance(t *testing.T) {
mu.Lock() mu.Lock()
messageCalls = append(messageCalls, login) messageCalls = append(messageCalls, login)
mu.Unlock() mu.Unlock()
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
if err != nil {
t.Logf("failed to login peer: %v", err)
return
}
atomic.AddInt32(counterStart, 1) go func(peerLogin PeerLogin, counterStart *int32) {
if *counterStart%100 == 0 { defer wgPeer.Done()
t.Logf("registered %d peers", *counterStart) _, _, _, err = am.LoginPeer(context.Background(), peerLogin)
} if err != nil {
t.Logf("failed to login peer: %v", err)
return
}
atomic.AddInt32(counterStart, 1)
if *counterStart%100 == 0 {
t.Logf("registered %d peers", *counterStart)
}
}(peerLogin, counterStart)
} }
wgPeer.Wait()
t.Logf("Time for registration: %s", time.Since(startTime))
}(j, &counter, &counterStart) }(j, &counter, &counterStart)
} }
wg.Wait() wgAccount.Wait()
t.Logf("prepared %d login calls", len(messageCalls)) t.Logf("prepared %d login calls", len(messageCalls))
testLoginPerformance(t, messageCalls) testLoginPerformance(t, messageCalls)

View File

@@ -11,6 +11,7 @@ import (
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
@@ -371,164 +372,175 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
} }
}() }()
var account *Account
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
if am.idpManager != nil {
userdata, err := am.lookupUserInCache(ctx, userID, account)
if err == nil && userdata != nil {
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
}
}
}
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice. // This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow) // Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
// and the peer disconnects with a timeout and tries to register again. // and the peer disconnects with a timeout and tries to register again.
// We just check if this machine has been registered before and reject the second registration. // We just check if this machine has been registered before and reject the second registration.
// The connecting peer should be able to recover with a retry. // The connecting peer should be able to recover with a retry.
_, err = account.FindPeerByPubKey(peer.Key) _, err = am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peer.Key)
if err == nil { if err == nil {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered") return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
} }
opEvent := &activity.Event{ opEvent := &activity.Event{
Timestamp: time.Now().UTC(), Timestamp: time.Now().UTC(),
AccountID: account.Id, AccountID: accountID,
} }
var ephemeral bool var newPeer *nbpeer.Peer
setupKeyName := ""
if !addedByUser {
// validate the setup key if adding with a key
sk, err := account.FindSetupKey(upperKey)
if err != nil {
return nil, nil, nil, err
}
if !sk.IsValid() { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid") var groupsToAdd []string
} var setupKeyID string
var setupKeyName string
account.SetupKeys[sk.Key] = sk.IncrementUsage() var ephemeral bool
opEvent.InitiatorID = sk.Id if addedByUser {
opEvent.Activity = activity.PeerAddedWithSetupKey user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID)
ephemeral = sk.Ephemeral if err != nil {
setupKeyName = sk.Name return fmt.Errorf("failed to get user groups: %w", err)
} else { }
opEvent.InitiatorID = userID groupsToAdd = user.AutoGroups
opEvent.Activity = activity.PeerAddedByUser opEvent.InitiatorID = userID
} opEvent.Activity = activity.PeerAddedByUser
takenIps := account.getTakenIPs()
existingLabels := account.getPeerDNSLabels()
newLabel, err := getPeerHostLabel(peer.Meta.Hostname, existingLabels)
if err != nil {
return nil, nil, nil, err
}
peer.DNSLabel = newLabel
network := account.Network
nextIp, err := AllocatePeerIP(network.Net, takenIps)
if err != nil {
return nil, nil, nil, err
}
registrationTime := time.Now().UTC()
newPeer := &nbpeer.Peer{
ID: xid.New().String(),
Key: peer.Key,
SetupKey: upperKey,
IP: nextIp,
Meta: peer.Meta,
Name: peer.Meta.Hostname,
DNSLabel: newLabel,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
SSHEnabled: false,
SSHKey: peer.SSHKey,
LastLogin: registrationTime,
CreatedAt: registrationTime,
LoginExpirationEnabled: addedByUser,
Ephemeral: ephemeral,
Location: peer.Location,
}
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
} else { } else {
newPeer.Location.CountryCode = location.Country.ISOCode // Validate the setup key
newPeer.Location.CityName = location.City.Names.En sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, upperKey)
newPeer.Location.GeoNameID = location.City.GeonameID if err != nil {
} return fmt.Errorf("failed to get setup key: %w", err)
} }
// add peer to 'All' group if !sk.IsValid() {
group, err := account.GetGroupAll() return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
if err != nil { }
return nil, nil, nil, err
}
group.Peers = append(group.Peers, newPeer.ID)
var groupsToAdd []string opEvent.InitiatorID = sk.Id
if addedByUser { opEvent.Activity = activity.PeerAddedWithSetupKey
groupsToAdd, err = account.getUserGroups(userID) groupsToAdd = sk.AutoGroups
if err != nil { ephemeral = sk.Ephemeral
return nil, nil, nil, err setupKeyID = sk.Id
setupKeyName = sk.Name
} }
} else {
groupsToAdd, err = account.getSetupKeyGroups(upperKey)
if err != nil {
return nil, nil, nil, err
}
}
if len(groupsToAdd) > 0 { if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
for _, s := range groupsToAdd { if am.idpManager != nil {
if g, ok := account.Groups[s]; ok && g.Name != "All" { userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
g.Peers = append(g.Peers, newPeer.ID) if err == nil && userdata != nil {
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
}
} }
} }
}
newPeer = am.integratedPeerValidator.PreparePeer(ctx, account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra) freeLabel, err := am.getFreeDNSLabel(ctx, transaction, accountID, peer.Meta.Hostname)
if addedByUser {
user, err := account.FindUser(userID)
if err != nil { if err != nil {
return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user") return fmt.Errorf("failed to get free DNS label: %w", err)
} }
user.updateLastLogin(newPeer.LastLogin)
}
account.Peers[newPeer.ID] = newPeer freeIP, err := am.getFreeIP(ctx, transaction, accountID)
account.Network.IncSerial() if err != nil {
err = am.Store.SaveAccount(ctx, account) return fmt.Errorf("failed to get free IP: %w", err)
}
registrationTime := time.Now().UTC()
newPeer = &nbpeer.Peer{
ID: xid.New().String(),
AccountID: accountID,
Key: peer.Key,
SetupKey: upperKey,
IP: freeIP,
Meta: peer.Meta,
Name: peer.Meta.Hostname,
DNSLabel: freeLabel,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
SSHEnabled: false,
SSHKey: peer.SSHKey,
LastLogin: registrationTime,
CreatedAt: registrationTime,
LoginExpirationEnabled: addedByUser,
Ephemeral: ephemeral,
Location: peer.Location,
}
opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
if !addedByUser {
opEvent.Meta["setup_key_name"] = setupKeyName
}
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
} else {
newPeer.Location.CountryCode = location.Country.ISOCode
newPeer.Location.CityName = location.City.Names.En
newPeer.Location.GeoNameID = location.City.GeonameID
}
}
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return fmt.Errorf("failed to get account settings: %w", err)
}
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
if err != nil {
return fmt.Errorf("failed adding peer to All group: %w", err)
}
if len(groupsToAdd) > 0 {
for _, g := range groupsToAdd {
err = transaction.AddPeerToGroup(ctx, accountID, newPeer.ID, g)
if err != nil {
return err
}
}
}
err = transaction.AddPeerToAccount(ctx, newPeer)
if err != nil {
return fmt.Errorf("failed to add peer to account: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
if addedByUser {
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.LastLogin)
if err != nil {
return fmt.Errorf("failed to update user last login: %w", err)
}
} else {
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
if err != nil {
return fmt.Errorf("failed to increment setup key usage: %w", err)
}
}
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
return nil
})
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err)
} }
// Account is saved, we can release the lock if newPeer == nil {
unlock() return nil, nil, nil, fmt.Errorf("new peer is nil")
unlock = nil
opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
if !addedByUser {
opEvent.Meta["setup_key_name"] = setupKeyName
} }
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
unlock()
unlock = nil
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
approvedPeersMap, err := am.GetValidatedPeers(account) approvedPeersMap, err := am.GetValidatedPeers(account)
@@ -536,12 +548,31 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, err return nil, nil, nil, err
} }
postureChecks := am.getPeerPostureChecks(account, peer) postureChecks := am.getPeerPostureChecks(account, newPeer)
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
return newPeer, networkMap, postureChecks, nil return newPeer, networkMap, postureChecks, nil
} }
func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) {
takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get taken IPs: %w", err)
}
network, err := store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID)
if err != nil {
return nil, fmt.Errorf("failed getting network: %w", err)
}
nextIp, err := AllocatePeerIP(network.Net, takenIps)
if err != nil {
return nil, fmt.Errorf("failed to allocate new peer ip: %w", err)
}
return nextIp, nil
}
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey) peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
@@ -647,12 +678,12 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
} }
}() }()
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey) peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
settings, err := am.Store.GetAccountSettings(ctx, accountID) settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@@ -730,7 +761,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired // with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
// and before starting the engine, we do the checks without an account lock to avoid piling up requests. // and before starting the engine, we do the checks without an account lock to avoid piling up requests.
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error { func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey) peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, login.WireGuardPubKey)
if err != nil { if err != nil {
return err return err
} }
@@ -741,7 +772,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
return nil return nil
} }
settings, err := am.Store.GetAccountSettings(ctx, accountID) settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -786,7 +817,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us
return err return err
} }
err = am.Store.SaveUserLastLogin(user.AccountID, user.Id, peer.LastLogin) err = am.Store.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin)
if err != nil { if err != nil {
return err return err
} }
@@ -969,3 +1000,11 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
wg.Wait() wg.Wait()
} }
func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
labelMap := make(map[string]struct{}, len(existingLabels))
for _, label := range existingLabels {
labelMap[label] = struct{}{}
}
return labelMap
}

View File

@@ -7,20 +7,24 @@ import (
"net" "net"
"net/netip" "net/netip"
"os" "os"
"runtime"
"testing" "testing"
"time" "time"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/telemetry"
nbroute "github.com/netbirdio/netbird/route" nbroute "github.com/netbirdio/netbird/route"
) )
@@ -995,3 +999,184 @@ func TestToSyncResponse(t *testing.T) {
assert.Equal(t, 1, len(response.Checks)) assert.Equal(t, 1, len(response.Checks))
assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0]) assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0])
} }
func Test_RegisterPeerByUser(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
assert.NoError(t, err)
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
existingUserID := "edafee4e-63fb-11ec-90d6-0242ac120003"
_, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
newPeer := &nbpeer.Peer{
ID: xid.New().String(),
AccountID: existingAccountID,
Key: "newPeerKey",
SetupKey: "",
IP: net.IP{123, 123, 123, 123},
Meta: nbpeer.PeerSystemMeta{
Hostname: "newPeer",
GoOS: "linux",
},
Name: "newPeerName",
DNSLabel: "newPeer.test",
UserID: existingUserID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
LastLogin: time.Now(),
}
addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer)
require.NoError(t, err)
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, addedPeer.Key)
require.NoError(t, err)
assert.Equal(t, peer.AccountID, existingAccountID)
assert.Equal(t, peer.UserID, existingUserID)
account, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
assert.Contains(t, account.Peers, addedPeer.ID)
assert.Equal(t, peer.Meta.Hostname, newPeer.Meta.Hostname)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, addedPeer.ID)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
assert.Equal(t, uint64(1), account.Network.Serial)
lastLogin, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
assert.NoError(t, err)
assert.NotEqual(t, lastLogin, account.Users[existingUserID].LastLogin)
}
func Test_RegisterPeerBySetupKey(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
assert.NoError(t, err)
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
existingSetupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
_, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
newPeer := &nbpeer.Peer{
ID: xid.New().String(),
AccountID: existingAccountID,
Key: "newPeerKey",
SetupKey: "existingSetupKey",
UserID: "",
IP: net.IP{123, 123, 123, 123},
Meta: nbpeer.PeerSystemMeta{
Hostname: "newPeer",
GoOS: "linux",
},
Name: "newPeerName",
DNSLabel: "newPeer.test",
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
addedPeer, _, _, err := am.AddPeer(context.Background(), existingSetupKeyID, "", newPeer)
require.NoError(t, err)
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
require.NoError(t, err)
assert.Equal(t, peer.AccountID, existingAccountID)
assert.Equal(t, peer.SetupKey, existingSetupKeyID)
account, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
assert.Contains(t, account.Peers, addedPeer.ID)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g2g"].Peers, addedPeer.ID)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
assert.Equal(t, uint64(1), account.Network.Serial)
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
assert.NoError(t, err)
assert.NotEqual(t, lastUsed, account.SetupKeys[existingSetupKeyID].LastUsed)
assert.Equal(t, 1, account.SetupKeys[existingSetupKeyID].UsedTimes)
}
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
assert.NoError(t, err)
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
faultyKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBC"
_, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
newPeer := &nbpeer.Peer{
ID: xid.New().String(),
AccountID: existingAccountID,
Key: "newPeerKey",
SetupKey: "existingSetupKey",
UserID: "",
IP: net.IP{123, 123, 123, 123},
Meta: nbpeer.PeerSystemMeta{
Hostname: "newPeer",
GoOS: "linux",
},
Name: "newPeerName",
DNSLabel: "newPeer.test",
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
_, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
require.Error(t, err)
_, err = store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
require.Error(t, err)
account, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
assert.NotContains(t, account.Peers, newPeer.ID)
assert.NotContains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, newPeer.ID)
assert.NotContains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, newPeer.ID)
assert.Equal(t, uint64(0), account.Network.Serial)
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
assert.NoError(t, err)
assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed)
assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes)
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
@@ -33,6 +34,7 @@ import (
const ( const (
storeSqliteFileName = "store.db" storeSqliteFileName = "store.db"
idQueryCondition = "id = ?" idQueryCondition = "id = ?"
keyQueryCondition = "key = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?" accountAndIDQueryCondition = "account_id = ? and id = ?"
peerNotFoundFMT = "peer %s not found" peerNotFoundFMT = "peer %s not found"
) )
@@ -415,13 +417,12 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
var key SetupKey var key SetupKey
result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey)) result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, strings.ToUpper(setupKey))
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error) return nil, status.NewSetupKeyNotFoundError()
return nil, status.Errorf(status.Internal, "issue getting setup key from store")
} }
if key.AccountID == "" { if key.AccountID == "" {
@@ -474,15 +475,15 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
return &user, nil return &user, nil
} }
func (s *SqlStore) GetUserByUserID(ctx context.Context, userID string) (*User, error) { func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
var user User var user User
result := s.db.First(&user, idQueryCondition, userID) result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&user, idQueryCondition, userID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "user not found: index lookup failed") return nil, status.NewUserNotFoundError(userID)
} }
log.WithContext(ctx).Errorf("error when getting user from the store: %s", result.Error) return nil, status.NewGetUserFromStoreError()
return nil, status.Errorf(status.Internal, "issue getting user from store")
} }
return &user, nil return &user, nil
@@ -535,7 +536,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found") return nil, status.NewAccountNotFoundError(accountID)
} }
return nil, status.Errorf(status.Internal, "issue getting account from store") return nil, status.Errorf(status.Internal, "issue getting account from store")
} }
@@ -595,7 +596,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) { func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) {
var user User var user User
result := s.db.Select("account_id").First(&user, idQueryCondition, userID) result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
@@ -612,12 +613,11 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) { func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
var peer nbpeer.Peer var peer nbpeer.Peer
result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID) result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store") return nil, status.Errorf(status.Internal, "issue getting account from store")
} }
@@ -631,12 +631,11 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) { func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
var peer nbpeer.Peer var peer nbpeer.Peer
result := s.db.Select("account_id").First(&peer, "key = ?", peerKey) result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store") return nil, status.Errorf(status.Internal, "issue getting account from store")
} }
@@ -650,12 +649,11 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) { func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
var peer nbpeer.Peer var peer nbpeer.Peer
var accountID string var accountID string
result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID) result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed") return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting account from store") return "", status.Errorf(status.Internal, "issue getting account from store")
} }
@@ -677,61 +675,117 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
} }
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) { func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
var key SetupKey
var accountID string var accountID string
result := s.db.Model(&key).Select("account_id").Where("key = ?", strings.ToUpper(setupKey)).First(&accountID) result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, strings.ToUpper(setupKey)).First(&accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed") return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error) return "", status.NewSetupKeyNotFoundError()
return "", status.Errorf(status.Internal, "issue getting setup key from store") }
if accountID == "" {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
return accountID, nil return accountID, nil
} }
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) { func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
var ipJSONStrings []string
// Fetch the IP addresses as JSON strings
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID).
Pluck("ip", &ipJSONStrings)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
return nil, status.Errorf(status.Internal, "issue getting IPs from store")
}
// Convert the JSON strings to net.IP objects
ips := make([]net.IP, len(ipJSONStrings))
for i, ipJSON := range ipJSONStrings {
var ip net.IP
if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil {
return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store")
}
ips[i] = ip
}
return ips, nil
}
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
var labels []string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID).
Pluck("dns_label", &labels)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting dns labels from store")
}
return labels, nil
}
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
var accountNetwork AccountNetwork
if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.Errorf(status.Internal, "issue getting network from store")
}
return accountNetwork.Network, nil
}
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
var peer nbpeer.Peer var peer nbpeer.Peer
result := s.db.First(&peer, "key = ?", peerKey) result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found") return nil, status.Errorf(status.NotFound, "peer not found")
} }
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting peer from store") return nil, status.Errorf(status.Internal, "issue getting peer from store")
} }
return &peer, nil return &peer, nil
} }
func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) { func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
var accountSettings AccountSettings var accountSettings AccountSettings
if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "settings not found") return nil, status.Errorf(status.NotFound, "settings not found")
} }
log.WithContext(ctx).Errorf("error when getting settings from the store: %s", err)
return nil, status.Errorf(status.Internal, "issue getting settings from store") return nil, status.Errorf(status.Internal, "issue getting settings from store")
} }
return accountSettings.Settings, nil return accountSettings.Settings, nil
} }
// SaveUserLastLogin stores the last login time for a user in DB. // SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error { func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
var user User var user User
result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID) result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "user %s not found", userID) return status.NewUserNotFoundError(userID)
} }
return status.Errorf(status.Internal, "issue getting user from store") return status.NewGetUserFromStoreError()
} }
user.LastLogin = lastLogin user.LastLogin = lastLogin
return s.db.Save(user).Error return s.db.Save(&user).Error
} }
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
@@ -850,3 +904,123 @@ func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore,
return store, nil return store, nil
} }
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
var setupKey SetupKey
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&setupKey, keyQueryCondition, strings.ToUpper(key))
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
return nil, status.NewSetupKeyNotFoundError()
}
return &setupKey, nil
}
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
result := s.db.WithContext(ctx).Model(&SetupKey{}).
Where(idQueryCondition, setupKeyID).
Updates(map[string]interface{}{
"used_times": gorm.Expr("used_times + 1"),
"last_used": time.Now(),
})
if result.Error != nil {
return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "setup key not found")
}
return nil
}
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
var group nbgroup.Group
result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group 'All' not found for account")
}
return status.Errorf(status.Internal, "issue finding group 'All'")
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerID {
return nil
}
}
group.Peers = append(group.Peers, peerID)
if err := s.db.Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group 'All'")
}
return nil
}
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
var group nbgroup.Group
result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group not found for account")
}
return status.Errorf(status.Internal, "issue finding group")
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerId {
return nil
}
}
group.Peers = append(group.Peers, peerId)
if err := s.db.Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group")
}
return nil
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account")
}
return nil
}
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil {
return status.Errorf(status.Internal, "issue incrementing network serial count")
}
return nil
}
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
tx := s.db.WithContext(ctx).Begin()
if tx.Error != nil {
return tx.Error
}
repo := s.withTx(tx)
err := operation(repo)
if err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}
func (s *SqlStore) withTx(tx *gorm.DB) Store {
return &SqlStore{
db: tx,
}
}

View File

@@ -1003,3 +1003,163 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, id, user.PATs[id].ID) require.Equal(t, id, user.PATs[id].ID)
} }
func TestSqlite_GetTakenIPs(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []net.IP{}, takenIPs)
peer1 := &nbpeer.Peer{
ID: "peer1",
AccountID: existingAccountID,
IP: net.IP{1, 1, 1, 1},
}
err = store.AddPeerToAccount(context.Background(), peer1)
require.NoError(t, err)
takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
ip1 := net.IP{1, 1, 1, 1}.To16()
assert.Equal(t, []net.IP{ip1}, takenIPs)
peer2 := &nbpeer.Peer{
ID: "peer2",
AccountID: existingAccountID,
IP: net.IP{2, 2, 2, 2},
}
err = store.AddPeerToAccount(context.Background(), peer2)
require.NoError(t, err)
takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
ip2 := net.IP{2, 2, 2, 2}.To16()
assert.Equal(t, []net.IP{ip1, ip2}, takenIPs)
}
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{}, labels)
peer1 := &nbpeer.Peer{
ID: "peer1",
AccountID: existingAccountID,
DNSLabel: "peer1.domain.test",
}
err = store.AddPeerToAccount(context.Background(), peer1)
require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{"peer1.domain.test"}, labels)
peer2 := &nbpeer.Peer{
ID: "peer2",
AccountID: existingAccountID,
DNSLabel: "peer2.domain.test",
}
err = store.AddPeerToAccount(context.Background(), peer2)
require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels)
}
func TestSqlite_GetAccountNetwork(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
network, err := store.GetAccountNetwork(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
ip := net.IP{100, 64, 0, 0}.To16()
assert.Equal(t, ip, network.Net.IP)
assert.Equal(t, net.IPMask{255, 255, 0, 0}, network.Net.Mask)
assert.Equal(t, "", network.Dns)
assert.Equal(t, "af1c8024-ha40-4ce2-9418-34653101fc3c", network.Identifier)
assert.Equal(t, uint64(0), network.Serial)
}
func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.NoError(t, err)
assert.Equal(t, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", setupKey.Key)
assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID)
assert.Equal(t, "Default key", setupKey.Name)
}
func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.NoError(t, err)
assert.Equal(t, 0, setupKey.UsedTimes)
err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
require.NoError(t, err)
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.NoError(t, err)
assert.Equal(t, 1, setupKey.UsedTimes)
err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
require.NoError(t, err)
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.NoError(t, err)
assert.Equal(t, 2, setupKey.UsedTimes)
}

View File

@@ -100,3 +100,13 @@ func NewPeerNotRegisteredError() error {
func NewPeerLoginExpiredError() error { func NewPeerLoginExpiredError() error {
return Errorf(PermissionDenied, "peer login has expired, please log in once more") return Errorf(PermissionDenied, "peer login has expired, please log in once more")
} }
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
func NewSetupKeyNotFoundError() error {
return Errorf(NotFound, "setup key not found")
}
// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store
func NewGetUserFromStoreError() error {
return Errorf(Internal, "issue getting user from store")
}

View File

@@ -27,6 +27,15 @@ import (
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
type LockingStrength string
const (
LockingStrengthUpdate LockingStrength = "UPDATE" // Strongest lock, preventing any changes by other transactions until your transaction completes.
LockingStrengthShare LockingStrength = "SHARE" // Allows reading but prevents changes by other transactions.
LockingStrengthNoKeyUpdate LockingStrength = "NO KEY UPDATE" // Similar to UPDATE but allows changes to related rows.
LockingStrengthKeyShare LockingStrength = "KEY SHARE" // Protects against changes to primary/unique keys but allows other updates.
)
type Store interface { type Store interface {
GetAllAccounts(ctx context.Context) []*Account GetAllAccounts(ctx context.Context) []*Account
GetAccount(ctx context.Context, accountID string) (*Account, error) GetAccount(ctx context.Context, accountID string) (*Account, error)
@@ -41,7 +50,7 @@ type Store interface {
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, userID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
SaveAccount(ctx context.Context, account *Account) error SaveAccount(ctx context.Context, account *Account) error
@@ -60,14 +69,24 @@ type Store interface {
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error
SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
// Close should close the store persisting all unsaved data. // Close should close the store persisting all unsaved data.
Close(ctx context.Context) error Close(ctx context.Context) error
// GetStoreEngine should return StoreEngine of the current store implementation. // GetStoreEngine should return StoreEngine of the current store implementation.
// This is also a method of metrics.DataSource interface. // This is also a method of metrics.DataSource interface.
GetStoreEngine() StoreEngine GetStoreEngine() StoreEngine
GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
IncrementNetworkSerial(ctx context.Context, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
} }
type StoreEngine string type StoreEngine string

View File

@@ -0,0 +1,120 @@
{
"Accounts": {
"bf1c8084-ba50-4ce7-9439-34653001fc3b": {
"Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b",
"CreatedBy": "",
"Domain": "test.com",
"DomainCategory": "private",
"IsDomainPrimaryAccount": true,
"SetupKeys": {
"A2C8E62B-38F5-4553-B31E-DD66C696CEBB": {
"Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
"AccountID": "",
"Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
"Name": "Default key",
"Type": "reusable",
"CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
"ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
"UpdatedAt": "0001-01-01T00:00:00Z",
"Revoked": false,
"UsedTimes": 0,
"LastUsed": "0001-01-01T00:00:00Z",
"AutoGroups": ["cfefqs706sqkneg59g2g"],
"UsageLimit": 0,
"Ephemeral": false
},
"A2C8E62B-38F5-4553-B31E-DD66C696CEBC": {
"Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
"AccountID": "",
"Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
"Name": "Faulty key with non existing group",
"Type": "reusable",
"CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
"ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
"UpdatedAt": "0001-01-01T00:00:00Z",
"Revoked": false,
"UsedTimes": 0,
"LastUsed": "0001-01-01T00:00:00Z",
"AutoGroups": ["abcd"],
"UsageLimit": 0,
"Ephemeral": false
}
},
"Network": {
"id": "af1c8024-ha40-4ce2-9418-34653101fc3c",
"Net": {
"IP": "100.64.0.0",
"Mask": "//8AAA=="
},
"Dns": "",
"Serial": 0
},
"Peers": {},
"Users": {
"edafee4e-63fb-11ec-90d6-0242ac120003": {
"Id": "edafee4e-63fb-11ec-90d6-0242ac120003",
"AccountID": "",
"Role": "admin",
"IsServiceUser": false,
"ServiceUserName": "",
"AutoGroups": ["cfefqs706sqkneg59g3g"],
"PATs": {},
"Blocked": false,
"LastLogin": "0001-01-01T00:00:00Z"
},
"f4f6d672-63fb-11ec-90d6-0242ac120003": {
"Id": "f4f6d672-63fb-11ec-90d6-0242ac120003",
"AccountID": "",
"Role": "user",
"IsServiceUser": false,
"ServiceUserName": "",
"AutoGroups": null,
"PATs": {
"9dj38s35-63fb-11ec-90d6-0242ac120003": {
"ID": "9dj38s35-63fb-11ec-90d6-0242ac120003",
"UserID": "",
"Name": "",
"HashedToken": "SoMeHaShEdToKeN",
"ExpirationDate": "2023-02-27T00:00:00Z",
"CreatedBy": "user",
"CreatedAt": "2023-01-01T00:00:00Z",
"LastUsed": "2023-02-01T00:00:00Z"
}
},
"Blocked": false,
"LastLogin": "0001-01-01T00:00:00Z"
}
},
"Groups": {
"cfefqs706sqkneg59g4g": {
"ID": "cfefqs706sqkneg59g4g",
"Name": "All",
"Peers": []
},
"cfefqs706sqkneg59g3g": {
"ID": "cfefqs706sqkneg59g3g",
"Name": "AwesomeGroup1",
"Peers": []
},
"cfefqs706sqkneg59g2g": {
"ID": "cfefqs706sqkneg59g2g",
"Name": "AwesomeGroup2",
"Peers": []
}
},
"Rules": null,
"Policies": [],
"Routes": null,
"NameServerGroups": null,
"DNSSettings": null,
"Settings": {
"PeerLoginExpirationEnabled": false,
"PeerLoginExpiration": 86400000000000,
"GroupsPropagationEnabled": false,
"JWTGroupsEnabled": false,
"JWTGroupsClaimName": ""
}
}
},
"InstallationID": ""
}

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/sha1" "crypto/sha1"
"crypto/sha256" "crypto/sha256"
"encoding/base64"
"fmt" "fmt"
"sync" "sync"
"time" "time"
@@ -12,6 +13,7 @@ import (
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
auth "github.com/netbirdio/netbird/relay/auth/hmac" auth "github.com/netbirdio/netbird/relay/auth/hmac"
authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
) )
const defaultDuration = 12 * time.Hour const defaultDuration = 12 * time.Hour
@@ -30,7 +32,7 @@ type TimeBasedAuthSecretsManager struct {
turnCfg *TURNConfig turnCfg *TURNConfig
relayCfg *Relay relayCfg *Relay
turnHmacToken *auth.TimedHMAC turnHmacToken *auth.TimedHMAC
relayHmacToken *auth.TimedHMAC relayHmacToken *authv2.Generator
updateManager *PeersUpdateManager updateManager *PeersUpdateManager
turnCancelMap map[string]chan struct{} turnCancelMap map[string]chan struct{}
relayCancelMap map[string]chan struct{} relayCancelMap map[string]chan struct{}
@@ -63,7 +65,11 @@ func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *
duration = defaultDuration duration = defaultDuration
} }
mgr.relayHmacToken = auth.NewTimedHMAC(relayCfg.Secret, duration) hashedSecret := sha256.Sum256([]byte(relayCfg.Secret))
var err error
if mgr.relayHmacToken, err = authv2.NewGenerator(authv2.AuthAlgoHMACSHA256, hashedSecret[:], duration); err != nil {
log.Errorf("failed to create relay token generator: %s", err)
}
} }
return mgr return mgr
@@ -76,7 +82,7 @@ func (m *TimeBasedAuthSecretsManager) GenerateTurnToken() (*Token, error) {
} }
turnToken, err := m.turnHmacToken.GenerateToken(sha1.New) turnToken, err := m.turnHmacToken.GenerateToken(sha1.New)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate TURN token: %s", err) return nil, fmt.Errorf("generate TURN token: %s", err)
} }
return (*Token)(turnToken), nil return (*Token)(turnToken), nil
} }
@@ -86,11 +92,15 @@ func (m *TimeBasedAuthSecretsManager) GenerateRelayToken() (*Token, error) {
if m.relayHmacToken == nil { if m.relayHmacToken == nil {
return nil, fmt.Errorf("relay configuration is not set") return nil, fmt.Errorf("relay configuration is not set")
} }
relayToken, err := m.relayHmacToken.GenerateToken(sha256.New) relayToken, err := m.relayHmacToken.GenerateToken()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate relay token: %s", err) return nil, fmt.Errorf("generate relay token: %s", err)
} }
return (*Token)(relayToken), nil
return &Token{
Payload: string(relayToken.Payload),
Signature: base64.StdEncoding.EncodeToString(relayToken.Signature),
}, nil
} }
func (m *TimeBasedAuthSecretsManager) cancelTURN(peerID string) { func (m *TimeBasedAuthSecretsManager) cancelTURN(peerID string) {
@@ -200,7 +210,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNTokens(ctx context.Context, pee
} }
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, peerID string) { func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, peerID string) {
relayToken, err := m.relayHmacToken.GenerateToken(sha256.New) relayToken, err := m.relayHmacToken.GenerateToken()
if err != nil { if err != nil {
log.Errorf("failed to generate relay token for peer '%s': %s", peerID, err) log.Errorf("failed to generate relay token for peer '%s': %s", peerID, err)
return return
@@ -210,8 +220,8 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, pe
WiretrusteeConfig: &proto.WiretrusteeConfig{ WiretrusteeConfig: &proto.WiretrusteeConfig{
Relay: &proto.RelayConfig{ Relay: &proto.RelayConfig{
Urls: m.relayCfg.Addresses, Urls: m.relayCfg.Addresses,
TokenPayload: relayToken.Payload, TokenPayload: string(relayToken.Payload),
TokenSignature: relayToken.Signature, TokenSignature: base64.StdEncoding.EncodeToString(relayToken.Signature),
}, },
// omit Turns to avoid updates there // omit Turns to avoid updates there
}, },

View File

@@ -63,7 +63,8 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
t.Errorf("expected generated relay signature not to be empty, got empty") t.Errorf("expected generated relay signature not to be empty, got empty")
} }
validateMAC(t, sha256.New, relayCredentials.Payload, relayCredentials.Signature, []byte(secret)) hashedSecret := sha256.Sum256([]byte(secret))
validateMAC(t, sha256.New, relayCredentials.Payload, relayCredentials.Signature, hashedSecret[:])
} }
func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {

View File

@@ -89,10 +89,6 @@ func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool {
return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero() return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero()
} }
func (u *User) updateLastLogin(login time.Time) {
u.LastLogin = login
}
// HasAdminPower returns true if the user has admin or owner roles, false otherwise // HasAdminPower returns true if the user has admin or owner roles, false otherwise
func (u *User) HasAdminPower() bool { func (u *User) HasAdminPower() bool {
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
@@ -386,7 +382,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event. // server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
newLogin := user.LastDashboardLoginChanged(claims.LastLogin) newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
err = am.Store.SaveUserLastLogin(account.Id, claims.UserId, claims.LastLogin) err = am.Store.SaveUserLastLogin(ctx, account.Id, claims.UserId, claims.LastLogin)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed saving user last login: %v", err) log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
} }

View File

@@ -1,12 +1,14 @@
package allow package allow
import "hash"
// Auth is a Validator that allows all connections. // Auth is a Validator that allows all connections.
// Used this for testing purposes only. // Used this for testing purposes only.
type Auth struct { type Auth struct {
} }
func (a *Auth) Validate(func() hash.Hash, any) error { func (a *Auth) Validate(any) error {
return nil
}
func (a *Auth) ValidateHelloMsgType(any) error {
return nil return nil
} }

View File

@@ -1,9 +1,11 @@
package hmac package hmac
import ( import (
"encoding/base64"
"fmt"
"sync" "sync"
log "github.com/sirupsen/logrus" v2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
) )
// TokenStore is a simple in-memory store for token // TokenStore is a simple in-memory store for token
@@ -20,12 +22,18 @@ func (a *TokenStore) UpdateToken(token *Token) error {
return nil return nil
} }
t, err := marshalToken(*token) sig, err := base64.StdEncoding.DecodeString(token.Signature)
if err != nil { if err != nil {
log.Debugf("failed to marshal token: %s", err) return fmt.Errorf("decode signature: %w", err)
return err
} }
a.token = t
tok := v2.Token{
AuthAlgo: v2.AuthAlgoHMACSHA256,
Signature: sig,
Payload: []byte(token.Payload),
}
a.token = tok.Marshal()
return nil return nil
} }

View File

@@ -18,17 +18,6 @@ type Token struct {
Signature string Signature string
} }
func marshalToken(token Token) ([]byte, error) {
var buffer bytes.Buffer
encoder := gob.NewEncoder(&buffer)
err := encoder.Encode(token)
if err != nil {
log.Debugf("failed to marshal token: %s", err)
return nil, fmt.Errorf("failed to marshal token: %w", err)
}
return buffer.Bytes(), nil
}
func unmarshalToken(payload []byte) (Token, error) { func unmarshalToken(payload []byte) (Token, error) {
var creds Token var creds Token
buffer := bytes.NewBuffer(payload) buffer := bytes.NewBuffer(payload)

View File

@@ -0,0 +1,40 @@
package v2
import (
"crypto/sha256"
"hash"
)
const (
AuthAlgoUnknown AuthAlgo = iota
AuthAlgoHMACSHA256
)
type AuthAlgo uint8
func (a AuthAlgo) String() string {
switch a {
case AuthAlgoHMACSHA256:
return "HMAC-SHA256"
default:
return "Unknown"
}
}
func (a AuthAlgo) New() func() hash.Hash {
switch a {
case AuthAlgoHMACSHA256:
return sha256.New
default:
return nil
}
}
func (a AuthAlgo) Size() int {
switch a {
case AuthAlgoHMACSHA256:
return sha256.Size
default:
return 0
}
}

View File

@@ -0,0 +1,45 @@
package v2
import (
"crypto/hmac"
"fmt"
"hash"
"strconv"
"time"
)
type Generator struct {
algo func() hash.Hash
algoType AuthAlgo
secret []byte
timeToLive time.Duration
}
func NewGenerator(algo AuthAlgo, secret []byte, timeToLive time.Duration) (*Generator, error) {
algoFunc := algo.New()
if algoFunc == nil {
return nil, fmt.Errorf("unsupported auth algorithm: %s", algo)
}
return &Generator{
algo: algoFunc,
algoType: algo,
secret: secret,
timeToLive: timeToLive,
}, nil
}
func (g *Generator) GenerateToken() (*Token, error) {
expirationTime := time.Now().Add(g.timeToLive).Unix()
payload := []byte(strconv.FormatInt(expirationTime, 10))
h := hmac.New(g.algo, g.secret)
h.Write(payload)
signature := h.Sum(nil)
return &Token{
AuthAlgo: g.algoType,
Signature: signature,
Payload: payload,
}, nil
}

View File

@@ -0,0 +1,110 @@
package v2
import (
"strconv"
"testing"
"time"
)
func TestGenerateCredentials(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if len(token.Payload) == 0 {
t.Fatalf("expected non-empty payload")
}
_, err = strconv.ParseInt(string(token.Payload), 10, 64)
if err != nil {
t.Fatalf("expected payload to be a valid unix timestamp, got %v", err)
}
}
func TestValidateCredentials(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
v := NewValidator([]byte(secret))
if err := v.Validate(token.Marshal()); err != nil {
t.Fatalf("expected valid token: %s", err)
}
}
func TestInvalidSignature(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
token.Signature = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
v := NewValidator([]byte(secret))
if err := v.Validate(token.Marshal()); err == nil {
t.Fatalf("expected valid token: %s", err)
}
}
func TestExpired(t *testing.T) {
secret := "supersecret"
timeToLive := -1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
v := NewValidator([]byte(secret))
if err := v.Validate(token.Marshal()); err == nil {
t.Fatalf("expected valid token: %s", err)
}
}
func TestInvalidPayload(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
token.Payload = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
v := NewValidator([]byte(secret))
if err := v.Validate(token.Marshal()); err == nil {
t.Fatalf("expected invalid token due to invalid payload")
}
}

View File

@@ -0,0 +1,39 @@
package v2
import "errors"
type Token struct {
AuthAlgo AuthAlgo
Signature []byte
Payload []byte
}
func (t *Token) Marshal() []byte {
size := 1 + len(t.Signature) + len(t.Payload)
buf := make([]byte, size)
buf[0] = byte(t.AuthAlgo)
copy(buf[1:], t.Signature)
copy(buf[1+len(t.Signature):], t.Payload)
return buf
}
func UnmarshalToken(data []byte) (*Token, error) {
if len(data) == 0 {
return nil, errors.New("invalid token data")
}
algo := AuthAlgo(data[0])
sigSize := algo.Size()
if len(data) < 1+sigSize {
return nil, errors.New("invalid token data: insufficient length")
}
return &Token{
AuthAlgo: algo,
Signature: data[1 : 1+sigSize],
Payload: data[1+sigSize:],
}, nil
}

View File

@@ -0,0 +1,59 @@
package v2
import (
"crypto/hmac"
"errors"
"fmt"
"strconv"
"time"
)
const minLengthUnixTimestamp = 10
type Validator struct {
secret []byte
}
func NewValidator(secret []byte) *Validator {
return &Validator{secret: secret}
}
func (v *Validator) Validate(data any) error {
d, ok := data.([]byte)
if !ok {
return fmt.Errorf("invalid data type")
}
token, err := UnmarshalToken(d)
if err != nil {
return fmt.Errorf("unmarshal token: %w", err)
}
if len(token.Payload) < minLengthUnixTimestamp {
return errors.New("invalid payload: insufficient length")
}
hashFunc := token.AuthAlgo.New()
if hashFunc == nil {
return fmt.Errorf("unsupported auth algorithm: %s", token.AuthAlgo)
}
h := hmac.New(hashFunc, v.secret)
h.Write(token.Payload)
expectedMAC := h.Sum(nil)
if !hmac.Equal(token.Signature, expectedMAC) {
return errors.New("invalid signature")
}
timestamp, err := strconv.ParseInt(string(token.Payload), 10, 64)
if err != nil {
return fmt.Errorf("invalid payload: %w", err)
}
if time.Now().Unix() > timestamp {
return fmt.Errorf("expired token")
}
return nil
}

View File

@@ -1,8 +1,8 @@
package hmac package hmac
import ( import (
"crypto/sha256"
"fmt" "fmt"
"hash"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -19,7 +19,7 @@ func NewTimedHMACValidator(secret string, duration time.Duration) *TimedHMACVali
} }
} }
func (a *TimedHMACValidator) Validate(algo func() hash.Hash, credentials any) error { func (a *TimedHMACValidator) Validate(credentials any) error {
b, ok := credentials.([]byte) b, ok := credentials.([]byte)
if !ok { if !ok {
return fmt.Errorf("invalid credentials type") return fmt.Errorf("invalid credentials type")
@@ -29,5 +29,5 @@ func (a *TimedHMACValidator) Validate(algo func() hash.Hash, credentials any) er
log.Debugf("failed to unmarshal token: %s", err) log.Debugf("failed to unmarshal token: %s", err)
return err return err
} }
return a.TimedHMAC.Validate(algo, c) return a.TimedHMAC.Validate(sha256.New, c)
} }

View File

@@ -1,8 +1,35 @@
package auth package auth
import "hash" import (
"time"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
)
// Validator is an interface that defines the Validate method. // Validator is an interface that defines the Validate method.
type Validator interface { type Validator interface {
Validate(func() hash.Hash, any) error Validate(any) error
// Deprecated: Use Validate instead.
ValidateHelloMsgType(any) error
}
type TimedHMACValidator struct {
authenticatorV2 *authv2.Validator
authenticator *auth.TimedHMACValidator
}
func NewTimedHMACValidator(secret []byte, duration time.Duration) *TimedHMACValidator {
return &TimedHMACValidator{
authenticatorV2: authv2.NewValidator(secret),
authenticator: auth.NewTimedHMACValidator(string(secret), duration),
}
}
func (a *TimedHMACValidator) Validate(credentials any) error {
return a.authenticatorV2.Validate(credentials)
}
func (a *TimedHMACValidator) ValidateHelloMsgType(credentials any) error {
return a.authenticator.Validate(credentials)
} }

View File

@@ -14,8 +14,6 @@ import (
"github.com/netbirdio/netbird/relay/client/dialer/ws" "github.com/netbirdio/netbird/relay/client/dialer/ws"
"github.com/netbirdio/netbird/relay/healthcheck" "github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages" "github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/relay/messages/address"
auth2 "github.com/netbirdio/netbird/relay/messages/auth"
) )
const ( const (
@@ -60,37 +58,65 @@ func (m *Msg) Free() {
m.bufPool.Put(m.bufPtr) m.bufPool.Put(m.bufPtr)
} }
// connContainer is a container for the connection to the peer. It is responsible for managing the messages from the
// server and forwarding them to the upper layer content reader.
type connContainer struct { type connContainer struct {
log *log.Entry
conn *Conn conn *Conn
messages chan Msg messages chan Msg
msgChanLock sync.Mutex msgChanLock sync.Mutex
closed bool // flag to check if channel is closed closed bool // flag to check if channel is closed
ctx context.Context
cancel context.CancelFunc
} }
func newConnContainer(conn *Conn, messages chan Msg) *connContainer { func newConnContainer(log *log.Entry, conn *Conn, messages chan Msg) *connContainer {
ctx, cancel := context.WithCancel(context.Background())
return &connContainer{ return &connContainer{
log: log,
conn: conn, conn: conn,
messages: messages, messages: messages,
ctx: ctx,
cancel: cancel,
} }
} }
func (cc *connContainer) writeMsg(msg Msg) { func (cc *connContainer) writeMsg(msg Msg) {
cc.msgChanLock.Lock() cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock() defer cc.msgChanLock.Unlock()
if cc.closed { if cc.closed {
msg.Free()
return return
} }
cc.messages <- msg
select {
case cc.messages <- msg:
case <-cc.ctx.Done():
msg.Free()
default:
msg.Free()
cc.log.Infof("message queue is full")
// todo consider to close the connection
}
} }
func (cc *connContainer) close() { func (cc *connContainer) close() {
cc.cancel()
cc.msgChanLock.Lock() cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock() defer cc.msgChanLock.Unlock()
if cc.closed { if cc.closed {
return return
} }
close(cc.messages)
cc.closed = true cc.closed = true
close(cc.messages)
for msg := range cc.messages {
msg.Free()
}
} }
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and // Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
@@ -122,8 +148,8 @@ type Client struct {
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect // NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client { func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
hashedID, hashedStringId := messages.HashID(peerID) hashedID, hashedStringId := messages.HashID(peerID)
return &Client{ c := &Client{
log: log.WithField("client_id", hashedStringId), log: log.WithFields(log.Fields{"relay": serverURL}),
parentCtx: ctx, parentCtx: ctx,
connectionURL: serverURL, connectionURL: serverURL,
authTokenStore: authTokenStore, authTokenStore: authTokenStore,
@@ -136,11 +162,13 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token
}, },
conns: make(map[string]*connContainer), conns: make(map[string]*connContainer),
} }
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedStringId)
return c
} }
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs. // Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
func (c *Client) Connect() error { func (c *Client) Connect() error {
c.log.Infof("connecting to relay server: %s", c.connectionURL) c.log.Infof("connecting to relay server")
c.readLoopMutex.Lock() c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock() defer c.readLoopMutex.Unlock()
@@ -161,7 +189,7 @@ func (c *Client) Connect() error {
c.wgReadLoop.Add(1) c.wgReadLoop.Add(1)
go c.readLoop(c.relayConn) go c.readLoop(c.relayConn)
c.log.Infof("relay connection established with: %s", c.connectionURL) c.log.Infof("relay connection established")
return nil return nil
} }
@@ -183,11 +211,11 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
return nil, ErrConnAlreadyExists return nil, ErrConnAlreadyExists
} }
log.Infof("open connection to peer: %s", hashedStringID) c.log.Infof("open connection to peer: %s", hashedStringID)
msgChannel := make(chan Msg, 2) msgChannel := make(chan Msg, 100)
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL) conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
c.conns[hashedStringID] = newConnContainer(conn, msgChannel) c.conns[hashedStringID] = newConnContainer(c.log, conn, msgChannel)
return conn, nil return conn, nil
} }
@@ -231,7 +259,7 @@ func (c *Client) connect() error {
if err != nil { if err != nil {
cErr := conn.Close() cErr := conn.Close()
if cErr != nil { if cErr != nil {
log.Errorf("failed to close connection: %s", cErr) c.log.Errorf("failed to close connection: %s", cErr)
} }
return err return err
} }
@@ -240,31 +268,21 @@ func (c *Client) connect() error {
} }
func (c *Client) handShake() error { func (c *Client) handShake() error {
authMsg := &auth2.Msg{ msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
AuthAlgorithm: auth2.AlgoHMACSHA256,
AdditionalData: c.authTokenStore.TokenBinary(),
}
authData, err := authMsg.Marshal()
if err != nil { if err != nil {
return fmt.Errorf("marshal auth message: %w", err) c.log.Errorf("failed to marshal auth message: %s", err)
}
msg, err := messages.MarshalHelloMsg(c.hashedID, authData)
if err != nil {
log.Errorf("failed to marshal hello message: %s", err)
return err return err
} }
_, err = c.relayConn.Write(msg) _, err = c.relayConn.Write(msg)
if err != nil { if err != nil {
log.Errorf("failed to send hello message: %s", err) c.log.Errorf("failed to send auth message: %s", err)
return err return err
} }
buf := make([]byte, messages.MaxHandshakeSize) buf := make([]byte, messages.MaxHandshakeRespSize)
n, err := c.readWithTimeout(buf) n, err := c.readWithTimeout(buf)
if err != nil { if err != nil {
log.Errorf("failed to read hello response: %s", err) c.log.Errorf("failed to read auth response: %s", err)
return err return err
} }
@@ -275,34 +293,29 @@ func (c *Client) handShake() error {
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n]) msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
if err != nil { if err != nil {
log.Errorf("failed to determine message type: %s", err) c.log.Errorf("failed to determine message type: %s", err)
return err return err
} }
if msgType != messages.MsgTypeHelloResponse { if msgType != messages.MsgTypeAuthResponse {
log.Errorf("unexpected message type: %s", msgType) c.log.Errorf("unexpected message type: %s", msgType)
return fmt.Errorf("unexpected message type") return fmt.Errorf("unexpected message type")
} }
additionalData, err := messages.UnmarshalHelloResponse(buf[messages.SizeOfProtoHeader:n]) addr, err := messages.UnmarshalAuthResponse(buf[messages.SizeOfProtoHeader:n])
if err != nil { if err != nil {
return err return err
} }
addr, err := address.Unmarshal(additionalData)
if err != nil {
return fmt.Errorf("unmarshal address: %w", err)
}
c.muInstanceURL.Lock() c.muInstanceURL.Lock()
c.instanceURL = &RelayAddr{addr: addr.URL} c.instanceURL = &RelayAddr{addr: addr}
c.muInstanceURL.Unlock() c.muInstanceURL.Unlock()
return nil return nil
} }
func (c *Client) readLoop(relayConn net.Conn) { func (c *Client) readLoop(relayConn net.Conn) {
internallyStoppedFlag := newInternalStopFlag() internallyStoppedFlag := newInternalStopFlag()
hc := healthcheck.NewReceiver() hc := healthcheck.NewReceiver(c.log)
go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag) go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
var ( var (
@@ -314,6 +327,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
buf := *bufPtr buf := *bufPtr
n, errExit = relayConn.Read(buf) n, errExit = relayConn.Read(buf)
if errExit != nil { if errExit != nil {
c.log.Infof("start to Relay read loop exit")
c.mu.Lock() c.mu.Lock()
if c.serviceIsRunning && !internallyStoppedFlag.isSet() { if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Debugf("failed to read message from relay server: %s", errExit) c.log.Debugf("failed to read message from relay server: %s", errExit)
@@ -360,7 +374,7 @@ func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte,
case messages.MsgTypeTransport: case messages.MsgTypeTransport:
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag) return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
case messages.MsgTypeClose: case messages.MsgTypeClose:
log.Debugf("relay connection close by server") c.log.Debugf("relay connection close by server")
c.bufPool.Put(bufPtr) c.bufPool.Put(bufPtr)
return false return false
} }
@@ -429,14 +443,14 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [
// todo: use buffer pool instead of create new transport msg. // todo: use buffer pool instead of create new transport msg.
msg, err := messages.MarshalTransportMsg(dstID, payload) msg, err := messages.MarshalTransportMsg(dstID, payload)
if err != nil { if err != nil {
log.Errorf("failed to marshal transport message: %s", err) c.log.Errorf("failed to marshal transport message: %s", err)
return 0, err return 0, err
} }
// the write always return with 0 length because the underling does not support the size feedback. // the write always return with 0 length because the underling does not support the size feedback.
_, err = c.relayConn.Write(msg) _, err = c.relayConn.Write(msg)
if err != nil { if err != nil {
log.Errorf("failed to write transport message: %s", err) c.log.Errorf("failed to write transport message: %s", err)
} }
return len(payload), err return len(payload), err
} }
@@ -450,12 +464,15 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in
} }
c.log.Errorf("health check timeout") c.log.Errorf("health check timeout")
internalStopFlag.set() internalStopFlag.set()
_ = conn.Close() // ignore the err because the readLoop will handle it if err := conn.Close(); err != nil {
// ignore the err handling because the readLoop will handle it
c.log.Warnf("failed to close connection: %s", err)
}
return return
case <-c.parentCtx.Done(): case <-c.parentCtx.Done():
err := c.close(true) err := c.close(true)
if err != nil { if err != nil {
log.Errorf("failed to teardown connection: %s", err) c.log.Errorf("failed to teardown connection: %s", err)
} }
return return
} }
@@ -481,8 +498,9 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
if container.conn != connReference { if container.conn != connReference {
return fmt.Errorf("conn reference mismatch") return fmt.Errorf("conn reference mismatch")
} }
container.close() c.log.Infof("free up connection to peer: %s", id)
delete(c.conns, id) delete(c.conns, id)
container.close()
return nil return nil
} }
@@ -495,10 +513,12 @@ func (c *Client) close(gracefullyExit bool) error {
var err error var err error
if !c.serviceIsRunning { if !c.serviceIsRunning {
c.mu.Unlock() c.mu.Unlock()
c.log.Warn("relay connection was already marked as not running")
return nil return nil
} }
c.serviceIsRunning = false c.serviceIsRunning = false
c.log.Infof("closing all peer connections")
c.closeAllConns() c.closeAllConns()
if gracefullyExit { if gracefullyExit {
c.writeCloseMsg() c.writeCloseMsg()
@@ -506,8 +526,9 @@ func (c *Client) close(gracefullyExit bool) error {
err = c.relayConn.Close() err = c.relayConn.Close()
c.mu.Unlock() c.mu.Unlock()
c.log.Infof("waiting for read loop to close")
c.wgReadLoop.Wait() c.wgReadLoop.Wait()
c.log.Infof("relay connection closed with: %s", c.connectionURL) c.log.Infof("relay connection closed")
return err return err
} }

View File

@@ -618,6 +618,87 @@ func TestCloseByClient(t *testing.T) {
} }
} }
func TestCloseNotDrainedChannel(t *testing.T) {
ctx := context.Background()
idAlice := "alice"
idBob := "bob"
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
err = clientAlice.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err := clientAlice.Close()
if err != nil {
t.Errorf("failed to close Alice client: %s", err)
}
}()
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob)
err = clientBob.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err := clientBob.Close()
if err != nil {
t.Errorf("failed to close Bob client: %s", err)
}
}()
connAliceToBob, err := clientAlice.OpenConn(idBob)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
connBobToAlice, err := clientBob.OpenConn(idAlice)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
payload := "hello bob, I am alice"
// the internal channel buffer size is 2. So we should overflow it
for i := 0; i < 5; i++ {
_, err = connAliceToBob.Write([]byte(payload))
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
}
// wait for delivery
time.Sleep(1 * time.Second)
err = connBobToAlice.Close()
if err != nil {
t.Errorf("failed to close channel: %s", err)
}
}
func waitForServerToStart(errChan chan error) error { func waitForServerToStart(errChan chan error) error {
select { select {
case err := <-errChan: case err := <-errChan:

View File

@@ -3,7 +3,6 @@ package client
import ( import (
"container/list" "container/list"
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"reflect" "reflect"
@@ -17,8 +16,6 @@ import (
var ( var (
relayCleanupInterval = 60 * time.Second relayCleanupInterval = 60 * time.Second
connectionTimeout = 30 * time.Second
maxConcurrentServers = 7
ErrRelayClientNotConnected = fmt.Errorf("relay client not connected") ErrRelayClientNotConnected = fmt.Errorf("relay client not connected")
) )
@@ -92,67 +89,23 @@ func (m *Manager) Serve() error {
} }
log.Debugf("starting relay client manager with %v relay servers", m.serverURLs) log.Debugf("starting relay client manager with %v relay servers", m.serverURLs)
totalServers := len(m.serverURLs) sp := ServerPicker{
TokenStore: m.tokenStore,
successChan := make(chan *Client, 1) PeerID: m.peerID,
errChan := make(chan error, len(m.serverURLs))
ctx, cancel := context.WithTimeout(m.ctx, connectionTimeout)
defer cancel()
sem := make(chan struct{}, maxConcurrentServers)
for _, url := range m.serverURLs {
sem <- struct{}{}
go func(url string) {
defer func() { <-sem }()
m.connect(m.ctx, url, successChan, errChan)
}(url)
} }
var errCount int client, err := sp.PickServer(m.ctx, m.serverURLs)
if err != nil {
for { return err
select {
case client := <-successChan:
log.Infof("Successfully connected to relay server: %s", client.connectionURL)
m.relayClient = client
m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
m.relayClient.SetOnDisconnectListener(func() {
m.onServerDisconnected(client.connectionURL)
})
m.startCleanupLoop()
return nil
case err := <-errChan:
errCount++
log.Warnf("Connection attempt failed: %v", err)
if errCount == totalServers {
return errors.New("failed to connect to any relay server: all attempts failed")
}
case <-ctx.Done():
return fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
}
} }
} m.relayClient = client
func (m *Manager) connect(ctx context.Context, serverURL string, successChan chan<- *Client, errChan chan<- error) { m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
// TODO: abort the connection if another connection was successful m.relayClient.SetOnDisconnectListener(func() {
relayClient := NewClient(ctx, serverURL, m.tokenStore, m.peerID) m.onServerDisconnected(client.connectionURL)
if err := relayClient.Connect(); err != nil { })
errChan <- fmt.Errorf("failed to connect to %s: %w", serverURL, err) m.startCleanupLoop()
return return nil
}
select {
case successChan <- relayClient:
// This client was the first to connect successfully
default:
if err := relayClient.Close(); err != nil {
log.Debugf("failed to close relay client: %s", err)
}
}
} }
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be // OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be

98
relay/client/picker.go Normal file
View File

@@ -0,0 +1,98 @@
package client
import (
"context"
"errors"
"fmt"
"time"
log "github.com/sirupsen/logrus"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
)
const (
connectionTimeout = 30 * time.Second
maxConcurrentServers = 7
)
type connResult struct {
RelayClient *Client
Url string
Err error
}
type ServerPicker struct {
TokenStore *auth.TokenStore
PeerID string
}
func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*Client, error) {
ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout)
defer cancel()
totalServers := len(urls)
connResultChan := make(chan connResult, totalServers)
successChan := make(chan connResult, 1)
concurrentLimiter := make(chan struct{}, maxConcurrentServers)
for _, url := range urls {
// todo check if we have a successful connection so we do not need to connect to other servers
concurrentLimiter <- struct{}{}
go func(url string) {
defer func() {
<-concurrentLimiter
}()
sp.startConnection(parentCtx, connResultChan, url)
}(url)
}
go sp.processConnResults(connResultChan, successChan)
select {
case cr, ok := <-successChan:
if !ok {
return nil, errors.New("failed to connect to any relay server: all attempts failed")
}
log.Infof("chosen home Relay server: %s", cr.Url)
return cr.RelayClient, nil
case <-ctx.Done():
return nil, fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
}
}
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
log.Infof("try to connecting to relay server: %s", url)
relayClient := NewClient(ctx, url, sp.TokenStore, sp.PeerID)
err := relayClient.Connect()
resultChan <- connResult{
RelayClient: relayClient,
Url: url,
Err: err,
}
}
func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult) {
var hasSuccess bool
for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ {
cr := <-resultChan
if cr.Err != nil {
log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
continue
}
log.Infof("connected to Relay server: %s", cr.Url)
if hasSuccess {
log.Infof("closing unnecessary Relay connection to: %s", cr.Url)
if err := cr.RelayClient.Close(); err != nil {
log.Errorf("failed to close connection to %s: %v", cr.Url, err)
}
continue
}
hasSuccess = true
successChan <- cr
}
close(successChan)
}

View File

@@ -0,0 +1,31 @@
package client
import (
"context"
"errors"
"testing"
"time"
)
func TestServerPicker_UnavailableServers(t *testing.T) {
sp := ServerPicker{
TokenStore: nil,
PeerID: "test",
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
go func() {
_, err := sp.PickServer(ctx, []string{"rel://dummy1", "rel://dummy2"})
if err == nil {
t.Error(err)
}
cancel()
}()
<-ctx.Done()
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
t.Errorf("PickServer() took too long to complete")
}
}

View File

@@ -2,6 +2,7 @@ package cmd
import ( import (
"context" "context"
"crypto/sha256"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
@@ -16,21 +17,18 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
auth "github.com/netbirdio/netbird/relay/auth/hmac" "github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/signal/metrics" "github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
const (
metricsPort = 9090
)
type Config struct { type Config struct {
ListenAddress string ListenAddress string
// in HA every peer connect to a common domain, the instance domain has been distributed during the p2p connection // in HA every peer connect to a common domain, the instance domain has been distributed during the p2p connection
// it is a domain:port or ip:port // it is a domain:port or ip:port
ExposedAddress string ExposedAddress string
MetricsPort int
LetsencryptEmail string LetsencryptEmail string
LetsencryptDataDir string LetsencryptDataDir string
LetsencryptDomains []string LetsencryptDomains []string
@@ -79,6 +77,7 @@ func init() {
cobraConfig = &Config{} cobraConfig = &Config{}
rootCmd.PersistentFlags().StringVarP(&cobraConfig.ListenAddress, "listen-address", "l", ":443", "listen address") rootCmd.PersistentFlags().StringVarP(&cobraConfig.ListenAddress, "listen-address", "l", ":443", "listen address")
rootCmd.PersistentFlags().StringVarP(&cobraConfig.ExposedAddress, "exposed-address", "e", "", "instance domain address (or ip) and port, it will be distributes between peers") rootCmd.PersistentFlags().StringVarP(&cobraConfig.ExposedAddress, "exposed-address", "e", "", "instance domain address (or ip) and port, it will be distributes between peers")
rootCmd.PersistentFlags().IntVar(&cobraConfig.MetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
rootCmd.PersistentFlags().StringVarP(&cobraConfig.LetsencryptDataDir, "letsencrypt-data-dir", "d", "", "a directory to store Let's Encrypt data. Required if Let's Encrypt is enabled.") rootCmd.PersistentFlags().StringVarP(&cobraConfig.LetsencryptDataDir, "letsencrypt-data-dir", "d", "", "a directory to store Let's Encrypt data. Required if Let's Encrypt is enabled.")
rootCmd.PersistentFlags().StringSliceVarP(&cobraConfig.LetsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") rootCmd.PersistentFlags().StringSliceVarP(&cobraConfig.LetsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
rootCmd.PersistentFlags().StringVar(&cobraConfig.LetsencryptEmail, "letsencrypt-email", "", "email address to use for Let's Encrypt certificate registration") rootCmd.PersistentFlags().StringVar(&cobraConfig.LetsencryptEmail, "letsencrypt-email", "", "email address to use for Let's Encrypt certificate registration")
@@ -115,7 +114,7 @@ func execute(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to initialize log: %s", err) return fmt.Errorf("failed to initialize log: %s", err)
} }
metricsServer, err := metrics.NewServer(metricsPort, "") metricsServer, err := metrics.NewServer(cobraConfig.MetricsPort, "")
if err != nil { if err != nil {
log.Debugf("setup metrics: %v", err) log.Debugf("setup metrics: %v", err)
return fmt.Errorf("setup metrics: %v", err) return fmt.Errorf("setup metrics: %v", err)
@@ -139,7 +138,9 @@ func execute(cmd *cobra.Command, args []string) error {
} }
srvListenerCfg.TLSConfig = tlsConfig srvListenerCfg.TLSConfig = tlsConfig
authenticator := auth.NewTimedHMACValidator(cobraConfig.AuthSecret, 24*time.Hour) hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator) srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator)
if err != nil { if err != nil {
log.Debugf("failed to create relay server: %v", err) log.Debugf("failed to create relay server: %v", err)

View File

@@ -3,10 +3,12 @@ package healthcheck
import ( import (
"context" "context"
"time" "time"
log "github.com/sirupsen/logrus"
) )
var ( var (
heartbeatTimeout = healthCheckInterval + 3*time.Second heartbeatTimeout = healthCheckInterval + 10*time.Second
) )
// Receiver is a healthcheck receiver // Receiver is a healthcheck receiver
@@ -14,23 +16,26 @@ var (
// If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work // If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work
// The heartbeat timeout is a bit longer than the sender's healthcheck interval // The heartbeat timeout is a bit longer than the sender's healthcheck interval
type Receiver struct { type Receiver struct {
OnTimeout chan struct{} OnTimeout chan struct{}
log *log.Entry
ctx context.Context ctx context.Context
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
heartbeat chan struct{} heartbeat chan struct{}
alive bool alive bool
attemptThreshold int
} }
// NewReceiver creates a new healthcheck receiver and start the timer in the background // NewReceiver creates a new healthcheck receiver and start the timer in the background
func NewReceiver() *Receiver { func NewReceiver(log *log.Entry) *Receiver {
ctx, ctxCancel := context.WithCancel(context.Background()) ctx, ctxCancel := context.WithCancel(context.Background())
r := &Receiver{ r := &Receiver{
OnTimeout: make(chan struct{}, 1), OnTimeout: make(chan struct{}, 1),
ctx: ctx, log: log,
ctxCancel: ctxCancel, ctx: ctx,
heartbeat: make(chan struct{}, 1), ctxCancel: ctxCancel,
heartbeat: make(chan struct{}, 1),
attemptThreshold: getAttemptThresholdFromEnv(),
} }
go r.waitForHealthcheck() go r.waitForHealthcheck()
@@ -56,16 +61,23 @@ func (r *Receiver) waitForHealthcheck() {
defer r.ctxCancel() defer r.ctxCancel()
defer close(r.OnTimeout) defer close(r.OnTimeout)
failureCounter := 0
for { for {
select { select {
case <-r.heartbeat: case <-r.heartbeat:
r.alive = true r.alive = true
failureCounter = 0
case <-ticker.C: case <-ticker.C:
if r.alive { if r.alive {
r.alive = false r.alive = false
continue continue
} }
failureCounter++
if failureCounter < r.attemptThreshold {
r.log.Warnf("healthcheck failed, attempt %d", failureCounter)
continue
}
r.notifyTimeout() r.notifyTimeout()
return return
case <-r.ctx.Done(): case <-r.ctx.Done():

View File

@@ -1,13 +1,18 @@
package healthcheck package healthcheck
import ( import (
"context"
"fmt"
"os"
"testing" "testing"
"time" "time"
log "github.com/sirupsen/logrus"
) )
func TestNewReceiver(t *testing.T) { func TestNewReceiver(t *testing.T) {
heartbeatTimeout = 5 * time.Second heartbeatTimeout = 5 * time.Second
r := NewReceiver() r := NewReceiver(log.WithContext(context.Background()))
select { select {
case <-r.OnTimeout: case <-r.OnTimeout:
@@ -19,7 +24,7 @@ func TestNewReceiver(t *testing.T) {
func TestNewReceiverNotReceive(t *testing.T) { func TestNewReceiverNotReceive(t *testing.T) {
heartbeatTimeout = 1 * time.Second heartbeatTimeout = 1 * time.Second
r := NewReceiver() r := NewReceiver(log.WithContext(context.Background()))
select { select {
case <-r.OnTimeout: case <-r.OnTimeout:
@@ -30,7 +35,7 @@ func TestNewReceiverNotReceive(t *testing.T) {
func TestNewReceiverAck(t *testing.T) { func TestNewReceiverAck(t *testing.T) {
heartbeatTimeout = 2 * time.Second heartbeatTimeout = 2 * time.Second
r := NewReceiver() r := NewReceiver(log.WithContext(context.Background()))
r.Heartbeat() r.Heartbeat()
@@ -40,3 +45,53 @@ func TestNewReceiverAck(t *testing.T) {
case <-time.After(3 * time.Second): case <-time.After(3 * time.Second):
} }
} }
func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
testsCases := []struct {
name string
threshold int
resetCounterOnce bool
}{
{"Default attempt threshold", defaultAttemptThreshold, false},
{"Custom attempt threshold", 3, false},
{"Should reset threshold once", 2, true},
}
for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) {
originalInterval := healthCheckInterval
originalTimeout := heartbeatTimeout
healthCheckInterval = 1 * time.Second
heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
defer func() {
healthCheckInterval = originalInterval
heartbeatTimeout = originalTimeout
}()
//nolint:tenv
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
defer os.Unsetenv(defaultAttemptThresholdEnv)
receiver := NewReceiver(log.WithField("test_name", tc.name))
testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval
if tc.resetCounterOnce {
receiver.Heartbeat()
t.Logf("reset counter once")
}
select {
case <-receiver.OnTimeout:
if tc.resetCounterOnce {
t.Fatalf("should not have timed out before %s", testTimeout)
}
case <-time.After(testTimeout):
if tc.resetCounterOnce {
return
}
t.Fatalf("should have timed out before %s", testTimeout)
}
})
}
}

View File

@@ -2,12 +2,21 @@ package healthcheck
import ( import (
"context" "context"
"os"
"strconv"
"time" "time"
log "github.com/sirupsen/logrus"
)
const (
defaultAttemptThreshold = 1
defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD"
) )
var ( var (
healthCheckInterval = 25 * time.Second healthCheckInterval = 25 * time.Second
healthCheckTimeout = 5 * time.Second healthCheckTimeout = 20 * time.Second
) )
// Sender is a healthcheck sender // Sender is a healthcheck sender
@@ -15,20 +24,25 @@ var (
// If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work // If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work
// It will also stop if the context is canceled // It will also stop if the context is canceled
type Sender struct { type Sender struct {
log *log.Entry
// HealthCheck is a channel to send health check signal to the peer // HealthCheck is a channel to send health check signal to the peer
HealthCheck chan struct{} HealthCheck chan struct{}
// Timeout is a channel to the health check signal is not received in a certain time // Timeout is a channel to the health check signal is not received in a certain time
Timeout chan struct{} Timeout chan struct{}
ack chan struct{} ack chan struct{}
alive bool
attemptThreshold int
} }
// NewSender creates a new healthcheck sender // NewSender creates a new healthcheck sender
func NewSender() *Sender { func NewSender(log *log.Entry) *Sender {
hc := &Sender{ hc := &Sender{
HealthCheck: make(chan struct{}, 1), log: log,
Timeout: make(chan struct{}, 1), HealthCheck: make(chan struct{}, 1),
ack: make(chan struct{}, 1), Timeout: make(chan struct{}, 1),
ack: make(chan struct{}, 1),
attemptThreshold: getAttemptThresholdFromEnv(),
} }
return hc return hc
@@ -46,23 +60,51 @@ func (hc *Sender) StartHealthCheck(ctx context.Context) {
ticker := time.NewTicker(healthCheckInterval) ticker := time.NewTicker(healthCheckInterval)
defer ticker.Stop() defer ticker.Stop()
timeoutTimer := time.NewTimer(healthCheckInterval + healthCheckTimeout) timeoutTicker := time.NewTicker(hc.getTimeoutTime())
defer timeoutTimer.Stop() defer timeoutTicker.Stop()
defer close(hc.HealthCheck) defer close(hc.HealthCheck)
defer close(hc.Timeout) defer close(hc.Timeout)
failureCounter := 0
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
hc.HealthCheck <- struct{}{} hc.HealthCheck <- struct{}{}
case <-timeoutTimer.C: case <-timeoutTicker.C:
if hc.alive {
hc.alive = false
continue
}
failureCounter++
if failureCounter < hc.attemptThreshold {
hc.log.Warnf("Health check failed attempt %d.", failureCounter)
continue
}
hc.Timeout <- struct{}{} hc.Timeout <- struct{}{}
return return
case <-hc.ack: case <-hc.ack:
timeoutTimer.Reset(healthCheckInterval + healthCheckTimeout) failureCounter = 0
hc.alive = true
case <-ctx.Done(): case <-ctx.Done():
return return
} }
} }
} }
func (hc *Sender) getTimeoutTime() time.Duration {
return healthCheckInterval + healthCheckTimeout
}
func getAttemptThresholdFromEnv() int {
if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" {
threshold, err := strconv.ParseInt(attemptThreshold, 10, 64)
if err != nil {
log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold)
return defaultAttemptThreshold
}
return int(threshold)
}
return defaultAttemptThreshold
}

View File

@@ -2,9 +2,12 @@ package healthcheck
import ( import (
"context" "context"
"fmt"
"os" "os"
"testing" "testing"
"time" "time"
log "github.com/sirupsen/logrus"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@@ -18,7 +21,7 @@ func TestMain(m *testing.M) {
func TestNewHealthPeriod(t *testing.T) { func TestNewHealthPeriod(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
hc := NewSender() hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx) go hc.StartHealthCheck(ctx)
iterations := 0 iterations := 0
@@ -38,7 +41,7 @@ func TestNewHealthPeriod(t *testing.T) {
func TestNewHealthFailed(t *testing.T) { func TestNewHealthFailed(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
hc := NewSender() hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx) go hc.StartHealthCheck(ctx)
select { select {
@@ -50,7 +53,7 @@ func TestNewHealthFailed(t *testing.T) {
func TestNewHealthcheckStop(t *testing.T) { func TestNewHealthcheckStop(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
hc := NewSender() hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx) go hc.StartHealthCheck(ctx)
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
@@ -75,7 +78,7 @@ func TestNewHealthcheckStop(t *testing.T) {
func TestTimeoutReset(t *testing.T) { func TestTimeoutReset(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
hc := NewSender() hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx) go hc.StartHealthCheck(ctx)
iterations := 0 iterations := 0
@@ -101,3 +104,102 @@ func TestTimeoutReset(t *testing.T) {
t.Fatalf("is not exited") t.Fatalf("is not exited")
} }
} }
func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
testsCases := []struct {
name string
threshold int
resetCounterOnce bool
}{
{"Default attempt threshold", defaultAttemptThreshold, false},
{"Custom attempt threshold", 3, false},
{"Should reset threshold once", 2, true},
}
for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) {
originalInterval := healthCheckInterval
originalTimeout := healthCheckTimeout
healthCheckInterval = 1 * time.Second
healthCheckTimeout = 500 * time.Millisecond
defer func() {
healthCheckInterval = originalInterval
healthCheckTimeout = originalTimeout
}()
//nolint:tenv
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
defer os.Unsetenv(defaultAttemptThresholdEnv)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sender := NewSender(log.WithField("test_name", tc.name))
go sender.StartHealthCheck(ctx)
go func() {
responded := false
for {
select {
case <-ctx.Done():
return
case _, ok := <-sender.HealthCheck:
if !ok {
return
}
if tc.resetCounterOnce && !responded {
responded = true
sender.OnHCResponse()
}
}
}
}()
testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval
select {
case <-sender.Timeout:
if tc.resetCounterOnce {
t.Fatalf("should not have timed out before %s", testTimeout)
}
case <-time.After(testTimeout):
if tc.resetCounterOnce {
return
}
t.Fatalf("should have timed out before %s", testTimeout)
}
})
}
}
//nolint:tenv
func TestGetAttemptThresholdFromEnv(t *testing.T) {
tests := []struct {
name string
envValue string
expected int
}{
{"Default attempt threshold when env is not set", "", defaultAttemptThreshold},
{"Custom attempt threshold when env is set to a valid integer", "3", 3},
{"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.envValue == "" {
os.Unsetenv(defaultAttemptThresholdEnv)
} else {
os.Setenv(defaultAttemptThresholdEnv, tt.envValue)
}
result := getAttemptThresholdFromEnv()
if result != tt.expected {
t.Fatalf("Expected %d, got %d", tt.expected, result)
}
os.Unsetenv(defaultAttemptThresholdEnv)
})
}
}

View File

@@ -1,3 +1,4 @@
// Deprecated: This package is deprecated and will be removed in a future release.
package address package address
import ( import (
@@ -18,13 +19,3 @@ func (addr *Address) Marshal() ([]byte, error) {
} }
return buf.Bytes(), nil return buf.Bytes(), nil
} }
func Unmarshal(data []byte) (*Address, error) {
var addr Address
buf := bytes.NewBuffer(data)
dec := gob.NewDecoder(buf)
if err := dec.Decode(&addr); err != nil {
return nil, fmt.Errorf("decode Address: %w", err)
}
return &addr, nil
}

View File

@@ -1,3 +1,4 @@
// Deprecated: This package is deprecated and will be removed in a future release.
package auth package auth
import ( import (
@@ -30,15 +31,6 @@ type Msg struct {
AdditionalData []byte AdditionalData []byte
} }
func (msg *Msg) Marshal() ([]byte, error) {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
if err := enc.Encode(msg); err != nil {
return nil, fmt.Errorf("encode Msg: %w", err)
}
return buf.Bytes(), nil
}
func UnmarshalMsg(data []byte) (*Msg, error) { func UnmarshalMsg(data []byte) (*Msg, error) {
var msg *Msg var msg *Msg

View File

@@ -7,12 +7,21 @@ import (
) )
const ( const (
MsgTypeUnknown MsgType = 0 MaxHandshakeSize = 212
MsgTypeHello MsgType = 1 MaxHandshakeRespSize = 8192
CurrentProtocolVersion = 1
MsgTypeUnknown MsgType = 0
// Deprecated: Use MsgTypeAuth instead.
MsgTypeHello MsgType = 1
// Deprecated: Use MsgTypeAuthResponse instead.
MsgTypeHelloResponse MsgType = 2 MsgTypeHelloResponse MsgType = 2
MsgTypeTransport MsgType = 3 MsgTypeTransport MsgType = 3
MsgTypeClose MsgType = 4 MsgTypeClose MsgType = 4
MsgTypeHealthCheck MsgType = 5 MsgTypeHealthCheck MsgType = 5
MsgTypeAuth = 6
MsgTypeAuthResponse = 7
SizeOfVersionByte = 1 SizeOfVersionByte = 1
SizeOfMsgType = 1 SizeOfMsgType = 1
@@ -22,12 +31,12 @@ const (
sizeOfMagicByte = 4 sizeOfMagicByte = 4
headerSizeTransport = IDSize headerSizeTransport = IDSize
headerSizeHello = sizeOfMagicByte + IDSize headerSizeHello = sizeOfMagicByte + IDSize
headerSizeHelloResp = 0 headerSizeHelloResp = 0
MaxHandshakeSize = 8192 headerSizeAuth = sizeOfMagicByte + IDSize
headerSizeAuthResp = 0
CurrentProtocolVersion = 1
) )
var ( var (
@@ -47,6 +56,10 @@ func (m MsgType) String() string {
return "hello" return "hello"
case MsgTypeHelloResponse: case MsgTypeHelloResponse:
return "hello response" return "hello response"
case MsgTypeAuth:
return "auth"
case MsgTypeAuthResponse:
return "auth response"
case MsgTypeTransport: case MsgTypeTransport:
return "transport" return "transport"
case MsgTypeClose: case MsgTypeClose:
@@ -58,10 +71,6 @@ func (m MsgType) String() string {
} }
} }
type HelloResponse struct {
InstanceAddress string
}
// ValidateVersion checks if the given version is supported by the protocol // ValidateVersion checks if the given version is supported by the protocol
func ValidateVersion(msg []byte) (int, error) { func ValidateVersion(msg []byte) (int, error) {
if len(msg) < SizeOfVersionByte { if len(msg) < SizeOfVersionByte {
@@ -84,6 +93,7 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) {
switch msgType { switch msgType {
case case
MsgTypeHello, MsgTypeHello,
MsgTypeAuth,
MsgTypeTransport, MsgTypeTransport,
MsgTypeClose, MsgTypeClose,
MsgTypeHealthCheck: MsgTypeHealthCheck:
@@ -103,6 +113,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
switch msgType { switch msgType {
case case
MsgTypeHelloResponse, MsgTypeHelloResponse,
MsgTypeAuthResponse,
MsgTypeTransport, MsgTypeTransport,
MsgTypeClose, MsgTypeClose,
MsgTypeHealthCheck: MsgTypeHealthCheck:
@@ -112,6 +123,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
} }
} }
// Deprecated: Use MarshalAuthMsg instead.
// MarshalHelloMsg initial hello message // MarshalHelloMsg initial hello message
// The Hello message is the first message sent by a client after establishing a connection with the Relay server. This // The Hello message is the first message sent by a client after establishing a connection with the Relay server. This
// message is used to authenticate the client with the server. The authentication is done using an HMAC method. // message is used to authenticate the client with the server. The authentication is done using an HMAC method.
@@ -135,6 +147,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
return msg, nil return msg, nil
} }
// Deprecated: Use UnmarshalAuthMsg instead.
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to // UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
// authenticate the client with the server. // authenticate the client with the server.
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) { func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
@@ -148,6 +161,7 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil
} }
// Deprecated: Use MarshalAuthResponse instead.
// MarshalHelloResponse creates a response message to the hello message. // MarshalHelloResponse creates a response message to the hello message.
// In case of success connection the server response with a Hello Response message. This message contains the server's // In case of success connection the server response with a Hello Response message. This message contains the server's
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay // instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
@@ -163,6 +177,7 @@ func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
return msg, nil return msg, nil
} }
// Deprecated: Use UnmarshalAuthResponse instead.
// UnmarshalHelloResponse extracts the additional data from the hello response message. // UnmarshalHelloResponse extracts the additional data from the hello response message.
func UnmarshalHelloResponse(msg []byte) ([]byte, error) { func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
if len(msg) < headerSizeHelloResp { if len(msg) < headerSizeHelloResp {
@@ -171,6 +186,69 @@ func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
return msg, nil return msg, nil
} }
// MarshalAuthMsg initial authentication message
// The Auth message is the first message sent by a client after establishing a connection with the Relay server. This
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
// close the network connection without any response.
func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}
msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeAuth+len(authPayload))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeAuth)
copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
msg = append(msg, peerID...)
msg = append(msg, authPayload...)
return msg, nil
}
// UnmarshalAuthMsg extracts peerID and the auth payload from the message
func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
if len(msg) < headerSizeAuth {
return nil, nil, ErrInvalidMessageLength
}
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
return nil, nil, errors.New("invalid magic header")
}
return msg[sizeOfMagicByte:headerSizeAuth], msg[headerSizeAuth:], nil
}
// MarshalAuthResponse creates a response message to the auth.
// In case of success connection the server response with a AuthResponse message. This message contains the server's
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
// servers.
func MarshalAuthResponse(address string) ([]byte, error) {
ab := []byte(address)
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeAuthResp+len(ab))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeAuthResponse)
msg = append(msg, ab...)
if len(msg) > MaxHandshakeRespSize {
return nil, fmt.Errorf("invalid message length: %d", len(msg))
}
return msg, nil
}
// UnmarshalAuthResponse it is a confirmation message to auth success
func UnmarshalAuthResponse(msg []byte) (string, error) {
if len(msg) < headerSizeAuthResp+1 {
return "", ErrInvalidMessageLength
}
return string(msg), nil
}
// MarshalCloseMsg creates a close message. // MarshalCloseMsg creates a close message.
// The close message is used to close the connection gracefully between the client and the server. The server and the // The close message is used to close the connection gracefully between the client and the server. The server and the
// client can send this message. After receiving this message, the server or client will close the connection. // client can send this message. After receiving this message, the server or client will close the connection.

View File

@@ -20,6 +20,22 @@ func TestMarshalHelloMsg(t *testing.T) {
} }
} }
func TestMarshalAuthMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
bHello, err := MarshalAuthMsg(peerID, []byte{})
if err != nil {
t.Fatalf("error: %v", err)
}
receivedPeerID, _, err := UnmarshalAuthMsg(bHello[SizeOfProtoHeader:])
if err != nil {
t.Fatalf("error: %v", err)
}
if string(receivedPeerID) != string(peerID) {
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
}
}
func TestMarshalTransportMsg(t *testing.T) { func TestMarshalTransportMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
payload := []byte("payload") payload := []byte("payload")

View File

@@ -103,7 +103,7 @@ func (m *Metrics) PeerActivity(peerID string) {
select { select {
case m.peerActivityChan <- peerID: case m.peerActivityChan <- peerID:
default: default:
log.Errorf("peer activity channel is full, dropping activity metrics for peer %s", peerID) log.Tracef("peer activity channel is full, dropping activity metrics for peer %s", peerID)
} }
} }

View File

@@ -49,7 +49,7 @@ func (p *Peer) Work() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
hc := healthcheck.NewSender() hc := healthcheck.NewSender(p.log)
go hc.StartHealthCheck(ctx) go hc.StartHealthCheck(ctx)
go p.handleHealthcheckEvents(ctx, hc) go p.handleHealthcheckEvents(ctx, hc)
@@ -115,6 +115,7 @@ func (p *Peer) Write(b []byte) (int, error) {
// connection. // connection.
func (p *Peer) CloseGracefully(ctx context.Context) { func (p *Peer) CloseGracefully(ctx context.Context) {
p.connMu.Lock() p.connMu.Lock()
defer p.connMu.Unlock()
err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg()) err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg())
if err != nil { if err != nil {
p.log.Errorf("failed to send close message to peer: %s", p.String()) p.log.Errorf("failed to send close message to peer: %s", p.String())
@@ -124,8 +125,15 @@ func (p *Peer) CloseGracefully(ctx context.Context) {
if err != nil { if err != nil {
p.log.Errorf("failed to close connection to peer: %s", err) p.log.Errorf("failed to close connection to peer: %s", err)
} }
}
func (p *Peer) Close() {
p.connMu.Lock()
defer p.connMu.Unlock() defer p.connMu.Unlock()
if err := p.conn.Close(); err != nil {
p.log.Errorf("failed to close connection to peer: %s", err)
}
} }
// String returns the peer ID // String returns the peer ID
@@ -167,6 +175,7 @@ func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Send
if err != nil { if err != nil {
p.log.Errorf("failed to close connection to peer: %s", err) p.log.Errorf("failed to close connection to peer: %s", err)
} }
p.log.Info("peer connection closed due healthcheck timeout")
return return
case <-ctx.Done(): case <-ctx.Done():
return return
@@ -184,7 +193,7 @@ func (p *Peer) handleTransportMsg(msg []byte) {
stringPeerID := messages.HashIDToString(peerID) stringPeerID := messages.HashIDToString(peerID)
dp, ok := p.store.Peer(stringPeerID) dp, ok := p.store.Peer(stringPeerID)
if !ok { if !ok {
p.log.Errorf("peer not found: %s", stringPeerID) p.log.Debugf("peer not found: %s", stringPeerID)
return return
} }

View File

@@ -2,7 +2,6 @@ package server
import ( import (
"context" "context"
"crypto/sha256"
"fmt" "fmt"
"net" "net"
"net/url" "net/url"
@@ -14,7 +13,9 @@ import (
"github.com/netbirdio/netbird/relay/auth" "github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/messages" "github.com/netbirdio/netbird/relay/messages"
//nolint:staticcheck
"github.com/netbirdio/netbird/relay/messages/address" "github.com/netbirdio/netbird/relay/messages/address"
//nolint:staticcheck
authmsg "github.com/netbirdio/netbird/relay/messages/auth" authmsg "github.com/netbirdio/netbird/relay/messages/auth"
"github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/metrics"
) )
@@ -168,39 +169,81 @@ func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err) return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
} }
if msgType != messages.MsgTypeHello { var (
return nil, fmt.Errorf("invalid message type from %s", conn.RemoteAddr()) responseMsg []byte
peerID []byte
)
switch msgType {
//nolint:staticcheck
case messages.MsgTypeHello:
peerID, responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
case messages.MsgTypeAuth:
peerID, responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
default:
return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr())
} }
peerID, authData, err := messages.UnmarshalHelloMsg(buf[messages.SizeOfProtoHeader:n])
if err != nil { if err != nil {
return nil, fmt.Errorf("unmarshal hello message: %w", err) return nil, err
} }
authMsg, err := authmsg.UnmarshalMsg(authData) _, err = conn.Write(responseMsg)
if err != nil {
return nil, fmt.Errorf("unmarshal auth message: %w", err)
}
if err := r.validator.Validate(sha256.New, authMsg.AdditionalData); err != nil {
return nil, fmt.Errorf("validate %s (%s): %w", peerID, conn.RemoteAddr(), err)
}
addr := &address.Address{URL: r.instanceURL}
addrData, err := addr.Marshal()
if err != nil {
return nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, conn.RemoteAddr(), err)
}
msg, err := messages.MarshalHelloResponse(addrData)
if err != nil {
return nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, conn.RemoteAddr(), err)
}
_, err = conn.Write(msg)
if err != nil { if err != nil {
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err) return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
} }
return peerID, nil return peerID, nil
} }
func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) {
//nolint:staticcheck
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
if err != nil {
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
}
peerID := messages.HashIDToString(rawPeerID)
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr)
authMsg, err := authmsg.UnmarshalMsg(authData)
if err != nil {
return nil, nil, fmt.Errorf("unmarshal auth message: %w", err)
}
//nolint:staticcheck
if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err)
}
addr := &address.Address{URL: r.instanceURL}
addrData, err := addr.Marshal()
if err != nil {
return nil, nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err)
}
//nolint:staticcheck
responseMsg, err := messages.MarshalHelloResponse(addrData)
if err != nil {
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err)
}
return rawPeerID, responseMsg, nil
}
func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) {
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
if err != nil {
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
}
peerID := messages.HashIDToString(rawPeerID)
if err := r.validator.Validate(authPayload); err != nil {
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err)
}
responseMsg, err := messages.MarshalAuthResponse(r.instanceURL)
if err != nil {
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err)
}
return rawPeerID, responseMsg, nil
}

View File

@@ -19,10 +19,14 @@ func NewStore() *Store {
} }
// AddPeer adds a peer to the store // AddPeer adds a peer to the store
// todo: consider to close peer conn if the peer already exists
func (s *Store) AddPeer(peer *Peer) { func (s *Store) AddPeer(peer *Peer) {
s.peersLock.Lock() s.peersLock.Lock()
defer s.peersLock.Unlock() defer s.peersLock.Unlock()
odlPeer, ok := s.peers[peer.String()]
if ok {
odlPeer.Close()
}
s.peers[peer.String()] = peer s.peers[peer.String()] = peer
} }

View File

@@ -2,13 +2,57 @@ package server
import ( import (
"context" "context"
"net"
"testing" "testing"
"time"
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/metrics"
) )
type mockConn struct {
}
func (m mockConn) Read(b []byte) (n int, err error) {
//TODO implement me
panic("implement me")
}
func (m mockConn) Write(b []byte) (n int, err error) {
//TODO implement me
panic("implement me")
}
func (m mockConn) Close() error {
return nil
}
func (m mockConn) LocalAddr() net.Addr {
//TODO implement me
panic("implement me")
}
func (m mockConn) RemoteAddr() net.Addr {
//TODO implement me
panic("implement me")
}
func (m mockConn) SetDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (m mockConn) SetReadDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (m mockConn) SetWriteDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func TestStore_DeletePeer(t *testing.T) { func TestStore_DeletePeer(t *testing.T) {
s := NewStore() s := NewStore()
@@ -27,8 +71,9 @@ func TestStore_DeleteDeprecatedPeer(t *testing.T) {
m, _ := metrics.NewMetrics(context.Background(), otel.Meter("")) m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
p1 := NewPeer(m, []byte("peer_id"), nil, nil) conn := &mockConn{}
p2 := NewPeer(m, []byte("peer_id"), nil, nil) p1 := NewPeer(m, []byte("peer_id"), conn, nil)
p2 := NewPeer(m, []byte("peer_id"), conn, nil)
s.AddPeer(p1) s.AddPeer(p1)
s.AddPeer(p2) s.AddPeer(p2)

View File

@@ -300,11 +300,13 @@ install_netbird() {
echo "package_manager=$PACKAGE_MANAGER" | ${SUDO} tee "$CONFIG_FILE" > /dev/null echo "package_manager=$PACKAGE_MANAGER" | ${SUDO} tee "$CONFIG_FILE" > /dev/null
# Load and start netbird service # Load and start netbird service
if ! ${SUDO} netbird service install 2>&1; then if [ "$PACKAGE_MANAGER" != "rpm-ostree" ]; then
echo "NetBird service has already been loaded" if ! ${SUDO} netbird service install 2>&1; then
fi echo "NetBird service has already been loaded"
if ! ${SUDO} netbird service start 2>&1; then fi
echo "NetBird service has already been started" if ! ${SUDO} netbird service start 2>&1; then
echo "NetBird service has already been started"
fi
fi fi

View File

@@ -29,12 +29,9 @@ import (
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
) )
const (
metricsPort = 9090
)
var ( var (
signalPort int signalPort int
metricsPort int
signalLetsencryptDomain string signalLetsencryptDomain string
signalSSLDir string signalSSLDir string
defaultSignalSSLDir string defaultSignalSSLDir string
@@ -288,6 +285,7 @@ func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
func init() { func init() {
runCmd.PersistentFlags().IntVar(&signalPort, "port", 80, "Server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise") runCmd.PersistentFlags().IntVar(&signalPort, "port", 80, "Server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
runCmd.Flags().IntVar(&metricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
runCmd.Flags().StringVar(&signalSSLDir, "ssl-dir", defaultSignalSSLDir, "server ssl directory location. *Required only for Let's Encrypt certificates.") runCmd.Flags().StringVar(&signalSSLDir, "ssl-dir", defaultSignalSSLDir, "server ssl directory location. *Required only for Let's Encrypt certificates.")
runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
runCmd.Flags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") runCmd.Flags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")

View File

@@ -82,8 +82,11 @@ func (registry *Registry) Register(peer *Peer) {
log.Warnf("peer [%s] is already registered [new streamID %d, previous StreamID %d]. Will override stream.", log.Warnf("peer [%s] is already registered [new streamID %d, previous StreamID %d]. Will override stream.",
peer.Id, peer.StreamID, pp.StreamID) peer.Id, peer.StreamID, pp.StreamID)
registry.Peers.Store(peer.Id, peer) registry.Peers.Store(peer.Id, peer)
return
} }
log.Debugf("peer registered [%s]", peer.Id) log.Debugf("peer registered [%s]", peer.Id)
registry.metrics.ActivePeers.Add(context.Background(), 1)
// record time as milliseconds // record time as milliseconds
registry.metrics.RegistrationDelay.Record(context.Background(), float64(time.Since(start).Nanoseconds())/1e6) registry.metrics.RegistrationDelay.Record(context.Background(), float64(time.Since(start).Nanoseconds())/1e6)
@@ -105,8 +108,8 @@ func (registry *Registry) Deregister(peer *Peer) {
peer.Id, pp.StreamID, peer.StreamID) peer.Id, pp.StreamID, peer.StreamID)
return return
} }
registry.metrics.ActivePeers.Add(context.Background(), -1)
log.Debugf("peer deregistered [%s]", peer.Id)
registry.metrics.Deregistrations.Add(context.Background(), 1)
} }
log.Debugf("peer deregistered [%s]", peer.Id)
registry.metrics.Deregistrations.Add(context.Background(), 1)
} }

View File

@@ -133,8 +133,6 @@ func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (
s.registry.Register(p) s.registry.Register(p)
s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer) s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer)
s.metrics.ActivePeers.Add(stream.Context(), 1)
return p, nil return p, nil
} else { } else {
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId))) s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId)))
@@ -151,7 +149,6 @@ func (s *Server) DeregisterPeer(p *peer.Peer) {
s.registry.Deregister(p) s.registry.Deregister(p)
s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds())) s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds()))
s.metrics.ActivePeers.Add(context.Background(), -1)
} }
func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) { func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) {

View File

@@ -10,51 +10,30 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// WriteJson writes JSON config object to a file creating parent directories if required // WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory
// The output JSON is pretty-formatted func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
func WriteJson(file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file) configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil { if err != nil {
return err return err
} }
// make it pretty err = EnforcePermission(file)
bs, err := json.MarshalIndent(obj, "", " ")
if err != nil { if err != nil {
return err return err
} }
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName) return writeJson(file, obj, configDir, configFileName)
}
// WriteJson writes JSON config object to a file creating parent directories if required
// The output JSON is pretty-formatted
func WriteJson(file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil { if err != nil {
return err return err
} }
tempFileName := tempFile.Name() return writeJson(file, obj, configDir, configFileName)
// closing file ops as windows doesn't allow to move it
err = tempFile.Close()
if err != nil {
return err
}
defer func() {
_, err = os.Stat(tempFileName)
if err == nil {
os.Remove(tempFileName)
}
}()
err = os.WriteFile(tempFileName, bs, 0600)
if err != nil {
return err
}
err = os.Rename(tempFileName, file)
if err != nil {
return err
}
return nil
} }
// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file // DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file
@@ -96,6 +75,46 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error {
return nil return nil
} }
func writeJson(file string, obj interface{}, configDir string, configFileName string) error {
// make it pretty
bs, err := json.MarshalIndent(obj, "", " ")
if err != nil {
return err
}
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
if err != nil {
return err
}
tempFileName := tempFile.Name()
// closing file ops as windows doesn't allow to move it
err = tempFile.Close()
if err != nil {
return err
}
defer func() {
_, err = os.Stat(tempFileName)
if err == nil {
os.Remove(tempFileName)
}
}()
err = os.WriteFile(tempFileName, bs, 0600)
if err != nil {
return err
}
err = os.Rename(tempFileName, file)
if err != nil {
return err
}
return nil
}
func openOrCreateFile(file string) (*os.File, error) { func openOrCreateFile(file string) (*os.File, error) {
s, err := os.Stat(file) s, err := os.Stat(file)
if err == nil { if err == nil {
@@ -172,5 +191,9 @@ func prepareConfigFileDir(file string) (string, string, error) {
} }
err := os.MkdirAll(configDir, 0750) err := os.MkdirAll(configDir, 0750)
if err != nil {
return "", "", err
}
return configDir, configFileName, err return configDir, configFileName, err
} }

View File

@@ -5,6 +5,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"slices" "slices"
"strconv"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gopkg.in/natefinch/lumberjack.v2" "gopkg.in/natefinch/lumberjack.v2"
@@ -12,6 +13,8 @@ import (
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
) )
const defaultLogSize = 5
// InitLog parses and sets log-level input // InitLog parses and sets log-level input
func InitLog(logLevel string, logPath string) error { func InitLog(logLevel string, logPath string) error {
level, err := log.ParseLevel(logLevel) level, err := log.ParseLevel(logLevel)
@@ -19,13 +22,14 @@ func InitLog(logLevel string, logPath string) error {
log.Errorf("Failed parsing log-level %s: %s", logLevel, err) log.Errorf("Failed parsing log-level %s: %s", logLevel, err)
return err return err
} }
customOutputs := []string{"console", "syslog"}; customOutputs := []string{"console", "syslog"}
if logPath != "" && !slices.Contains(customOutputs, logPath) { if logPath != "" && !slices.Contains(customOutputs, logPath) {
maxLogSize := getLogMaxSize()
lumberjackLogger := &lumberjack.Logger{ lumberjackLogger := &lumberjack.Logger{
// Log file absolute path, os agnostic // Log file absolute path, os agnostic
Filename: filepath.ToSlash(logPath), Filename: filepath.ToSlash(logPath),
MaxSize: 5, // MB MaxSize: maxLogSize, // MB
MaxBackups: 10, MaxBackups: 10,
MaxAge: 30, // days MaxAge: 30, // days
Compress: true, Compress: true,
@@ -46,3 +50,18 @@ func InitLog(logLevel string, logPath string) error {
log.SetLevel(level) log.SetLevel(level)
return nil return nil
} }
func getLogMaxSize() int {
if sizeVar, ok := os.LookupEnv("NB_LOG_MAX_SIZE_MB"); ok {
size, err := strconv.ParseInt(sizeVar, 10, 64)
if err != nil {
log.Errorf("Failed parsing log-size %s: %s. Should be just an integer", sizeVar, err)
return defaultLogSize
}
log.Infof("Setting log file max size to %d MB", size)
return int(size)
}
return defaultLogSize
}

7
util/permission.go Normal file
View File

@@ -0,0 +1,7 @@
//go:build !windows
package util
func EnforcePermission(dirPath string) error {
return nil
}

View File

@@ -0,0 +1,86 @@
package util
import (
"path/filepath"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
const (
securityFlags = windows.OWNER_SECURITY_INFORMATION |
windows.GROUP_SECURITY_INFORMATION |
windows.DACL_SECURITY_INFORMATION |
windows.PROTECTED_DACL_SECURITY_INFORMATION
)
func EnforcePermission(file string) error {
dirPath := filepath.Dir(file)
user, group, err := sids()
if err != nil {
return err
}
adminGroupSid, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid)
if err != nil {
return err
}
explicitAccess := []windows.EXPLICIT_ACCESS{
{
AccessPermissions: windows.GENERIC_ALL,
AccessMode: windows.SET_ACCESS,
Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT,
Trustee: windows.TRUSTEE{
MultipleTrusteeOperation: windows.NO_MULTIPLE_TRUSTEE,
TrusteeForm: windows.TRUSTEE_IS_SID,
TrusteeType: windows.TRUSTEE_IS_USER,
TrusteeValue: windows.TrusteeValueFromSID(user),
},
},
{
AccessPermissions: windows.GENERIC_ALL,
AccessMode: windows.SET_ACCESS,
Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT,
Trustee: windows.TRUSTEE{
MultipleTrusteeOperation: windows.NO_MULTIPLE_TRUSTEE,
TrusteeForm: windows.TRUSTEE_IS_SID,
TrusteeType: windows.TRUSTEE_IS_WELL_KNOWN_GROUP,
TrusteeValue: windows.TrusteeValueFromSID(adminGroupSid),
},
},
}
dacl, err := windows.ACLFromEntries(explicitAccess, nil)
if err != nil {
return err
}
return windows.SetNamedSecurityInfo(dirPath, windows.SE_FILE_OBJECT, securityFlags, user, group, dacl, nil)
}
func sids() (*windows.SID, *windows.SID, error) {
var token windows.Token
err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token)
if err != nil {
return nil, nil, err
}
defer func() {
if err := token.Close(); err != nil {
log.Errorf("failed to close process token: %v", err)
}
}()
tu, err := token.GetTokenUser()
if err != nil {
return nil, nil, err
}
pg, err := token.GetTokenPrimaryGroup()
if err != nil {
return nil, nil, err
}
return tu.User.Sid, pg.PrimaryGroup, nil
}