Compare commits

...

12 Commits

Author SHA1 Message Date
Maycon Santos
f6d57e7a96 [misc] Support configurable max log size with var NB_LOG_MAX_SIZE_MB (#2592)
* Support configurable max log size with var NB_LOG_MAX_SIZE_MB

* add better logs
2024-09-12 19:56:55 +02:00
Zoltan Papp
ab892b8cf9 Fix wg handshake checking (#2590)
* Fix wg handshake checking

* Ensure in the initial handshake reading

* Change the handshake period
2024-09-12 19:18:02 +02:00
Gianluca Boiano
33c9b2d989 fix: install.sh: avoid call of netbird executable after rpm installation (#2589) 2024-09-12 17:32:47 +02:00
Bethuel Mmbaga
170e842422 [management] Add accessible peers endpoint (#2579)
* move accessible peer to separate endpoint in api doc

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add endpoint to get accessible peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Update management/server/http/api/openapi.yml

Co-authored-by: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>

* Update management/server/http/api/openapi.yml

Co-authored-by: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>

* Update management/server/http/peers_handler.go

Co-authored-by: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
2024-09-12 16:19:27 +03:00
Maycon Santos
4c130a0291 Update Go version to 1.23 (#2588) 2024-09-12 13:46:28 +02:00
Maycon Santos
afb9673bc4 [misc] Update core github actions (#2584) 2024-09-11 21:49:05 +02:00
Bethuel Mmbaga
cf6210a6f4 [management] Add GCM encryption and migrate legacy encrypted events (#2569)
* Add AES-GCM encryption

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* migrate legacy encrypted data to AES-GCM encryption

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor and use transaction when migrating data

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add events migration tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* skip migrating record on error

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Preallocate capacity for nonce to avoid allocations in Seal

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-11 20:09:57 +03:00
Maycon Santos
c59a39d27d Update service package version (#2582) 2024-09-11 19:05:10 +02:00
Maycon Santos
47adb976f8 Remove pre-release step from workflow (#2583) 2024-09-11 18:59:19 +02:00
Zoltan Papp
9cfc8f8aa4 [relay] change log levels (#2580) 2024-09-11 18:36:19 +02:00
Viktor Liu
2d1bf3982d [relay] Improve relay messages (#2574)
Co-authored-by: Zoltán Papp <zoltan.pmail@gmail.com>
2024-09-11 16:20:30 +02:00
Viktor Liu
50ebbe482e [client] Don't overwrite allowed IPs when updating the wg peer's endpoint address (#2578)
This will fix broken routes on routing clients when upgrading/downgrading from/to relayed connections.
2024-09-11 16:05:13 +02:00
46 changed files with 1218 additions and 426 deletions

View File

@@ -18,14 +18,14 @@ jobs:
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v4 uses: actions/setup-go@v5
with: with:
go-version: "1.21.x" go-version: "1.23.x"
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v3 uses: actions/cache@v4
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: macos-go-${{ hashFiles('**/go.sum') }} key: macos-go-${{ hashFiles('**/go.sum') }}

View File

@@ -19,13 +19,13 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v4 uses: actions/setup-go@v5
with: with:
go-version: "1.21.x" go-version: "1.23.x"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v3 uses: actions/cache@v4
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -33,7 +33,7 @@ jobs:
${{ runner.os }}-go- ${{ runner.os }}-go-
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -55,12 +55,12 @@ jobs:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v4 uses: actions/setup-go@v5
with: with:
go-version: "1.21.x" go-version: "1.23.x"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v3 uses: actions/cache@v4
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -68,7 +68,7 @@ jobs:
${{ runner.os }}-go- ${{ runner.os }}-go-
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev

View File

@@ -17,13 +17,13 @@ jobs:
runs-on: windows-latest runs-on: windows-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install Go - name: Install Go
uses: actions/setup-go@v4 uses: actions/setup-go@v5
id: go id: go
with: with:
go-version: "1.21.x" go-version: "1.23.x"
- name: Download wintun - name: Download wintun
uses: carlosperate/download-file-action@v2 uses: carlosperate/download-file-action@v2

View File

@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: codespell - name: codespell
uses: codespell-project/actions-codespell@v2 uses: codespell-project/actions-codespell@v2
with: with:
@@ -32,15 +32,15 @@ jobs:
timeout-minutes: 15 timeout-minutes: 15
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Check for duplicate constants - name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest' if: matrix.os == 'ubuntu-latest'
run: | run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep . ! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go - name: Install Go
uses: actions/setup-go@v4 uses: actions/setup-go@v5
with: with:
go-version: "1.21.x" go-version: "1.23.x"
cache: false cache: false
- name: Install dependencies - name: Install dependencies
if: matrix.os == 'ubuntu-latest' if: matrix.os == 'ubuntu-latest'

View File

@@ -21,7 +21,7 @@ jobs:
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: run install script - name: run install script
env: env:

View File

@@ -15,23 +15,23 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install Go - name: Install Go
uses: actions/setup-go@v4 uses: actions/setup-go@v5
with: with:
go-version: "1.21.x" go-version: "1.23.x"
- name: Setup Android SDK - name: Setup Android SDK
uses: android-actions/setup-android@v3 uses: android-actions/setup-android@v3
with: with:
cmdline-tools-version: 8512546 cmdline-tools-version: 8512546
- name: Setup Java - name: Setup Java
uses: actions/setup-java@v3 uses: actions/setup-java@v4
with: with:
java-version: "11" java-version: "11"
distribution: "adopt" distribution: "adopt"
- name: NDK Cache - name: NDK Cache
id: ndk-cache id: ndk-cache
uses: actions/cache@v3 uses: actions/cache@v4
with: with:
path: /usr/local/lib/android/sdk/ndk path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620 key: ndk-cache-23.1.7779620
@@ -50,11 +50,11 @@ jobs:
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install Go - name: Install Go
uses: actions/setup-go@v4 uses: actions/setup-go@v5
with: with:
go-version: "1.21.x" go-version: "1.23.x"
- name: install gomobile - name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init - name: gomobile init

View File

@@ -36,18 +36,18 @@ jobs:
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
- -
name: Checkout name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v4
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- -
name: Set up Go name: Set up Go
uses: actions/setup-go@v4 uses: actions/setup-go@v5
with: with:
go-version: "1.21" go-version: "1.23"
cache: false cache: false
- -
name: Cache Go modules name: Cache Go modules
uses: actions/cache@v3 uses: actions/cache@v4
with: with:
path: | path: |
~/go/pkg/mod ~/go/pkg/mod
@@ -93,28 +93,28 @@ jobs:
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- -
name: upload non tags for debug purposes name: upload non tags for debug purposes
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v4
with: with:
name: release name: release
path: dist/ path: dist/
retention-days: 3 retention-days: 3
- -
name: upload linux packages name: upload linux packages
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v4
with: with:
name: linux-packages name: linux-packages
path: dist/netbird_linux** path: dist/netbird_linux**
retention-days: 3 retention-days: 3
- -
name: upload windows packages name: upload windows packages
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v4
with: with:
name: windows-packages name: windows-packages
path: dist/netbird_windows** path: dist/netbird_windows**
retention-days: 3 retention-days: 3
- -
name: upload macos packages name: upload macos packages
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v4
with: with:
name: macos-packages name: macos-packages
path: dist/netbird_darwin** path: dist/netbird_darwin**
@@ -133,17 +133,17 @@ jobs:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v4
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v4 uses: actions/setup-go@v5
with: with:
go-version: "1.21" go-version: "1.23"
cache: false cache: false
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v3 uses: actions/cache@v4
with: with:
path: | path: |
~/go/pkg/mod ~/go/pkg/mod
@@ -176,7 +176,7 @@ jobs:
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- name: upload non tags for debug purposes - name: upload non tags for debug purposes
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v4
with: with:
name: release-ui name: release-ui
path: dist/ path: dist/
@@ -189,18 +189,18 @@ jobs:
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
- -
name: Checkout name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v4
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- -
name: Set up Go name: Set up Go
uses: actions/setup-go@v4 uses: actions/setup-go@v5
with: with:
go-version: "1.21" go-version: "1.23"
cache: false cache: false
- -
name: Cache Go modules name: Cache Go modules
uses: actions/cache@v3 uses: actions/cache@v4
with: with:
path: | path: |
~/go/pkg/mod ~/go/pkg/mod
@@ -225,7 +225,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- -
name: upload non tags for debug purposes name: upload non tags for debug purposes
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v4
with: with:
name: release-ui-darwin name: release-ui-darwin
path: dist/ path: dist/

View File

@@ -50,12 +50,12 @@ jobs:
run: sudo apt-get install -y curl run: sudo apt-get install -y curl
- name: Install Go - name: Install Go
uses: actions/setup-go@v4 uses: actions/setup-go@v5
with: with:
go-version: "1.21.x" go-version: "1.23.x"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v3 uses: actions/cache@v4
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -63,7 +63,7 @@ jobs:
${{ runner.os }}-go- ${{ runner.os }}-go-
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: cp setup.env - name: cp setup.env
run: cp infrastructure_files/tests/setup.env infrastructure_files/ run: cp infrastructure_files/tests/setup.env infrastructure_files/
@@ -219,10 +219,7 @@ jobs:
run: sudo apt-get install -y jq run: sudo apt-get install -y jq
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: handle insisting image # remove after release
run: docker pull netbirdio/relay:latest || docker pull netbirdio/signal:latest && docker tag netbirdio/signal:latest netbirdio/relay:latest
- name: run script with Zitadel PostgreSQL - name: run script with Zitadel PostgreSQL
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
@@ -259,9 +256,6 @@ jobs:
docker compose down --volumes --rmi all docker compose down --volumes --rmi all
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
- name: handle insisting image gen CockroachDB # remove after release
run: docker pull netbirdio/relay:latest || docker pull netbirdio/signal:latest && docker tag netbirdio/signal:latest netbirdio/relay:latest
- name: run script with Zitadel CockroachDB - name: run script with Zitadel CockroachDB
run: bash -x infrastructure_files/getting-started-with-zitadel.sh run: bash -x infrastructure_files/getting-started-with-zitadel.sh
env: env:

View File

@@ -484,11 +484,11 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
// switch back to relay connection // switch back to relay connection
if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay { if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay {
conn.log.Debugf("ICE disconnected, set Relay to active connection") conn.log.Debugf("ICE disconnected, set Relay to active connection")
conn.workerRelay.EnableWgWatcher(conn.ctx)
err := conn.configureWGEndpoint(conn.endpointRelay) err := conn.configureWGEndpoint(conn.endpointRelay)
if err != nil { if err != nil {
conn.log.Errorf("failed to switch to relay conn: %v", err) conn.log.Errorf("failed to switch to relay conn: %v", err)
} }
conn.workerRelay.EnableWgWatcher(conn.ctx)
conn.currentConnPriority = connPriorityRelay conn.currentConnPriority = connPriorityRelay
} }
@@ -551,6 +551,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
} }
} }
conn.workerRelay.EnableWgWatcher(conn.ctx)
err = conn.configureWGEndpoint(endpointUdpAddr) err = conn.configureWGEndpoint(endpointUdpAddr)
if err != nil { if err != nil {
if err := wgProxy.CloseConn(); err != nil { if err := wgProxy.CloseConn(); err != nil {
@@ -560,7 +561,6 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
return return
} }
wgConfigWorkaround() wgConfigWorkaround()
conn.workerRelay.EnableWgWatcher(conn.ctx)
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
if err := conn.wgProxyRelay.CloseConn(); err != nil { if err := conn.wgProxyRelay.CloseConn(); err != nil {

View File

@@ -14,7 +14,7 @@ import (
) )
var ( var (
wgHandshakePeriod = 2 * time.Minute wgHandshakePeriod = 3 * time.Minute
wgHandshakeOvertime = 30 * time.Second wgHandshakeOvertime = 30 * time.Second
) )
@@ -109,7 +109,7 @@ func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
} }
ctx, ctxCancel := context.WithCancel(ctx) ctx, ctxCancel := context.WithCancel(ctx)
go w.wgStateCheck(ctx) w.wgStateCheck(ctx)
w.ctxWgWatch = ctx w.ctxWgWatch = ctx
w.ctxCancelWgWatch = ctxCancel w.ctxCancelWgWatch = ctxCancel
@@ -157,37 +157,50 @@ func (w *WorkerRelay) CloseConn() {
} }
} }
// wgStateCheck help to check the state of the wireguard handshake and relay connection // wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WorkerRelay) wgStateCheck(ctx context.Context) { func (w *WorkerRelay) wgStateCheck(ctx context.Context) {
timer := time.NewTimer(wgHandshakeOvertime) lastHandshake, err := w.wgState()
defer timer.Stop() if err != nil {
expected := wgHandshakeOvertime w.log.Errorf("failed to read wg stats: %v", err)
for { lastHandshake = time.Time{}
select { }
case <-timer.C:
lastHandshake, err := w.wgState()
if err != nil {
w.log.Errorf("failed to read wg stats: %v", err)
continue
}
w.log.Tracef("last handshake: %v", lastHandshake)
if time.Since(lastHandshake) > expected { go func(lastHandshake time.Time) {
w.log.Infof("Wireguard handshake timed out, closing relay connection") timer := time.NewTimer(wgHandshakeOvertime)
w.relayLock.Lock() defer timer.Stop()
_ = w.relayedConn.Close()
w.relayLock.Unlock() for {
w.callBacks.OnDisconnected() select {
case <-timer.C:
handshake, err := w.wgState()
if err != nil {
w.log.Errorf("failed to read wg stats: %v", err)
timer.Reset(wgHandshakeOvertime)
continue
}
w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake)
if handshake.Equal(lastHandshake) {
w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake)
w.relayLock.Lock()
_ = w.relayedConn.Close()
w.relayLock.Unlock()
w.callBacks.OnDisconnected()
return
}
resetTime := time.Until(handshake.Add(wgHandshakePeriod + wgHandshakeOvertime))
lastHandshake = handshake
timer.Reset(resetTime)
case <-ctx.Done():
w.log.Debugf("WireGuard watcher stopped")
return return
} }
resetTime := time.Until(lastHandshake.Add(wgHandshakePeriod + wgHandshakeOvertime))
timer.Reset(resetTime)
expected = wgHandshakePeriod
case <-ctx.Done():
w.log.Debugf("WireGuard watcher stopped")
return
} }
} }(lastHandshake)
} }
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool { func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {

4
go.mod
View File

@@ -1,6 +1,6 @@
module github.com/netbirdio/netbird module github.com/netbirdio/netbird
go 1.21.0 go 1.23.0
require ( require (
cunicu.li/go-rosenpass v0.4.0 cunicu.li/go-rosenpass v0.4.0
@@ -232,7 +232,7 @@ require (
k8s.io/apimachinery v0.26.2 // indirect k8s.io/apimachinery v0.26.2 // indirect
) )
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240904111318-17777758453a replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949

4
go.sum
View File

@@ -523,8 +523,8 @@ github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6R
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI= github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
github.com/netbirdio/service v0.0.0-20240904111318-17777758453a h1:2EcDFDT39Odz5EC38pOSyjCd3bLUjPi7pMQpH6k+zzk= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240904111318-17777758453a/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080 h1:mXJkoWLdqJTlkQ7DgQ536kcXHXIdUPeagkN8i4eFDdg= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080 h1:mXJkoWLdqJTlkQ7DgQ536kcXHXIdUPeagkN8i4eFDdg=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=

View File

@@ -56,8 +56,9 @@ func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepA
return err return err
} }
peer := wgtypes.PeerConfig{ peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true, ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: []net.IPNet{*ipNet}, AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive, PersistentKeepaliveInterval: &keepAlive,
Endpoint: endpoint, Endpoint: endpoint,

View File

@@ -64,8 +64,9 @@ func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAliv
return err return err
} }
peer := wgtypes.PeerConfig{ peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true, ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: []net.IPNet{*ipNet}, AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive, PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey, PresharedKey: preSharedKey,

View File

@@ -6,6 +6,7 @@ import (
"crypto/cipher" "crypto/cipher"
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
) )
@@ -13,6 +14,7 @@ var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
type FieldEncrypt struct { type FieldEncrypt struct {
block cipher.Block block cipher.Block
gcm cipher.AEAD
} }
func GenerateKey() (string, error) { func GenerateKey() (string, error) {
@@ -35,14 +37,21 @@ func NewFieldEncrypt(key string) (*FieldEncrypt, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
ec := &FieldEncrypt{ ec := &FieldEncrypt{
block: block, block: block,
gcm: gcm,
} }
return ec, nil return ec, nil
} }
func (ec *FieldEncrypt) Encrypt(payload string) string { func (ec *FieldEncrypt) LegacyEncrypt(payload string) string {
plainText := pkcs5Padding([]byte(payload)) plainText := pkcs5Padding([]byte(payload))
cipherText := make([]byte, len(plainText)) cipherText := make([]byte, len(plainText))
cbc := cipher.NewCBCEncrypter(ec.block, iv) cbc := cipher.NewCBCEncrypter(ec.block, iv)
@@ -50,7 +59,22 @@ func (ec *FieldEncrypt) Encrypt(payload string) string {
return base64.StdEncoding.EncodeToString(cipherText) return base64.StdEncoding.EncodeToString(cipherText)
} }
func (ec *FieldEncrypt) Decrypt(data string) (string, error) { // Encrypt encrypts plaintext using AES-GCM
func (ec *FieldEncrypt) Encrypt(payload string) (string, error) {
plaintext := []byte(payload)
nonceSize := ec.gcm.NonceSize()
nonce := make([]byte, nonceSize, len(plaintext)+nonceSize+ec.gcm.Overhead())
if _, err := rand.Read(nonce); err != nil {
return "", err
}
ciphertext := ec.gcm.Seal(nonce, nonce, plaintext, nil)
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
func (ec *FieldEncrypt) LegacyDecrypt(data string) (string, error) {
cipherText, err := base64.StdEncoding.DecodeString(data) cipherText, err := base64.StdEncoding.DecodeString(data)
if err != nil { if err != nil {
return "", err return "", err
@@ -65,6 +89,27 @@ func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
return string(payload), nil return string(payload), nil
} }
// Decrypt decrypts ciphertext using AES-GCM
func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
cipherText, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return "", err
}
nonceSize := ec.gcm.NonceSize()
if len(cipherText) < nonceSize {
return "", errors.New("cipher text too short")
}
nonce, cipherText := cipherText[:nonceSize], cipherText[nonceSize:]
plainText, err := ec.gcm.Open(nil, nonce, cipherText, nil)
if err != nil {
return "", err
}
return string(plainText), nil
}
func pkcs5Padding(ciphertext []byte) []byte { func pkcs5Padding(ciphertext []byte) []byte {
padding := aes.BlockSize - len(ciphertext)%aes.BlockSize padding := aes.BlockSize - len(ciphertext)%aes.BlockSize
padText := bytes.Repeat([]byte{byte(padding)}, padding) padText := bytes.Repeat([]byte{byte(padding)}, padding)

View File

@@ -15,7 +15,11 @@ func TestGenerateKey(t *testing.T) {
t.Fatalf("failed to init email encryption: %s", err) t.Fatalf("failed to init email encryption: %s", err)
} }
encrypted := ee.Encrypt(testData) encrypted, err := ee.Encrypt(testData)
if err != nil {
t.Fatalf("failed to encrypt data: %s", err)
}
if encrypted == "" { if encrypted == "" {
t.Fatalf("invalid encrypted text") t.Fatalf("invalid encrypted text")
} }
@@ -30,6 +34,32 @@ func TestGenerateKey(t *testing.T) {
} }
} }
func TestGenerateKeyLegacy(t *testing.T) {
testData := "exampl@netbird.io"
key, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err := NewFieldEncrypt(key)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
encrypted := ee.LegacyEncrypt(testData)
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
decrypted, err := ee.LegacyDecrypt(encrypted)
if err != nil {
t.Fatalf("failed to decrypt data: %s", err)
}
if decrypted != testData {
t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted)
}
}
func TestCorruptKey(t *testing.T) { func TestCorruptKey(t *testing.T) {
testData := "exampl@netbird.io" testData := "exampl@netbird.io"
key, err := GenerateKey() key, err := GenerateKey()
@@ -41,7 +71,11 @@ func TestCorruptKey(t *testing.T) {
t.Fatalf("failed to init email encryption: %s", err) t.Fatalf("failed to init email encryption: %s", err)
} }
encrypted := ee.Encrypt(testData) encrypted, err := ee.Encrypt(testData)
if err != nil {
t.Fatalf("failed to encrypt data: %s", err)
}
if encrypted == "" { if encrypted == "" {
t.Fatalf("invalid encrypted text") t.Fatalf("invalid encrypted text")
} }

View File

@@ -0,0 +1,157 @@
package sqlite
import (
"context"
"database/sql"
"fmt"
log "github.com/sirupsen/logrus"
)
func migrate(ctx context.Context, crypt *FieldEncrypt, db *sql.DB) error {
if _, err := db.Exec(createTableQuery); err != nil {
return err
}
if _, err := db.Exec(creatTableDeletedUsersQuery); err != nil {
return err
}
if err := updateDeletedUsersTable(ctx, db); err != nil {
return fmt.Errorf("failed to update deleted_users table: %v", err)
}
return migrateLegacyEncryptedUsersToGCM(ctx, crypt, db)
}
// updateDeletedUsersTable checks and updates the deleted_users table schema to ensure required columns exist.
func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error {
exists, err := checkColumnExists(db, "deleted_users", "name")
if err != nil {
return err
}
if !exists {
log.WithContext(ctx).Debug("Adding name column to the deleted_users table")
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
if err != nil {
return err
}
log.WithContext(ctx).Debug("Successfully added name column to the deleted_users table")
}
exists, err = checkColumnExists(db, "deleted_users", "enc_algo")
if err != nil {
return err
}
if !exists {
log.WithContext(ctx).Debug("Adding enc_algo column to the deleted_users table")
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN enc_algo TEXT;`)
if err != nil {
return err
}
log.WithContext(ctx).Debug("Successfully added enc_algo column to the deleted_users table")
}
return nil
}
// migrateLegacyEncryptedUsersToGCM migrates previously encrypted data using,
// legacy CBC encryption with a static IV to the new GCM encryption method.
func migrateLegacyEncryptedUsersToGCM(ctx context.Context, crypt *FieldEncrypt, db *sql.DB) error {
log.WithContext(ctx).Debug("Migrating CBC encrypted deleted users to GCM")
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("failed to begin transaction: %v", err)
}
defer func() {
_ = tx.Rollback()
}()
rows, err := tx.Query(fmt.Sprintf(`SELECT id, email, name FROM deleted_users where enc_algo IS NULL OR enc_algo != '%s'`, gcmEncAlgo))
if err != nil {
return fmt.Errorf("failed to execute select query: %v", err)
}
defer rows.Close()
updateStmt, err := tx.Prepare(`UPDATE deleted_users SET email = ?, name = ?, enc_algo = ? WHERE id = ?`)
if err != nil {
return fmt.Errorf("failed to prepare update statement: %v", err)
}
defer updateStmt.Close()
if err = processUserRows(ctx, crypt, rows, updateStmt); err != nil {
return err
}
if err = tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %v", err)
}
log.WithContext(ctx).Debug("Successfully migrated CBC encrypted deleted users to GCM")
return nil
}
// processUserRows processes database rows of user data, decrypts legacy encryption fields, and re-encrypts them using GCM.
func processUserRows(ctx context.Context, crypt *FieldEncrypt, rows *sql.Rows, updateStmt *sql.Stmt) error {
for rows.Next() {
var (
id, decryptedEmail, decryptedName string
email, name *string
)
err := rows.Scan(&id, &email, &name)
if err != nil {
return err
}
if email != nil {
decryptedEmail, err = crypt.LegacyDecrypt(*email)
if err != nil {
log.WithContext(ctx).Warnf("skipping migrating deleted user %s: %v",
id,
fmt.Errorf("failed to decrypt email: %w", err),
)
continue
}
}
if name != nil {
decryptedName, err = crypt.LegacyDecrypt(*name)
if err != nil {
log.WithContext(ctx).Warnf("skipping migrating deleted user %s: %v",
id,
fmt.Errorf("failed to decrypt name: %w", err),
)
continue
}
}
encryptedEmail, err := crypt.Encrypt(decryptedEmail)
if err != nil {
return fmt.Errorf("failed to encrypt email: %w", err)
}
encryptedName, err := crypt.Encrypt(decryptedName)
if err != nil {
return fmt.Errorf("failed to encrypt name: %w", err)
}
_, err = updateStmt.Exec(encryptedEmail, encryptedName, gcmEncAlgo, id)
if err != nil {
return err
}
}
if err := rows.Err(); err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,84 @@
package sqlite
import (
"context"
"database/sql"
"path/filepath"
"testing"
"time"
_ "github.com/mattn/go-sqlite3"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/stretchr/testify/require"
)
func setupDatabase(t *testing.T) *sql.DB {
t.Helper()
dbFile := filepath.Join(t.TempDir(), eventSinkDB)
db, err := sql.Open("sqlite3", dbFile)
require.NoError(t, err, "Failed to open database")
t.Cleanup(func() {
_ = db.Close()
})
_, err = db.Exec(createTableQuery)
require.NoError(t, err, "Failed to create events table")
_, err = db.Exec(`CREATE TABLE deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);`)
require.NoError(t, err, "Failed to create deleted_users table")
return db
}
func TestMigrate(t *testing.T) {
db := setupDatabase(t)
key, err := GenerateKey()
require.NoError(t, err, "Failed to generate key")
crypt, err := NewFieldEncrypt(key)
require.NoError(t, err, "Failed to initialize FieldEncrypt")
legacyEmail := crypt.LegacyEncrypt("testaccount@test.com")
legacyName := crypt.LegacyEncrypt("Test Account")
_, err = db.Exec(`INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) VALUES(?, ?, ?, ?, ?, ?)`,
activity.UserDeleted, time.Now(), "initiatorID", "targetID", "accountID", "")
require.NoError(t, err, "Failed to insert event")
_, err = db.Exec(`INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)`, "targetID", legacyEmail, legacyName)
require.NoError(t, err, "Failed to insert legacy encrypted data")
colExists, err := checkColumnExists(db, "deleted_users", "enc_algo")
require.NoError(t, err, "Failed to check if enc_algo column exists")
require.False(t, colExists, "enc_algo column should not exist before migration")
err = migrate(context.Background(), crypt, db)
require.NoError(t, err, "Migration failed")
colExists, err = checkColumnExists(db, "deleted_users", "enc_algo")
require.NoError(t, err, "Failed to check if enc_algo column exists after migration")
require.True(t, colExists, "enc_algo column should exist after migration")
var encAlgo string
err = db.QueryRow(`SELECT enc_algo FROM deleted_users LIMIT 1`, "").Scan(&encAlgo)
require.NoError(t, err, "Failed to select updated data")
require.Equal(t, gcmEncAlgo, encAlgo, "enc_algo should be set to 'GCM' after migration")
store, err := createStore(crypt, db)
require.NoError(t, err, "Failed to create store")
events, err := store.Get(context.Background(), "accountID", 0, 1, false)
require.NoError(t, err, "Failed to get events")
require.Len(t, events, 1, "Should have one event")
require.Equal(t, activity.UserDeleted, events[0].Activity, "activity should match")
require.Equal(t, "initiatorID", events[0].InitiatorID, "initiator id should match")
require.Equal(t, "targetID", events[0].TargetID, "target id should match")
require.Equal(t, "accountID", events[0].AccountID, "account id should match")
require.Equal(t, "testaccount@test.com", events[0].Meta["email"], "email should match")
require.Equal(t, "Test Account", events[0].Meta["username"], "username should match")
}

View File

@@ -26,7 +26,7 @@ const (
"meta TEXT," + "meta TEXT," +
" target_id TEXT);" " target_id TEXT);"
creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);` creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT, enc_algo TEXT NOT NULL);`
selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta
FROM events FROM events
@@ -69,10 +69,12 @@ const (
and some selfhosted deployments might have duplicates already so we need to clean the table first. and some selfhosted deployments might have duplicates already so we need to clean the table first.
*/ */
insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)` insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name, enc_algo) VALUES(?, ?, ?, ?)`
fallbackName = "unknown" fallbackName = "unknown"
fallbackEmail = "unknown@unknown.com" fallbackEmail = "unknown@unknown.com"
gcmEncAlgo = "GCM"
) )
// Store is the implementation of the activity.Store interface backed by SQLite // Store is the implementation of the activity.Store interface backed by SQLite
@@ -100,58 +102,12 @@ func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (
return nil, err return nil, err
} }
_, err = db.Exec(createTableQuery) if err = migrate(ctx, crypt, db); err != nil {
if err != nil {
_ = db.Close() _ = db.Close()
return nil, err return nil, fmt.Errorf("events database migration: %w", err)
} }
_, err = db.Exec(creatTableDeletedUsersQuery) return createStore(crypt, db)
if err != nil {
_ = db.Close()
return nil, err
}
err = updateDeletedUsersTable(ctx, db)
if err != nil {
_ = db.Close()
return nil, err
}
insertStmt, err := db.Prepare(insertQuery)
if err != nil {
_ = db.Close()
return nil, err
}
selectDescStmt, err := db.Prepare(selectDescQuery)
if err != nil {
_ = db.Close()
return nil, err
}
selectAscStmt, err := db.Prepare(selectAscQuery)
if err != nil {
_ = db.Close()
return nil, err
}
deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
if err != nil {
_ = db.Close()
return nil, err
}
s := &Store{
db: db,
fieldEncrypt: crypt,
insertStatement: insertStmt,
selectDescStatement: selectDescStmt,
selectAscStatement: selectAscStmt,
deleteUserStmt: deleteUserStmt,
}
return s, nil
} }
func (store *Store) processResult(ctx context.Context, result *sql.Rows) ([]*activity.Event, error) { func (store *Store) processResult(ctx context.Context, result *sql.Rows) ([]*activity.Event, error) {
@@ -302,9 +258,16 @@ func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event
return event.Meta, nil return event.Meta, nil
} }
encryptedEmail := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email)) encryptedEmail, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email))
encryptedName := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name)) if err != nil {
_, err := store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName) return nil, err
}
encryptedName, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name))
if err != nil {
return nil, err
}
_, err = store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName, gcmEncAlgo)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -325,43 +288,70 @@ func (store *Store) Close(_ context.Context) error {
return nil return nil
} }
func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error { // createStore initializes and returns a new Store instance with prepared SQL statements.
log.WithContext(ctx).Debugf("check deleted_users table version") func createStore(crypt *FieldEncrypt, db *sql.DB) (*Store, error) {
rows, err := db.Query(`PRAGMA table_info(deleted_users);`) insertStmt, err := db.Prepare(insertQuery)
if err != nil { if err != nil {
return err _ = db.Close()
return nil, err
}
selectDescStmt, err := db.Prepare(selectDescQuery)
if err != nil {
_ = db.Close()
return nil, err
}
selectAscStmt, err := db.Prepare(selectAscQuery)
if err != nil {
_ = db.Close()
return nil, err
}
deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
if err != nil {
_ = db.Close()
return nil, err
}
return &Store{
db: db,
fieldEncrypt: crypt,
insertStatement: insertStmt,
selectDescStatement: selectDescStmt,
selectAscStatement: selectAscStmt,
deleteUserStmt: deleteUserStmt,
}, nil
}
// checkColumnExists checks if a column exists in a specified table
func checkColumnExists(db *sql.DB, tableName, columnName string) (bool, error) {
query := fmt.Sprintf("PRAGMA table_info(%s);", tableName)
rows, err := db.Query(query)
if err != nil {
return false, fmt.Errorf("failed to query table info: %w", err)
} }
defer rows.Close() defer rows.Close()
found := false
for rows.Next() { for rows.Next() {
var ( var cid int
cid int var name, ctype string
name string var notnull, pk int
dataType string var dfltValue sql.NullString
notNull int
dfltVal sql.NullString err = rows.Scan(&cid, &name, &ctype, &notnull, &dfltValue, &pk)
pk int
)
err := rows.Scan(&cid, &name, &dataType, &notNull, &dfltVal, &pk)
if err != nil { if err != nil {
return err return false, fmt.Errorf("failed to scan row: %w", err)
} }
if name == "name" {
found = true if name == columnName {
break return true, nil
} }
} }
err = rows.Err() if err = rows.Err(); err != nil {
if err != nil { return false, err
return err
} }
if found { return false, nil
return nil
}
log.WithContext(ctx).Debugf("update delted_users table")
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
return err
} }

View File

@@ -251,7 +251,7 @@ components:
- name - name
- ssh_enabled - ssh_enabled
- login_expiration_enabled - login_expiration_enabled
PeerBase: Peer:
allOf: allOf:
- $ref: '#/components/schemas/PeerMinimum' - $ref: '#/components/schemas/PeerMinimum'
- type: object - type: object
@@ -378,25 +378,40 @@ components:
description: User ID of the user that enrolled this peer description: User ID of the user that enrolled this peer
type: string type: string
example: google-oauth2|277474792786460067937 example: google-oauth2|277474792786460067937
os:
description: Peer's operating system and version
type: string
example: linux
country_code:
$ref: '#/components/schemas/CountryCode'
city_name:
$ref: '#/components/schemas/CityName'
geoname_id:
description: Unique identifier from the GeoNames database for a specific geographical location.
type: integer
example: 2643743
connected:
description: Peer to Management connection status
type: boolean
example: true
last_seen:
description: Last time peer connected to Netbird's management service
type: string
format: date-time
example: "2023-05-05T10:05:26.420578Z"
required: required:
- ip - ip
- dns_label - dns_label
- user_id - user_id
Peer: - os
allOf: - country_code
- $ref: '#/components/schemas/PeerBase' - city_name
- type: object - geoname_id
properties: - connected
accessible_peers: - last_seen
description: List of accessible peers
type: array
items:
$ref: '#/components/schemas/AccessiblePeer'
required:
- accessible_peers
PeerBatch: PeerBatch:
allOf: allOf:
- $ref: '#/components/schemas/PeerBase' - $ref: '#/components/schemas/Peer'
- type: object - type: object
properties: properties:
accessible_peers_count: accessible_peers_count:
@@ -1806,6 +1821,38 @@ paths:
"$ref": "#/components/responses/forbidden" "$ref": "#/components/responses/forbidden"
'500': '500':
"$ref": "#/components/responses/internal_error" "$ref": "#/components/responses/internal_error"
/api/peers/{peerId}/accessible-peers:
get:
summary: List accessible Peers
description: Returns a list of peers that the specified peer can connect to within the network.
tags: [ Peers ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: peerId
required: true
schema:
type: string
description: The unique identifier of a peer
responses:
'200':
description: A JSON Array of Accessible Peers
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/AccessiblePeer'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/setup-keys: /api/setup-keys:
get: get:
summary: List all Setup Keys summary: List all Setup Keys

View File

@@ -152,18 +152,36 @@ const (
// AccessiblePeer defines model for AccessiblePeer. // AccessiblePeer defines model for AccessiblePeer.
type AccessiblePeer struct { type AccessiblePeer struct {
// CityName Commonly used English name of the city
CityName CityName `json:"city_name"`
// Connected Peer to Management connection status
Connected bool `json:"connected"`
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
CountryCode CountryCode `json:"country_code"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"` DnsLabel string `json:"dns_label"`
// GeonameId Unique identifier from the GeoNames database for a specific geographical location.
GeonameId int `json:"geoname_id"`
// Id Peer ID // Id Peer ID
Id string `json:"id"` Id string `json:"id"`
// Ip Peer's IP address // Ip Peer's IP address
Ip string `json:"ip"` Ip string `json:"ip"`
// LastSeen Last time peer connected to Netbird's management service
LastSeen time.Time `json:"last_seen"`
// Name Peer's hostname // Name Peer's hostname
Name string `json:"name"` Name string `json:"name"`
// Os Peer's operating system and version
Os string `json:"os"`
// UserId User ID of the user that enrolled this peer // UserId User ID of the user that enrolled this peer
UserId string `json:"user_id"` UserId string `json:"user_id"`
} }
@@ -490,81 +508,6 @@ type OSVersionCheck struct {
// Peer defines model for Peer. // Peer defines model for Peer.
type Peer struct { type Peer struct {
// AccessiblePeers List of accessible peers
AccessiblePeers []AccessiblePeer `json:"accessible_peers"`
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired bool `json:"approval_required"`
// CityName Commonly used English name of the city
CityName CityName `json:"city_name"`
// Connected Peer to Management connection status
Connected bool `json:"connected"`
// ConnectionIp Peer's public connection IP address
ConnectionIp string `json:"connection_ip"`
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
CountryCode CountryCode `json:"country_code"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
// GeonameId Unique identifier from the GeoNames database for a specific geographical location.
GeonameId int `json:"geoname_id"`
// Groups Groups that the peer belongs to
Groups []GroupMinimum `json:"groups"`
// Hostname Hostname of the machine
Hostname string `json:"hostname"`
// Id Peer ID
Id string `json:"id"`
// Ip Peer's IP address
Ip string `json:"ip"`
// KernelVersion Peer's operating system kernel version
KernelVersion string `json:"kernel_version"`
// LastLogin Last time this peer performed log in (authentication). E.g., user authenticated.
LastLogin time.Time `json:"last_login"`
// LastSeen Last time peer connected to Netbird's management service
LastSeen time.Time `json:"last_seen"`
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
// LoginExpired Indicates whether peer's login expired or not
LoginExpired bool `json:"login_expired"`
// Name Peer's hostname
Name string `json:"name"`
// Os Peer's operating system and version
Os string `json:"os"`
// SerialNumber System serial number
SerialNumber string `json:"serial_number"`
// SshEnabled Indicates whether SSH server is enabled on this peer
SshEnabled bool `json:"ssh_enabled"`
// UiVersion Peer's desktop UI version
UiVersion string `json:"ui_version"`
// UserId User ID of the user that enrolled this peer
UserId string `json:"user_id"`
// Version Peer's daemon or cli version
Version string `json:"version"`
}
// PeerBase defines model for PeerBase.
type PeerBase struct {
// ApprovalRequired (Cloud only) Indicates whether peer needs approval // ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired bool `json:"approval_required"` ApprovalRequired bool `json:"approval_required"`

View File

@@ -115,6 +115,7 @@ func (apiHandler *apiHandler) addPeersEndpoint() {
apiHandler.Router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") apiHandler.Router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). apiHandler.Router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS") Methods("GET", "PUT", "DELETE", "OPTIONS")
apiHandler.Router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
} }
func (apiHandler *apiHandler) addUsersEndpoint() { func (apiHandler *apiHandler) addUsersEndpoint() {

View File

@@ -71,12 +71,8 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
return return
} }
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID] _, valid := validPeers[peer.ID]
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid)) util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
} }
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) { func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
@@ -117,13 +113,9 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
return return
} }
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID] _, valid := validPeers[peer.ID]
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid)) util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, valid))
} }
func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
@@ -220,32 +212,66 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv
} }
} }
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, _, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
return
}
dnsDomain := h.accountManager.GetDNSDomain()
validPeers, err := h.accountManager.GetValidatedPeers(account)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
return
}
customZone := account.GetPeersCustomZone(r.Context(), h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil)
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
}
func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.AccessiblePeer { func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.AccessiblePeer {
accessiblePeers := make([]api.AccessiblePeer, 0, len(netMap.Peers)+len(netMap.OfflinePeers)) accessiblePeers := make([]api.AccessiblePeer, 0, len(netMap.Peers)+len(netMap.OfflinePeers))
for _, p := range netMap.Peers { for _, p := range netMap.Peers {
ap := api.AccessiblePeer{ accessiblePeers = append(accessiblePeers, peerToAccessiblePeer(p, dnsDomain))
Id: p.ID,
Name: p.Name,
Ip: p.IP.String(),
DnsLabel: fqdn(p, dnsDomain),
UserId: p.UserID,
}
accessiblePeers = append(accessiblePeers, ap)
} }
for _, p := range netMap.OfflinePeers { for _, p := range netMap.OfflinePeers {
ap := api.AccessiblePeer{ accessiblePeers = append(accessiblePeers, peerToAccessiblePeer(p, dnsDomain))
Id: p.ID,
Name: p.Name,
Ip: p.IP.String(),
DnsLabel: fqdn(p, dnsDomain),
UserId: p.UserID,
}
accessiblePeers = append(accessiblePeers, ap)
} }
return accessiblePeers return accessiblePeers
} }
func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePeer {
return api.AccessiblePeer{
CityName: peer.Location.CityName,
Connected: peer.Status.Connected,
CountryCode: peer.Location.CountryCode,
DnsLabel: fqdn(peer, dnsDomain),
GeonameId: int(peer.Location.GeoNameID),
Id: peer.ID,
Ip: peer.IP.String(),
LastSeen: peer.Status.LastSeen,
Name: peer.Name,
Os: peer.Meta.OS,
UserId: peer.UserID,
}
}
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum { func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
var groupsInfo []api.GroupMinimum var groupsInfo []api.GroupMinimum
groupsChecked := make(map[string]struct{}) groupsChecked := make(map[string]struct{})
@@ -270,7 +296,7 @@ func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMi
return groupsInfo return groupsInfo
} }
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeer []api.AccessiblePeer, approved bool) *api.Peer { func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer {
osVersion := peer.Meta.OSVersion osVersion := peer.Meta.OSVersion
if osVersion == "" { if osVersion == "" {
osVersion = peer.Meta.Core osVersion = peer.Meta.Core
@@ -296,7 +322,6 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
LoginExpirationEnabled: peer.LoginExpirationEnabled, LoginExpirationEnabled: peer.LoginExpirationEnabled,
LastLogin: peer.LastLogin, LastLogin: peer.LastLogin,
LoginExpired: peer.Status.LoginExpired, LoginExpired: peer.Status.LoginExpired,
AccessiblePeers: accessiblePeer,
ApprovalRequired: !approved, ApprovalRequired: !approved,
CountryCode: peer.Location.CountryCode, CountryCode: peer.Location.CountryCode,
CityName: peer.Location.CityName, CityName: peer.Location.CityName,

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/sha1" "crypto/sha1"
"crypto/sha256" "crypto/sha256"
"encoding/base64"
"fmt" "fmt"
"sync" "sync"
"time" "time"
@@ -12,6 +13,7 @@ import (
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
auth "github.com/netbirdio/netbird/relay/auth/hmac" auth "github.com/netbirdio/netbird/relay/auth/hmac"
authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
) )
const defaultDuration = 12 * time.Hour const defaultDuration = 12 * time.Hour
@@ -30,7 +32,7 @@ type TimeBasedAuthSecretsManager struct {
turnCfg *TURNConfig turnCfg *TURNConfig
relayCfg *Relay relayCfg *Relay
turnHmacToken *auth.TimedHMAC turnHmacToken *auth.TimedHMAC
relayHmacToken *auth.TimedHMAC relayHmacToken *authv2.Generator
updateManager *PeersUpdateManager updateManager *PeersUpdateManager
turnCancelMap map[string]chan struct{} turnCancelMap map[string]chan struct{}
relayCancelMap map[string]chan struct{} relayCancelMap map[string]chan struct{}
@@ -63,7 +65,11 @@ func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *
duration = defaultDuration duration = defaultDuration
} }
mgr.relayHmacToken = auth.NewTimedHMAC(relayCfg.Secret, duration) hashedSecret := sha256.Sum256([]byte(relayCfg.Secret))
var err error
if mgr.relayHmacToken, err = authv2.NewGenerator(authv2.AuthAlgoHMACSHA256, hashedSecret[:], duration); err != nil {
log.Errorf("failed to create relay token generator: %s", err)
}
} }
return mgr return mgr
@@ -76,7 +82,7 @@ func (m *TimeBasedAuthSecretsManager) GenerateTurnToken() (*Token, error) {
} }
turnToken, err := m.turnHmacToken.GenerateToken(sha1.New) turnToken, err := m.turnHmacToken.GenerateToken(sha1.New)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate TURN token: %s", err) return nil, fmt.Errorf("generate TURN token: %s", err)
} }
return (*Token)(turnToken), nil return (*Token)(turnToken), nil
} }
@@ -86,11 +92,15 @@ func (m *TimeBasedAuthSecretsManager) GenerateRelayToken() (*Token, error) {
if m.relayHmacToken == nil { if m.relayHmacToken == nil {
return nil, fmt.Errorf("relay configuration is not set") return nil, fmt.Errorf("relay configuration is not set")
} }
relayToken, err := m.relayHmacToken.GenerateToken(sha256.New) relayToken, err := m.relayHmacToken.GenerateToken()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate relay token: %s", err) return nil, fmt.Errorf("generate relay token: %s", err)
} }
return (*Token)(relayToken), nil
return &Token{
Payload: string(relayToken.Payload),
Signature: base64.StdEncoding.EncodeToString(relayToken.Signature),
}, nil
} }
func (m *TimeBasedAuthSecretsManager) cancelTURN(peerID string) { func (m *TimeBasedAuthSecretsManager) cancelTURN(peerID string) {
@@ -200,7 +210,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNTokens(ctx context.Context, pee
} }
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, peerID string) { func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, peerID string) {
relayToken, err := m.relayHmacToken.GenerateToken(sha256.New) relayToken, err := m.relayHmacToken.GenerateToken()
if err != nil { if err != nil {
log.Errorf("failed to generate relay token for peer '%s': %s", peerID, err) log.Errorf("failed to generate relay token for peer '%s': %s", peerID, err)
return return
@@ -210,8 +220,8 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, pe
WiretrusteeConfig: &proto.WiretrusteeConfig{ WiretrusteeConfig: &proto.WiretrusteeConfig{
Relay: &proto.RelayConfig{ Relay: &proto.RelayConfig{
Urls: m.relayCfg.Addresses, Urls: m.relayCfg.Addresses,
TokenPayload: relayToken.Payload, TokenPayload: string(relayToken.Payload),
TokenSignature: relayToken.Signature, TokenSignature: base64.StdEncoding.EncodeToString(relayToken.Signature),
}, },
// omit Turns to avoid updates there // omit Turns to avoid updates there
}, },

View File

@@ -63,7 +63,8 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
t.Errorf("expected generated relay signature not to be empty, got empty") t.Errorf("expected generated relay signature not to be empty, got empty")
} }
validateMAC(t, sha256.New, relayCredentials.Payload, relayCredentials.Signature, []byte(secret)) hashedSecret := sha256.Sum256([]byte(secret))
validateMAC(t, sha256.New, relayCredentials.Payload, relayCredentials.Signature, hashedSecret[:])
} }
func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {

View File

@@ -1,12 +1,14 @@
package allow package allow
import "hash"
// Auth is a Validator that allows all connections. // Auth is a Validator that allows all connections.
// Used this for testing purposes only. // Used this for testing purposes only.
type Auth struct { type Auth struct {
} }
func (a *Auth) Validate(func() hash.Hash, any) error { func (a *Auth) Validate(any) error {
return nil
}
func (a *Auth) ValidateHelloMsgType(any) error {
return nil return nil
} }

View File

@@ -1,9 +1,11 @@
package hmac package hmac
import ( import (
"encoding/base64"
"fmt"
"sync" "sync"
log "github.com/sirupsen/logrus" v2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
) )
// TokenStore is a simple in-memory store for token // TokenStore is a simple in-memory store for token
@@ -20,12 +22,18 @@ func (a *TokenStore) UpdateToken(token *Token) error {
return nil return nil
} }
t, err := marshalToken(*token) sig, err := base64.StdEncoding.DecodeString(token.Signature)
if err != nil { if err != nil {
log.Debugf("failed to marshal token: %s", err) return fmt.Errorf("decode signature: %w", err)
return err
} }
a.token = t
tok := v2.Token{
AuthAlgo: v2.AuthAlgoHMACSHA256,
Signature: sig,
Payload: []byte(token.Payload),
}
a.token = tok.Marshal()
return nil return nil
} }

View File

@@ -18,17 +18,6 @@ type Token struct {
Signature string Signature string
} }
func marshalToken(token Token) ([]byte, error) {
var buffer bytes.Buffer
encoder := gob.NewEncoder(&buffer)
err := encoder.Encode(token)
if err != nil {
log.Debugf("failed to marshal token: %s", err)
return nil, fmt.Errorf("failed to marshal token: %w", err)
}
return buffer.Bytes(), nil
}
func unmarshalToken(payload []byte) (Token, error) { func unmarshalToken(payload []byte) (Token, error) {
var creds Token var creds Token
buffer := bytes.NewBuffer(payload) buffer := bytes.NewBuffer(payload)

View File

@@ -0,0 +1,40 @@
package v2
import (
"crypto/sha256"
"hash"
)
const (
AuthAlgoUnknown AuthAlgo = iota
AuthAlgoHMACSHA256
)
type AuthAlgo uint8
func (a AuthAlgo) String() string {
switch a {
case AuthAlgoHMACSHA256:
return "HMAC-SHA256"
default:
return "Unknown"
}
}
func (a AuthAlgo) New() func() hash.Hash {
switch a {
case AuthAlgoHMACSHA256:
return sha256.New
default:
return nil
}
}
func (a AuthAlgo) Size() int {
switch a {
case AuthAlgoHMACSHA256:
return sha256.Size
default:
return 0
}
}

View File

@@ -0,0 +1,45 @@
package v2
import (
"crypto/hmac"
"fmt"
"hash"
"strconv"
"time"
)
type Generator struct {
algo func() hash.Hash
algoType AuthAlgo
secret []byte
timeToLive time.Duration
}
func NewGenerator(algo AuthAlgo, secret []byte, timeToLive time.Duration) (*Generator, error) {
algoFunc := algo.New()
if algoFunc == nil {
return nil, fmt.Errorf("unsupported auth algorithm: %s", algo)
}
return &Generator{
algo: algoFunc,
algoType: algo,
secret: secret,
timeToLive: timeToLive,
}, nil
}
func (g *Generator) GenerateToken() (*Token, error) {
expirationTime := time.Now().Add(g.timeToLive).Unix()
payload := []byte(strconv.FormatInt(expirationTime, 10))
h := hmac.New(g.algo, g.secret)
h.Write(payload)
signature := h.Sum(nil)
return &Token{
AuthAlgo: g.algoType,
Signature: signature,
Payload: payload,
}, nil
}

View File

@@ -0,0 +1,110 @@
package v2
import (
"strconv"
"testing"
"time"
)
func TestGenerateCredentials(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if len(token.Payload) == 0 {
t.Fatalf("expected non-empty payload")
}
_, err = strconv.ParseInt(string(token.Payload), 10, 64)
if err != nil {
t.Fatalf("expected payload to be a valid unix timestamp, got %v", err)
}
}
func TestValidateCredentials(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
v := NewValidator([]byte(secret))
if err := v.Validate(token.Marshal()); err != nil {
t.Fatalf("expected valid token: %s", err)
}
}
func TestInvalidSignature(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
token.Signature = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
v := NewValidator([]byte(secret))
if err := v.Validate(token.Marshal()); err == nil {
t.Fatalf("expected valid token: %s", err)
}
}
func TestExpired(t *testing.T) {
secret := "supersecret"
timeToLive := -1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
v := NewValidator([]byte(secret))
if err := v.Validate(token.Marshal()); err == nil {
t.Fatalf("expected valid token: %s", err)
}
}
func TestInvalidPayload(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
token.Payload = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
v := NewValidator([]byte(secret))
if err := v.Validate(token.Marshal()); err == nil {
t.Fatalf("expected invalid token due to invalid payload")
}
}

View File

@@ -0,0 +1,39 @@
package v2
import "errors"
type Token struct {
AuthAlgo AuthAlgo
Signature []byte
Payload []byte
}
func (t *Token) Marshal() []byte {
size := 1 + len(t.Signature) + len(t.Payload)
buf := make([]byte, size)
buf[0] = byte(t.AuthAlgo)
copy(buf[1:], t.Signature)
copy(buf[1+len(t.Signature):], t.Payload)
return buf
}
func UnmarshalToken(data []byte) (*Token, error) {
if len(data) == 0 {
return nil, errors.New("invalid token data")
}
algo := AuthAlgo(data[0])
sigSize := algo.Size()
if len(data) < 1+sigSize {
return nil, errors.New("invalid token data: insufficient length")
}
return &Token{
AuthAlgo: algo,
Signature: data[1 : 1+sigSize],
Payload: data[1+sigSize:],
}, nil
}

View File

@@ -0,0 +1,59 @@
package v2
import (
"crypto/hmac"
"errors"
"fmt"
"strconv"
"time"
)
const minLengthUnixTimestamp = 10
type Validator struct {
secret []byte
}
func NewValidator(secret []byte) *Validator {
return &Validator{secret: secret}
}
func (v *Validator) Validate(data any) error {
d, ok := data.([]byte)
if !ok {
return fmt.Errorf("invalid data type")
}
token, err := UnmarshalToken(d)
if err != nil {
return fmt.Errorf("unmarshal token: %w", err)
}
if len(token.Payload) < minLengthUnixTimestamp {
return errors.New("invalid payload: insufficient length")
}
hashFunc := token.AuthAlgo.New()
if hashFunc == nil {
return fmt.Errorf("unsupported auth algorithm: %s", token.AuthAlgo)
}
h := hmac.New(hashFunc, v.secret)
h.Write(token.Payload)
expectedMAC := h.Sum(nil)
if !hmac.Equal(token.Signature, expectedMAC) {
return errors.New("invalid signature")
}
timestamp, err := strconv.ParseInt(string(token.Payload), 10, 64)
if err != nil {
return fmt.Errorf("invalid payload: %w", err)
}
if time.Now().Unix() > timestamp {
return fmt.Errorf("expired token")
}
return nil
}

View File

@@ -1,8 +1,8 @@
package hmac package hmac
import ( import (
"crypto/sha256"
"fmt" "fmt"
"hash"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -19,7 +19,7 @@ func NewTimedHMACValidator(secret string, duration time.Duration) *TimedHMACVali
} }
} }
func (a *TimedHMACValidator) Validate(algo func() hash.Hash, credentials any) error { func (a *TimedHMACValidator) Validate(credentials any) error {
b, ok := credentials.([]byte) b, ok := credentials.([]byte)
if !ok { if !ok {
return fmt.Errorf("invalid credentials type") return fmt.Errorf("invalid credentials type")
@@ -29,5 +29,5 @@ func (a *TimedHMACValidator) Validate(algo func() hash.Hash, credentials any) er
log.Debugf("failed to unmarshal token: %s", err) log.Debugf("failed to unmarshal token: %s", err)
return err return err
} }
return a.TimedHMAC.Validate(algo, c) return a.TimedHMAC.Validate(sha256.New, c)
} }

View File

@@ -1,8 +1,35 @@
package auth package auth
import "hash" import (
"time"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
)
// Validator is an interface that defines the Validate method. // Validator is an interface that defines the Validate method.
type Validator interface { type Validator interface {
Validate(func() hash.Hash, any) error Validate(any) error
// Deprecated: Use Validate instead.
ValidateHelloMsgType(any) error
}
type TimedHMACValidator struct {
authenticatorV2 *authv2.Validator
authenticator *auth.TimedHMACValidator
}
func NewTimedHMACValidator(secret []byte, duration time.Duration) *TimedHMACValidator {
return &TimedHMACValidator{
authenticatorV2: authv2.NewValidator(secret),
authenticator: auth.NewTimedHMACValidator(string(secret), duration),
}
}
func (a *TimedHMACValidator) Validate(credentials any) error {
return a.authenticatorV2.Validate(credentials)
}
func (a *TimedHMACValidator) ValidateHelloMsgType(credentials any) error {
return a.authenticator.Validate(credentials)
} }

View File

@@ -14,8 +14,6 @@ import (
"github.com/netbirdio/netbird/relay/client/dialer/ws" "github.com/netbirdio/netbird/relay/client/dialer/ws"
"github.com/netbirdio/netbird/relay/healthcheck" "github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages" "github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/relay/messages/address"
auth2 "github.com/netbirdio/netbird/relay/messages/auth"
) )
const ( const (
@@ -240,31 +238,21 @@ func (c *Client) connect() error {
} }
func (c *Client) handShake() error { func (c *Client) handShake() error {
authMsg := &auth2.Msg{ msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
AuthAlgorithm: auth2.AlgoHMACSHA256,
AdditionalData: c.authTokenStore.TokenBinary(),
}
authData, err := authMsg.Marshal()
if err != nil { if err != nil {
return fmt.Errorf("marshal auth message: %w", err) log.Errorf("failed to marshal auth message: %s", err)
}
msg, err := messages.MarshalHelloMsg(c.hashedID, authData)
if err != nil {
log.Errorf("failed to marshal hello message: %s", err)
return err return err
} }
_, err = c.relayConn.Write(msg) _, err = c.relayConn.Write(msg)
if err != nil { if err != nil {
log.Errorf("failed to send hello message: %s", err) log.Errorf("failed to send auth message: %s", err)
return err return err
} }
buf := make([]byte, messages.MaxHandshakeSize) buf := make([]byte, messages.MaxHandshakeRespSize)
n, err := c.readWithTimeout(buf) n, err := c.readWithTimeout(buf)
if err != nil { if err != nil {
log.Errorf("failed to read hello response: %s", err) log.Errorf("failed to read auth response: %s", err)
return err return err
} }
@@ -279,23 +267,18 @@ func (c *Client) handShake() error {
return err return err
} }
if msgType != messages.MsgTypeHelloResponse { if msgType != messages.MsgTypeAuthResponse {
log.Errorf("unexpected message type: %s", msgType) log.Errorf("unexpected message type: %s", msgType)
return fmt.Errorf("unexpected message type") return fmt.Errorf("unexpected message type")
} }
additionalData, err := messages.UnmarshalHelloResponse(buf[messages.SizeOfProtoHeader:n]) addr, err := messages.UnmarshalAuthResponse(buf[messages.SizeOfProtoHeader:n])
if err != nil { if err != nil {
return err return err
} }
addr, err := address.Unmarshal(additionalData)
if err != nil {
return fmt.Errorf("unmarshal address: %w", err)
}
c.muInstanceURL.Lock() c.muInstanceURL.Lock()
c.instanceURL = &RelayAddr{addr: addr.URL} c.instanceURL = &RelayAddr{addr: addr}
c.muInstanceURL.Unlock() c.muInstanceURL.Unlock()
return nil return nil
} }

View File

@@ -2,6 +2,7 @@ package cmd
import ( import (
"context" "context"
"crypto/sha256"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
@@ -16,7 +17,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
auth "github.com/netbirdio/netbird/relay/auth/hmac" "github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/signal/metrics" "github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
@@ -139,7 +140,9 @@ func execute(cmd *cobra.Command, args []string) error {
} }
srvListenerCfg.TLSConfig = tlsConfig srvListenerCfg.TLSConfig = tlsConfig
authenticator := auth.NewTimedHMACValidator(cobraConfig.AuthSecret, 24*time.Hour) hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator) srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator)
if err != nil { if err != nil {
log.Debugf("failed to create relay server: %v", err) log.Debugf("failed to create relay server: %v", err)

View File

@@ -1,3 +1,4 @@
// Deprecated: This package is deprecated and will be removed in a future release.
package address package address
import ( import (
@@ -18,13 +19,3 @@ func (addr *Address) Marshal() ([]byte, error) {
} }
return buf.Bytes(), nil return buf.Bytes(), nil
} }
func Unmarshal(data []byte) (*Address, error) {
var addr Address
buf := bytes.NewBuffer(data)
dec := gob.NewDecoder(buf)
if err := dec.Decode(&addr); err != nil {
return nil, fmt.Errorf("decode Address: %w", err)
}
return &addr, nil
}

View File

@@ -1,3 +1,4 @@
// Deprecated: This package is deprecated and will be removed in a future release.
package auth package auth
import ( import (
@@ -30,15 +31,6 @@ type Msg struct {
AdditionalData []byte AdditionalData []byte
} }
func (msg *Msg) Marshal() ([]byte, error) {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
if err := enc.Encode(msg); err != nil {
return nil, fmt.Errorf("encode Msg: %w", err)
}
return buf.Bytes(), nil
}
func UnmarshalMsg(data []byte) (*Msg, error) { func UnmarshalMsg(data []byte) (*Msg, error) {
var msg *Msg var msg *Msg

View File

@@ -7,12 +7,21 @@ import (
) )
const ( const (
MsgTypeUnknown MsgType = 0 MaxHandshakeSize = 212
MsgTypeHello MsgType = 1 MaxHandshakeRespSize = 8192
CurrentProtocolVersion = 1
MsgTypeUnknown MsgType = 0
// Deprecated: Use MsgTypeAuth instead.
MsgTypeHello MsgType = 1
// Deprecated: Use MsgTypeAuthResponse instead.
MsgTypeHelloResponse MsgType = 2 MsgTypeHelloResponse MsgType = 2
MsgTypeTransport MsgType = 3 MsgTypeTransport MsgType = 3
MsgTypeClose MsgType = 4 MsgTypeClose MsgType = 4
MsgTypeHealthCheck MsgType = 5 MsgTypeHealthCheck MsgType = 5
MsgTypeAuth = 6
MsgTypeAuthResponse = 7
SizeOfVersionByte = 1 SizeOfVersionByte = 1
SizeOfMsgType = 1 SizeOfMsgType = 1
@@ -22,12 +31,12 @@ const (
sizeOfMagicByte = 4 sizeOfMagicByte = 4
headerSizeTransport = IDSize headerSizeTransport = IDSize
headerSizeHello = sizeOfMagicByte + IDSize headerSizeHello = sizeOfMagicByte + IDSize
headerSizeHelloResp = 0 headerSizeHelloResp = 0
MaxHandshakeSize = 8192 headerSizeAuth = sizeOfMagicByte + IDSize
headerSizeAuthResp = 0
CurrentProtocolVersion = 1
) )
var ( var (
@@ -47,6 +56,10 @@ func (m MsgType) String() string {
return "hello" return "hello"
case MsgTypeHelloResponse: case MsgTypeHelloResponse:
return "hello response" return "hello response"
case MsgTypeAuth:
return "auth"
case MsgTypeAuthResponse:
return "auth response"
case MsgTypeTransport: case MsgTypeTransport:
return "transport" return "transport"
case MsgTypeClose: case MsgTypeClose:
@@ -58,10 +71,6 @@ func (m MsgType) String() string {
} }
} }
type HelloResponse struct {
InstanceAddress string
}
// ValidateVersion checks if the given version is supported by the protocol // ValidateVersion checks if the given version is supported by the protocol
func ValidateVersion(msg []byte) (int, error) { func ValidateVersion(msg []byte) (int, error) {
if len(msg) < SizeOfVersionByte { if len(msg) < SizeOfVersionByte {
@@ -84,6 +93,7 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) {
switch msgType { switch msgType {
case case
MsgTypeHello, MsgTypeHello,
MsgTypeAuth,
MsgTypeTransport, MsgTypeTransport,
MsgTypeClose, MsgTypeClose,
MsgTypeHealthCheck: MsgTypeHealthCheck:
@@ -103,6 +113,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
switch msgType { switch msgType {
case case
MsgTypeHelloResponse, MsgTypeHelloResponse,
MsgTypeAuthResponse,
MsgTypeTransport, MsgTypeTransport,
MsgTypeClose, MsgTypeClose,
MsgTypeHealthCheck: MsgTypeHealthCheck:
@@ -112,6 +123,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
} }
} }
// Deprecated: Use MarshalAuthMsg instead.
// MarshalHelloMsg initial hello message // MarshalHelloMsg initial hello message
// The Hello message is the first message sent by a client after establishing a connection with the Relay server. This // The Hello message is the first message sent by a client after establishing a connection with the Relay server. This
// message is used to authenticate the client with the server. The authentication is done using an HMAC method. // message is used to authenticate the client with the server. The authentication is done using an HMAC method.
@@ -135,6 +147,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
return msg, nil return msg, nil
} }
// Deprecated: Use UnmarshalAuthMsg instead.
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to // UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
// authenticate the client with the server. // authenticate the client with the server.
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) { func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
@@ -148,6 +161,7 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil
} }
// Deprecated: Use MarshalAuthResponse instead.
// MarshalHelloResponse creates a response message to the hello message. // MarshalHelloResponse creates a response message to the hello message.
// In case of success connection the server response with a Hello Response message. This message contains the server's // In case of success connection the server response with a Hello Response message. This message contains the server's
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay // instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
@@ -163,6 +177,7 @@ func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
return msg, nil return msg, nil
} }
// Deprecated: Use UnmarshalAuthResponse instead.
// UnmarshalHelloResponse extracts the additional data from the hello response message. // UnmarshalHelloResponse extracts the additional data from the hello response message.
func UnmarshalHelloResponse(msg []byte) ([]byte, error) { func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
if len(msg) < headerSizeHelloResp { if len(msg) < headerSizeHelloResp {
@@ -171,6 +186,69 @@ func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
return msg, nil return msg, nil
} }
// MarshalAuthMsg initial authentication message
// The Auth message is the first message sent by a client after establishing a connection with the Relay server. This
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
// close the network connection without any response.
func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}
msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeAuth+len(authPayload))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeAuth)
copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
msg = append(msg, peerID...)
msg = append(msg, authPayload...)
return msg, nil
}
// UnmarshalAuthMsg extracts peerID and the auth payload from the message
func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
if len(msg) < headerSizeAuth {
return nil, nil, ErrInvalidMessageLength
}
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
return nil, nil, errors.New("invalid magic header")
}
return msg[sizeOfMagicByte:headerSizeAuth], msg[headerSizeAuth:], nil
}
// MarshalAuthResponse creates a response message to the auth.
// In case of success connection the server response with a AuthResponse message. This message contains the server's
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
// servers.
func MarshalAuthResponse(address string) ([]byte, error) {
ab := []byte(address)
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeAuthResp+len(ab))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeAuthResponse)
msg = append(msg, ab...)
if len(msg) > MaxHandshakeRespSize {
return nil, fmt.Errorf("invalid message length: %d", len(msg))
}
return msg, nil
}
// UnmarshalAuthResponse it is a confirmation message to auth success
func UnmarshalAuthResponse(msg []byte) (string, error) {
if len(msg) < headerSizeAuthResp+1 {
return "", ErrInvalidMessageLength
}
return string(msg), nil
}
// MarshalCloseMsg creates a close message. // MarshalCloseMsg creates a close message.
// The close message is used to close the connection gracefully between the client and the server. The server and the // The close message is used to close the connection gracefully between the client and the server. The server and the
// client can send this message. After receiving this message, the server or client will close the connection. // client can send this message. After receiving this message, the server or client will close the connection.

View File

@@ -20,6 +20,22 @@ func TestMarshalHelloMsg(t *testing.T) {
} }
} }
func TestMarshalAuthMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
bHello, err := MarshalAuthMsg(peerID, []byte{})
if err != nil {
t.Fatalf("error: %v", err)
}
receivedPeerID, _, err := UnmarshalAuthMsg(bHello[SizeOfProtoHeader:])
if err != nil {
t.Fatalf("error: %v", err)
}
if string(receivedPeerID) != string(peerID) {
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
}
}
func TestMarshalTransportMsg(t *testing.T) { func TestMarshalTransportMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
payload := []byte("payload") payload := []byte("payload")

View File

@@ -103,7 +103,7 @@ func (m *Metrics) PeerActivity(peerID string) {
select { select {
case m.peerActivityChan <- peerID: case m.peerActivityChan <- peerID:
default: default:
log.Errorf("peer activity channel is full, dropping activity metrics for peer %s", peerID) log.Tracef("peer activity channel is full, dropping activity metrics for peer %s", peerID)
} }
} }

View File

@@ -184,7 +184,7 @@ func (p *Peer) handleTransportMsg(msg []byte) {
stringPeerID := messages.HashIDToString(peerID) stringPeerID := messages.HashIDToString(peerID)
dp, ok := p.store.Peer(stringPeerID) dp, ok := p.store.Peer(stringPeerID)
if !ok { if !ok {
p.log.Errorf("peer not found: %s", stringPeerID) p.log.Debugf("peer not found: %s", stringPeerID)
return return
} }

View File

@@ -2,7 +2,6 @@ package server
import ( import (
"context" "context"
"crypto/sha256"
"fmt" "fmt"
"net" "net"
"net/url" "net/url"
@@ -14,7 +13,9 @@ import (
"github.com/netbirdio/netbird/relay/auth" "github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/messages" "github.com/netbirdio/netbird/relay/messages"
//nolint:staticcheck
"github.com/netbirdio/netbird/relay/messages/address" "github.com/netbirdio/netbird/relay/messages/address"
//nolint:staticcheck
authmsg "github.com/netbirdio/netbird/relay/messages/auth" authmsg "github.com/netbirdio/netbird/relay/messages/auth"
"github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/metrics"
) )
@@ -168,39 +169,81 @@ func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err) return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
} }
if msgType != messages.MsgTypeHello { var (
return nil, fmt.Errorf("invalid message type from %s", conn.RemoteAddr()) responseMsg []byte
peerID []byte
)
switch msgType {
//nolint:staticcheck
case messages.MsgTypeHello:
peerID, responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
case messages.MsgTypeAuth:
peerID, responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
default:
return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr())
} }
peerID, authData, err := messages.UnmarshalHelloMsg(buf[messages.SizeOfProtoHeader:n])
if err != nil { if err != nil {
return nil, fmt.Errorf("unmarshal hello message: %w", err) return nil, err
} }
authMsg, err := authmsg.UnmarshalMsg(authData) _, err = conn.Write(responseMsg)
if err != nil {
return nil, fmt.Errorf("unmarshal auth message: %w", err)
}
if err := r.validator.Validate(sha256.New, authMsg.AdditionalData); err != nil {
return nil, fmt.Errorf("validate %s (%s): %w", peerID, conn.RemoteAddr(), err)
}
addr := &address.Address{URL: r.instanceURL}
addrData, err := addr.Marshal()
if err != nil {
return nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, conn.RemoteAddr(), err)
}
msg, err := messages.MarshalHelloResponse(addrData)
if err != nil {
return nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, conn.RemoteAddr(), err)
}
_, err = conn.Write(msg)
if err != nil { if err != nil {
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err) return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
} }
return peerID, nil return peerID, nil
} }
func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) {
//nolint:staticcheck
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
if err != nil {
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
}
peerID := messages.HashIDToString(rawPeerID)
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr)
authMsg, err := authmsg.UnmarshalMsg(authData)
if err != nil {
return nil, nil, fmt.Errorf("unmarshal auth message: %w", err)
}
//nolint:staticcheck
if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err)
}
addr := &address.Address{URL: r.instanceURL}
addrData, err := addr.Marshal()
if err != nil {
return nil, nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err)
}
//nolint:staticcheck
responseMsg, err := messages.MarshalHelloResponse(addrData)
if err != nil {
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err)
}
return rawPeerID, responseMsg, nil
}
func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) {
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
if err != nil {
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
}
peerID := messages.HashIDToString(rawPeerID)
if err := r.validator.Validate(authPayload); err != nil {
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err)
}
responseMsg, err := messages.MarshalAuthResponse(r.instanceURL)
if err != nil {
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err)
}
return rawPeerID, responseMsg, nil
}

View File

@@ -300,11 +300,13 @@ install_netbird() {
echo "package_manager=$PACKAGE_MANAGER" | ${SUDO} tee "$CONFIG_FILE" > /dev/null echo "package_manager=$PACKAGE_MANAGER" | ${SUDO} tee "$CONFIG_FILE" > /dev/null
# Load and start netbird service # Load and start netbird service
if ! ${SUDO} netbird service install 2>&1; then if [ "$PACKAGE_MANAGER" != "rpm-ostree" ]; then
echo "NetBird service has already been loaded" if ! ${SUDO} netbird service install 2>&1; then
fi echo "NetBird service has already been loaded"
if ! ${SUDO} netbird service start 2>&1; then fi
echo "NetBird service has already been started" if ! ${SUDO} netbird service start 2>&1; then
echo "NetBird service has already been started"
fi
fi fi

View File

@@ -5,6 +5,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"slices" "slices"
"strconv"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gopkg.in/natefinch/lumberjack.v2" "gopkg.in/natefinch/lumberjack.v2"
@@ -12,6 +13,8 @@ import (
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
) )
const defaultLogSize = 5
// InitLog parses and sets log-level input // InitLog parses and sets log-level input
func InitLog(logLevel string, logPath string) error { func InitLog(logLevel string, logPath string) error {
level, err := log.ParseLevel(logLevel) level, err := log.ParseLevel(logLevel)
@@ -19,13 +22,14 @@ func InitLog(logLevel string, logPath string) error {
log.Errorf("Failed parsing log-level %s: %s", logLevel, err) log.Errorf("Failed parsing log-level %s: %s", logLevel, err)
return err return err
} }
customOutputs := []string{"console", "syslog"}; customOutputs := []string{"console", "syslog"}
if logPath != "" && !slices.Contains(customOutputs, logPath) { if logPath != "" && !slices.Contains(customOutputs, logPath) {
maxLogSize := getLogMaxSize()
lumberjackLogger := &lumberjack.Logger{ lumberjackLogger := &lumberjack.Logger{
// Log file absolute path, os agnostic // Log file absolute path, os agnostic
Filename: filepath.ToSlash(logPath), Filename: filepath.ToSlash(logPath),
MaxSize: 5, // MB MaxSize: maxLogSize, // MB
MaxBackups: 10, MaxBackups: 10,
MaxAge: 30, // days MaxAge: 30, // days
Compress: true, Compress: true,
@@ -46,3 +50,18 @@ func InitLog(logLevel string, logPath string) error {
log.SetLevel(level) log.SetLevel(level)
return nil return nil
} }
func getLogMaxSize() int {
if sizeVar, ok := os.LookupEnv("NB_LOG_MAX_SIZE_MB"); ok {
size, err := strconv.ParseInt(sizeVar, 10, 64)
if err != nil {
log.Errorf("Failed parsing log-size %s: %s. Should be just an integer", sizeVar, err)
return defaultLogSize
}
log.Infof("Setting log file max size to %d MB", size)
return int(size)
}
return defaultLogSize
}