Compare commits

...

35 Commits

Author SHA1 Message Date
Pascal Fischer
ad8459ea2f add mysql support [WIP] 2024-09-27 13:44:50 +02:00
Zoltan Papp
4ebf6e1c4c [client] Close the remote conn in proxy (#2626)
Port the conn close call to eBPF proxy
2024-09-25 18:50:10 +02:00
pascal-fischer
1e4a0f77e2 Add get DB method to store (#2650) 2024-09-25 18:22:27 +02:00
Viktor Liu
b51d75204b [client] Anonymize relay address in status peers view (#2640) 2024-09-24 20:58:18 +02:00
Viktor Liu
e7d52c8c95 [client] Fix error count formatting (#2641) 2024-09-24 20:57:56 +02:00
Viktor Liu
ab82302c95 [client] Remove usage of custom dialer for localhost (#2639)
* Downgrade error log level for network monitor warnings

* Do not use custom dialer for localhost
2024-09-24 12:29:15 +02:00
pascal-fischer
d47be154ea [misc] Fix ip range posture check example (#2628) 2024-09-23 10:02:03 +02:00
Bethuel Mmbaga
35c892aea3 [management] Restrict accessible peers to user-owned peers for non-admins (#2618)
* Restrict accessible peers to user-owned peers for non-admin users

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add service user test

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* reuse account from token

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* return error when peer not found

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-20 12:36:58 +03:00
Zoltan Papp
fc4b37f7bc Exit from processConnResults after all tries (#2621)
* Exit from processConnResults after all tries

If all server is unavailable then the server picker never return
because we never close the result channel.
Count the number of the results and exit when we reached the
expected size
2024-09-19 13:49:28 +02:00
Zoltan Papp
6f0fd1d1b3 - Increase queue size and drop the overflowed messages (#2617)
- Explicit close the net.Conn in user space wgProxy when close the wgProxy
- Add extra logs
2024-09-19 13:49:09 +02:00
Zoltan Papp
28cbb4b70f [client] Cancel the context of wg watcher when the go routine exit (#2612) 2024-09-17 12:10:17 +02:00
Zoltan Papp
1104c9c048 [client] Fix race condition while read/write conn status in peer conn (#2607) 2024-09-17 11:15:14 +02:00
Maycon Santos
5bc601111d [relay] Add health check attempt threshold (#2609)
* Add health check attempt threshold for receiver

* Add health check attempt threshold for sender
2024-09-17 10:04:17 +02:00
Zoltan Papp
b74951f29e [client] Enforce permissions on Win (#2568)
Enforce folder permission on Windows, giving only administrators and system access to the NetBird folder.
2024-09-16 22:42:37 +02:00
Zoltan Papp
97e10e440c Fix leaked server connections (#2596)
Fix leaked server connections

close unused connections in the client lib
close deprecated connection in the server lib
The Server Picker is reusable in the guard if we want in the future. So we can support the server address changes.

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>

* Add logging

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2024-09-16 16:11:10 +02:00
pascal-fischer
6c50b0c84b [management] Add transaction to addPeer (#2469)
This PR removes the GetAccount and SaveAccount operations from the AddPeer and instead makes use of gorm.Transaction to add the new peer.
2024-09-16 15:47:03 +02:00
pascal-fischer
730dd1733e [signal] Fix signal active peers metrics (#2591) 2024-09-15 16:46:55 +02:00
Bethuel Mmbaga
82739e2832 [management] fix legacy decrypting of empty values (#2595)
* allow legacy decrypting on empty values

* validate source size and padding limits

* added tests

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2024-09-15 16:22:46 +02:00
Maycon Santos
fa7767e612 Fix get management and signal state race condition (#2570)
* Fix get management and signal state race condition

* fix get full status lock
2024-09-15 16:07:26 +02:00
benniekiss
f1171198de [management] Add command flag to set metrics port for signal and relay service, and update management port (#2599)
* add flags to customize metrics port for relay and signal

* change management default metrics port to match other services
2024-09-14 10:34:32 +02:00
Zoltan Papp
9e041b7f82 Fix blocked net.Conn Close call (#2600) 2024-09-14 10:27:37 +02:00
Zoltan Papp
b4c8cf0a67 Change heartbeat timeout (#2598) 2024-09-14 10:12:54 +02:00
Carlos Hernandez
1ef51a4ffa [client] Ensure engine is stopped before starting it back (#2565)
Before starting a new instance of the engine, check if it is nil and stop the current instance
2024-09-13 16:46:59 +02:00
Maycon Santos
f6d57e7a96 [misc] Support configurable max log size with var NB_LOG_MAX_SIZE_MB (#2592)
* Support configurable max log size with var NB_LOG_MAX_SIZE_MB

* add better logs
2024-09-12 19:56:55 +02:00
Zoltan Papp
ab892b8cf9 Fix wg handshake checking (#2590)
* Fix wg handshake checking

* Ensure in the initial handshake reading

* Change the handshake period
2024-09-12 19:18:02 +02:00
Gianluca Boiano
33c9b2d989 fix: install.sh: avoid call of netbird executable after rpm installation (#2589) 2024-09-12 17:32:47 +02:00
Bethuel Mmbaga
170e842422 [management] Add accessible peers endpoint (#2579)
* move accessible peer to separate endpoint in api doc

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add endpoint to get accessible peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Update management/server/http/api/openapi.yml

Co-authored-by: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>

* Update management/server/http/api/openapi.yml

Co-authored-by: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>

* Update management/server/http/peers_handler.go

Co-authored-by: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
2024-09-12 16:19:27 +03:00
Maycon Santos
4c130a0291 Update Go version to 1.23 (#2588) 2024-09-12 13:46:28 +02:00
Maycon Santos
afb9673bc4 [misc] Update core github actions (#2584) 2024-09-11 21:49:05 +02:00
Bethuel Mmbaga
cf6210a6f4 [management] Add GCM encryption and migrate legacy encrypted events (#2569)
* Add AES-GCM encryption

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* migrate legacy encrypted data to AES-GCM encryption

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor and use transaction when migrating data

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add events migration tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* skip migrating record on error

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Preallocate capacity for nonce to avoid allocations in Seal

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-11 20:09:57 +03:00
Maycon Santos
c59a39d27d Update service package version (#2582) 2024-09-11 19:05:10 +02:00
Maycon Santos
47adb976f8 Remove pre-release step from workflow (#2583) 2024-09-11 18:59:19 +02:00
Zoltan Papp
9cfc8f8aa4 [relay] change log levels (#2580) 2024-09-11 18:36:19 +02:00
Viktor Liu
2d1bf3982d [relay] Improve relay messages (#2574)
Co-authored-by: Zoltán Papp <zoltan.pmail@gmail.com>
2024-09-11 16:20:30 +02:00
Viktor Liu
50ebbe482e [client] Don't overwrite allowed IPs when updating the wg peer's endpoint address (#2578)
This will fix broken routes on routing clients when upgrading/downgrading from/to relayed connections.
2024-09-11 16:05:13 +02:00
97 changed files with 4129 additions and 1127 deletions

View File

@@ -18,14 +18,14 @@ jobs:
runs-on: macos-latest
steps:
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21.x"
go-version: "1.23.x"
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: macos-go-${{ hashFiles('**/go.sum') }}

View File

@@ -19,13 +19,13 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21.x"
go-version: "1.23.x"
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -33,7 +33,7 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -49,18 +49,18 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
test_client_on_docker:
runs-on: ubuntu-20.04
steps:
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21.x"
go-version: "1.23.x"
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -68,7 +68,7 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -124,4 +124,4 @@ jobs:
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Peer tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -17,13 +17,13 @@ jobs:
runs-on: windows-latest
steps:
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
id: go
with:
go-version: "1.21.x"
go-version: "1.23.x"
- name: Download wintun
uses: carlosperate/download-file-action@v2

View File

@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
@@ -32,15 +32,15 @@ jobs:
timeout-minutes: 15
steps:
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest'
run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21.x"
go-version: "1.23.x"
cache: false
- name: Install dependencies
if: matrix.os == 'ubuntu-latest'
@@ -49,4 +49,4 @@ jobs:
uses: golangci/golangci-lint-action@v3
with:
version: latest
args: --timeout=12m
args: --timeout=12m

View File

@@ -21,7 +21,7 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: run install script
env:

View File

@@ -15,23 +15,23 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21.x"
go-version: "1.23.x"
- name: Setup Android SDK
uses: android-actions/setup-android@v3
with:
cmdline-tools-version: 8512546
- name: Setup Java
uses: actions/setup-java@v3
uses: actions/setup-java@v4
with:
java-version: "11"
distribution: "adopt"
- name: NDK Cache
id: ndk-cache
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620
@@ -50,11 +50,11 @@ jobs:
runs-on: macos-latest
steps:
- name: Checkout repository
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21.x"
go-version: "1.23.x"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init
@@ -62,4 +62,4 @@ jobs:
- name: build iOS netbird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o ./NetBirdSDK.xcframework ./client/ios/NetBirdSDK
env:
CGO_ENABLED: 0
CGO_ENABLED: 0

View File

@@ -36,18 +36,18 @@ jobs:
run: echo "flags=--snapshot" >> $GITHUB_ENV
-
name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
-
name: Set up Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21"
go-version: "1.23"
cache: false
-
name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
@@ -93,28 +93,28 @@ jobs:
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
-
name: upload non tags for debug purposes
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4
with:
name: release
path: dist/
retention-days: 3
-
name: upload linux packages
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4
with:
name: linux-packages
path: dist/netbird_linux**
retention-days: 3
-
name: upload windows packages
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4
with:
name: windows-packages
path: dist/netbird_windows**
retention-days: 3
-
name: upload macos packages
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4
with:
name: macos-packages
path: dist/netbird_darwin**
@@ -133,17 +133,17 @@ jobs:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21"
go-version: "1.23"
cache: false
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
@@ -176,7 +176,7 @@ jobs:
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- name: upload non tags for debug purposes
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4
with:
name: release-ui
path: dist/
@@ -189,18 +189,18 @@ jobs:
run: echo "flags=--snapshot" >> $GITHUB_ENV
-
name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
-
name: Set up Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21"
go-version: "1.23"
cache: false
-
name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
@@ -225,7 +225,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-
name: upload non tags for debug purposes
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4
with:
name: release-ui-darwin
path: dist/

View File

@@ -50,12 +50,12 @@ jobs:
run: sudo apt-get install -y curl
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21.x"
go-version: "1.23.x"
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -63,7 +63,7 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: cp setup.env
run: cp infrastructure_files/tests/setup.env infrastructure_files/
@@ -219,10 +219,7 @@ jobs:
run: sudo apt-get install -y jq
- name: Checkout code
uses: actions/checkout@v3
- name: handle insisting image # remove after release
run: docker pull netbirdio/relay:latest || docker pull netbirdio/signal:latest && docker tag netbirdio/signal:latest netbirdio/relay:latest
uses: actions/checkout@v4
- name: run script with Zitadel PostgreSQL
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
@@ -259,9 +256,6 @@ jobs:
docker compose down --volumes --rmi all
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
- name: handle insisting image gen CockroachDB # remove after release
run: docker pull netbirdio/relay:latest || docker pull netbirdio/signal:latest && docker tag netbirdio/signal:latest netbirdio/relay:latest
- name: run script with Zitadel CockroachDB
run: bash -x infrastructure_files/getting-started-with-zitadel.sh
env:

View File

@@ -805,6 +805,9 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
}
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeIPString(route)
}

View File

@@ -8,8 +8,8 @@ import (
)
func formatError(es []error) string {
if len(es) == 0 {
return fmt.Sprintf("0 error occurred:\n\t* %s", es[0])
if len(es) == 1 {
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
}
points := make([]string, len(es))

View File

@@ -117,6 +117,11 @@ type Config struct {
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
func ReadConfig(configPath string) (*Config, error) {
if configFileIsExists(configPath) {
err := util.EnforcePermission(configPath)
if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err)
}
config := &Config{}
if _, err := util.ReadJson(configPath, config); err != nil {
return nil, err
@@ -159,13 +164,17 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if err != nil {
return nil, err
}
err = WriteOutConfig(input.ConfigPath, cfg)
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg)
return cfg, err
}
if isPreSharedKeyHidden(input.PreSharedKey) {
input.PreSharedKey = nil
}
err := util.EnforcePermission(input.ConfigPath)
if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err)
}
return update(input)
}

View File

@@ -158,6 +158,7 @@ func (c *ConnectClient) run(
}
defer c.statusRecorder.ClientStop()
runningChanOpen := true
operation := func() error {
// if context cancelled we not start new backoff cycle
if c.isContextCancelled() {
@@ -267,6 +268,12 @@ func (c *ConnectClient) run(
checks := loginResp.GetChecks()
c.engineMutex.Lock()
if c.engine != nil && c.engine.ctx.Err() != nil {
log.Info("Stopping Netbird Engine")
if err := c.engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
}
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
c.engineMutex.Unlock()
@@ -279,9 +286,10 @@ func (c *ConnectClient) run(
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
if runningChan != nil {
if runningChan != nil && runningChanOpen {
runningChan <- nil
close(runningChan)
runningChanOpen = false
}
<-engineCtx.Done()

View File

@@ -292,7 +292,7 @@ func (e *Engine) Start() error {
e.wgInterface = wgIface
userspace := e.wgInterface.IsUserspaceBind()
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort)
e.wgProxyFactory = wgproxy.NewFactory(userspace, e.config.WgPort)
if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled")
@@ -1115,10 +1115,7 @@ func (e *Engine) close() {
}
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
if e.dnsServer != nil {
e.dnsServer.Stop()
e.dnsServer = nil
}
e.stopDNSServer()
if e.routeManager != nil {
e.routeManager.Stop()
@@ -1360,12 +1357,16 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
}
func (e *Engine) restartEngine() {
log.Info("restarting engine")
CtxGetState(e.ctx).Set(StatusConnecting)
if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
if err := e.Start(); err != nil {
log.Errorf("Failed to start engine: %v", err)
}
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
log.Infof("cancelling client, engine will be recreated")
e.clientCancel()
}
func (e *Engine) startNetworkMonitor() {
@@ -1387,6 +1388,7 @@ func (e *Engine) startNetworkMonitor() {
defer mu.Unlock()
if debounceTimer != nil {
log.Infof("Network monitor: detected network change, reset debounceTimer")
debounceTimer.Stop()
}
@@ -1396,7 +1398,7 @@ func (e *Engine) startNetworkMonitor() {
mu.Lock()
defer mu.Unlock()
log.Infof("Network monitor detected network change, restarting engine")
log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine()
})
})
@@ -1421,6 +1423,20 @@ func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
return false, netip.Prefix{}, nil
}
func (e *Engine) stopDNSServer() {
err := fmt.Errorf("DNS server stopped")
nsGroupStates := e.statusRecorder.GetDNSStates()
for i := range nsGroupStates {
nsGroupStates[i].Enabled = false
nsGroupStates[i].Error = err
}
e.statusRecorder.UpdateDNSStates(nsGroupStates)
if e.dnsServer != nil {
e.dnsServer.Stop()
e.dnsServer = nil
}
}
// isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {

View File

@@ -24,7 +24,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
defer func() {
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Errorf("Network monitor: failed to close routing socket: %v", err)
log.Warnf("Network monitor: failed to close routing socket: %v", err)
}
}()
@@ -32,7 +32,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
<-ctx.Done()
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Debugf("Network monitor: closed routing socket")
log.Debugf("Network monitor: closed routing socket: %v", err)
}
}()
@@ -45,12 +45,12 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
n, err := unix.Read(fd, buf)
if err != nil {
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
}
continue
}
if n < unix.SizeofRtMsghdr {
log.Errorf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
continue
}
@@ -61,7 +61,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
if err != nil {
log.Errorf("Network monitor: error parsing routing message: %v", err)
log.Debugf("Network monitor: error parsing routing message: %v", err)
continue
}

View File

@@ -89,8 +89,8 @@ type Conn struct {
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string, wgIP string)
statusRelay ConnStatus
statusICE ConnStatus
statusRelay *AtomicConnStatus
statusICE *AtomicConnStatus
currentConnPriority ConnPriority
opened bool // this flag is used to prevent close in case of not opened connection
@@ -131,8 +131,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
signaler: signaler,
relayManager: relayManager,
allowedIPsIP: allowedIPsIP.String(),
statusRelay: StatusDisconnected,
statusICE: StatusDisconnected,
statusRelay: NewAtomicConnStatus(),
statusICE: NewAtomicConnStatus(),
iCEDisconnected: make(chan bool, 1),
relayDisconnected: make(chan bool, 1),
}
@@ -323,11 +323,11 @@ func (conn *Conn) reconnectLoopWithRetry() {
}
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay == StatusDisconnected || conn.statusICE == StatusDisconnected {
if conn.statusRelay.Get() == StatusDisconnected || conn.statusICE.Get() == StatusDisconnected {
conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
}
} else {
if conn.statusICE == StatusDisconnected {
if conn.statusICE.Get() == StatusDisconnected {
conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE)
}
}
@@ -419,7 +419,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
conn.log.Debugf("ICE connection is ready")
conn.statusICE = StatusConnected
conn.statusICE.Set(StatusConnected)
defer conn.updateIceState(iceConnInfo)
@@ -492,8 +492,8 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
conn.currentConnPriority = connPriorityRelay
}
changed := conn.statusICE != newState && newState != StatusConnecting
conn.statusICE = newState
changed := conn.statusICE.Get() != newState && newState != StatusConnecting
conn.statusICE.Set(newState)
select {
case conn.iCEDisconnected <- changed:
@@ -518,18 +518,22 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
defer conn.mu.Unlock()
if conn.ctx.Err() != nil {
if err := rci.relayedConn.Close(); err != nil {
log.Warnf("failed to close unnecessary relayed connection: %v", err)
}
return
}
conn.log.Debugf("Relay connection is ready to use")
conn.statusRelay = StatusConnected
conn.statusRelay.Set(StatusConnected)
wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
endpoint, err := wgProxy.AddTurnConn(rci.relayedConn)
wgProxy := conn.wgProxyFactory.GetProxy()
endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn)
if err != nil {
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return
}
conn.log.Infof("created new wgProxy for relay connection: %s", endpoint)
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
conn.endpointRelay = endpointUdpAddr
@@ -538,7 +542,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
if conn.currentConnPriority > connPriorityRelay {
if conn.statusICE == StatusConnected {
if conn.statusICE.Get() == StatusConnected {
log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
return
}
@@ -559,8 +563,8 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
conn.log.Errorf("Failed to update wg peer configuration: %v", err)
return
}
wgConfigWorkaround()
conn.workerRelay.EnableWgWatcher(conn.ctx)
wgConfigWorkaround()
if conn.wgProxyRelay != nil {
if err := conn.wgProxyRelay.CloseConn(); err != nil {
@@ -594,8 +598,8 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
conn.wgProxyRelay = nil
}
changed := conn.statusRelay != StatusDisconnected
conn.statusRelay = StatusDisconnected
changed := conn.statusRelay.Get() != StatusDisconnected
conn.statusRelay.Set(StatusDisconnected)
select {
case conn.relayDisconnected <- changed:
@@ -661,8 +665,8 @@ func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) {
}
func (conn *Conn) setStatusToDisconnected() {
conn.statusRelay = StatusDisconnected
conn.statusICE = StatusDisconnected
conn.statusRelay.Set(StatusDisconnected)
conn.statusICE.Set(StatusDisconnected)
peerState := State{
PubKey: conn.config.Key,
@@ -706,7 +710,7 @@ func (conn *Conn) waitInitialRandomSleepTime() {
}
func (conn *Conn) isRelayed() bool {
if conn.statusRelay == StatusDisconnected && (conn.statusICE == StatusDisconnected || conn.statusICE == StatusConnecting) {
if conn.statusRelay.Get() == StatusDisconnected && (conn.statusICE.Get() == StatusDisconnected || conn.statusICE.Get() == StatusConnecting) {
return false
}
@@ -718,11 +722,11 @@ func (conn *Conn) isRelayed() bool {
}
func (conn *Conn) evalStatus() ConnStatus {
if conn.statusRelay == StatusConnected || conn.statusICE == StatusConnected {
if conn.statusRelay.Get() == StatusConnected || conn.statusICE.Get() == StatusConnected {
return StatusConnected
}
if conn.statusRelay == StatusConnecting || conn.statusICE == StatusConnecting {
if conn.statusRelay.Get() == StatusConnecting || conn.statusICE.Get() == StatusConnecting {
return StatusConnecting
}
@@ -733,12 +737,12 @@ func (conn *Conn) isConnected() bool {
conn.mu.Lock()
defer conn.mu.Unlock()
if conn.statusICE != StatusConnected && conn.statusICE != StatusConnecting {
if conn.statusICE.Get() != StatusConnected && conn.statusICE.Get() != StatusConnecting {
return false
}
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay != StatusConnected {
if conn.statusRelay.Get() != StatusConnected {
return false
}
}
@@ -771,13 +775,12 @@ func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr,
return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil
}
conn.log.Debugf("setup ice turn connection")
wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
ep, err := wgProxy.AddTurnConn(iceConnInfo.RemoteConn)
wgProxy := conn.wgProxyFactory.GetProxy()
ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn)
if err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
err = wgProxy.CloseConn()
if err != nil {
conn.log.Warnf("failed to close turn proxy connection: %v", err)
if errClose := wgProxy.CloseConn(); errClose != nil {
conn.log.Warnf("failed to close turn proxy connection: %v", errClose)
}
return nil, nil, err
}

View File

@@ -1,6 +1,10 @@
package peer
import log "github.com/sirupsen/logrus"
import (
"sync/atomic"
log "github.com/sirupsen/logrus"
)
const (
// StatusConnected indicate the peer is in connected state
@@ -12,7 +16,34 @@ const (
)
// ConnStatus describe the status of a peer's connection
type ConnStatus int
type ConnStatus int32
// AtomicConnStatus is a thread-safe wrapper for ConnStatus
type AtomicConnStatus struct {
status atomic.Int32
}
// NewAtomicConnStatus creates a new AtomicConnStatus with the given initial status
func NewAtomicConnStatus() *AtomicConnStatus {
acs := &AtomicConnStatus{}
acs.Set(StatusDisconnected)
return acs
}
// Get returns the current connection status
func (acs *AtomicConnStatus) Get() ConnStatus {
return ConnStatus(acs.status.Load())
}
// Set updates the connection status
func (acs *AtomicConnStatus) Set(status ConnStatus) {
acs.status.Store(int32(status))
}
// String returns the string representation of the current status
func (acs *AtomicConnStatus) String() string {
return acs.Get().String()
}
func (s ConnStatus) String() string {
switch s {

View File

@@ -44,7 +44,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
}
func TestConn_GetKey(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -59,7 +59,7 @@ func TestConn_GetKey(t *testing.T) {
}
func TestConn_OnRemoteOffer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -96,7 +96,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
}
func TestConn_OnRemoteAnswer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -132,7 +132,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait()
}
func TestConn_Status(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -158,8 +158,13 @@ func TestConn_Status(t *testing.T) {
for _, table := range tables {
t.Run(table.name, func(t *testing.T) {
conn.statusICE = table.statusIce
conn.statusRelay = table.statusRelay
si := NewAtomicConnStatus()
si.Set(table.statusIce)
conn.statusICE = si
sr := NewAtomicConnStatus()
sr.Set(table.statusRelay)
conn.statusRelay = sr
got := conn.Status()
assert.Equal(t, got, table.want, "they should be equal")

View File

@@ -597,6 +597,8 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
}
func (d *Status) GetRosenpassState() RosenpassState {
d.mux.Lock()
defer d.mux.Unlock()
return RosenpassState{
d.rosenpassEnabled,
d.rosenpassPermissive,
@@ -604,6 +606,8 @@ func (d *Status) GetRosenpassState() RosenpassState {
}
func (d *Status) GetManagementState() ManagementState {
d.mux.Lock()
defer d.mux.Unlock()
return ManagementState{
d.mgmAddress,
d.managementState,
@@ -645,6 +649,8 @@ func (d *Status) IsLoginRequired() bool {
}
func (d *Status) GetSignalState() SignalState {
d.mux.Lock()
defer d.mux.Unlock()
return SignalState{
d.signalAddress,
d.signalState,
@@ -654,6 +660,8 @@ func (d *Status) GetSignalState() SignalState {
// GetRelayStates returns the stun/turn/permanent relay states
func (d *Status) GetRelayStates() []relay.ProbeResult {
d.mux.Lock()
defer d.mux.Unlock()
if d.relayMgr == nil {
return d.relayStates
}
@@ -684,6 +692,8 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
}
func (d *Status) GetDNSStates() []NSGroupState {
d.mux.Lock()
defer d.mux.Unlock()
return d.nsGroupStates
}
@@ -695,18 +705,19 @@ func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
// GetFullStatus gets full status
func (d *Status) GetFullStatus() FullStatus {
d.mux.Lock()
defer d.mux.Unlock()
fullStatus := FullStatus{
ManagementState: d.GetManagementState(),
SignalState: d.GetSignalState(),
LocalPeerState: d.localPeer,
Relays: d.GetRelayStates(),
RosenpassState: d.GetRosenpassState(),
NSGroupStates: d.GetDNSStates(),
}
d.mux.Lock()
defer d.mux.Unlock()
fullStatus.LocalPeerState = d.localPeer
for _, status := range d.peers {
fullStatus.Peers = append(fullStatus.Peers, status)
}

View File

@@ -14,7 +14,7 @@ import (
)
var (
wgHandshakePeriod = 2 * time.Minute
wgHandshakePeriod = 3 * time.Minute
wgHandshakeOvertime = 30 * time.Second
)
@@ -109,10 +109,10 @@ func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
}
ctx, ctxCancel := context.WithCancel(ctx)
go w.wgStateCheck(ctx)
w.ctxWgWatch = ctx
w.ctxCancelWgWatch = ctxCancel
w.wgStateCheck(ctx, ctxCancel)
}
func (w *WorkerRelay) DisableWgWatcher() {
@@ -157,37 +157,51 @@ func (w *WorkerRelay) CloseConn() {
}
}
// wgStateCheck help to check the state of the wireguard handshake and relay connection
func (w *WorkerRelay) wgStateCheck(ctx context.Context) {
timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop()
expected := wgHandshakeOvertime
for {
select {
case <-timer.C:
lastHandshake, err := w.wgState()
if err != nil {
w.log.Errorf("failed to read wg stats: %v", err)
continue
}
w.log.Tracef("last handshake: %v", lastHandshake)
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WorkerRelay) wgStateCheck(ctx context.Context, ctxCancel context.CancelFunc) {
w.log.Debugf("WireGuard watcher started")
lastHandshake, err := w.wgState()
if err != nil {
w.log.Warnf("failed to read wg stats: %v", err)
lastHandshake = time.Time{}
}
if time.Since(lastHandshake) > expected {
w.log.Infof("Wireguard handshake timed out, closing relay connection")
w.relayLock.Lock()
_ = w.relayedConn.Close()
w.relayLock.Unlock()
w.callBacks.OnDisconnected()
go func(lastHandshake time.Time) {
timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop()
defer ctxCancel()
for {
select {
case <-timer.C:
handshake, err := w.wgState()
if err != nil {
w.log.Errorf("failed to read wg stats: %v", err)
timer.Reset(wgHandshakeOvertime)
continue
}
w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake)
if handshake.Equal(lastHandshake) {
w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake)
w.relayLock.Lock()
_ = w.relayedConn.Close()
w.relayLock.Unlock()
w.callBacks.OnDisconnected()
return
}
resetTime := time.Until(handshake.Add(wgHandshakePeriod + wgHandshakeOvertime))
lastHandshake = handshake
timer.Reset(resetTime)
case <-ctx.Done():
w.log.Debugf("WireGuard watcher stopped")
return
}
resetTime := time.Until(lastHandshake.Add(wgHandshakePeriod + wgHandshakeOvertime))
timer.Reset(resetTime)
expected = wgHandshakePeriod
case <-ctx.Done():
w.log.Debugf("WireGuard watcher stopped")
return
}
}
}(lastHandshake)
}
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {

View File

@@ -1,4 +1,4 @@
package wgproxy
package ebpf
import (
"fmt"

View File

@@ -1,4 +1,4 @@
package wgproxy
package ebpf
import (
"fmt"

View File

@@ -1,6 +1,6 @@
//go:build linux && !android
package wgproxy
package ebpf
import (
"context"
@@ -13,47 +13,49 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/hashicorp/go-multierror"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
loopbackAddr = "127.0.0.1"
)
// WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct {
ebpfManager ebpfMgr.Manager
ctx context.Context
cancel context.CancelFunc
lastUsedPort uint16
localWGListenPort int
ebpfManager ebpfMgr.Manager
turnConnStore map[uint16]net.Conn
turnConnMutex sync.Mutex
rawConn net.PacketConn
conn transport.UDPConn
lastUsedPort uint16
rawConn net.PacketConn
conn transport.UDPConn
ctx context.Context
ctxCancel context.CancelFunc
}
// NewWGEBPFProxy create new WGEBPFProxy instance
func NewWGEBPFProxy(ctx context.Context, wgPort int) *WGEBPFProxy {
func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
log.Debugf("instantiate ebpf proxy")
wgProxy := &WGEBPFProxy{
localWGListenPort: wgPort,
ebpfManager: ebpf.GetEbpfManagerInstance(),
lastUsedPort: 0,
turnConnStore: make(map[uint16]net.Conn),
}
wgProxy.ctx, wgProxy.cancel = context.WithCancel(ctx)
return wgProxy
}
// listen load ebpf program and listen the proxy
func (p *WGEBPFProxy) listen() error {
// Listen load ebpf program and listen the proxy
func (p *WGEBPFProxy) Listen() error {
pl := portLookup{}
wgPorxyPort, err := pl.searchFreePort()
if err != nil {
@@ -72,9 +74,11 @@ func (p *WGEBPFProxy) listen() error {
addr := net.UDPAddr{
Port: wgPorxyPort,
IP: net.ParseIP("127.0.0.1"),
IP: net.ParseIP(loopbackAddr),
}
p.ctx, p.ctxCancel = context.WithCancel(context.Background())
conn, err := nbnet.ListenUDP("udp", &addr)
if err != nil {
cErr := p.Free()
@@ -91,108 +95,112 @@ func (p *WGEBPFProxy) listen() error {
}
// AddTurnConn add new turn connection for the proxy
func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) {
wgEndpointPort, err := p.storeTurnConn(turnConn)
if err != nil {
return nil, err
}
go p.proxyToLocal(wgEndpointPort, turnConn)
go p.proxyToLocal(ctx, wgEndpointPort, turnConn)
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
wgEndpoint := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
IP: net.ParseIP(loopbackAddr),
Port: int(wgEndpointPort),
}
return wgEndpoint, nil
}
// CloseConn doing nothing because this type of proxy implementation does not store the connection
func (p *WGEBPFProxy) CloseConn() error {
return nil
}
// Free resources
// Free resources except the remoteConns will be keep open.
func (p *WGEBPFProxy) Free() error {
log.Debugf("free up ebpf wg proxy")
var err1, err2, err3 error
if p.conn != nil {
err1 = p.conn.Close()
if p.ctx != nil && p.ctx.Err() != nil {
//nolint
return nil
}
err2 = p.ebpfManager.FreeWGProxy()
if p.rawConn != nil {
err3 = p.rawConn.Close()
p.ctxCancel()
var result *multierror.Error
if err := p.conn.Close(); err != nil {
result = multierror.Append(result, err)
}
if err1 != nil {
return err1
if err := p.ebpfManager.FreeWGProxy(); err != nil {
result = multierror.Append(result, err)
}
if err2 != nil {
return err2
if err := p.rawConn.Close(); err != nil {
result = multierror.Append(result, err)
}
return err3
return nberrors.FormatErrorOrNil(result)
}
func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) {
defer p.removeTurnConn(endpointPort)
var (
err error
n int
)
buf := make([]byte, 1500)
var err error
defer func() {
p.removeTurnConn(endpointPort)
}()
for {
select {
case <-p.ctx.Done():
return
default:
var n int
n, err = remoteConn.Read(buf)
if err != nil {
if err != io.EOF {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
for ctx.Err() == nil {
n, err = remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return
}
err = p.sendPkg(buf[:n], endpointPort)
if err != nil {
log.Errorf("failed to write out turn pkg to local conn: %v", err)
if err != io.EOF {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
return
}
if err := p.sendPkg(buf[:n], endpointPort); err != nil {
if ctx.Err() != nil || p.ctx.Err() != nil {
return
}
log.Errorf("failed to write out turn pkg to local conn: %v", err)
}
}
}
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
// From this go routine has only one instance.
func (p *WGEBPFProxy) proxyToRemote() {
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
return
default:
n, addr, err := p.conn.ReadFromUDP(buf)
if err != nil {
log.Errorf("failed to read UDP pkg from WG: %s", err)
for p.ctx.Err() == nil {
if err := p.readAndForwardPacket(buf); err != nil {
if p.ctx.Err() != nil {
return
}
p.turnConnMutex.Lock()
conn, ok := p.turnConnStore[uint16(addr.Port)]
p.turnConnMutex.Unlock()
if !ok {
log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port)
continue
}
_, err = conn.Write(buf[:n])
if err != nil {
log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err)
}
log.Errorf("failed to proxy packet to remote conn: %s", err)
}
}
}
func (p *WGEBPFProxy) readAndForwardPacket(buf []byte) error {
n, addr, err := p.conn.ReadFromUDP(buf)
if err != nil {
return fmt.Errorf("failed to read UDP packet from WG: %w", err)
}
p.turnConnMutex.Lock()
conn, ok := p.turnConnStore[uint16(addr.Port)]
p.turnConnMutex.Unlock()
if !ok {
if p.ctx.Err() == nil {
log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port)
}
return nil
}
if _, err := conn.Write(buf[:n]); err != nil {
return fmt.Errorf("failed to forward local WG packet (%d) to remote turn conn: %w", addr.Port, err)
}
return nil
}
func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
p.turnConnMutex.Lock()
defer p.turnConnMutex.Unlock()
@@ -206,11 +214,14 @@ func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
}
func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) {
log.Debugf("remove turn conn from store by port: %d", turnConnID)
p.turnConnMutex.Lock()
defer p.turnConnMutex.Unlock()
delete(p.turnConnStore, turnConnID)
_, ok := p.turnConnStore[turnConnID]
if ok {
log.Debugf("remove turn conn from store by port: %d", turnConnID)
}
delete(p.turnConnStore, turnConnID)
}
func (p *WGEBPFProxy) nextFreePort() (uint16, error) {

View File

@@ -1,14 +1,13 @@
//go:build linux && !android
package wgproxy
package ebpf
import (
"context"
"testing"
)
func TestWGEBPFProxy_connStore(t *testing.T) {
wgProxy := NewWGEBPFProxy(context.Background(), 1)
wgProxy := NewWGEBPFProxy(1)
p, _ := wgProxy.storeTurnConn(nil)
if p != 1 {
@@ -28,7 +27,7 @@ func TestWGEBPFProxy_connStore(t *testing.T) {
}
func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
wgProxy := NewWGEBPFProxy(context.Background(), 1)
wgProxy := NewWGEBPFProxy(1)
_, _ = wgProxy.storeTurnConn(nil)
wgProxy.lastUsedPort = 65535
@@ -44,7 +43,7 @@ func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
}
func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) {
wgProxy := NewWGEBPFProxy(context.Background(), 1)
wgProxy := NewWGEBPFProxy(1)
for i := 0; i < 65535; i++ {
_, _ = wgProxy.storeTurnConn(nil)

View File

@@ -0,0 +1,44 @@
//go:build linux && !android
package ebpf
import (
"context"
"fmt"
"net"
)
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
type ProxyWrapper struct {
WgeBPFProxy *WGEBPFProxy
remoteConn net.Conn
cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread
}
func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
ctxConn, cancel := context.WithCancel(ctx)
addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn)
if err != nil {
cancel()
return nil, fmt.Errorf("add turn conn: %w", err)
}
e.remoteConn = remoteConn
e.cancel = cancel
return addr, err
}
// CloseConn close the remoteConn and automatically remove the conn instance from the map
func (e *ProxyWrapper) CloseConn() error {
if e.cancel == nil {
return fmt.Errorf("proxy not started")
}
e.cancel()
if err := e.remoteConn.Close(); err != nil {
return fmt.Errorf("failed to close remote conn: %w", err)
}
return nil
}

View File

@@ -1,22 +0,0 @@
package wgproxy
import "context"
type Factory struct {
wgPort int
ebpfProxy Proxy
}
func (w *Factory) GetProxy(ctx context.Context) Proxy {
if w.ebpfProxy != nil {
return w.ebpfProxy
}
return NewWGUserSpaceProxy(ctx, w.wgPort)
}
func (w *Factory) Free() error {
if w.ebpfProxy != nil {
return w.ebpfProxy.Free()
}
return nil
}

View File

@@ -3,20 +3,26 @@
package wgproxy
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/internal/wgproxy/usp"
)
func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
type Factory struct {
wgPort int
ebpfProxy *ebpf.WGEBPFProxy
}
func NewFactory(userspace bool, wgPort int) *Factory {
f := &Factory{wgPort: wgPort}
if userspace {
return f
}
ebpfProxy := NewWGEBPFProxy(ctx, wgPort)
err := ebpfProxy.listen()
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort)
err := ebpfProxy.Listen()
if err != nil {
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
return f
@@ -25,3 +31,20 @@ func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
f.ebpfProxy = ebpfProxy
return f
}
func (w *Factory) GetProxy() Proxy {
if w.ebpfProxy != nil {
p := &ebpf.ProxyWrapper{
WgeBPFProxy: w.ebpfProxy,
}
return p
}
return usp.NewWGUserSpaceProxy(w.wgPort)
}
func (w *Factory) Free() error {
if w.ebpfProxy == nil {
return nil
}
return w.ebpfProxy.Free()
}

View File

@@ -2,8 +2,20 @@
package wgproxy
import "context"
import "github.com/netbirdio/netbird/client/internal/wgproxy/usp"
func NewFactory(ctx context.Context, _ bool, wgPort int) *Factory {
type Factory struct {
wgPort int
}
func NewFactory(_ bool, wgPort int) *Factory {
return &Factory{wgPort: wgPort}
}
func (w *Factory) GetProxy() Proxy {
return usp.NewWGUserSpaceProxy(w.wgPort)
}
func (w *Factory) Free() error {
return nil
}

View File

@@ -1,12 +1,12 @@
package wgproxy
import (
"context"
"net"
)
// Proxy is a transfer layer between the Turn connection and the WireGuard
// Proxy is a transfer layer between the relayed connection and the WireGuard
type Proxy interface {
AddTurnConn(turnConn net.Conn) (net.Addr, error)
AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error)
CloseConn() error
Free() error
}

View File

@@ -0,0 +1,128 @@
//go:build linux
package wgproxy
import (
"context"
"io"
"net"
"os"
"runtime"
"testing"
"time"
"github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/internal/wgproxy/usp"
"github.com/netbirdio/netbird/util"
)
func TestMain(m *testing.M) {
_ = util.InitLog("trace", "console")
code := m.Run()
os.Exit(code)
}
type mocConn struct {
closeChan chan struct{}
closed bool
}
func newMockConn() *mocConn {
return &mocConn{
closeChan: make(chan struct{}),
}
}
func (m *mocConn) Read(b []byte) (n int, err error) {
<-m.closeChan
return 0, io.EOF
}
func (m *mocConn) Write(b []byte) (n int, err error) {
<-m.closeChan
return 0, io.EOF
}
func (m *mocConn) Close() error {
if m.closed == true {
return nil
}
m.closed = true
close(m.closeChan)
return nil
}
func (m *mocConn) LocalAddr() net.Addr {
panic("implement me")
}
func (m *mocConn) RemoteAddr() net.Addr {
return &net.UDPAddr{
IP: net.ParseIP("172.16.254.1"),
}
}
func (m *mocConn) SetDeadline(t time.Time) error {
panic("implement me")
}
func (m *mocConn) SetReadDeadline(t time.Time) error {
panic("implement me")
}
func (m *mocConn) SetWriteDeadline(t time.Time) error {
panic("implement me")
}
func TestProxyCloseByRemoteConn(t *testing.T) {
ctx := context.Background()
tests := []struct {
name string
proxy Proxy
}{
{
name: "userspace proxy",
proxy: usp.NewWGUserSpaceProxy(51830),
},
}
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" {
ebpfProxy := ebpf.NewWGEBPFProxy(51831)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %s", err)
}
}()
proxyWrapper := &ebpf.ProxyWrapper{
WgeBPFProxy: ebpfProxy,
}
tests = append(tests, struct {
name string
proxy Proxy
}{
name: "ebpf proxy",
proxy: proxyWrapper,
})
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
relayedConn := newMockConn()
_, err := tt.proxy.AddTurnConn(ctx, relayedConn)
if err != nil {
t.Errorf("error: %v", err)
}
_ = relayedConn.Close()
if err := tt.proxy.CloseConn(); err != nil {
t.Errorf("error: %v", err)
}
})
}
}

View File

@@ -1,120 +0,0 @@
package wgproxy
import (
"context"
"fmt"
"io"
"net"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net"
)
// WGUserSpaceProxy proxies
type WGUserSpaceProxy struct {
localWGListenPort int
ctx context.Context
cancel context.CancelFunc
remoteConn net.Conn
localConn net.Conn
}
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy
func NewWGUserSpaceProxy(ctx context.Context, wgPort int) *WGUserSpaceProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUserSpaceProxy{
localWGListenPort: wgPort,
}
p.ctx, p.cancel = context.WithCancel(ctx)
return p
}
// AddTurnConn start the proxy with the given remote conn
func (p *WGUserSpaceProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
p.remoteConn = turnConn
var err error
p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err)
return nil, err
}
go p.proxyToRemote()
go p.proxyToLocal()
return p.localConn.LocalAddr(), err
}
// CloseConn close the localConn
func (p *WGUserSpaceProxy) CloseConn() error {
p.cancel()
if p.localConn == nil {
return nil
}
return p.localConn.Close()
}
// Free doing nothing because this implementation of proxy does not have global state
func (p *WGUserSpaceProxy) Free() error {
return nil
}
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
// blocks
func (p *WGUserSpaceProxy) proxyToRemote() {
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
return
default:
n, err := p.localConn.Read(buf)
if err != nil {
log.Debugf("failed to read from wg interface conn: %s", err)
continue
}
_, err = p.remoteConn.Write(buf[:n])
if err != nil {
if err == io.EOF {
p.cancel()
} else {
log.Debugf("failed to write to remote conn: %s", err)
}
continue
}
}
}
}
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
// blocks
func (p *WGUserSpaceProxy) proxyToLocal() {
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
return
default:
n, err := p.remoteConn.Read(buf)
if err != nil {
if err == io.EOF {
p.cancel()
return
}
log.Errorf("failed to read from remote conn: %s", err)
continue
}
_, err = p.localConn.Write(buf[:n])
if err != nil {
log.Debugf("failed to write to wg interface conn: %s", err)
continue
}
}
}
}

View File

@@ -0,0 +1,146 @@
package usp
import (
"context"
"fmt"
"net"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/errors"
)
// WGUserSpaceProxy proxies
type WGUserSpaceProxy struct {
localWGListenPort int
ctx context.Context
cancel context.CancelFunc
remoteConn net.Conn
localConn net.Conn
closeMu sync.Mutex
closed bool
}
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUserSpaceProxy{
localWGListenPort: wgPort,
}
return p
}
// AddTurnConn start the proxy with the given remote conn
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
p.ctx, p.cancel = context.WithCancel(ctx)
p.remoteConn = remoteConn
var err error
dialer := net.Dialer{}
p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err)
return nil, err
}
go p.proxyToRemote()
go p.proxyToLocal()
return p.localConn.LocalAddr(), err
}
// CloseConn close the localConn
func (p *WGUserSpaceProxy) CloseConn() error {
if p.cancel == nil {
return fmt.Errorf("proxy not started")
}
return p.close()
}
func (p *WGUserSpaceProxy) close() error {
p.closeMu.Lock()
defer p.closeMu.Unlock()
// prevent double close
if p.closed {
return nil
}
p.closed = true
p.cancel()
var result *multierror.Error
if err := p.remoteConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
}
if err := p.localConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
}
return errors.FormatErrorOrNil(result)
}
// proxyToRemote proxies from Wireguard to the RemoteKey
func (p *WGUserSpaceProxy) proxyToRemote() {
defer func() {
if err := p.close(); err != nil {
log.Warnf("error in proxy to remote loop: %s", err)
}
}()
buf := make([]byte, 1500)
for p.ctx.Err() == nil {
n, err := p.localConn.Read(buf)
if err != nil {
if p.ctx.Err() != nil {
return
}
log.Debugf("failed to read from wg interface conn: %s", err)
return
}
_, err = p.remoteConn.Write(buf[:n])
if err != nil {
if p.ctx.Err() != nil {
return
}
log.Debugf("failed to write to remote conn: %s", err)
return
}
}
}
// proxyToLocal proxies from the Remote peer to local WireGuard
func (p *WGUserSpaceProxy) proxyToLocal() {
defer func() {
if err := p.close(); err != nil {
log.Warnf("error in proxy to local loop: %s", err)
}
}()
buf := make([]byte, 1500)
for p.ctx.Err() == nil {
n, err := p.remoteConn.Read(buf)
if err != nil {
if p.ctx.Err() != nil {
return
}
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return
}
_, err = p.localConn.Write(buf[:n])
if err != nil {
if p.ctx.Err() != nil {
return
}
log.Debugf("failed to write to wg interface conn: %s", err)
continue
}
}
}

8
go.mod
View File

@@ -1,6 +1,6 @@
module github.com/netbirdio/netbird
go 1.21.0
go 1.23.0
require (
cunicu.li/go-rosenpass v0.4.0
@@ -95,9 +95,10 @@ require (
golang.org/x/term v0.21.0
google.golang.org/api v0.177.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.5.7
gorm.io/driver/postgres v1.5.7
gorm.io/driver/sqlite v1.5.3
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde
gorm.io/gorm v1.25.7
nhooyr.io/websocket v1.8.11
)
@@ -151,6 +152,7 @@ require (
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
github.com/go-redis/redis/v8 v8.11.5 // indirect
github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/go-text/render v0.1.0 // indirect
github.com/go-text/typesetting v0.1.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
@@ -232,7 +234,7 @@ require (
k8s.io/apimachinery v0.26.2 // indirect
)
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240904111318-17777758453a
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949

12
go.sum
View File

@@ -238,6 +238,8 @@ github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
@@ -523,8 +525,8 @@ github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6R
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
github.com/netbirdio/service v0.0.0-20240904111318-17777758453a h1:2EcDFDT39Odz5EC38pOSyjCd3bLUjPi7pMQpH6k+zzk=
github.com/netbirdio/service v0.0.0-20240904111318-17777758453a/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080 h1:mXJkoWLdqJTlkQ7DgQ536kcXHXIdUPeagkN8i4eFDdg=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
@@ -1224,12 +1226,14 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo=
gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM=
gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA=
gorm.io/driver/sqlite v1.5.3 h1:7/0dUgX28KAcopdfbRWWl68Rflh6osa4rDh+m51KL2g=
gorm.io/driver/sqlite v1.5.3/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4=
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde h1:9DShaph9qhkIYw7QF91I/ynrr4cOO2PZra2PFD7Mfeg=
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gorm.io/gorm v1.25.7 h1:VsD6acwRjz2zFxGO50gPO6AkNs7KKnvfzUjHQhZDz/A=
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=

View File

@@ -56,8 +56,9 @@ func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepA
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true,
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive,
Endpoint: endpoint,

View File

@@ -64,8 +64,9 @@ func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAliv
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true,
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey,

View File

@@ -54,7 +54,7 @@ func Execute() error {
func init() {
stopCh = make(chan int)
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 8081, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
mgmtCmd.Flags().StringVar(&mgmtConfig, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")

View File

@@ -263,6 +263,11 @@ type AccountSettings struct {
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
}
// Subclass used in gorm to only load network and not whole account
type AccountNetwork struct {
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
}
type UserPermissions struct {
DashboardView string `json:"dashboard_view"`
}
@@ -700,14 +705,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string {
return grps
}
func (a *Account) getUserGroups(userID string) ([]string, error) {
user, err := a.FindUser(userID)
if err != nil {
return nil, err
}
return user.AutoGroups, nil
}
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
peerGroups := a.getPeerGroups(peerID)
enabled := true
@@ -734,14 +731,6 @@ func (a *Account) getPeerGroups(peerID string) lookupMap {
return groupList
}
func (a *Account) getSetupKeyGroups(setupKey string) ([]string, error) {
key, err := a.FindSetupKey(setupKey)
if err != nil {
return nil, err
}
return key.AutoGroups, nil
}
func (a *Account) getTakenIPs() []net.IP {
var takenIps []net.IP
for _, existingPeer := range a.Peers {
@@ -2082,7 +2071,7 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
}
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) {
user, err := am.Store.GetUserByUserID(ctx, peer.UserID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
if err != nil {
return false, err
}
@@ -2103,6 +2092,25 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee
return false, nil
}
func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Store, accountID string, peerHostName string) (string, error) {
existingLabels, err := store.GetPeerLabelsInAccount(ctx, LockingStrengthShare, accountID)
if err != nil {
return "", fmt.Errorf("failed to get peer dns labels: %w", err)
}
labelMap := ConvertSliceToMap(existingLabels)
newLabel, err := getPeerHostLabel(peerHostName, labelMap)
if err != nil {
return "", fmt.Errorf("failed to get new host label: %w", err)
}
if newLabel == "" {
return "", fmt.Errorf("failed to get new host label: %w", err)
}
return newLabel, nil
}
// addAllGroup to account object if it doesn't exist
func addAllGroup(account *Account) error {
if len(account.Groups) == 0 {

View File

@@ -6,13 +6,14 @@ import (
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"fmt"
"errors"
)
var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
type FieldEncrypt struct {
block cipher.Block
gcm cipher.AEAD
}
func GenerateKey() (string, error) {
@@ -35,14 +36,21 @@ func NewFieldEncrypt(key string) (*FieldEncrypt, error) {
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
ec := &FieldEncrypt{
block: block,
gcm: gcm,
}
return ec, nil
}
func (ec *FieldEncrypt) Encrypt(payload string) string {
func (ec *FieldEncrypt) LegacyEncrypt(payload string) string {
plainText := pkcs5Padding([]byte(payload))
cipherText := make([]byte, len(plainText))
cbc := cipher.NewCBCEncrypter(ec.block, iv)
@@ -50,7 +58,22 @@ func (ec *FieldEncrypt) Encrypt(payload string) string {
return base64.StdEncoding.EncodeToString(cipherText)
}
func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
// Encrypt encrypts plaintext using AES-GCM
func (ec *FieldEncrypt) Encrypt(payload string) (string, error) {
plaintext := []byte(payload)
nonceSize := ec.gcm.NonceSize()
nonce := make([]byte, nonceSize, len(plaintext)+nonceSize+ec.gcm.Overhead())
if _, err := rand.Read(nonce); err != nil {
return "", err
}
ciphertext := ec.gcm.Seal(nonce, nonce, plaintext, nil)
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
func (ec *FieldEncrypt) LegacyDecrypt(data string) (string, error) {
cipherText, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return "", err
@@ -65,17 +88,49 @@ func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
return string(payload), nil
}
// Decrypt decrypts ciphertext using AES-GCM
func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
cipherText, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return "", err
}
nonceSize := ec.gcm.NonceSize()
if len(cipherText) < nonceSize {
return "", errors.New("cipher text too short")
}
nonce, cipherText := cipherText[:nonceSize], cipherText[nonceSize:]
plainText, err := ec.gcm.Open(nil, nonce, cipherText, nil)
if err != nil {
return "", err
}
return string(plainText), nil
}
func pkcs5Padding(ciphertext []byte) []byte {
padding := aes.BlockSize - len(ciphertext)%aes.BlockSize
padText := bytes.Repeat([]byte{byte(padding)}, padding)
return append(ciphertext, padText...)
}
func pkcs5UnPadding(src []byte) ([]byte, error) {
srcLen := len(src)
paddingLen := int(src[srcLen-1])
if paddingLen >= srcLen || paddingLen > aes.BlockSize {
return nil, fmt.Errorf("padding size error")
if srcLen == 0 {
return nil, errors.New("input data is empty")
}
paddingLen := int(src[srcLen-1])
if paddingLen == 0 || paddingLen > aes.BlockSize || paddingLen > srcLen {
return nil, errors.New("invalid padding size")
}
// Verify that all padding bytes are the same
for i := 0; i < paddingLen; i++ {
if src[srcLen-1-i] != byte(paddingLen) {
return nil, errors.New("invalid padding")
}
}
return src[:srcLen-paddingLen], nil
}

View File

@@ -1,6 +1,7 @@
package sqlite
import (
"bytes"
"testing"
)
@@ -15,7 +16,11 @@ func TestGenerateKey(t *testing.T) {
t.Fatalf("failed to init email encryption: %s", err)
}
encrypted := ee.Encrypt(testData)
encrypted, err := ee.Encrypt(testData)
if err != nil {
t.Fatalf("failed to encrypt data: %s", err)
}
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
@@ -30,6 +35,32 @@ func TestGenerateKey(t *testing.T) {
}
}
func TestGenerateKeyLegacy(t *testing.T) {
testData := "exampl@netbird.io"
key, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err := NewFieldEncrypt(key)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
encrypted := ee.LegacyEncrypt(testData)
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
decrypted, err := ee.LegacyDecrypt(encrypted)
if err != nil {
t.Fatalf("failed to decrypt data: %s", err)
}
if decrypted != testData {
t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted)
}
}
func TestCorruptKey(t *testing.T) {
testData := "exampl@netbird.io"
key, err := GenerateKey()
@@ -41,7 +72,11 @@ func TestCorruptKey(t *testing.T) {
t.Fatalf("failed to init email encryption: %s", err)
}
encrypted := ee.Encrypt(testData)
encrypted, err := ee.Encrypt(testData)
if err != nil {
t.Fatalf("failed to encrypt data: %s", err)
}
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
@@ -61,3 +96,215 @@ func TestCorruptKey(t *testing.T) {
t.Fatalf("incorrect decryption, the result is: %s", res)
}
}
func TestEncryptDecrypt(t *testing.T) {
// Generate a key for encryption/decryption
key, err := GenerateKey()
if err != nil {
t.Fatalf("Failed to generate key: %v", err)
}
// Initialize the FieldEncrypt with the generated key
ec, err := NewFieldEncrypt(key)
if err != nil {
t.Fatalf("Failed to create FieldEncrypt: %v", err)
}
// Test cases
testCases := []struct {
name string
input string
}{
{
name: "Empty String",
input: "",
},
{
name: "Short String",
input: "Hello",
},
{
name: "String with Spaces",
input: "Hello, World!",
},
{
name: "Long String",
input: "The quick brown fox jumps over the lazy dog.",
},
{
name: "Unicode Characters",
input: "こんにちは世界",
},
{
name: "Special Characters",
input: "!@#$%^&*()_+-=[]{}|;':\",./<>?",
},
{
name: "Numeric String",
input: "1234567890",
},
{
name: "Repeated Characters",
input: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
},
{
name: "Multi-block String",
input: "This is a longer string that will span multiple blocks in the encryption algorithm.",
},
{
name: "Non-ASCII and ASCII Mix",
input: "Hello 世界 123",
},
}
for _, tc := range testCases {
t.Run(tc.name+" - Legacy", func(t *testing.T) {
// Legacy Encryption
encryptedLegacy := ec.LegacyEncrypt(tc.input)
if encryptedLegacy == "" {
t.Errorf("LegacyEncrypt returned empty string for input '%s'", tc.input)
}
// Legacy Decryption
decryptedLegacy, err := ec.LegacyDecrypt(encryptedLegacy)
if err != nil {
t.Errorf("LegacyDecrypt failed for input '%s': %v", tc.input, err)
}
// Verify that the decrypted value matches the original input
if decryptedLegacy != tc.input {
t.Errorf("LegacyDecrypt output '%s' does not match original input '%s'", decryptedLegacy, tc.input)
}
})
t.Run(tc.name+" - New", func(t *testing.T) {
// New Encryption
encryptedNew, err := ec.Encrypt(tc.input)
if err != nil {
t.Errorf("Encrypt failed for input '%s': %v", tc.input, err)
}
if encryptedNew == "" {
t.Errorf("Encrypt returned empty string for input '%s'", tc.input)
}
// New Decryption
decryptedNew, err := ec.Decrypt(encryptedNew)
if err != nil {
t.Errorf("Decrypt failed for input '%s': %v", tc.input, err)
}
// Verify that the decrypted value matches the original input
if decryptedNew != tc.input {
t.Errorf("Decrypt output '%s' does not match original input '%s'", decryptedNew, tc.input)
}
})
}
}
func TestPKCS5UnPadding(t *testing.T) {
tests := []struct {
name string
input []byte
expected []byte
expectError bool
}{
{
name: "Valid Padding",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{4}, 4)...),
expected: []byte("Hello, World!"),
},
{
name: "Empty Input",
input: []byte{},
expectError: true,
},
{
name: "Padding Length Zero",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{0}, 4)...),
expectError: true,
},
{
name: "Padding Length Exceeds Block Size",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{17}, 17)...),
expectError: true,
},
{
name: "Padding Length Exceeds Input Length",
input: []byte{5, 5, 5},
expectError: true,
},
{
name: "Invalid Padding Bytes",
input: append([]byte("Hello, World!"), []byte{2, 3, 4, 5}...),
expectError: true,
},
{
name: "Valid Single Byte Padding",
input: append([]byte("Hello, World!"), byte(1)),
expected: []byte("Hello, World!"),
},
{
name: "Invalid Mixed Padding Bytes",
input: append([]byte("Hello, World!"), []byte{3, 3, 2}...),
expectError: true,
},
{
name: "Valid Full Block Padding",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{16}, 16)...),
expected: []byte("Hello, World!"),
},
{
name: "Non-Padding Byte at End",
input: append([]byte("Hello, World!"), []byte{4, 4, 4, 5}...),
expectError: true,
},
{
name: "Valid Padding with Different Text Length",
input: append([]byte("Test"), bytes.Repeat([]byte{12}, 12)...),
expected: []byte("Test"),
},
{
name: "Padding Length Equal to Input Length",
input: bytes.Repeat([]byte{8}, 8),
expected: []byte{},
},
{
name: "Invalid Padding Length Zero (Again)",
input: append([]byte("Test"), byte(0)),
expectError: true,
},
{
name: "Padding Length Greater Than Input",
input: []byte{10},
expectError: true,
},
{
name: "Input Length Not Multiple of Block Size",
input: append([]byte("Invalid Length"), byte(1)),
expected: []byte("Invalid Length"),
},
{
name: "Valid Padding with Non-ASCII Characters",
input: append([]byte("こんにちは"), bytes.Repeat([]byte{2}, 2)...),
expected: []byte("こんにちは"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := pkcs5UnPadding(tt.input)
if tt.expectError {
if err == nil {
t.Errorf("Expected error but got nil")
}
} else {
if err != nil {
t.Errorf("Did not expect error but got: %v", err)
}
if !bytes.Equal(result, tt.expected) {
t.Errorf("Expected output %v, got %v", tt.expected, result)
}
}
})
}
}

View File

@@ -0,0 +1,157 @@
package sqlite
import (
"context"
"database/sql"
"fmt"
log "github.com/sirupsen/logrus"
)
func migrate(ctx context.Context, crypt *FieldEncrypt, db *sql.DB) error {
if _, err := db.Exec(createTableQuery); err != nil {
return err
}
if _, err := db.Exec(creatTableDeletedUsersQuery); err != nil {
return err
}
if err := updateDeletedUsersTable(ctx, db); err != nil {
return fmt.Errorf("failed to update deleted_users table: %v", err)
}
return migrateLegacyEncryptedUsersToGCM(ctx, crypt, db)
}
// updateDeletedUsersTable checks and updates the deleted_users table schema to ensure required columns exist.
func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error {
exists, err := checkColumnExists(db, "deleted_users", "name")
if err != nil {
return err
}
if !exists {
log.WithContext(ctx).Debug("Adding name column to the deleted_users table")
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
if err != nil {
return err
}
log.WithContext(ctx).Debug("Successfully added name column to the deleted_users table")
}
exists, err = checkColumnExists(db, "deleted_users", "enc_algo")
if err != nil {
return err
}
if !exists {
log.WithContext(ctx).Debug("Adding enc_algo column to the deleted_users table")
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN enc_algo TEXT;`)
if err != nil {
return err
}
log.WithContext(ctx).Debug("Successfully added enc_algo column to the deleted_users table")
}
return nil
}
// migrateLegacyEncryptedUsersToGCM migrates previously encrypted data using,
// legacy CBC encryption with a static IV to the new GCM encryption method.
func migrateLegacyEncryptedUsersToGCM(ctx context.Context, crypt *FieldEncrypt, db *sql.DB) error {
log.WithContext(ctx).Debug("Migrating CBC encrypted deleted users to GCM")
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("failed to begin transaction: %v", err)
}
defer func() {
_ = tx.Rollback()
}()
rows, err := tx.Query(fmt.Sprintf(`SELECT id, email, name FROM deleted_users where enc_algo IS NULL OR enc_algo != '%s'`, gcmEncAlgo))
if err != nil {
return fmt.Errorf("failed to execute select query: %v", err)
}
defer rows.Close()
updateStmt, err := tx.Prepare(`UPDATE deleted_users SET email = ?, name = ?, enc_algo = ? WHERE id = ?`)
if err != nil {
return fmt.Errorf("failed to prepare update statement: %v", err)
}
defer updateStmt.Close()
if err = processUserRows(ctx, crypt, rows, updateStmt); err != nil {
return err
}
if err = tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %v", err)
}
log.WithContext(ctx).Debug("Successfully migrated CBC encrypted deleted users to GCM")
return nil
}
// processUserRows processes database rows of user data, decrypts legacy encryption fields, and re-encrypts them using GCM.
func processUserRows(ctx context.Context, crypt *FieldEncrypt, rows *sql.Rows, updateStmt *sql.Stmt) error {
for rows.Next() {
var (
id, decryptedEmail, decryptedName string
email, name *string
)
err := rows.Scan(&id, &email, &name)
if err != nil {
return err
}
if email != nil {
decryptedEmail, err = crypt.LegacyDecrypt(*email)
if err != nil {
log.WithContext(ctx).Warnf("skipping migrating deleted user %s: %v",
id,
fmt.Errorf("failed to decrypt email: %w", err),
)
continue
}
}
if name != nil {
decryptedName, err = crypt.LegacyDecrypt(*name)
if err != nil {
log.WithContext(ctx).Warnf("skipping migrating deleted user %s: %v",
id,
fmt.Errorf("failed to decrypt name: %w", err),
)
continue
}
}
encryptedEmail, err := crypt.Encrypt(decryptedEmail)
if err != nil {
return fmt.Errorf("failed to encrypt email: %w", err)
}
encryptedName, err := crypt.Encrypt(decryptedName)
if err != nil {
return fmt.Errorf("failed to encrypt name: %w", err)
}
_, err = updateStmt.Exec(encryptedEmail, encryptedName, gcmEncAlgo, id)
if err != nil {
return err
}
}
if err := rows.Err(); err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,84 @@
package sqlite
import (
"context"
"database/sql"
"path/filepath"
"testing"
"time"
_ "github.com/mattn/go-sqlite3"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/stretchr/testify/require"
)
func setupDatabase(t *testing.T) *sql.DB {
t.Helper()
dbFile := filepath.Join(t.TempDir(), eventSinkDB)
db, err := sql.Open("sqlite3", dbFile)
require.NoError(t, err, "Failed to open database")
t.Cleanup(func() {
_ = db.Close()
})
_, err = db.Exec(createTableQuery)
require.NoError(t, err, "Failed to create events table")
_, err = db.Exec(`CREATE TABLE deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);`)
require.NoError(t, err, "Failed to create deleted_users table")
return db
}
func TestMigrate(t *testing.T) {
db := setupDatabase(t)
key, err := GenerateKey()
require.NoError(t, err, "Failed to generate key")
crypt, err := NewFieldEncrypt(key)
require.NoError(t, err, "Failed to initialize FieldEncrypt")
legacyEmail := crypt.LegacyEncrypt("testaccount@test.com")
legacyName := crypt.LegacyEncrypt("Test Account")
_, err = db.Exec(`INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) VALUES(?, ?, ?, ?, ?, ?)`,
activity.UserDeleted, time.Now(), "initiatorID", "targetID", "accountID", "")
require.NoError(t, err, "Failed to insert event")
_, err = db.Exec(`INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)`, "targetID", legacyEmail, legacyName)
require.NoError(t, err, "Failed to insert legacy encrypted data")
colExists, err := checkColumnExists(db, "deleted_users", "enc_algo")
require.NoError(t, err, "Failed to check if enc_algo column exists")
require.False(t, colExists, "enc_algo column should not exist before migration")
err = migrate(context.Background(), crypt, db)
require.NoError(t, err, "Migration failed")
colExists, err = checkColumnExists(db, "deleted_users", "enc_algo")
require.NoError(t, err, "Failed to check if enc_algo column exists after migration")
require.True(t, colExists, "enc_algo column should exist after migration")
var encAlgo string
err = db.QueryRow(`SELECT enc_algo FROM deleted_users LIMIT 1`, "").Scan(&encAlgo)
require.NoError(t, err, "Failed to select updated data")
require.Equal(t, gcmEncAlgo, encAlgo, "enc_algo should be set to 'GCM' after migration")
store, err := createStore(crypt, db)
require.NoError(t, err, "Failed to create store")
events, err := store.Get(context.Background(), "accountID", 0, 1, false)
require.NoError(t, err, "Failed to get events")
require.Len(t, events, 1, "Should have one event")
require.Equal(t, activity.UserDeleted, events[0].Activity, "activity should match")
require.Equal(t, "initiatorID", events[0].InitiatorID, "initiator id should match")
require.Equal(t, "targetID", events[0].TargetID, "target id should match")
require.Equal(t, "accountID", events[0].AccountID, "account id should match")
require.Equal(t, "testaccount@test.com", events[0].Meta["email"], "email should match")
require.Equal(t, "Test Account", events[0].Meta["username"], "username should match")
}

View File

@@ -26,7 +26,7 @@ const (
"meta TEXT," +
" target_id TEXT);"
creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);`
creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT, enc_algo TEXT NOT NULL);`
selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta
FROM events
@@ -69,10 +69,12 @@ const (
and some selfhosted deployments might have duplicates already so we need to clean the table first.
*/
insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)`
insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name, enc_algo) VALUES(?, ?, ?, ?)`
fallbackName = "unknown"
fallbackEmail = "unknown@unknown.com"
gcmEncAlgo = "GCM"
)
// Store is the implementation of the activity.Store interface backed by SQLite
@@ -100,58 +102,12 @@ func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (
return nil, err
}
_, err = db.Exec(createTableQuery)
if err != nil {
if err = migrate(ctx, crypt, db); err != nil {
_ = db.Close()
return nil, err
return nil, fmt.Errorf("events database migration: %w", err)
}
_, err = db.Exec(creatTableDeletedUsersQuery)
if err != nil {
_ = db.Close()
return nil, err
}
err = updateDeletedUsersTable(ctx, db)
if err != nil {
_ = db.Close()
return nil, err
}
insertStmt, err := db.Prepare(insertQuery)
if err != nil {
_ = db.Close()
return nil, err
}
selectDescStmt, err := db.Prepare(selectDescQuery)
if err != nil {
_ = db.Close()
return nil, err
}
selectAscStmt, err := db.Prepare(selectAscQuery)
if err != nil {
_ = db.Close()
return nil, err
}
deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
if err != nil {
_ = db.Close()
return nil, err
}
s := &Store{
db: db,
fieldEncrypt: crypt,
insertStatement: insertStmt,
selectDescStatement: selectDescStmt,
selectAscStatement: selectAscStmt,
deleteUserStmt: deleteUserStmt,
}
return s, nil
return createStore(crypt, db)
}
func (store *Store) processResult(ctx context.Context, result *sql.Rows) ([]*activity.Event, error) {
@@ -302,9 +258,16 @@ func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event
return event.Meta, nil
}
encryptedEmail := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email))
encryptedName := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name))
_, err := store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName)
encryptedEmail, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email))
if err != nil {
return nil, err
}
encryptedName, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name))
if err != nil {
return nil, err
}
_, err = store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName, gcmEncAlgo)
if err != nil {
return nil, err
}
@@ -325,43 +288,70 @@ func (store *Store) Close(_ context.Context) error {
return nil
}
func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error {
log.WithContext(ctx).Debugf("check deleted_users table version")
rows, err := db.Query(`PRAGMA table_info(deleted_users);`)
// createStore initializes and returns a new Store instance with prepared SQL statements.
func createStore(crypt *FieldEncrypt, db *sql.DB) (*Store, error) {
insertStmt, err := db.Prepare(insertQuery)
if err != nil {
return err
_ = db.Close()
return nil, err
}
selectDescStmt, err := db.Prepare(selectDescQuery)
if err != nil {
_ = db.Close()
return nil, err
}
selectAscStmt, err := db.Prepare(selectAscQuery)
if err != nil {
_ = db.Close()
return nil, err
}
deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
if err != nil {
_ = db.Close()
return nil, err
}
return &Store{
db: db,
fieldEncrypt: crypt,
insertStatement: insertStmt,
selectDescStatement: selectDescStmt,
selectAscStatement: selectAscStmt,
deleteUserStmt: deleteUserStmt,
}, nil
}
// checkColumnExists checks if a column exists in a specified table
func checkColumnExists(db *sql.DB, tableName, columnName string) (bool, error) {
query := fmt.Sprintf("PRAGMA table_info(%s);", tableName)
rows, err := db.Query(query)
if err != nil {
return false, fmt.Errorf("failed to query table info: %w", err)
}
defer rows.Close()
found := false
for rows.Next() {
var (
cid int
name string
dataType string
notNull int
dfltVal sql.NullString
pk int
)
err := rows.Scan(&cid, &name, &dataType, &notNull, &dfltVal, &pk)
var cid int
var name, ctype string
var notnull, pk int
var dfltValue sql.NullString
err = rows.Scan(&cid, &name, &ctype, &notnull, &dfltValue, &pk)
if err != nil {
return err
return false, fmt.Errorf("failed to scan row: %w", err)
}
if name == "name" {
found = true
break
if name == columnName {
return true, nil
}
}
err = rows.Err()
if err != nil {
return err
if err = rows.Err(); err != nil {
return false, err
}
if found {
return nil
}
log.WithContext(ctx).Debugf("update delted_users table")
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
return err
return false, nil
}

View File

@@ -7,6 +7,7 @@ import (
"time"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
)
type MockStore struct {
@@ -24,7 +25,7 @@ func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Accou
return s.account, nil
}
return nil, fmt.Errorf("account not found")
return nil, status.NewPeerNotFoundError(peerId)
}
type MocAccountManager struct {

View File

@@ -2,6 +2,8 @@ package server
import (
"context"
"errors"
"net"
"os"
"path/filepath"
"strings"
@@ -46,6 +48,158 @@ type FileStore struct {
metrics telemetry.AppMetrics `json:"-"`
}
func (s *FileStore) ExecuteInTransaction(ctx context.Context, f func(store Store) error) error {
return f(s)
}
func (s *FileStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
s.mux.Lock()
defer s.mux.Unlock()
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKeyID)]
if !ok {
return status.NewSetupKeyNotFoundError()
}
account, err := s.getAccount(accountID)
if err != nil {
return err
}
account.SetupKeys[setupKeyID].UsedTimes++
return s.SaveAccount(ctx, account)
}
func (s *FileStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return err
}
allGroup, err := account.GetGroupAll()
if err != nil || allGroup == nil {
return errors.New("all group not found")
}
allGroup.Peers = append(allGroup.Peers, peerID)
return nil
}
func (s *FileStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountId)
if err != nil {
return err
}
account.Groups[groupID].Peers = append(account.Groups[groupID].Peers, peerId)
return nil
}
func (s *FileStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
s.mux.Lock()
defer s.mux.Unlock()
account, ok := s.Accounts[peer.AccountID]
if !ok {
return status.NewAccountNotFoundError(peer.AccountID)
}
account.Peers[peer.ID] = peer
return s.SaveAccount(ctx, account)
}
func (s *FileStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
s.mux.Lock()
defer s.mux.Unlock()
account, ok := s.Accounts[accountId]
if !ok {
return status.NewAccountNotFoundError(accountId)
}
account.Network.Serial++
return s.SaveAccount(ctx, account)
}
func (s *FileStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
s.mux.Lock()
defer s.mux.Unlock()
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(key)]
if !ok {
return nil, status.NewSetupKeyNotFoundError()
}
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
setupKey, ok := account.SetupKeys[key]
if !ok {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
return setupKey, nil
}
func (s *FileStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
var takenIps []net.IP
for _, existingPeer := range account.Peers {
takenIps = append(takenIps, existingPeer.IP)
}
return takenIps, nil
}
func (s *FileStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
existingLabels := []string{}
for _, peer := range account.Peers {
if peer.DNSLabel != "" {
existingLabels = append(existingLabels, peer.DNSLabel)
}
}
return existingLabels, nil
}
func (s *FileStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
return account.Network, nil
}
type StoredAccount struct{}
// NewFileStore restores a store from the file located in the datadir
@@ -422,7 +576,7 @@ func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*A
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !ok {
return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
return nil, status.NewSetupKeyNotFoundError()
}
account, err := s.getAccount(accountID)
@@ -469,7 +623,7 @@ func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User,
return account.Users[userID].Copy(), nil
}
func (s *FileStore) GetUserByUserID(_ context.Context, userID string) (*User, error) {
func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID string) (*User, error) {
accountID, ok := s.UserID2AccountID[userID]
if !ok {
return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
@@ -513,7 +667,7 @@ func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
func (s *FileStore) getAccount(accountID string) (*Account, error) {
account, ok := s.Accounts[accountID]
if !ok {
return nil, status.Errorf(status.NotFound, "account not found")
return nil, status.NewAccountNotFoundError(accountID)
}
return account, nil
@@ -639,13 +793,13 @@ func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !ok {
return "", status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
return "", status.NewSetupKeyNotFoundError()
}
return accountID, nil
}
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbpeer.Peer, error) {
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, _ LockingStrength, peerKey string) (*nbpeer.Peer, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -668,7 +822,7 @@ func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbp
return nil, status.NewPeerNotFoundError(peerKey)
}
func (s *FileStore) GetAccountSettings(_ context.Context, accountID string) (*Settings, error) {
func (s *FileStore) GetAccountSettings(_ context.Context, _ LockingStrength, accountID string) (*Settings, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -758,7 +912,7 @@ func (s *FileStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.
}
// SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things.
func (s *FileStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID string, lastLogin time.Time) error {
s.mux.Lock()
defer s.mux.Unlock()

View File

@@ -251,7 +251,7 @@ components:
- name
- ssh_enabled
- login_expiration_enabled
PeerBase:
Peer:
allOf:
- $ref: '#/components/schemas/PeerMinimum'
- type: object
@@ -378,25 +378,40 @@ components:
description: User ID of the user that enrolled this peer
type: string
example: google-oauth2|277474792786460067937
os:
description: Peer's operating system and version
type: string
example: linux
country_code:
$ref: '#/components/schemas/CountryCode'
city_name:
$ref: '#/components/schemas/CityName'
geoname_id:
description: Unique identifier from the GeoNames database for a specific geographical location.
type: integer
example: 2643743
connected:
description: Peer to Management connection status
type: boolean
example: true
last_seen:
description: Last time peer connected to Netbird's management service
type: string
format: date-time
example: "2023-05-05T10:05:26.420578Z"
required:
- ip
- dns_label
- user_id
Peer:
allOf:
- $ref: '#/components/schemas/PeerBase'
- type: object
properties:
accessible_peers:
description: List of accessible peers
type: array
items:
$ref: '#/components/schemas/AccessiblePeer'
required:
- accessible_peers
- os
- country_code
- city_name
- geoname_id
- connected
- last_seen
PeerBatch:
allOf:
- $ref: '#/components/schemas/PeerBase'
- $ref: '#/components/schemas/Peer'
- type: object
properties:
accessible_peers_count:
@@ -935,7 +950,7 @@ components:
type: array
items:
type: string
example: ["192.168.1.0/24", "10.0.0.0/8", "2001:db8:1234:1a00::/56"]
example: ["192.168.1.0/24", "10.0.0.0/8", "2001:db8:1234:1a00::/56"]
action:
description: Action to take upon policy match
type: string
@@ -1806,6 +1821,38 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/peers/{peerId}/accessible-peers:
get:
summary: List accessible Peers
description: Returns a list of peers that the specified peer can connect to within the network.
tags: [ Peers ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: peerId
required: true
schema:
type: string
description: The unique identifier of a peer
responses:
'200':
description: A JSON Array of Accessible Peers
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/AccessiblePeer'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/setup-keys:
get:
summary: List all Setup Keys

View File

@@ -152,18 +152,36 @@ const (
// AccessiblePeer defines model for AccessiblePeer.
type AccessiblePeer struct {
// CityName Commonly used English name of the city
CityName CityName `json:"city_name"`
// Connected Peer to Management connection status
Connected bool `json:"connected"`
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
CountryCode CountryCode `json:"country_code"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
// GeonameId Unique identifier from the GeoNames database for a specific geographical location.
GeonameId int `json:"geoname_id"`
// Id Peer ID
Id string `json:"id"`
// Ip Peer's IP address
Ip string `json:"ip"`
// LastSeen Last time peer connected to Netbird's management service
LastSeen time.Time `json:"last_seen"`
// Name Peer's hostname
Name string `json:"name"`
// Os Peer's operating system and version
Os string `json:"os"`
// UserId User ID of the user that enrolled this peer
UserId string `json:"user_id"`
}
@@ -490,81 +508,6 @@ type OSVersionCheck struct {
// Peer defines model for Peer.
type Peer struct {
// AccessiblePeers List of accessible peers
AccessiblePeers []AccessiblePeer `json:"accessible_peers"`
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired bool `json:"approval_required"`
// CityName Commonly used English name of the city
CityName CityName `json:"city_name"`
// Connected Peer to Management connection status
Connected bool `json:"connected"`
// ConnectionIp Peer's public connection IP address
ConnectionIp string `json:"connection_ip"`
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
CountryCode CountryCode `json:"country_code"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
// GeonameId Unique identifier from the GeoNames database for a specific geographical location.
GeonameId int `json:"geoname_id"`
// Groups Groups that the peer belongs to
Groups []GroupMinimum `json:"groups"`
// Hostname Hostname of the machine
Hostname string `json:"hostname"`
// Id Peer ID
Id string `json:"id"`
// Ip Peer's IP address
Ip string `json:"ip"`
// KernelVersion Peer's operating system kernel version
KernelVersion string `json:"kernel_version"`
// LastLogin Last time this peer performed log in (authentication). E.g., user authenticated.
LastLogin time.Time `json:"last_login"`
// LastSeen Last time peer connected to Netbird's management service
LastSeen time.Time `json:"last_seen"`
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
// LoginExpired Indicates whether peer's login expired or not
LoginExpired bool `json:"login_expired"`
// Name Peer's hostname
Name string `json:"name"`
// Os Peer's operating system and version
Os string `json:"os"`
// SerialNumber System serial number
SerialNumber string `json:"serial_number"`
// SshEnabled Indicates whether SSH server is enabled on this peer
SshEnabled bool `json:"ssh_enabled"`
// UiVersion Peer's desktop UI version
UiVersion string `json:"ui_version"`
// UserId User ID of the user that enrolled this peer
UserId string `json:"user_id"`
// Version Peer's daemon or cli version
Version string `json:"version"`
}
// PeerBase defines model for PeerBase.
type PeerBase struct {
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired bool `json:"approval_required"`

View File

@@ -115,6 +115,7 @@ func (apiHandler *apiHandler) addPeersEndpoint() {
apiHandler.Router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
apiHandler.Router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
}
func (apiHandler *apiHandler) addUsersEndpoint() {

View File

@@ -7,8 +7,6 @@ import (
"net/http"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
@@ -16,6 +14,7 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
)
// PeersHandler is a handler that returns peers of the account
@@ -71,12 +70,8 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
return
}
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID]
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid))
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
}
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
@@ -117,13 +112,9 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
return
}
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID]
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid))
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, valid))
}
func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
@@ -220,32 +211,81 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv
}
}
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
return
}
// If the user is regular user and does not own the peer
// with the given peerID return an empty list
if !user.HasAdminPower() && !user.IsServiceUser {
peer, ok := account.Peers[peerID]
if !ok {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "peer not found"), w)
return
}
if peer.UserID != user.Id {
util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{})
return
}
}
dnsDomain := h.accountManager.GetDNSDomain()
validPeers, err := h.accountManager.GetValidatedPeers(account)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
return
}
customZone := account.GetPeersCustomZone(r.Context(), h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil)
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
}
func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.AccessiblePeer {
accessiblePeers := make([]api.AccessiblePeer, 0, len(netMap.Peers)+len(netMap.OfflinePeers))
for _, p := range netMap.Peers {
ap := api.AccessiblePeer{
Id: p.ID,
Name: p.Name,
Ip: p.IP.String(),
DnsLabel: fqdn(p, dnsDomain),
UserId: p.UserID,
}
accessiblePeers = append(accessiblePeers, ap)
accessiblePeers = append(accessiblePeers, peerToAccessiblePeer(p, dnsDomain))
}
for _, p := range netMap.OfflinePeers {
ap := api.AccessiblePeer{
Id: p.ID,
Name: p.Name,
Ip: p.IP.String(),
DnsLabel: fqdn(p, dnsDomain),
UserId: p.UserID,
}
accessiblePeers = append(accessiblePeers, ap)
accessiblePeers = append(accessiblePeers, peerToAccessiblePeer(p, dnsDomain))
}
return accessiblePeers
}
func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePeer {
return api.AccessiblePeer{
CityName: peer.Location.CityName,
Connected: peer.Status.Connected,
CountryCode: peer.Location.CountryCode,
DnsLabel: fqdn(peer, dnsDomain),
GeonameId: int(peer.Location.GeoNameID),
Id: peer.ID,
Ip: peer.IP.String(),
LastSeen: peer.Status.LastSeen,
Name: peer.Name,
Os: peer.Meta.OS,
UserId: peer.UserID,
}
}
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
var groupsInfo []api.GroupMinimum
groupsChecked := make(map[string]struct{})
@@ -270,7 +310,7 @@ func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMi
return groupsInfo
}
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeer []api.AccessiblePeer, approved bool) *api.Peer {
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer {
osVersion := peer.Meta.OSVersion
if osVersion == "" {
osVersion = peer.Meta.Core
@@ -296,7 +336,6 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
LoginExpirationEnabled: peer.LoginExpirationEnabled,
LastLogin: peer.LastLogin,
LoginExpired: peer.Status.LoginExpired,
AccessiblePeers: accessiblePeer,
ApprovalRequired: !approved,
CountryCode: peer.Location.CountryCode,
CityName: peer.Location.CityName,

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
@@ -12,20 +13,30 @@ import (
"time"
"github.com/gorilla/mux"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/magiconair/properties/assert"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server"
)
const testPeerID = "test_peer"
const noUpdateChannelTestPeerID = "no-update-channel"
type ctxKey string
const (
testPeerID = "test_peer"
noUpdateChannelTestPeerID = "no-update-channel"
adminUser = "admin_user"
regularUser = "regular_user"
serviceUser = "service_user"
userIDKey ctxKey = "user_id"
)
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
return &PeersHandler{
@@ -60,21 +71,57 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
return "netbird.selfhosted"
},
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range peers {
peersMap[peer.ID] = peer.Copy()
}
policy := &server.Policy{
ID: "policy",
AccountID: claims.AccountId,
Name: "policy",
Enabled: true,
Rules: []*server.PolicyRule{
{
ID: "rule",
Name: "rule",
Enabled: true,
Action: "accept",
Destinations: []string{"group1"},
Sources: []string{"group1"},
Bidirectional: true,
Protocol: "all",
Ports: []string{"80"},
},
},
}
srvUser := server.NewRegularUser(serviceUser)
srvUser.IsServiceUser = true
account := &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
Peers: map[string]*nbpeer.Peer{
peers[0].ID: peers[0],
peers[1].ID: peers[1],
},
Peers: peersMap,
Users: map[string]*server.User{
"test_user": user,
adminUser: server.NewAdminUser(adminUser),
regularUser: server.NewRegularUser(regularUser),
serviceUser: srvUser,
},
Groups: map[string]*nbgroup.Group{
"group1": {
ID: "group1",
AccountID: claims.AccountId,
Name: "group1",
Issued: "api",
Peers: maps.Keys(peersMap),
},
},
Settings: &server.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: time.Hour,
},
Policies: []*server.Policy{policy},
Network: &server.Network{
Identifier: "ciclqisab2ss43jdn8q0",
Net: net.IPNet{
@@ -83,7 +130,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
},
Serial: 51,
},
}, user, nil
}
return account, account.Users[claims.UserId], nil
},
HasConnectedChannelFunc: func(peerID string) bool {
statuses := make(map[string]struct{})
@@ -99,8 +148,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
userID := r.Context().Value(userIDKey).(string)
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
UserId: userID,
Domain: "hotmail.com",
AccountId: "test_id",
}
@@ -197,6 +247,8 @@ func TestGetPeers(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
ctx := context.WithValue(context.Background(), userIDKey, "admin_user")
req = req.WithContext(ctx)
router := mux.NewRouter()
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
@@ -251,3 +303,119 @@ func TestGetPeers(t *testing.T) {
})
}
}
func TestGetAccessiblePeers(t *testing.T) {
peer1 := &nbpeer.Peer{
ID: "peer1",
Key: "key1",
IP: net.ParseIP("100.64.0.1"),
Status: &nbpeer.PeerStatus{Connected: true},
Name: "peer1",
LoginExpirationEnabled: false,
UserID: regularUser,
}
peer2 := &nbpeer.Peer{
ID: "peer2",
Key: "key2",
IP: net.ParseIP("100.64.0.2"),
Status: &nbpeer.PeerStatus{Connected: true},
Name: "peer2",
LoginExpirationEnabled: false,
UserID: adminUser,
}
peer3 := &nbpeer.Peer{
ID: "peer3",
Key: "key3",
IP: net.ParseIP("100.64.0.3"),
Status: &nbpeer.PeerStatus{Connected: true},
Name: "peer3",
LoginExpirationEnabled: false,
UserID: regularUser,
}
p := initTestMetaData(peer1, peer2, peer3)
tt := []struct {
name string
peerID string
callerUserID string
expectedStatus int
expectedPeers []string
}{
{
name: "non admin user can access owned peer",
peerID: "peer1",
callerUserID: regularUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer2", "peer3"},
},
{
name: "non admin user can't access unowned peer",
peerID: "peer2",
callerUserID: regularUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{},
},
{
name: "admin user can access owned peer",
peerID: "peer2",
callerUserID: adminUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer1", "peer3"},
},
{
name: "admin user can access unowned peer",
peerID: "peer3",
callerUserID: adminUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer1", "peer2"},
},
{
name: "service user can access unowned peer",
peerID: "peer3",
callerUserID: serviceUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer1", "peer2"},
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
ctx := context.WithValue(context.Background(), userIDKey, tc.callerUserID)
req = req.WithContext(ctx)
router := mux.NewRouter()
router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
if res.StatusCode != tc.expectedStatus {
t.Fatalf("handler returned wrong status code: got %v want %v", res.StatusCode, tc.expectedStatus)
}
body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("failed to read response body: %v", err)
}
defer res.Body.Close()
var accessiblePeers []api.AccessiblePeer
err = json.Unmarshal(body, &accessiblePeers)
if err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
peerIDs := make([]string, len(accessiblePeers))
for i, peer := range accessiblePeers {
peerIDs[i] = peer.Id
}
assert.ElementsMatch(t, peerIDs, tc.expectedPeers)
})
}
}

View File

@@ -627,7 +627,7 @@ func testSyncStatusRace(t *testing.T) {
}
time.Sleep(10 * time.Millisecond)
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), peerWithInvalidStatus.PublicKey().String())
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peerWithInvalidStatus.PublicKey().String())
if err != nil {
t.Fatal(err)
return
@@ -638,8 +638,8 @@ func testSyncStatusRace(t *testing.T) {
}
func Test_LoginPerformance(t *testing.T) {
if os.Getenv("CI") == "true" {
t.Skip("Skipping on CI")
if os.Getenv("CI") == "true" || runtime.GOOS == "windows" {
t.Skip("Skipping test on CI or Windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
@@ -655,7 +655,7 @@ func Test_LoginPerformance(t *testing.T) {
// {"M", 250, 1},
// {"L", 500, 1},
// {"XL", 750, 1},
{"XXL", 2000, 1},
{"XXL", 5000, 1},
}
log.SetOutput(io.Discard)
@@ -700,15 +700,18 @@ func Test_LoginPerformance(t *testing.T) {
}
defer mgmtServer.GracefulStop()
t.Logf("management setup complete, start registering peers")
var counter int32
var counterStart int32
var wg sync.WaitGroup
var wgAccount sync.WaitGroup
var mu sync.Mutex
messageCalls := []func() error{}
for j := 0; j < bc.accounts; j++ {
wg.Add(1)
wgAccount.Add(1)
var wgPeer sync.WaitGroup
go func(j int, counter *int32, counterStart *int32) {
defer wg.Done()
defer wgAccount.Done()
account, err := createAccount(am, fmt.Sprintf("account-%d", j), fmt.Sprintf("user-%d", j), fmt.Sprintf("domain-%d", j))
if err != nil {
@@ -722,7 +725,9 @@ func Test_LoginPerformance(t *testing.T) {
return
}
startTime := time.Now()
for i := 0; i < bc.peers; i++ {
wgPeer.Add(1)
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Logf("failed to generate key: %v", err)
@@ -763,21 +768,29 @@ func Test_LoginPerformance(t *testing.T) {
mu.Lock()
messageCalls = append(messageCalls, login)
mu.Unlock()
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
if err != nil {
t.Logf("failed to login peer: %v", err)
return
}
atomic.AddInt32(counterStart, 1)
if *counterStart%100 == 0 {
t.Logf("registered %d peers", *counterStart)
}
go func(peerLogin PeerLogin, counterStart *int32) {
defer wgPeer.Done()
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
if err != nil {
t.Logf("failed to login peer: %v", err)
return
}
atomic.AddInt32(counterStart, 1)
if *counterStart%100 == 0 {
t.Logf("registered %d peers", *counterStart)
}
}(peerLogin, counterStart)
}
wgPeer.Wait()
t.Logf("Time for registration: %s", time.Since(startTime))
}(j, &counter, &counterStart)
}
wg.Wait()
wgAccount.Wait()
t.Logf("prepared %d login calls", len(messageCalls))
testLoginPerformance(t, messageCalls)

View File

@@ -11,6 +11,7 @@ import (
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/proto"
@@ -371,164 +372,175 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}
}()
var account *Account
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
if am.idpManager != nil {
userdata, err := am.lookupUserInCache(ctx, userID, account)
if err == nil && userdata != nil {
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
}
}
}
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
// and the peer disconnects with a timeout and tries to register again.
// We just check if this machine has been registered before and reject the second registration.
// The connecting peer should be able to recover with a retry.
_, err = account.FindPeerByPubKey(peer.Key)
_, err = am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peer.Key)
if err == nil {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
}
opEvent := &activity.Event{
Timestamp: time.Now().UTC(),
AccountID: account.Id,
AccountID: accountID,
}
var ephemeral bool
setupKeyName := ""
if !addedByUser {
// validate the setup key if adding with a key
sk, err := account.FindSetupKey(upperKey)
if err != nil {
return nil, nil, nil, err
}
var newPeer *nbpeer.Peer
if !sk.IsValid() {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
}
account.SetupKeys[sk.Key] = sk.IncrementUsage()
opEvent.InitiatorID = sk.Id
opEvent.Activity = activity.PeerAddedWithSetupKey
ephemeral = sk.Ephemeral
setupKeyName = sk.Name
} else {
opEvent.InitiatorID = userID
opEvent.Activity = activity.PeerAddedByUser
}
takenIps := account.getTakenIPs()
existingLabels := account.getPeerDNSLabels()
newLabel, err := getPeerHostLabel(peer.Meta.Hostname, existingLabels)
if err != nil {
return nil, nil, nil, err
}
peer.DNSLabel = newLabel
network := account.Network
nextIp, err := AllocatePeerIP(network.Net, takenIps)
if err != nil {
return nil, nil, nil, err
}
registrationTime := time.Now().UTC()
newPeer := &nbpeer.Peer{
ID: xid.New().String(),
Key: peer.Key,
SetupKey: upperKey,
IP: nextIp,
Meta: peer.Meta,
Name: peer.Meta.Hostname,
DNSLabel: newLabel,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
SSHEnabled: false,
SSHKey: peer.SSHKey,
LastLogin: registrationTime,
CreatedAt: registrationTime,
LoginExpirationEnabled: addedByUser,
Ephemeral: ephemeral,
Location: peer.Location,
}
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
var groupsToAdd []string
var setupKeyID string
var setupKeyName string
var ephemeral bool
if addedByUser {
user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID)
if err != nil {
return fmt.Errorf("failed to get user groups: %w", err)
}
groupsToAdd = user.AutoGroups
opEvent.InitiatorID = userID
opEvent.Activity = activity.PeerAddedByUser
} else {
newPeer.Location.CountryCode = location.Country.ISOCode
newPeer.Location.CityName = location.City.Names.En
newPeer.Location.GeoNameID = location.City.GeonameID
}
}
// Validate the setup key
sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, upperKey)
if err != nil {
return fmt.Errorf("failed to get setup key: %w", err)
}
// add peer to 'All' group
group, err := account.GetGroupAll()
if err != nil {
return nil, nil, nil, err
}
group.Peers = append(group.Peers, newPeer.ID)
if !sk.IsValid() {
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
}
var groupsToAdd []string
if addedByUser {
groupsToAdd, err = account.getUserGroups(userID)
if err != nil {
return nil, nil, nil, err
opEvent.InitiatorID = sk.Id
opEvent.Activity = activity.PeerAddedWithSetupKey
groupsToAdd = sk.AutoGroups
ephemeral = sk.Ephemeral
setupKeyID = sk.Id
setupKeyName = sk.Name
}
} else {
groupsToAdd, err = account.getSetupKeyGroups(upperKey)
if err != nil {
return nil, nil, nil, err
}
}
if len(groupsToAdd) > 0 {
for _, s := range groupsToAdd {
if g, ok := account.Groups[s]; ok && g.Name != "All" {
g.Peers = append(g.Peers, newPeer.ID)
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
if am.idpManager != nil {
userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
if err == nil && userdata != nil {
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
}
}
}
}
newPeer = am.integratedPeerValidator.PreparePeer(ctx, account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra)
if addedByUser {
user, err := account.FindUser(userID)
freeLabel, err := am.getFreeDNSLabel(ctx, transaction, accountID, peer.Meta.Hostname)
if err != nil {
return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user")
return fmt.Errorf("failed to get free DNS label: %w", err)
}
user.updateLastLogin(newPeer.LastLogin)
}
account.Peers[newPeer.ID] = newPeer
account.Network.IncSerial()
err = am.Store.SaveAccount(ctx, account)
freeIP, err := am.getFreeIP(ctx, transaction, accountID)
if err != nil {
return fmt.Errorf("failed to get free IP: %w", err)
}
registrationTime := time.Now().UTC()
newPeer = &nbpeer.Peer{
ID: xid.New().String(),
AccountID: accountID,
Key: peer.Key,
SetupKey: upperKey,
IP: freeIP,
Meta: peer.Meta,
Name: peer.Meta.Hostname,
DNSLabel: freeLabel,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
SSHEnabled: false,
SSHKey: peer.SSHKey,
LastLogin: registrationTime,
CreatedAt: registrationTime,
LoginExpirationEnabled: addedByUser,
Ephemeral: ephemeral,
Location: peer.Location,
}
opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
if !addedByUser {
opEvent.Meta["setup_key_name"] = setupKeyName
}
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
} else {
newPeer.Location.CountryCode = location.Country.ISOCode
newPeer.Location.CityName = location.City.Names.En
newPeer.Location.GeoNameID = location.City.GeonameID
}
}
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return fmt.Errorf("failed to get account settings: %w", err)
}
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
if err != nil {
return fmt.Errorf("failed adding peer to All group: %w", err)
}
if len(groupsToAdd) > 0 {
for _, g := range groupsToAdd {
err = transaction.AddPeerToGroup(ctx, accountID, newPeer.ID, g)
if err != nil {
return err
}
}
}
err = transaction.AddPeerToAccount(ctx, newPeer)
if err != nil {
return fmt.Errorf("failed to add peer to account: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
if addedByUser {
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.LastLogin)
if err != nil {
return fmt.Errorf("failed to update user last login: %w", err)
}
} else {
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
if err != nil {
return fmt.Errorf("failed to increment setup key usage: %w", err)
}
}
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
return nil
})
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err)
}
// Account is saved, we can release the lock
unlock()
unlock = nil
opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
if !addedByUser {
opEvent.Meta["setup_key_name"] = setupKeyName
if newPeer == nil {
return nil, nil, nil, fmt.Errorf("new peer is nil")
}
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
unlock()
unlock = nil
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account)
approvedPeersMap, err := am.GetValidatedPeers(account)
@@ -536,12 +548,31 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, err
}
postureChecks := am.getPeerPostureChecks(account, peer)
postureChecks := am.getPeerPostureChecks(account, newPeer)
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
return newPeer, networkMap, postureChecks, nil
}
func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) {
takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get taken IPs: %w", err)
}
network, err := store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID)
if err != nil {
return nil, fmt.Errorf("failed getting network: %w", err)
}
nextIp, err := AllocatePeerIP(network.Net, takenIps)
if err != nil {
return nil, fmt.Errorf("failed to allocate new peer ip: %w", err)
}
return nextIp, nil
}
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
@@ -647,12 +678,12 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
}
}()
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey)
if err != nil {
return nil, nil, nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, accountID)
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, nil, nil, err
}
@@ -730,7 +761,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
// and before starting the engine, we do the checks without an account lock to avoid piling up requests.
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, login.WireGuardPubKey)
if err != nil {
return err
}
@@ -741,7 +772,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
return nil
}
settings, err := am.Store.GetAccountSettings(ctx, accountID)
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
@@ -786,7 +817,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us
return err
}
err = am.Store.SaveUserLastLogin(user.AccountID, user.Id, peer.LastLogin)
err = am.Store.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin)
if err != nil {
return err
}
@@ -969,3 +1000,11 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
wg.Wait()
}
func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
labelMap := make(map[string]struct{}, len(existingLabels))
for _, label := range existingLabels {
labelMap[label] = struct{}{}
}
return labelMap
}

View File

@@ -7,20 +7,24 @@ import (
"net"
"net/netip"
"os"
"runtime"
"testing"
"time"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/telemetry"
nbroute "github.com/netbirdio/netbird/route"
)
@@ -995,3 +999,184 @@ func TestToSyncResponse(t *testing.T) {
assert.Equal(t, 1, len(response.Checks))
assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0])
}
func Test_RegisterPeerByUser(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
assert.NoError(t, err)
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
existingUserID := "edafee4e-63fb-11ec-90d6-0242ac120003"
_, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
newPeer := &nbpeer.Peer{
ID: xid.New().String(),
AccountID: existingAccountID,
Key: "newPeerKey",
SetupKey: "",
IP: net.IP{123, 123, 123, 123},
Meta: nbpeer.PeerSystemMeta{
Hostname: "newPeer",
GoOS: "linux",
},
Name: "newPeerName",
DNSLabel: "newPeer.test",
UserID: existingUserID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
LastLogin: time.Now(),
}
addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer)
require.NoError(t, err)
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, addedPeer.Key)
require.NoError(t, err)
assert.Equal(t, peer.AccountID, existingAccountID)
assert.Equal(t, peer.UserID, existingUserID)
account, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
assert.Contains(t, account.Peers, addedPeer.ID)
assert.Equal(t, peer.Meta.Hostname, newPeer.Meta.Hostname)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, addedPeer.ID)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
assert.Equal(t, uint64(1), account.Network.Serial)
lastLogin, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
assert.NoError(t, err)
assert.NotEqual(t, lastLogin, account.Users[existingUserID].LastLogin)
}
func Test_RegisterPeerBySetupKey(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
assert.NoError(t, err)
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
existingSetupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
_, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
newPeer := &nbpeer.Peer{
ID: xid.New().String(),
AccountID: existingAccountID,
Key: "newPeerKey",
SetupKey: "existingSetupKey",
UserID: "",
IP: net.IP{123, 123, 123, 123},
Meta: nbpeer.PeerSystemMeta{
Hostname: "newPeer",
GoOS: "linux",
},
Name: "newPeerName",
DNSLabel: "newPeer.test",
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
addedPeer, _, _, err := am.AddPeer(context.Background(), existingSetupKeyID, "", newPeer)
require.NoError(t, err)
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
require.NoError(t, err)
assert.Equal(t, peer.AccountID, existingAccountID)
assert.Equal(t, peer.SetupKey, existingSetupKeyID)
account, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
assert.Contains(t, account.Peers, addedPeer.ID)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g2g"].Peers, addedPeer.ID)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
assert.Equal(t, uint64(1), account.Network.Serial)
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
assert.NoError(t, err)
assert.NotEqual(t, lastUsed, account.SetupKeys[existingSetupKeyID].LastUsed)
assert.Equal(t, 1, account.SetupKeys[existingSetupKeyID].UsedTimes)
}
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
assert.NoError(t, err)
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
faultyKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBC"
_, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
newPeer := &nbpeer.Peer{
ID: xid.New().String(),
AccountID: existingAccountID,
Key: "newPeerKey",
SetupKey: "existingSetupKey",
UserID: "",
IP: net.IP{123, 123, 123, 123},
Meta: nbpeer.PeerSystemMeta{
Hostname: "newPeer",
GoOS: "linux",
},
Name: "newPeerName",
DNSLabel: "newPeer.test",
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
_, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
require.Error(t, err)
_, err = store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
require.Error(t, err)
account, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
assert.NotContains(t, account.Peers, newPeer.ID)
assert.NotContains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, newPeer.ID)
assert.NotContains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, newPeer.ID)
assert.Equal(t, uint64(0), account.Network.Serial)
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
assert.NoError(t, err)
assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed)
assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes)
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net"
"os"
"path/filepath"
"runtime"
@@ -14,6 +15,7 @@ import (
"time"
log "github.com/sirupsen/logrus"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
@@ -33,6 +35,7 @@ import (
const (
storeSqliteFileName = "store.db"
idQueryCondition = "id = ?"
keyQueryCondition = "key = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?"
peerNotFoundFMT = "peer %s not found"
)
@@ -415,13 +418,12 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
var key SetupKey
result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey))
result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, strings.ToUpper(setupKey))
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting setup key from store")
return nil, status.NewSetupKeyNotFoundError()
}
if key.AccountID == "" {
@@ -474,15 +476,15 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
return &user, nil
}
func (s *SqlStore) GetUserByUserID(ctx context.Context, userID string) (*User, error) {
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
var user User
result := s.db.First(&user, idQueryCondition, userID)
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "user not found: index lookup failed")
return nil, status.NewUserNotFoundError(userID)
}
log.WithContext(ctx).Errorf("error when getting user from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting user from store")
return nil, status.NewGetUserFromStoreError()
}
return &user, nil
@@ -535,7 +537,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found")
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
@@ -595,7 +597,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) {
var user User
result := s.db.Select("account_id").First(&user, idQueryCondition, userID)
result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
@@ -612,12 +614,11 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
var peer nbpeer.Peer
result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID)
result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
@@ -631,12 +632,11 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
var peer nbpeer.Peer
result := s.db.Select("account_id").First(&peer, "key = ?", peerKey)
result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
@@ -650,12 +650,11 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
var peer nbpeer.Peer
var accountID string
result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID)
result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting account from store")
}
@@ -677,61 +676,117 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
}
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
var key SetupKey
var accountID string
result := s.db.Model(&key).Select("account_id").Where("key = ?", strings.ToUpper(setupKey)).First(&accountID)
result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, strings.ToUpper(setupKey)).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting setup key from store")
return "", status.NewSetupKeyNotFoundError()
}
if accountID == "" {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return accountID, nil
}
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) {
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
var ipJSONStrings []string
// Fetch the IP addresses as JSON strings
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID).
Pluck("ip", &ipJSONStrings)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
return nil, status.Errorf(status.Internal, "issue getting IPs from store")
}
// Convert the JSON strings to net.IP objects
ips := make([]net.IP, len(ipJSONStrings))
for i, ipJSON := range ipJSONStrings {
var ip net.IP
if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil {
return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store")
}
ips[i] = ip
}
return ips, nil
}
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
var labels []string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID).
Pluck("dns_label", &labels)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting dns labels from store")
}
return labels, nil
}
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
var accountNetwork AccountNetwork
if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.Errorf(status.Internal, "issue getting network from store")
}
return accountNetwork.Network, nil
}
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
var peer nbpeer.Peer
result := s.db.First(&peer, "key = ?", peerKey)
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found")
}
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting peer from store")
}
return &peer, nil
}
func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) {
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
var accountSettings AccountSettings
if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "settings not found")
}
log.WithContext(ctx).Errorf("error when getting settings from the store: %s", err)
return nil, status.Errorf(status.Internal, "issue getting settings from store")
}
return accountSettings.Settings, nil
}
// SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
var user User
result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID)
result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "user %s not found", userID)
return status.NewUserNotFoundError(userID)
}
return status.Errorf(status.Internal, "issue getting user from store")
return status.NewGetUserFromStoreError()
}
user.LastLogin = lastLogin
return s.db.Save(user).Error
return s.db.Save(&user).Error
}
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
@@ -790,6 +845,16 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe
return NewSqlStore(ctx, db, PostgresStoreEngine, metrics)
}
// NewMysqlStore creates a new MySql store.
func NewMysqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
db, err := gorm.Open(mysql.Open(dsn), getGormConfig())
if err != nil {
return nil, err
}
return NewSqlStore(ctx, db, MySqlStoreEngine, metrics)
}
func getGormConfig() *gorm.Config {
return &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
@@ -807,6 +872,15 @@ func newPostgresStore(ctx context.Context, metrics telemetry.AppMetrics) (Store,
return NewPostgresqlStore(ctx, dsn, metrics)
}
// newMySqlStore initializes a new MySql store.
func newMySqlStore(ctx context.Context, metrics telemetry.AppMetrics) (Store, error) {
dsn, ok := os.LookupEnv(mySqlDsnEnv)
if !ok {
return nil, fmt.Errorf("%s is not set", mySqlDsnEnv)
}
return NewMysqlStore(ctx, dsn, metrics)
}
// NewSqliteStoreFromFileStore restores a store from FileStore and stores SQLite DB in the file located in datadir.
func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) {
store, err := NewSqliteStore(ctx, dataDir, metrics)
@@ -850,3 +924,127 @@ func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore,
return store, nil
}
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
var setupKey SetupKey
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&setupKey, keyQueryCondition, strings.ToUpper(key))
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
return nil, status.NewSetupKeyNotFoundError()
}
return &setupKey, nil
}
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
result := s.db.WithContext(ctx).Model(&SetupKey{}).
Where(idQueryCondition, setupKeyID).
Updates(map[string]interface{}{
"used_times": gorm.Expr("used_times + 1"),
"last_used": time.Now(),
})
if result.Error != nil {
return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "setup key not found")
}
return nil
}
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
var group nbgroup.Group
result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group 'All' not found for account")
}
return status.Errorf(status.Internal, "issue finding group 'All'")
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerID {
return nil
}
}
group.Peers = append(group.Peers, peerID)
if err := s.db.Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group 'All'")
}
return nil
}
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
var group nbgroup.Group
result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group not found for account")
}
return status.Errorf(status.Internal, "issue finding group")
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerId {
return nil
}
}
group.Peers = append(group.Peers, peerId)
if err := s.db.Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group")
}
return nil
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account")
}
return nil
}
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil {
return status.Errorf(status.Internal, "issue incrementing network serial count")
}
return nil
}
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
tx := s.db.WithContext(ctx).Begin()
if tx.Error != nil {
return tx.Error
}
repo := s.withTx(tx)
err := operation(repo)
if err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}
func (s *SqlStore) withTx(tx *gorm.DB) Store {
return &SqlStore{
db: tx,
}
}
func (s *SqlStore) GetDB() *gorm.DB {
return s.db
}

View File

@@ -1003,3 +1003,163 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
require.NoError(t, err)
require.Equal(t, id, user.PATs[id].ID)
}
func TestSqlite_GetTakenIPs(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []net.IP{}, takenIPs)
peer1 := &nbpeer.Peer{
ID: "peer1",
AccountID: existingAccountID,
IP: net.IP{1, 1, 1, 1},
}
err = store.AddPeerToAccount(context.Background(), peer1)
require.NoError(t, err)
takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
ip1 := net.IP{1, 1, 1, 1}.To16()
assert.Equal(t, []net.IP{ip1}, takenIPs)
peer2 := &nbpeer.Peer{
ID: "peer2",
AccountID: existingAccountID,
IP: net.IP{2, 2, 2, 2},
}
err = store.AddPeerToAccount(context.Background(), peer2)
require.NoError(t, err)
takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
ip2 := net.IP{2, 2, 2, 2}.To16()
assert.Equal(t, []net.IP{ip1, ip2}, takenIPs)
}
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{}, labels)
peer1 := &nbpeer.Peer{
ID: "peer1",
AccountID: existingAccountID,
DNSLabel: "peer1.domain.test",
}
err = store.AddPeerToAccount(context.Background(), peer1)
require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{"peer1.domain.test"}, labels)
peer2 := &nbpeer.Peer{
ID: "peer2",
AccountID: existingAccountID,
DNSLabel: "peer2.domain.test",
}
err = store.AddPeerToAccount(context.Background(), peer2)
require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels)
}
func TestSqlite_GetAccountNetwork(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
network, err := store.GetAccountNetwork(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
ip := net.IP{100, 64, 0, 0}.To16()
assert.Equal(t, ip, network.Net.IP)
assert.Equal(t, net.IPMask{255, 255, 0, 0}, network.Net.Mask)
assert.Equal(t, "", network.Dns)
assert.Equal(t, "af1c8024-ha40-4ce2-9418-34653101fc3c", network.Identifier)
assert.Equal(t, uint64(0), network.Serial)
}
func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.NoError(t, err)
assert.Equal(t, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", setupKey.Key)
assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID)
assert.Equal(t, "Default key", setupKey.Name)
}
func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.NoError(t, err)
assert.Equal(t, 0, setupKey.UsedTimes)
err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
require.NoError(t, err)
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.NoError(t, err)
assert.Equal(t, 1, setupKey.UsedTimes)
err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
require.NoError(t, err)
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.NoError(t, err)
assert.Equal(t, 2, setupKey.UsedTimes)
}

View File

@@ -100,3 +100,13 @@ func NewPeerNotRegisteredError() error {
func NewPeerLoginExpiredError() error {
return Errorf(PermissionDenied, "peer login has expired, please log in once more")
}
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
func NewSetupKeyNotFoundError() error {
return Errorf(NotFound, "setup key not found")
}
// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store
func NewGetUserFromStoreError() error {
return Errorf(Internal, "issue getting user from store")
}

View File

@@ -27,6 +27,15 @@ import (
"github.com/netbirdio/netbird/route"
)
type LockingStrength string
const (
LockingStrengthUpdate LockingStrength = "UPDATE" // Strongest lock, preventing any changes by other transactions until your transaction completes.
LockingStrengthShare LockingStrength = "SHARE" // Allows reading but prevents changes by other transactions.
LockingStrengthNoKeyUpdate LockingStrength = "NO KEY UPDATE" // Similar to UPDATE but allows changes to related rows.
LockingStrengthKeyShare LockingStrength = "KEY SHARE" // Protects against changes to primary/unique keys but allows other updates.
)
type Store interface {
GetAllAccounts(ctx context.Context) []*Account
GetAccount(ctx context.Context, accountID string) (*Account, error)
@@ -41,7 +50,7 @@ type Store interface {
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, userID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
SaveAccount(ctx context.Context, account *Account) error
@@ -60,14 +69,24 @@ type Store interface {
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
// Close should close the store persisting all unsaved data.
Close(ctx context.Context) error
// GetStoreEngine should return StoreEngine of the current store implementation.
// This is also a method of metrics.DataSource interface.
GetStoreEngine() StoreEngine
GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error)
GetAccountSettings(ctx context.Context, accountID string) (*Settings, error)
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
IncrementNetworkSerial(ctx context.Context, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
}
type StoreEngine string
@@ -76,8 +95,10 @@ const (
FileStoreEngine StoreEngine = "jsonfile"
SqliteStoreEngine StoreEngine = "sqlite"
PostgresStoreEngine StoreEngine = "postgres"
MySqlStoreEngine StoreEngine = "mysql"
postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
mySqlDsnEnv = "NETBIRD_STORE_ENGINE_MYSQL_DSN"
)
func getStoreEngineFromEnv() StoreEngine {
@@ -88,11 +109,12 @@ func getStoreEngineFromEnv() StoreEngine {
}
value := StoreEngine(strings.ToLower(kind))
if value == SqliteStoreEngine || value == PostgresStoreEngine {
switch value {
case SqliteStoreEngine, PostgresStoreEngine, MySqlStoreEngine:
return value
default:
return SqliteStoreEngine
}
return SqliteStoreEngine
}
// getStoreEngine determines the store engine to use.
@@ -139,6 +161,9 @@ func NewStore(ctx context.Context, kind StoreEngine, dataDir string, metrics tel
case PostgresStoreEngine:
log.WithContext(ctx).Info("using Postgres store engine")
return newPostgresStore(ctx, metrics)
case MySqlStoreEngine:
log.WithContext(ctx).Info("using MySQL store engine")
return newMySqlStore(ctx, metrics)
default:
return nil, fmt.Errorf("unsupported kind of store: %s", kind)
}

View File

@@ -0,0 +1,120 @@
{
"Accounts": {
"bf1c8084-ba50-4ce7-9439-34653001fc3b": {
"Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b",
"CreatedBy": "",
"Domain": "test.com",
"DomainCategory": "private",
"IsDomainPrimaryAccount": true,
"SetupKeys": {
"A2C8E62B-38F5-4553-B31E-DD66C696CEBB": {
"Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
"AccountID": "",
"Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
"Name": "Default key",
"Type": "reusable",
"CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
"ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
"UpdatedAt": "0001-01-01T00:00:00Z",
"Revoked": false,
"UsedTimes": 0,
"LastUsed": "0001-01-01T00:00:00Z",
"AutoGroups": ["cfefqs706sqkneg59g2g"],
"UsageLimit": 0,
"Ephemeral": false
},
"A2C8E62B-38F5-4553-B31E-DD66C696CEBC": {
"Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
"AccountID": "",
"Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
"Name": "Faulty key with non existing group",
"Type": "reusable",
"CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
"ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
"UpdatedAt": "0001-01-01T00:00:00Z",
"Revoked": false,
"UsedTimes": 0,
"LastUsed": "0001-01-01T00:00:00Z",
"AutoGroups": ["abcd"],
"UsageLimit": 0,
"Ephemeral": false
}
},
"Network": {
"id": "af1c8024-ha40-4ce2-9418-34653101fc3c",
"Net": {
"IP": "100.64.0.0",
"Mask": "//8AAA=="
},
"Dns": "",
"Serial": 0
},
"Peers": {},
"Users": {
"edafee4e-63fb-11ec-90d6-0242ac120003": {
"Id": "edafee4e-63fb-11ec-90d6-0242ac120003",
"AccountID": "",
"Role": "admin",
"IsServiceUser": false,
"ServiceUserName": "",
"AutoGroups": ["cfefqs706sqkneg59g3g"],
"PATs": {},
"Blocked": false,
"LastLogin": "0001-01-01T00:00:00Z"
},
"f4f6d672-63fb-11ec-90d6-0242ac120003": {
"Id": "f4f6d672-63fb-11ec-90d6-0242ac120003",
"AccountID": "",
"Role": "user",
"IsServiceUser": false,
"ServiceUserName": "",
"AutoGroups": null,
"PATs": {
"9dj38s35-63fb-11ec-90d6-0242ac120003": {
"ID": "9dj38s35-63fb-11ec-90d6-0242ac120003",
"UserID": "",
"Name": "",
"HashedToken": "SoMeHaShEdToKeN",
"ExpirationDate": "2023-02-27T00:00:00Z",
"CreatedBy": "user",
"CreatedAt": "2023-01-01T00:00:00Z",
"LastUsed": "2023-02-01T00:00:00Z"
}
},
"Blocked": false,
"LastLogin": "0001-01-01T00:00:00Z"
}
},
"Groups": {
"cfefqs706sqkneg59g4g": {
"ID": "cfefqs706sqkneg59g4g",
"Name": "All",
"Peers": []
},
"cfefqs706sqkneg59g3g": {
"ID": "cfefqs706sqkneg59g3g",
"Name": "AwesomeGroup1",
"Peers": []
},
"cfefqs706sqkneg59g2g": {
"ID": "cfefqs706sqkneg59g2g",
"Name": "AwesomeGroup2",
"Peers": []
}
},
"Rules": null,
"Policies": [],
"Routes": null,
"NameServerGroups": null,
"DNSSettings": null,
"Settings": {
"PeerLoginExpirationEnabled": false,
"PeerLoginExpiration": 86400000000000,
"GroupsPropagationEnabled": false,
"JWTGroupsEnabled": false,
"JWTGroupsClaimName": ""
}
}
},
"InstallationID": ""
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"crypto/sha1"
"crypto/sha256"
"encoding/base64"
"fmt"
"sync"
"time"
@@ -12,6 +13,7 @@ import (
"github.com/netbirdio/netbird/management/proto"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
)
const defaultDuration = 12 * time.Hour
@@ -30,7 +32,7 @@ type TimeBasedAuthSecretsManager struct {
turnCfg *TURNConfig
relayCfg *Relay
turnHmacToken *auth.TimedHMAC
relayHmacToken *auth.TimedHMAC
relayHmacToken *authv2.Generator
updateManager *PeersUpdateManager
turnCancelMap map[string]chan struct{}
relayCancelMap map[string]chan struct{}
@@ -63,7 +65,11 @@ func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *
duration = defaultDuration
}
mgr.relayHmacToken = auth.NewTimedHMAC(relayCfg.Secret, duration)
hashedSecret := sha256.Sum256([]byte(relayCfg.Secret))
var err error
if mgr.relayHmacToken, err = authv2.NewGenerator(authv2.AuthAlgoHMACSHA256, hashedSecret[:], duration); err != nil {
log.Errorf("failed to create relay token generator: %s", err)
}
}
return mgr
@@ -76,7 +82,7 @@ func (m *TimeBasedAuthSecretsManager) GenerateTurnToken() (*Token, error) {
}
turnToken, err := m.turnHmacToken.GenerateToken(sha1.New)
if err != nil {
return nil, fmt.Errorf("failed to generate TURN token: %s", err)
return nil, fmt.Errorf("generate TURN token: %s", err)
}
return (*Token)(turnToken), nil
}
@@ -86,11 +92,15 @@ func (m *TimeBasedAuthSecretsManager) GenerateRelayToken() (*Token, error) {
if m.relayHmacToken == nil {
return nil, fmt.Errorf("relay configuration is not set")
}
relayToken, err := m.relayHmacToken.GenerateToken(sha256.New)
relayToken, err := m.relayHmacToken.GenerateToken()
if err != nil {
return nil, fmt.Errorf("failed to generate relay token: %s", err)
return nil, fmt.Errorf("generate relay token: %s", err)
}
return (*Token)(relayToken), nil
return &Token{
Payload: string(relayToken.Payload),
Signature: base64.StdEncoding.EncodeToString(relayToken.Signature),
}, nil
}
func (m *TimeBasedAuthSecretsManager) cancelTURN(peerID string) {
@@ -200,7 +210,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNTokens(ctx context.Context, pee
}
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, peerID string) {
relayToken, err := m.relayHmacToken.GenerateToken(sha256.New)
relayToken, err := m.relayHmacToken.GenerateToken()
if err != nil {
log.Errorf("failed to generate relay token for peer '%s': %s", peerID, err)
return
@@ -210,8 +220,8 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, pe
WiretrusteeConfig: &proto.WiretrusteeConfig{
Relay: &proto.RelayConfig{
Urls: m.relayCfg.Addresses,
TokenPayload: relayToken.Payload,
TokenSignature: relayToken.Signature,
TokenPayload: string(relayToken.Payload),
TokenSignature: base64.StdEncoding.EncodeToString(relayToken.Signature),
},
// omit Turns to avoid updates there
},

View File

@@ -63,7 +63,8 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
t.Errorf("expected generated relay signature not to be empty, got empty")
}
validateMAC(t, sha256.New, relayCredentials.Payload, relayCredentials.Signature, []byte(secret))
hashedSecret := sha256.Sum256([]byte(secret))
validateMAC(t, sha256.New, relayCredentials.Payload, relayCredentials.Signature, hashedSecret[:])
}
func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {

View File

@@ -70,7 +70,7 @@ type User struct {
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
Blocked bool
// LastLogin is the last time the user logged in to IdP
LastLogin time.Time
LastLogin time.Time `gorm:"type:TIMESTAMP;null;default:null"`
// CreatedAt records the time the user was created
CreatedAt time.Time
@@ -89,10 +89,6 @@ func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool {
return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero()
}
func (u *User) updateLastLogin(login time.Time) {
u.LastLogin = login
}
// HasAdminPower returns true if the user has admin or owner roles, false otherwise
func (u *User) HasAdminPower() bool {
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
@@ -386,7 +382,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
err = am.Store.SaveUserLastLogin(account.Id, claims.UserId, claims.LastLogin)
err = am.Store.SaveUserLastLogin(ctx, account.Id, claims.UserId, claims.LastLogin)
if err != nil {
log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
}

View File

@@ -1,12 +1,14 @@
package allow
import "hash"
// Auth is a Validator that allows all connections.
// Used this for testing purposes only.
type Auth struct {
}
func (a *Auth) Validate(func() hash.Hash, any) error {
func (a *Auth) Validate(any) error {
return nil
}
func (a *Auth) ValidateHelloMsgType(any) error {
return nil
}

View File

@@ -1,9 +1,11 @@
package hmac
import (
"encoding/base64"
"fmt"
"sync"
log "github.com/sirupsen/logrus"
v2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
)
// TokenStore is a simple in-memory store for token
@@ -20,12 +22,18 @@ func (a *TokenStore) UpdateToken(token *Token) error {
return nil
}
t, err := marshalToken(*token)
sig, err := base64.StdEncoding.DecodeString(token.Signature)
if err != nil {
log.Debugf("failed to marshal token: %s", err)
return err
return fmt.Errorf("decode signature: %w", err)
}
a.token = t
tok := v2.Token{
AuthAlgo: v2.AuthAlgoHMACSHA256,
Signature: sig,
Payload: []byte(token.Payload),
}
a.token = tok.Marshal()
return nil
}

View File

@@ -18,17 +18,6 @@ type Token struct {
Signature string
}
func marshalToken(token Token) ([]byte, error) {
var buffer bytes.Buffer
encoder := gob.NewEncoder(&buffer)
err := encoder.Encode(token)
if err != nil {
log.Debugf("failed to marshal token: %s", err)
return nil, fmt.Errorf("failed to marshal token: %w", err)
}
return buffer.Bytes(), nil
}
func unmarshalToken(payload []byte) (Token, error) {
var creds Token
buffer := bytes.NewBuffer(payload)

View File

@@ -0,0 +1,40 @@
package v2
import (
"crypto/sha256"
"hash"
)
const (
AuthAlgoUnknown AuthAlgo = iota
AuthAlgoHMACSHA256
)
type AuthAlgo uint8
func (a AuthAlgo) String() string {
switch a {
case AuthAlgoHMACSHA256:
return "HMAC-SHA256"
default:
return "Unknown"
}
}
func (a AuthAlgo) New() func() hash.Hash {
switch a {
case AuthAlgoHMACSHA256:
return sha256.New
default:
return nil
}
}
func (a AuthAlgo) Size() int {
switch a {
case AuthAlgoHMACSHA256:
return sha256.Size
default:
return 0
}
}

View File

@@ -0,0 +1,45 @@
package v2
import (
"crypto/hmac"
"fmt"
"hash"
"strconv"
"time"
)
type Generator struct {
algo func() hash.Hash
algoType AuthAlgo
secret []byte
timeToLive time.Duration
}
func NewGenerator(algo AuthAlgo, secret []byte, timeToLive time.Duration) (*Generator, error) {
algoFunc := algo.New()
if algoFunc == nil {
return nil, fmt.Errorf("unsupported auth algorithm: %s", algo)
}
return &Generator{
algo: algoFunc,
algoType: algo,
secret: secret,
timeToLive: timeToLive,
}, nil
}
func (g *Generator) GenerateToken() (*Token, error) {
expirationTime := time.Now().Add(g.timeToLive).Unix()
payload := []byte(strconv.FormatInt(expirationTime, 10))
h := hmac.New(g.algo, g.secret)
h.Write(payload)
signature := h.Sum(nil)
return &Token{
AuthAlgo: g.algoType,
Signature: signature,
Payload: payload,
}, nil
}

View File

@@ -0,0 +1,110 @@
package v2
import (
"strconv"
"testing"
"time"
)
func TestGenerateCredentials(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if len(token.Payload) == 0 {
t.Fatalf("expected non-empty payload")
}
_, err = strconv.ParseInt(string(token.Payload), 10, 64)
if err != nil {
t.Fatalf("expected payload to be a valid unix timestamp, got %v", err)
}
}
func TestValidateCredentials(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
v := NewValidator([]byte(secret))
if err := v.Validate(token.Marshal()); err != nil {
t.Fatalf("expected valid token: %s", err)
}
}
func TestInvalidSignature(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
token.Signature = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
v := NewValidator([]byte(secret))
if err := v.Validate(token.Marshal()); err == nil {
t.Fatalf("expected valid token: %s", err)
}
}
func TestExpired(t *testing.T) {
secret := "supersecret"
timeToLive := -1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
v := NewValidator([]byte(secret))
if err := v.Validate(token.Marshal()); err == nil {
t.Fatalf("expected valid token: %s", err)
}
}
func TestInvalidPayload(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
token.Payload = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
v := NewValidator([]byte(secret))
if err := v.Validate(token.Marshal()); err == nil {
t.Fatalf("expected invalid token due to invalid payload")
}
}

View File

@@ -0,0 +1,39 @@
package v2
import "errors"
type Token struct {
AuthAlgo AuthAlgo
Signature []byte
Payload []byte
}
func (t *Token) Marshal() []byte {
size := 1 + len(t.Signature) + len(t.Payload)
buf := make([]byte, size)
buf[0] = byte(t.AuthAlgo)
copy(buf[1:], t.Signature)
copy(buf[1+len(t.Signature):], t.Payload)
return buf
}
func UnmarshalToken(data []byte) (*Token, error) {
if len(data) == 0 {
return nil, errors.New("invalid token data")
}
algo := AuthAlgo(data[0])
sigSize := algo.Size()
if len(data) < 1+sigSize {
return nil, errors.New("invalid token data: insufficient length")
}
return &Token{
AuthAlgo: algo,
Signature: data[1 : 1+sigSize],
Payload: data[1+sigSize:],
}, nil
}

View File

@@ -0,0 +1,59 @@
package v2
import (
"crypto/hmac"
"errors"
"fmt"
"strconv"
"time"
)
const minLengthUnixTimestamp = 10
type Validator struct {
secret []byte
}
func NewValidator(secret []byte) *Validator {
return &Validator{secret: secret}
}
func (v *Validator) Validate(data any) error {
d, ok := data.([]byte)
if !ok {
return fmt.Errorf("invalid data type")
}
token, err := UnmarshalToken(d)
if err != nil {
return fmt.Errorf("unmarshal token: %w", err)
}
if len(token.Payload) < minLengthUnixTimestamp {
return errors.New("invalid payload: insufficient length")
}
hashFunc := token.AuthAlgo.New()
if hashFunc == nil {
return fmt.Errorf("unsupported auth algorithm: %s", token.AuthAlgo)
}
h := hmac.New(hashFunc, v.secret)
h.Write(token.Payload)
expectedMAC := h.Sum(nil)
if !hmac.Equal(token.Signature, expectedMAC) {
return errors.New("invalid signature")
}
timestamp, err := strconv.ParseInt(string(token.Payload), 10, 64)
if err != nil {
return fmt.Errorf("invalid payload: %w", err)
}
if time.Now().Unix() > timestamp {
return fmt.Errorf("expired token")
}
return nil
}

View File

@@ -1,8 +1,8 @@
package hmac
import (
"crypto/sha256"
"fmt"
"hash"
"time"
log "github.com/sirupsen/logrus"
@@ -19,7 +19,7 @@ func NewTimedHMACValidator(secret string, duration time.Duration) *TimedHMACVali
}
}
func (a *TimedHMACValidator) Validate(algo func() hash.Hash, credentials any) error {
func (a *TimedHMACValidator) Validate(credentials any) error {
b, ok := credentials.([]byte)
if !ok {
return fmt.Errorf("invalid credentials type")
@@ -29,5 +29,5 @@ func (a *TimedHMACValidator) Validate(algo func() hash.Hash, credentials any) er
log.Debugf("failed to unmarshal token: %s", err)
return err
}
return a.TimedHMAC.Validate(algo, c)
return a.TimedHMAC.Validate(sha256.New, c)
}

View File

@@ -1,8 +1,35 @@
package auth
import "hash"
import (
"time"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
)
// Validator is an interface that defines the Validate method.
type Validator interface {
Validate(func() hash.Hash, any) error
Validate(any) error
// Deprecated: Use Validate instead.
ValidateHelloMsgType(any) error
}
type TimedHMACValidator struct {
authenticatorV2 *authv2.Validator
authenticator *auth.TimedHMACValidator
}
func NewTimedHMACValidator(secret []byte, duration time.Duration) *TimedHMACValidator {
return &TimedHMACValidator{
authenticatorV2: authv2.NewValidator(secret),
authenticator: auth.NewTimedHMACValidator(string(secret), duration),
}
}
func (a *TimedHMACValidator) Validate(credentials any) error {
return a.authenticatorV2.Validate(credentials)
}
func (a *TimedHMACValidator) ValidateHelloMsgType(credentials any) error {
return a.authenticator.Validate(credentials)
}

View File

@@ -14,8 +14,6 @@ import (
"github.com/netbirdio/netbird/relay/client/dialer/ws"
"github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/relay/messages/address"
auth2 "github.com/netbirdio/netbird/relay/messages/auth"
)
const (
@@ -60,37 +58,65 @@ func (m *Msg) Free() {
m.bufPool.Put(m.bufPtr)
}
// connContainer is a container for the connection to the peer. It is responsible for managing the messages from the
// server and forwarding them to the upper layer content reader.
type connContainer struct {
log *log.Entry
conn *Conn
messages chan Msg
msgChanLock sync.Mutex
closed bool // flag to check if channel is closed
ctx context.Context
cancel context.CancelFunc
}
func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
func newConnContainer(log *log.Entry, conn *Conn, messages chan Msg) *connContainer {
ctx, cancel := context.WithCancel(context.Background())
return &connContainer{
log: log,
conn: conn,
messages: messages,
ctx: ctx,
cancel: cancel,
}
}
func (cc *connContainer) writeMsg(msg Msg) {
cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock()
if cc.closed {
msg.Free()
return
}
cc.messages <- msg
select {
case cc.messages <- msg:
case <-cc.ctx.Done():
msg.Free()
default:
msg.Free()
cc.log.Infof("message queue is full")
// todo consider to close the connection
}
}
func (cc *connContainer) close() {
cc.cancel()
cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock()
if cc.closed {
return
}
close(cc.messages)
cc.closed = true
close(cc.messages)
for msg := range cc.messages {
msg.Free()
}
}
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
@@ -122,8 +148,8 @@ type Client struct {
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
hashedID, hashedStringId := messages.HashID(peerID)
return &Client{
log: log.WithField("client_id", hashedStringId),
c := &Client{
log: log.WithFields(log.Fields{"relay": serverURL}),
parentCtx: ctx,
connectionURL: serverURL,
authTokenStore: authTokenStore,
@@ -136,11 +162,13 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token
},
conns: make(map[string]*connContainer),
}
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedStringId)
return c
}
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
func (c *Client) Connect() error {
c.log.Infof("connecting to relay server: %s", c.connectionURL)
c.log.Infof("connecting to relay server")
c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock()
@@ -161,7 +189,7 @@ func (c *Client) Connect() error {
c.wgReadLoop.Add(1)
go c.readLoop(c.relayConn)
c.log.Infof("relay connection established with: %s", c.connectionURL)
c.log.Infof("relay connection established")
return nil
}
@@ -183,11 +211,11 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
return nil, ErrConnAlreadyExists
}
log.Infof("open connection to peer: %s", hashedStringID)
msgChannel := make(chan Msg, 2)
c.log.Infof("open connection to peer: %s", hashedStringID)
msgChannel := make(chan Msg, 100)
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
c.conns[hashedStringID] = newConnContainer(conn, msgChannel)
c.conns[hashedStringID] = newConnContainer(c.log, conn, msgChannel)
return conn, nil
}
@@ -231,7 +259,7 @@ func (c *Client) connect() error {
if err != nil {
cErr := conn.Close()
if cErr != nil {
log.Errorf("failed to close connection: %s", cErr)
c.log.Errorf("failed to close connection: %s", cErr)
}
return err
}
@@ -240,31 +268,21 @@ func (c *Client) connect() error {
}
func (c *Client) handShake() error {
authMsg := &auth2.Msg{
AuthAlgorithm: auth2.AlgoHMACSHA256,
AdditionalData: c.authTokenStore.TokenBinary(),
}
authData, err := authMsg.Marshal()
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
if err != nil {
return fmt.Errorf("marshal auth message: %w", err)
}
msg, err := messages.MarshalHelloMsg(c.hashedID, authData)
if err != nil {
log.Errorf("failed to marshal hello message: %s", err)
c.log.Errorf("failed to marshal auth message: %s", err)
return err
}
_, err = c.relayConn.Write(msg)
if err != nil {
log.Errorf("failed to send hello message: %s", err)
c.log.Errorf("failed to send auth message: %s", err)
return err
}
buf := make([]byte, messages.MaxHandshakeSize)
buf := make([]byte, messages.MaxHandshakeRespSize)
n, err := c.readWithTimeout(buf)
if err != nil {
log.Errorf("failed to read hello response: %s", err)
c.log.Errorf("failed to read auth response: %s", err)
return err
}
@@ -275,34 +293,29 @@ func (c *Client) handShake() error {
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
if err != nil {
log.Errorf("failed to determine message type: %s", err)
c.log.Errorf("failed to determine message type: %s", err)
return err
}
if msgType != messages.MsgTypeHelloResponse {
log.Errorf("unexpected message type: %s", msgType)
if msgType != messages.MsgTypeAuthResponse {
c.log.Errorf("unexpected message type: %s", msgType)
return fmt.Errorf("unexpected message type")
}
additionalData, err := messages.UnmarshalHelloResponse(buf[messages.SizeOfProtoHeader:n])
addr, err := messages.UnmarshalAuthResponse(buf[messages.SizeOfProtoHeader:n])
if err != nil {
return err
}
addr, err := address.Unmarshal(additionalData)
if err != nil {
return fmt.Errorf("unmarshal address: %w", err)
}
c.muInstanceURL.Lock()
c.instanceURL = &RelayAddr{addr: addr.URL}
c.instanceURL = &RelayAddr{addr: addr}
c.muInstanceURL.Unlock()
return nil
}
func (c *Client) readLoop(relayConn net.Conn) {
internallyStoppedFlag := newInternalStopFlag()
hc := healthcheck.NewReceiver()
hc := healthcheck.NewReceiver(c.log)
go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
var (
@@ -314,6 +327,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
buf := *bufPtr
n, errExit = relayConn.Read(buf)
if errExit != nil {
c.log.Infof("start to Relay read loop exit")
c.mu.Lock()
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Debugf("failed to read message from relay server: %s", errExit)
@@ -360,7 +374,7 @@ func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte,
case messages.MsgTypeTransport:
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
case messages.MsgTypeClose:
log.Debugf("relay connection close by server")
c.log.Debugf("relay connection close by server")
c.bufPool.Put(bufPtr)
return false
}
@@ -429,14 +443,14 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [
// todo: use buffer pool instead of create new transport msg.
msg, err := messages.MarshalTransportMsg(dstID, payload)
if err != nil {
log.Errorf("failed to marshal transport message: %s", err)
c.log.Errorf("failed to marshal transport message: %s", err)
return 0, err
}
// the write always return with 0 length because the underling does not support the size feedback.
_, err = c.relayConn.Write(msg)
if err != nil {
log.Errorf("failed to write transport message: %s", err)
c.log.Errorf("failed to write transport message: %s", err)
}
return len(payload), err
}
@@ -450,12 +464,15 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in
}
c.log.Errorf("health check timeout")
internalStopFlag.set()
_ = conn.Close() // ignore the err because the readLoop will handle it
if err := conn.Close(); err != nil {
// ignore the err handling because the readLoop will handle it
c.log.Warnf("failed to close connection: %s", err)
}
return
case <-c.parentCtx.Done():
err := c.close(true)
if err != nil {
log.Errorf("failed to teardown connection: %s", err)
c.log.Errorf("failed to teardown connection: %s", err)
}
return
}
@@ -481,8 +498,9 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
if container.conn != connReference {
return fmt.Errorf("conn reference mismatch")
}
container.close()
c.log.Infof("free up connection to peer: %s", id)
delete(c.conns, id)
container.close()
return nil
}
@@ -495,10 +513,12 @@ func (c *Client) close(gracefullyExit bool) error {
var err error
if !c.serviceIsRunning {
c.mu.Unlock()
c.log.Warn("relay connection was already marked as not running")
return nil
}
c.serviceIsRunning = false
c.log.Infof("closing all peer connections")
c.closeAllConns()
if gracefullyExit {
c.writeCloseMsg()
@@ -506,8 +526,9 @@ func (c *Client) close(gracefullyExit bool) error {
err = c.relayConn.Close()
c.mu.Unlock()
c.log.Infof("waiting for read loop to close")
c.wgReadLoop.Wait()
c.log.Infof("relay connection closed with: %s", c.connectionURL)
c.log.Infof("relay connection closed")
return err
}

View File

@@ -618,6 +618,87 @@ func TestCloseByClient(t *testing.T) {
}
}
func TestCloseNotDrainedChannel(t *testing.T) {
ctx := context.Background()
idAlice := "alice"
idBob := "bob"
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
err = clientAlice.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err := clientAlice.Close()
if err != nil {
t.Errorf("failed to close Alice client: %s", err)
}
}()
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob)
err = clientBob.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err := clientBob.Close()
if err != nil {
t.Errorf("failed to close Bob client: %s", err)
}
}()
connAliceToBob, err := clientAlice.OpenConn(idBob)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
connBobToAlice, err := clientBob.OpenConn(idAlice)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
payload := "hello bob, I am alice"
// the internal channel buffer size is 2. So we should overflow it
for i := 0; i < 5; i++ {
_, err = connAliceToBob.Write([]byte(payload))
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
}
// wait for delivery
time.Sleep(1 * time.Second)
err = connBobToAlice.Close()
if err != nil {
t.Errorf("failed to close channel: %s", err)
}
}
func waitForServerToStart(errChan chan error) error {
select {
case err := <-errChan:

View File

@@ -3,7 +3,6 @@ package client
import (
"container/list"
"context"
"errors"
"fmt"
"net"
"reflect"
@@ -17,8 +16,6 @@ import (
var (
relayCleanupInterval = 60 * time.Second
connectionTimeout = 30 * time.Second
maxConcurrentServers = 7
ErrRelayClientNotConnected = fmt.Errorf("relay client not connected")
)
@@ -92,67 +89,23 @@ func (m *Manager) Serve() error {
}
log.Debugf("starting relay client manager with %v relay servers", m.serverURLs)
totalServers := len(m.serverURLs)
successChan := make(chan *Client, 1)
errChan := make(chan error, len(m.serverURLs))
ctx, cancel := context.WithTimeout(m.ctx, connectionTimeout)
defer cancel()
sem := make(chan struct{}, maxConcurrentServers)
for _, url := range m.serverURLs {
sem <- struct{}{}
go func(url string) {
defer func() { <-sem }()
m.connect(m.ctx, url, successChan, errChan)
}(url)
sp := ServerPicker{
TokenStore: m.tokenStore,
PeerID: m.peerID,
}
var errCount int
for {
select {
case client := <-successChan:
log.Infof("Successfully connected to relay server: %s", client.connectionURL)
m.relayClient = client
m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
m.relayClient.SetOnDisconnectListener(func() {
m.onServerDisconnected(client.connectionURL)
})
m.startCleanupLoop()
return nil
case err := <-errChan:
errCount++
log.Warnf("Connection attempt failed: %v", err)
if errCount == totalServers {
return errors.New("failed to connect to any relay server: all attempts failed")
}
case <-ctx.Done():
return fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
}
client, err := sp.PickServer(m.ctx, m.serverURLs)
if err != nil {
return err
}
}
m.relayClient = client
func (m *Manager) connect(ctx context.Context, serverURL string, successChan chan<- *Client, errChan chan<- error) {
// TODO: abort the connection if another connection was successful
relayClient := NewClient(ctx, serverURL, m.tokenStore, m.peerID)
if err := relayClient.Connect(); err != nil {
errChan <- fmt.Errorf("failed to connect to %s: %w", serverURL, err)
return
}
select {
case successChan <- relayClient:
// This client was the first to connect successfully
default:
if err := relayClient.Close(); err != nil {
log.Debugf("failed to close relay client: %s", err)
}
}
m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
m.relayClient.SetOnDisconnectListener(func() {
m.onServerDisconnected(client.connectionURL)
})
m.startCleanupLoop()
return nil
}
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be

98
relay/client/picker.go Normal file
View File

@@ -0,0 +1,98 @@
package client
import (
"context"
"errors"
"fmt"
"time"
log "github.com/sirupsen/logrus"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
)
const (
connectionTimeout = 30 * time.Second
maxConcurrentServers = 7
)
type connResult struct {
RelayClient *Client
Url string
Err error
}
type ServerPicker struct {
TokenStore *auth.TokenStore
PeerID string
}
func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*Client, error) {
ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout)
defer cancel()
totalServers := len(urls)
connResultChan := make(chan connResult, totalServers)
successChan := make(chan connResult, 1)
concurrentLimiter := make(chan struct{}, maxConcurrentServers)
for _, url := range urls {
// todo check if we have a successful connection so we do not need to connect to other servers
concurrentLimiter <- struct{}{}
go func(url string) {
defer func() {
<-concurrentLimiter
}()
sp.startConnection(parentCtx, connResultChan, url)
}(url)
}
go sp.processConnResults(connResultChan, successChan)
select {
case cr, ok := <-successChan:
if !ok {
return nil, errors.New("failed to connect to any relay server: all attempts failed")
}
log.Infof("chosen home Relay server: %s", cr.Url)
return cr.RelayClient, nil
case <-ctx.Done():
return nil, fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
}
}
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
log.Infof("try to connecting to relay server: %s", url)
relayClient := NewClient(ctx, url, sp.TokenStore, sp.PeerID)
err := relayClient.Connect()
resultChan <- connResult{
RelayClient: relayClient,
Url: url,
Err: err,
}
}
func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult) {
var hasSuccess bool
for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ {
cr := <-resultChan
if cr.Err != nil {
log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
continue
}
log.Infof("connected to Relay server: %s", cr.Url)
if hasSuccess {
log.Infof("closing unnecessary Relay connection to: %s", cr.Url)
if err := cr.RelayClient.Close(); err != nil {
log.Errorf("failed to close connection to %s: %v", cr.Url, err)
}
continue
}
hasSuccess = true
successChan <- cr
}
close(successChan)
}

View File

@@ -0,0 +1,31 @@
package client
import (
"context"
"errors"
"testing"
"time"
)
func TestServerPicker_UnavailableServers(t *testing.T) {
sp := ServerPicker{
TokenStore: nil,
PeerID: "test",
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
go func() {
_, err := sp.PickServer(ctx, []string{"rel://dummy1", "rel://dummy2"})
if err == nil {
t.Error(err)
}
cancel()
}()
<-ctx.Done()
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
t.Errorf("PickServer() took too long to complete")
}
}

View File

@@ -2,6 +2,7 @@ package cmd
import (
"context"
"crypto/sha256"
"crypto/tls"
"errors"
"fmt"
@@ -16,21 +17,18 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/encryption"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/util"
)
const (
metricsPort = 9090
)
type Config struct {
ListenAddress string
// in HA every peer connect to a common domain, the instance domain has been distributed during the p2p connection
// it is a domain:port or ip:port
ExposedAddress string
MetricsPort int
LetsencryptEmail string
LetsencryptDataDir string
LetsencryptDomains []string
@@ -79,6 +77,7 @@ func init() {
cobraConfig = &Config{}
rootCmd.PersistentFlags().StringVarP(&cobraConfig.ListenAddress, "listen-address", "l", ":443", "listen address")
rootCmd.PersistentFlags().StringVarP(&cobraConfig.ExposedAddress, "exposed-address", "e", "", "instance domain address (or ip) and port, it will be distributes between peers")
rootCmd.PersistentFlags().IntVar(&cobraConfig.MetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
rootCmd.PersistentFlags().StringVarP(&cobraConfig.LetsencryptDataDir, "letsencrypt-data-dir", "d", "", "a directory to store Let's Encrypt data. Required if Let's Encrypt is enabled.")
rootCmd.PersistentFlags().StringSliceVarP(&cobraConfig.LetsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
rootCmd.PersistentFlags().StringVar(&cobraConfig.LetsencryptEmail, "letsencrypt-email", "", "email address to use for Let's Encrypt certificate registration")
@@ -115,7 +114,7 @@ func execute(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to initialize log: %s", err)
}
metricsServer, err := metrics.NewServer(metricsPort, "")
metricsServer, err := metrics.NewServer(cobraConfig.MetricsPort, "")
if err != nil {
log.Debugf("setup metrics: %v", err)
return fmt.Errorf("setup metrics: %v", err)
@@ -139,7 +138,9 @@ func execute(cmd *cobra.Command, args []string) error {
}
srvListenerCfg.TLSConfig = tlsConfig
authenticator := auth.NewTimedHMACValidator(cobraConfig.AuthSecret, 24*time.Hour)
hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator)
if err != nil {
log.Debugf("failed to create relay server: %v", err)

View File

@@ -3,10 +3,12 @@ package healthcheck
import (
"context"
"time"
log "github.com/sirupsen/logrus"
)
var (
heartbeatTimeout = healthCheckInterval + 3*time.Second
heartbeatTimeout = healthCheckInterval + 10*time.Second
)
// Receiver is a healthcheck receiver
@@ -14,23 +16,26 @@ var (
// If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work
// The heartbeat timeout is a bit longer than the sender's healthcheck interval
type Receiver struct {
OnTimeout chan struct{}
ctx context.Context
ctxCancel context.CancelFunc
heartbeat chan struct{}
alive bool
OnTimeout chan struct{}
log *log.Entry
ctx context.Context
ctxCancel context.CancelFunc
heartbeat chan struct{}
alive bool
attemptThreshold int
}
// NewReceiver creates a new healthcheck receiver and start the timer in the background
func NewReceiver() *Receiver {
func NewReceiver(log *log.Entry) *Receiver {
ctx, ctxCancel := context.WithCancel(context.Background())
r := &Receiver{
OnTimeout: make(chan struct{}, 1),
ctx: ctx,
ctxCancel: ctxCancel,
heartbeat: make(chan struct{}, 1),
OnTimeout: make(chan struct{}, 1),
log: log,
ctx: ctx,
ctxCancel: ctxCancel,
heartbeat: make(chan struct{}, 1),
attemptThreshold: getAttemptThresholdFromEnv(),
}
go r.waitForHealthcheck()
@@ -56,16 +61,23 @@ func (r *Receiver) waitForHealthcheck() {
defer r.ctxCancel()
defer close(r.OnTimeout)
failureCounter := 0
for {
select {
case <-r.heartbeat:
r.alive = true
failureCounter = 0
case <-ticker.C:
if r.alive {
r.alive = false
continue
}
failureCounter++
if failureCounter < r.attemptThreshold {
r.log.Warnf("healthcheck failed, attempt %d", failureCounter)
continue
}
r.notifyTimeout()
return
case <-r.ctx.Done():

View File

@@ -1,13 +1,18 @@
package healthcheck
import (
"context"
"fmt"
"os"
"testing"
"time"
log "github.com/sirupsen/logrus"
)
func TestNewReceiver(t *testing.T) {
heartbeatTimeout = 5 * time.Second
r := NewReceiver()
r := NewReceiver(log.WithContext(context.Background()))
select {
case <-r.OnTimeout:
@@ -19,7 +24,7 @@ func TestNewReceiver(t *testing.T) {
func TestNewReceiverNotReceive(t *testing.T) {
heartbeatTimeout = 1 * time.Second
r := NewReceiver()
r := NewReceiver(log.WithContext(context.Background()))
select {
case <-r.OnTimeout:
@@ -30,7 +35,7 @@ func TestNewReceiverNotReceive(t *testing.T) {
func TestNewReceiverAck(t *testing.T) {
heartbeatTimeout = 2 * time.Second
r := NewReceiver()
r := NewReceiver(log.WithContext(context.Background()))
r.Heartbeat()
@@ -40,3 +45,53 @@ func TestNewReceiverAck(t *testing.T) {
case <-time.After(3 * time.Second):
}
}
func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
testsCases := []struct {
name string
threshold int
resetCounterOnce bool
}{
{"Default attempt threshold", defaultAttemptThreshold, false},
{"Custom attempt threshold", 3, false},
{"Should reset threshold once", 2, true},
}
for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) {
originalInterval := healthCheckInterval
originalTimeout := heartbeatTimeout
healthCheckInterval = 1 * time.Second
heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
defer func() {
healthCheckInterval = originalInterval
heartbeatTimeout = originalTimeout
}()
//nolint:tenv
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
defer os.Unsetenv(defaultAttemptThresholdEnv)
receiver := NewReceiver(log.WithField("test_name", tc.name))
testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval
if tc.resetCounterOnce {
receiver.Heartbeat()
t.Logf("reset counter once")
}
select {
case <-receiver.OnTimeout:
if tc.resetCounterOnce {
t.Fatalf("should not have timed out before %s", testTimeout)
}
case <-time.After(testTimeout):
if tc.resetCounterOnce {
return
}
t.Fatalf("should have timed out before %s", testTimeout)
}
})
}
}

View File

@@ -2,12 +2,21 @@ package healthcheck
import (
"context"
"os"
"strconv"
"time"
log "github.com/sirupsen/logrus"
)
const (
defaultAttemptThreshold = 1
defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD"
)
var (
healthCheckInterval = 25 * time.Second
healthCheckTimeout = 5 * time.Second
healthCheckTimeout = 20 * time.Second
)
// Sender is a healthcheck sender
@@ -15,20 +24,25 @@ var (
// If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work
// It will also stop if the context is canceled
type Sender struct {
log *log.Entry
// HealthCheck is a channel to send health check signal to the peer
HealthCheck chan struct{}
// Timeout is a channel to the health check signal is not received in a certain time
Timeout chan struct{}
ack chan struct{}
ack chan struct{}
alive bool
attemptThreshold int
}
// NewSender creates a new healthcheck sender
func NewSender() *Sender {
func NewSender(log *log.Entry) *Sender {
hc := &Sender{
HealthCheck: make(chan struct{}, 1),
Timeout: make(chan struct{}, 1),
ack: make(chan struct{}, 1),
log: log,
HealthCheck: make(chan struct{}, 1),
Timeout: make(chan struct{}, 1),
ack: make(chan struct{}, 1),
attemptThreshold: getAttemptThresholdFromEnv(),
}
return hc
@@ -46,23 +60,51 @@ func (hc *Sender) StartHealthCheck(ctx context.Context) {
ticker := time.NewTicker(healthCheckInterval)
defer ticker.Stop()
timeoutTimer := time.NewTimer(healthCheckInterval + healthCheckTimeout)
defer timeoutTimer.Stop()
timeoutTicker := time.NewTicker(hc.getTimeoutTime())
defer timeoutTicker.Stop()
defer close(hc.HealthCheck)
defer close(hc.Timeout)
failureCounter := 0
for {
select {
case <-ticker.C:
hc.HealthCheck <- struct{}{}
case <-timeoutTimer.C:
case <-timeoutTicker.C:
if hc.alive {
hc.alive = false
continue
}
failureCounter++
if failureCounter < hc.attemptThreshold {
hc.log.Warnf("Health check failed attempt %d.", failureCounter)
continue
}
hc.Timeout <- struct{}{}
return
case <-hc.ack:
timeoutTimer.Reset(healthCheckInterval + healthCheckTimeout)
failureCounter = 0
hc.alive = true
case <-ctx.Done():
return
}
}
}
func (hc *Sender) getTimeoutTime() time.Duration {
return healthCheckInterval + healthCheckTimeout
}
func getAttemptThresholdFromEnv() int {
if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" {
threshold, err := strconv.ParseInt(attemptThreshold, 10, 64)
if err != nil {
log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold)
return defaultAttemptThreshold
}
return int(threshold)
}
return defaultAttemptThreshold
}

View File

@@ -2,9 +2,12 @@ package healthcheck
import (
"context"
"fmt"
"os"
"testing"
"time"
log "github.com/sirupsen/logrus"
)
func TestMain(m *testing.M) {
@@ -18,7 +21,7 @@ func TestMain(m *testing.M) {
func TestNewHealthPeriod(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := NewSender()
hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx)
iterations := 0
@@ -38,7 +41,7 @@ func TestNewHealthPeriod(t *testing.T) {
func TestNewHealthFailed(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := NewSender()
hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx)
select {
@@ -50,7 +53,7 @@ func TestNewHealthFailed(t *testing.T) {
func TestNewHealthcheckStop(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
hc := NewSender()
hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx)
time.Sleep(100 * time.Millisecond)
@@ -75,7 +78,7 @@ func TestNewHealthcheckStop(t *testing.T) {
func TestTimeoutReset(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := NewSender()
hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx)
iterations := 0
@@ -101,3 +104,102 @@ func TestTimeoutReset(t *testing.T) {
t.Fatalf("is not exited")
}
}
func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
testsCases := []struct {
name string
threshold int
resetCounterOnce bool
}{
{"Default attempt threshold", defaultAttemptThreshold, false},
{"Custom attempt threshold", 3, false},
{"Should reset threshold once", 2, true},
}
for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) {
originalInterval := healthCheckInterval
originalTimeout := healthCheckTimeout
healthCheckInterval = 1 * time.Second
healthCheckTimeout = 500 * time.Millisecond
defer func() {
healthCheckInterval = originalInterval
healthCheckTimeout = originalTimeout
}()
//nolint:tenv
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
defer os.Unsetenv(defaultAttemptThresholdEnv)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sender := NewSender(log.WithField("test_name", tc.name))
go sender.StartHealthCheck(ctx)
go func() {
responded := false
for {
select {
case <-ctx.Done():
return
case _, ok := <-sender.HealthCheck:
if !ok {
return
}
if tc.resetCounterOnce && !responded {
responded = true
sender.OnHCResponse()
}
}
}
}()
testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval
select {
case <-sender.Timeout:
if tc.resetCounterOnce {
t.Fatalf("should not have timed out before %s", testTimeout)
}
case <-time.After(testTimeout):
if tc.resetCounterOnce {
return
}
t.Fatalf("should have timed out before %s", testTimeout)
}
})
}
}
//nolint:tenv
func TestGetAttemptThresholdFromEnv(t *testing.T) {
tests := []struct {
name string
envValue string
expected int
}{
{"Default attempt threshold when env is not set", "", defaultAttemptThreshold},
{"Custom attempt threshold when env is set to a valid integer", "3", 3},
{"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.envValue == "" {
os.Unsetenv(defaultAttemptThresholdEnv)
} else {
os.Setenv(defaultAttemptThresholdEnv, tt.envValue)
}
result := getAttemptThresholdFromEnv()
if result != tt.expected {
t.Fatalf("Expected %d, got %d", tt.expected, result)
}
os.Unsetenv(defaultAttemptThresholdEnv)
})
}
}

View File

@@ -1,3 +1,4 @@
// Deprecated: This package is deprecated and will be removed in a future release.
package address
import (
@@ -18,13 +19,3 @@ func (addr *Address) Marshal() ([]byte, error) {
}
return buf.Bytes(), nil
}
func Unmarshal(data []byte) (*Address, error) {
var addr Address
buf := bytes.NewBuffer(data)
dec := gob.NewDecoder(buf)
if err := dec.Decode(&addr); err != nil {
return nil, fmt.Errorf("decode Address: %w", err)
}
return &addr, nil
}

View File

@@ -1,3 +1,4 @@
// Deprecated: This package is deprecated and will be removed in a future release.
package auth
import (
@@ -30,15 +31,6 @@ type Msg struct {
AdditionalData []byte
}
func (msg *Msg) Marshal() ([]byte, error) {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
if err := enc.Encode(msg); err != nil {
return nil, fmt.Errorf("encode Msg: %w", err)
}
return buf.Bytes(), nil
}
func UnmarshalMsg(data []byte) (*Msg, error) {
var msg *Msg

View File

@@ -7,12 +7,21 @@ import (
)
const (
MsgTypeUnknown MsgType = 0
MsgTypeHello MsgType = 1
MaxHandshakeSize = 212
MaxHandshakeRespSize = 8192
CurrentProtocolVersion = 1
MsgTypeUnknown MsgType = 0
// Deprecated: Use MsgTypeAuth instead.
MsgTypeHello MsgType = 1
// Deprecated: Use MsgTypeAuthResponse instead.
MsgTypeHelloResponse MsgType = 2
MsgTypeTransport MsgType = 3
MsgTypeClose MsgType = 4
MsgTypeHealthCheck MsgType = 5
MsgTypeAuth = 6
MsgTypeAuthResponse = 7
SizeOfVersionByte = 1
SizeOfMsgType = 1
@@ -22,12 +31,12 @@ const (
sizeOfMagicByte = 4
headerSizeTransport = IDSize
headerSizeHello = sizeOfMagicByte + IDSize
headerSizeHelloResp = 0
MaxHandshakeSize = 8192
CurrentProtocolVersion = 1
headerSizeAuth = sizeOfMagicByte + IDSize
headerSizeAuthResp = 0
)
var (
@@ -47,6 +56,10 @@ func (m MsgType) String() string {
return "hello"
case MsgTypeHelloResponse:
return "hello response"
case MsgTypeAuth:
return "auth"
case MsgTypeAuthResponse:
return "auth response"
case MsgTypeTransport:
return "transport"
case MsgTypeClose:
@@ -58,10 +71,6 @@ func (m MsgType) String() string {
}
}
type HelloResponse struct {
InstanceAddress string
}
// ValidateVersion checks if the given version is supported by the protocol
func ValidateVersion(msg []byte) (int, error) {
if len(msg) < SizeOfVersionByte {
@@ -84,6 +93,7 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) {
switch msgType {
case
MsgTypeHello,
MsgTypeAuth,
MsgTypeTransport,
MsgTypeClose,
MsgTypeHealthCheck:
@@ -103,6 +113,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
switch msgType {
case
MsgTypeHelloResponse,
MsgTypeAuthResponse,
MsgTypeTransport,
MsgTypeClose,
MsgTypeHealthCheck:
@@ -112,6 +123,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
}
}
// Deprecated: Use MarshalAuthMsg instead.
// MarshalHelloMsg initial hello message
// The Hello message is the first message sent by a client after establishing a connection with the Relay server. This
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
@@ -135,6 +147,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
return msg, nil
}
// Deprecated: Use UnmarshalAuthMsg instead.
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
// authenticate the client with the server.
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
@@ -148,6 +161,7 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil
}
// Deprecated: Use MarshalAuthResponse instead.
// MarshalHelloResponse creates a response message to the hello message.
// In case of success connection the server response with a Hello Response message. This message contains the server's
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
@@ -163,6 +177,7 @@ func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
return msg, nil
}
// Deprecated: Use UnmarshalAuthResponse instead.
// UnmarshalHelloResponse extracts the additional data from the hello response message.
func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
if len(msg) < headerSizeHelloResp {
@@ -171,6 +186,69 @@ func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
return msg, nil
}
// MarshalAuthMsg initial authentication message
// The Auth message is the first message sent by a client after establishing a connection with the Relay server. This
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
// close the network connection without any response.
func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}
msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeAuth+len(authPayload))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeAuth)
copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
msg = append(msg, peerID...)
msg = append(msg, authPayload...)
return msg, nil
}
// UnmarshalAuthMsg extracts peerID and the auth payload from the message
func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
if len(msg) < headerSizeAuth {
return nil, nil, ErrInvalidMessageLength
}
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
return nil, nil, errors.New("invalid magic header")
}
return msg[sizeOfMagicByte:headerSizeAuth], msg[headerSizeAuth:], nil
}
// MarshalAuthResponse creates a response message to the auth.
// In case of success connection the server response with a AuthResponse message. This message contains the server's
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
// servers.
func MarshalAuthResponse(address string) ([]byte, error) {
ab := []byte(address)
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeAuthResp+len(ab))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeAuthResponse)
msg = append(msg, ab...)
if len(msg) > MaxHandshakeRespSize {
return nil, fmt.Errorf("invalid message length: %d", len(msg))
}
return msg, nil
}
// UnmarshalAuthResponse it is a confirmation message to auth success
func UnmarshalAuthResponse(msg []byte) (string, error) {
if len(msg) < headerSizeAuthResp+1 {
return "", ErrInvalidMessageLength
}
return string(msg), nil
}
// MarshalCloseMsg creates a close message.
// The close message is used to close the connection gracefully between the client and the server. The server and the
// client can send this message. After receiving this message, the server or client will close the connection.

View File

@@ -20,6 +20,22 @@ func TestMarshalHelloMsg(t *testing.T) {
}
}
func TestMarshalAuthMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
bHello, err := MarshalAuthMsg(peerID, []byte{})
if err != nil {
t.Fatalf("error: %v", err)
}
receivedPeerID, _, err := UnmarshalAuthMsg(bHello[SizeOfProtoHeader:])
if err != nil {
t.Fatalf("error: %v", err)
}
if string(receivedPeerID) != string(peerID) {
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
}
}
func TestMarshalTransportMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
payload := []byte("payload")

View File

@@ -103,7 +103,7 @@ func (m *Metrics) PeerActivity(peerID string) {
select {
case m.peerActivityChan <- peerID:
default:
log.Errorf("peer activity channel is full, dropping activity metrics for peer %s", peerID)
log.Tracef("peer activity channel is full, dropping activity metrics for peer %s", peerID)
}
}

View File

@@ -49,7 +49,7 @@ func (p *Peer) Work() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := healthcheck.NewSender()
hc := healthcheck.NewSender(p.log)
go hc.StartHealthCheck(ctx)
go p.handleHealthcheckEvents(ctx, hc)
@@ -115,6 +115,7 @@ func (p *Peer) Write(b []byte) (int, error) {
// connection.
func (p *Peer) CloseGracefully(ctx context.Context) {
p.connMu.Lock()
defer p.connMu.Unlock()
err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg())
if err != nil {
p.log.Errorf("failed to send close message to peer: %s", p.String())
@@ -124,8 +125,15 @@ func (p *Peer) CloseGracefully(ctx context.Context) {
if err != nil {
p.log.Errorf("failed to close connection to peer: %s", err)
}
}
func (p *Peer) Close() {
p.connMu.Lock()
defer p.connMu.Unlock()
if err := p.conn.Close(); err != nil {
p.log.Errorf("failed to close connection to peer: %s", err)
}
}
// String returns the peer ID
@@ -167,6 +175,7 @@ func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Send
if err != nil {
p.log.Errorf("failed to close connection to peer: %s", err)
}
p.log.Info("peer connection closed due healthcheck timeout")
return
case <-ctx.Done():
return
@@ -184,7 +193,7 @@ func (p *Peer) handleTransportMsg(msg []byte) {
stringPeerID := messages.HashIDToString(peerID)
dp, ok := p.store.Peer(stringPeerID)
if !ok {
p.log.Errorf("peer not found: %s", stringPeerID)
p.log.Debugf("peer not found: %s", stringPeerID)
return
}

View File

@@ -2,7 +2,6 @@ package server
import (
"context"
"crypto/sha256"
"fmt"
"net"
"net/url"
@@ -14,7 +13,9 @@ import (
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/messages"
//nolint:staticcheck
"github.com/netbirdio/netbird/relay/messages/address"
//nolint:staticcheck
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
"github.com/netbirdio/netbird/relay/metrics"
)
@@ -168,39 +169,81 @@ func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
}
if msgType != messages.MsgTypeHello {
return nil, fmt.Errorf("invalid message type from %s", conn.RemoteAddr())
var (
responseMsg []byte
peerID []byte
)
switch msgType {
//nolint:staticcheck
case messages.MsgTypeHello:
peerID, responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
case messages.MsgTypeAuth:
peerID, responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
default:
return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr())
}
peerID, authData, err := messages.UnmarshalHelloMsg(buf[messages.SizeOfProtoHeader:n])
if err != nil {
return nil, fmt.Errorf("unmarshal hello message: %w", err)
return nil, err
}
authMsg, err := authmsg.UnmarshalMsg(authData)
if err != nil {
return nil, fmt.Errorf("unmarshal auth message: %w", err)
}
if err := r.validator.Validate(sha256.New, authMsg.AdditionalData); err != nil {
return nil, fmt.Errorf("validate %s (%s): %w", peerID, conn.RemoteAddr(), err)
}
addr := &address.Address{URL: r.instanceURL}
addrData, err := addr.Marshal()
if err != nil {
return nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, conn.RemoteAddr(), err)
}
msg, err := messages.MarshalHelloResponse(addrData)
if err != nil {
return nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, conn.RemoteAddr(), err)
}
_, err = conn.Write(msg)
_, err = conn.Write(responseMsg)
if err != nil {
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
}
return peerID, nil
}
func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) {
//nolint:staticcheck
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
if err != nil {
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
}
peerID := messages.HashIDToString(rawPeerID)
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr)
authMsg, err := authmsg.UnmarshalMsg(authData)
if err != nil {
return nil, nil, fmt.Errorf("unmarshal auth message: %w", err)
}
//nolint:staticcheck
if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err)
}
addr := &address.Address{URL: r.instanceURL}
addrData, err := addr.Marshal()
if err != nil {
return nil, nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err)
}
//nolint:staticcheck
responseMsg, err := messages.MarshalHelloResponse(addrData)
if err != nil {
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err)
}
return rawPeerID, responseMsg, nil
}
func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) {
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
if err != nil {
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
}
peerID := messages.HashIDToString(rawPeerID)
if err := r.validator.Validate(authPayload); err != nil {
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err)
}
responseMsg, err := messages.MarshalAuthResponse(r.instanceURL)
if err != nil {
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err)
}
return rawPeerID, responseMsg, nil
}

View File

@@ -19,10 +19,14 @@ func NewStore() *Store {
}
// AddPeer adds a peer to the store
// todo: consider to close peer conn if the peer already exists
func (s *Store) AddPeer(peer *Peer) {
s.peersLock.Lock()
defer s.peersLock.Unlock()
odlPeer, ok := s.peers[peer.String()]
if ok {
odlPeer.Close()
}
s.peers[peer.String()] = peer
}

View File

@@ -2,13 +2,57 @@ package server
import (
"context"
"net"
"testing"
"time"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/relay/metrics"
)
type mockConn struct {
}
func (m mockConn) Read(b []byte) (n int, err error) {
//TODO implement me
panic("implement me")
}
func (m mockConn) Write(b []byte) (n int, err error) {
//TODO implement me
panic("implement me")
}
func (m mockConn) Close() error {
return nil
}
func (m mockConn) LocalAddr() net.Addr {
//TODO implement me
panic("implement me")
}
func (m mockConn) RemoteAddr() net.Addr {
//TODO implement me
panic("implement me")
}
func (m mockConn) SetDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (m mockConn) SetReadDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (m mockConn) SetWriteDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func TestStore_DeletePeer(t *testing.T) {
s := NewStore()
@@ -27,8 +71,9 @@ func TestStore_DeleteDeprecatedPeer(t *testing.T) {
m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
p1 := NewPeer(m, []byte("peer_id"), nil, nil)
p2 := NewPeer(m, []byte("peer_id"), nil, nil)
conn := &mockConn{}
p1 := NewPeer(m, []byte("peer_id"), conn, nil)
p2 := NewPeer(m, []byte("peer_id"), conn, nil)
s.AddPeer(p1)
s.AddPeer(p2)

View File

@@ -300,11 +300,13 @@ install_netbird() {
echo "package_manager=$PACKAGE_MANAGER" | ${SUDO} tee "$CONFIG_FILE" > /dev/null
# Load and start netbird service
if ! ${SUDO} netbird service install 2>&1; then
echo "NetBird service has already been loaded"
fi
if ! ${SUDO} netbird service start 2>&1; then
echo "NetBird service has already been started"
if [ "$PACKAGE_MANAGER" != "rpm-ostree" ]; then
if ! ${SUDO} netbird service install 2>&1; then
echo "NetBird service has already been loaded"
fi
if ! ${SUDO} netbird service start 2>&1; then
echo "NetBird service has already been started"
fi
fi

View File

@@ -29,12 +29,9 @@ import (
"google.golang.org/grpc/keepalive"
)
const (
metricsPort = 9090
)
var (
signalPort int
metricsPort int
signalLetsencryptDomain string
signalSSLDir string
defaultSignalSSLDir string
@@ -288,6 +285,7 @@ func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
func init() {
runCmd.PersistentFlags().IntVar(&signalPort, "port", 80, "Server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
runCmd.Flags().IntVar(&metricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
runCmd.Flags().StringVar(&signalSSLDir, "ssl-dir", defaultSignalSSLDir, "server ssl directory location. *Required only for Let's Encrypt certificates.")
runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
runCmd.Flags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")

View File

@@ -82,8 +82,11 @@ func (registry *Registry) Register(peer *Peer) {
log.Warnf("peer [%s] is already registered [new streamID %d, previous StreamID %d]. Will override stream.",
peer.Id, peer.StreamID, pp.StreamID)
registry.Peers.Store(peer.Id, peer)
return
}
log.Debugf("peer registered [%s]", peer.Id)
registry.metrics.ActivePeers.Add(context.Background(), 1)
// record time as milliseconds
registry.metrics.RegistrationDelay.Record(context.Background(), float64(time.Since(start).Nanoseconds())/1e6)
@@ -105,8 +108,8 @@ func (registry *Registry) Deregister(peer *Peer) {
peer.Id, pp.StreamID, peer.StreamID)
return
}
registry.metrics.ActivePeers.Add(context.Background(), -1)
log.Debugf("peer deregistered [%s]", peer.Id)
registry.metrics.Deregistrations.Add(context.Background(), 1)
}
log.Debugf("peer deregistered [%s]", peer.Id)
registry.metrics.Deregistrations.Add(context.Background(), 1)
}

View File

@@ -133,8 +133,6 @@ func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (
s.registry.Register(p)
s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer)
s.metrics.ActivePeers.Add(stream.Context(), 1)
return p, nil
} else {
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId)))
@@ -151,7 +149,6 @@ func (s *Server) DeregisterPeer(p *peer.Peer) {
s.registry.Deregister(p)
s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds()))
s.metrics.ActivePeers.Add(context.Background(), -1)
}
func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) {

View File

@@ -10,51 +10,30 @@ import (
log "github.com/sirupsen/logrus"
)
// WriteJson writes JSON config object to a file creating parent directories if required
// The output JSON is pretty-formatted
func WriteJson(file string, obj interface{}) error {
// WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory
func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil {
return err
}
// make it pretty
bs, err := json.MarshalIndent(obj, "", " ")
err = EnforcePermission(file)
if err != nil {
return err
}
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
return writeJson(file, obj, configDir, configFileName)
}
// WriteJson writes JSON config object to a file creating parent directories if required
// The output JSON is pretty-formatted
func WriteJson(file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil {
return err
}
tempFileName := tempFile.Name()
// closing file ops as windows doesn't allow to move it
err = tempFile.Close()
if err != nil {
return err
}
defer func() {
_, err = os.Stat(tempFileName)
if err == nil {
os.Remove(tempFileName)
}
}()
err = os.WriteFile(tempFileName, bs, 0600)
if err != nil {
return err
}
err = os.Rename(tempFileName, file)
if err != nil {
return err
}
return nil
return writeJson(file, obj, configDir, configFileName)
}
// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file
@@ -96,6 +75,46 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error {
return nil
}
func writeJson(file string, obj interface{}, configDir string, configFileName string) error {
// make it pretty
bs, err := json.MarshalIndent(obj, "", " ")
if err != nil {
return err
}
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
if err != nil {
return err
}
tempFileName := tempFile.Name()
// closing file ops as windows doesn't allow to move it
err = tempFile.Close()
if err != nil {
return err
}
defer func() {
_, err = os.Stat(tempFileName)
if err == nil {
os.Remove(tempFileName)
}
}()
err = os.WriteFile(tempFileName, bs, 0600)
if err != nil {
return err
}
err = os.Rename(tempFileName, file)
if err != nil {
return err
}
return nil
}
func openOrCreateFile(file string) (*os.File, error) {
s, err := os.Stat(file)
if err == nil {
@@ -172,5 +191,9 @@ func prepareConfigFileDir(file string) (string, string, error) {
}
err := os.MkdirAll(configDir, 0750)
if err != nil {
return "", "", err
}
return configDir, configFileName, err
}

View File

@@ -5,6 +5,7 @@ import (
"os"
"path/filepath"
"slices"
"strconv"
log "github.com/sirupsen/logrus"
"gopkg.in/natefinch/lumberjack.v2"
@@ -12,6 +13,8 @@ import (
"github.com/netbirdio/netbird/formatter"
)
const defaultLogSize = 5
// InitLog parses and sets log-level input
func InitLog(logLevel string, logPath string) error {
level, err := log.ParseLevel(logLevel)
@@ -19,13 +22,14 @@ func InitLog(logLevel string, logPath string) error {
log.Errorf("Failed parsing log-level %s: %s", logLevel, err)
return err
}
customOutputs := []string{"console", "syslog"};
customOutputs := []string{"console", "syslog"}
if logPath != "" && !slices.Contains(customOutputs, logPath) {
maxLogSize := getLogMaxSize()
lumberjackLogger := &lumberjack.Logger{
// Log file absolute path, os agnostic
Filename: filepath.ToSlash(logPath),
MaxSize: 5, // MB
MaxSize: maxLogSize, // MB
MaxBackups: 10,
MaxAge: 30, // days
Compress: true,
@@ -46,3 +50,18 @@ func InitLog(logLevel string, logPath string) error {
log.SetLevel(level)
return nil
}
func getLogMaxSize() int {
if sizeVar, ok := os.LookupEnv("NB_LOG_MAX_SIZE_MB"); ok {
size, err := strconv.ParseInt(sizeVar, 10, 64)
if err != nil {
log.Errorf("Failed parsing log-size %s: %s. Should be just an integer", sizeVar, err)
return defaultLogSize
}
log.Infof("Setting log file max size to %d MB", size)
return int(size)
}
return defaultLogSize
}

7
util/permission.go Normal file
View File

@@ -0,0 +1,7 @@
//go:build !windows
package util
func EnforcePermission(dirPath string) error {
return nil
}

View File

@@ -0,0 +1,86 @@
package util
import (
"path/filepath"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
const (
securityFlags = windows.OWNER_SECURITY_INFORMATION |
windows.GROUP_SECURITY_INFORMATION |
windows.DACL_SECURITY_INFORMATION |
windows.PROTECTED_DACL_SECURITY_INFORMATION
)
func EnforcePermission(file string) error {
dirPath := filepath.Dir(file)
user, group, err := sids()
if err != nil {
return err
}
adminGroupSid, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid)
if err != nil {
return err
}
explicitAccess := []windows.EXPLICIT_ACCESS{
{
AccessPermissions: windows.GENERIC_ALL,
AccessMode: windows.SET_ACCESS,
Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT,
Trustee: windows.TRUSTEE{
MultipleTrusteeOperation: windows.NO_MULTIPLE_TRUSTEE,
TrusteeForm: windows.TRUSTEE_IS_SID,
TrusteeType: windows.TRUSTEE_IS_USER,
TrusteeValue: windows.TrusteeValueFromSID(user),
},
},
{
AccessPermissions: windows.GENERIC_ALL,
AccessMode: windows.SET_ACCESS,
Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT,
Trustee: windows.TRUSTEE{
MultipleTrusteeOperation: windows.NO_MULTIPLE_TRUSTEE,
TrusteeForm: windows.TRUSTEE_IS_SID,
TrusteeType: windows.TRUSTEE_IS_WELL_KNOWN_GROUP,
TrusteeValue: windows.TrusteeValueFromSID(adminGroupSid),
},
},
}
dacl, err := windows.ACLFromEntries(explicitAccess, nil)
if err != nil {
return err
}
return windows.SetNamedSecurityInfo(dirPath, windows.SE_FILE_OBJECT, securityFlags, user, group, dacl, nil)
}
func sids() (*windows.SID, *windows.SID, error) {
var token windows.Token
err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token)
if err != nil {
return nil, nil, err
}
defer func() {
if err := token.Close(); err != nil {
log.Errorf("failed to close process token: %v", err)
}
}()
tu, err := token.GetTokenUser()
if err != nil {
return nil, nil, err
}
pg, err := token.GetTokenPrimaryGroup()
if err != nil {
return nil, nil, err
}
return tu.User.Sid, pg.PrimaryGroup, nil
}