mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 17:56:39 +00:00
Compare commits
35 Commits
refactor/h
...
feature/ad
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ad8459ea2f | ||
|
|
4ebf6e1c4c | ||
|
|
1e4a0f77e2 | ||
|
|
b51d75204b | ||
|
|
e7d52c8c95 | ||
|
|
ab82302c95 | ||
|
|
d47be154ea | ||
|
|
35c892aea3 | ||
|
|
fc4b37f7bc | ||
|
|
6f0fd1d1b3 | ||
|
|
28cbb4b70f | ||
|
|
1104c9c048 | ||
|
|
5bc601111d | ||
|
|
b74951f29e | ||
|
|
97e10e440c | ||
|
|
6c50b0c84b | ||
|
|
730dd1733e | ||
|
|
82739e2832 | ||
|
|
fa7767e612 | ||
|
|
f1171198de | ||
|
|
9e041b7f82 | ||
|
|
b4c8cf0a67 | ||
|
|
1ef51a4ffa | ||
|
|
f6d57e7a96 | ||
|
|
ab892b8cf9 | ||
|
|
33c9b2d989 | ||
|
|
170e842422 | ||
|
|
4c130a0291 | ||
|
|
afb9673bc4 | ||
|
|
cf6210a6f4 | ||
|
|
c59a39d27d | ||
|
|
47adb976f8 | ||
|
|
9cfc8f8aa4 | ||
|
|
2d1bf3982d | ||
|
|
50ebbe482e |
8
.github/workflows/golang-test-darwin.yml
vendored
8
.github/workflows/golang-test-darwin.yml
vendored
@@ -18,14 +18,14 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.21.x"
|
||||
go-version: "1.23.x"
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: macos-go-${{ hashFiles('**/go.sum') }}
|
||||
|
||||
20
.github/workflows/golang-test-linux.yml
vendored
20
.github/workflows/golang-test-linux.yml
vendored
@@ -19,13 +19,13 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.21.x"
|
||||
go-version: "1.23.x"
|
||||
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -33,7 +33,7 @@ jobs:
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
@@ -49,18 +49,18 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
|
||||
|
||||
test_client_on_docker:
|
||||
runs-on: ubuntu-20.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.21.x"
|
||||
go-version: "1.23.x"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -68,7 +68,7 @@ jobs:
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
@@ -124,4 +124,4 @@ jobs:
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Peer tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
6
.github/workflows/golang-test-windows.yml
vendored
6
.github/workflows/golang-test-windows.yml
vendored
@@ -17,13 +17,13 @@ jobs:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v5
|
||||
id: go
|
||||
with:
|
||||
go-version: "1.21.x"
|
||||
go-version: "1.23.x"
|
||||
|
||||
- name: Download wintun
|
||||
uses: carlosperate/download-file-action@v2
|
||||
|
||||
10
.github/workflows/golangci-lint.yml
vendored
10
.github/workflows/golangci-lint.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
@@ -32,15 +32,15 @@ jobs:
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
- name: Check for duplicate constants
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: |
|
||||
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.21.x"
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
- name: Install dependencies
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
@@ -49,4 +49,4 @@ jobs:
|
||||
uses: golangci/golangci-lint-action@v3
|
||||
with:
|
||||
version: latest
|
||||
args: --timeout=12m
|
||||
args: --timeout=12m
|
||||
|
||||
2
.github/workflows/install-script-test.yml
vendored
2
.github/workflows/install-script-test.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: run install script
|
||||
env:
|
||||
|
||||
18
.github/workflows/mobile-build-validation.yml
vendored
18
.github/workflows/mobile-build-validation.yml
vendored
@@ -15,23 +15,23 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.21.x"
|
||||
go-version: "1.23.x"
|
||||
- name: Setup Android SDK
|
||||
uses: android-actions/setup-android@v3
|
||||
with:
|
||||
cmdline-tools-version: 8512546
|
||||
- name: Setup Java
|
||||
uses: actions/setup-java@v3
|
||||
uses: actions/setup-java@v4
|
||||
with:
|
||||
java-version: "11"
|
||||
distribution: "adopt"
|
||||
- name: NDK Cache
|
||||
id: ndk-cache
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: /usr/local/lib/android/sdk/ndk
|
||||
key: ndk-cache-23.1.7779620
|
||||
@@ -50,11 +50,11 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.21.x"
|
||||
go-version: "1.23.x"
|
||||
- name: install gomobile
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
|
||||
- name: gomobile init
|
||||
@@ -62,4 +62,4 @@ jobs:
|
||||
- name: build iOS netbird lib
|
||||
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o ./NetBirdSDK.xcframework ./client/ios/NetBirdSDK
|
||||
env:
|
||||
CGO_ENABLED: 0
|
||||
CGO_ENABLED: 0
|
||||
|
||||
36
.github/workflows/release.yml
vendored
36
.github/workflows/release.yml
vendored
@@ -36,18 +36,18 @@ jobs:
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
-
|
||||
name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
-
|
||||
name: Set up Go
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.21"
|
||||
go-version: "1.23"
|
||||
cache: false
|
||||
-
|
||||
name: Cache Go modules
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -93,28 +93,28 @@ jobs:
|
||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
-
|
||||
name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v3
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release
|
||||
path: dist/
|
||||
retention-days: 3
|
||||
-
|
||||
name: upload linux packages
|
||||
uses: actions/upload-artifact@v3
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: linux-packages
|
||||
path: dist/netbird_linux**
|
||||
retention-days: 3
|
||||
-
|
||||
name: upload windows packages
|
||||
uses: actions/upload-artifact@v3
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: windows-packages
|
||||
path: dist/netbird_windows**
|
||||
retention-days: 3
|
||||
-
|
||||
name: upload macos packages
|
||||
uses: actions/upload-artifact@v3
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: macos-packages
|
||||
path: dist/netbird_darwin**
|
||||
@@ -133,17 +133,17 @@ jobs:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.21"
|
||||
go-version: "1.23"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -176,7 +176,7 @@ jobs:
|
||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
- name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v3
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release-ui
|
||||
path: dist/
|
||||
@@ -189,18 +189,18 @@ jobs:
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
-
|
||||
name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
-
|
||||
name: Set up Go
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.21"
|
||||
go-version: "1.23"
|
||||
cache: false
|
||||
-
|
||||
name: Cache Go modules
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -225,7 +225,7 @@ jobs:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
-
|
||||
name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v3
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release-ui-darwin
|
||||
path: dist/
|
||||
|
||||
16
.github/workflows/test-infrastructure-files.yml
vendored
16
.github/workflows/test-infrastructure-files.yml
vendored
@@ -50,12 +50,12 @@ jobs:
|
||||
run: sudo apt-get install -y curl
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.21.x"
|
||||
go-version: "1.23.x"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -63,7 +63,7 @@ jobs:
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: cp setup.env
|
||||
run: cp infrastructure_files/tests/setup.env infrastructure_files/
|
||||
@@ -219,10 +219,7 @@ jobs:
|
||||
run: sudo apt-get install -y jq
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: handle insisting image # remove after release
|
||||
run: docker pull netbirdio/relay:latest || docker pull netbirdio/signal:latest && docker tag netbirdio/signal:latest netbirdio/relay:latest
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: run script with Zitadel PostgreSQL
|
||||
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
|
||||
@@ -259,9 +256,6 @@ jobs:
|
||||
docker compose down --volumes --rmi all
|
||||
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
|
||||
|
||||
- name: handle insisting image gen CockroachDB # remove after release
|
||||
run: docker pull netbirdio/relay:latest || docker pull netbirdio/signal:latest && docker tag netbirdio/signal:latest netbirdio/relay:latest
|
||||
|
||||
- name: run script with Zitadel CockroachDB
|
||||
run: bash -x infrastructure_files/getting-started-with-zitadel.sh
|
||||
env:
|
||||
|
||||
@@ -805,6 +805,9 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
|
||||
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
|
||||
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
|
||||
}
|
||||
|
||||
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
|
||||
|
||||
for i, route := range peer.Routes {
|
||||
peer.Routes[i] = a.AnonymizeIPString(route)
|
||||
}
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
)
|
||||
|
||||
func formatError(es []error) string {
|
||||
if len(es) == 0 {
|
||||
return fmt.Sprintf("0 error occurred:\n\t* %s", es[0])
|
||||
if len(es) == 1 {
|
||||
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
|
||||
}
|
||||
|
||||
points := make([]string, len(es))
|
||||
|
||||
@@ -117,6 +117,11 @@ type Config struct {
|
||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
func ReadConfig(configPath string) (*Config, error) {
|
||||
if configFileIsExists(configPath) {
|
||||
err := util.EnforcePermission(configPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to enforce permission on config dir: %v", err)
|
||||
}
|
||||
|
||||
config := &Config{}
|
||||
if _, err := util.ReadJson(configPath, config); err != nil {
|
||||
return nil, err
|
||||
@@ -159,13 +164,17 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = WriteOutConfig(input.ConfigPath, cfg)
|
||||
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
if isPreSharedKeyHidden(input.PreSharedKey) {
|
||||
input.PreSharedKey = nil
|
||||
}
|
||||
err := util.EnforcePermission(input.ConfigPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to enforce permission on config dir: %v", err)
|
||||
}
|
||||
return update(input)
|
||||
}
|
||||
|
||||
|
||||
@@ -158,6 +158,7 @@ func (c *ConnectClient) run(
|
||||
}
|
||||
|
||||
defer c.statusRecorder.ClientStop()
|
||||
runningChanOpen := true
|
||||
operation := func() error {
|
||||
// if context cancelled we not start new backoff cycle
|
||||
if c.isContextCancelled() {
|
||||
@@ -267,6 +268,12 @@ func (c *ConnectClient) run(
|
||||
checks := loginResp.GetChecks()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
if c.engine != nil && c.engine.ctx.Err() != nil {
|
||||
log.Info("Stopping Netbird Engine")
|
||||
if err := c.engine.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
}
|
||||
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
|
||||
|
||||
c.engineMutex.Unlock()
|
||||
@@ -279,9 +286,10 @@ func (c *ConnectClient) run(
|
||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||
state.Set(StatusConnected)
|
||||
|
||||
if runningChan != nil {
|
||||
if runningChan != nil && runningChanOpen {
|
||||
runningChan <- nil
|
||||
close(runningChan)
|
||||
runningChanOpen = false
|
||||
}
|
||||
|
||||
<-engineCtx.Done()
|
||||
|
||||
@@ -292,7 +292,7 @@ func (e *Engine) Start() error {
|
||||
e.wgInterface = wgIface
|
||||
|
||||
userspace := e.wgInterface.IsUserspaceBind()
|
||||
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort)
|
||||
e.wgProxyFactory = wgproxy.NewFactory(userspace, e.config.WgPort)
|
||||
|
||||
if e.config.RosenpassEnabled {
|
||||
log.Infof("rosenpass is enabled")
|
||||
@@ -1115,10 +1115,7 @@ func (e *Engine) close() {
|
||||
}
|
||||
|
||||
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
||||
if e.dnsServer != nil {
|
||||
e.dnsServer.Stop()
|
||||
e.dnsServer = nil
|
||||
}
|
||||
e.stopDNSServer()
|
||||
|
||||
if e.routeManager != nil {
|
||||
e.routeManager.Stop()
|
||||
@@ -1360,12 +1357,16 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
|
||||
}
|
||||
|
||||
func (e *Engine) restartEngine() {
|
||||
log.Info("restarting engine")
|
||||
CtxGetState(e.ctx).Set(StatusConnecting)
|
||||
|
||||
if err := e.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
if err := e.Start(); err != nil {
|
||||
log.Errorf("Failed to start engine: %v", err)
|
||||
}
|
||||
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
log.Infof("cancelling client, engine will be recreated")
|
||||
e.clientCancel()
|
||||
}
|
||||
|
||||
func (e *Engine) startNetworkMonitor() {
|
||||
@@ -1387,6 +1388,7 @@ func (e *Engine) startNetworkMonitor() {
|
||||
defer mu.Unlock()
|
||||
|
||||
if debounceTimer != nil {
|
||||
log.Infof("Network monitor: detected network change, reset debounceTimer")
|
||||
debounceTimer.Stop()
|
||||
}
|
||||
|
||||
@@ -1396,7 +1398,7 @@ func (e *Engine) startNetworkMonitor() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
log.Infof("Network monitor detected network change, restarting engine")
|
||||
log.Infof("Network monitor: detected network change, restarting engine")
|
||||
e.restartEngine()
|
||||
})
|
||||
})
|
||||
@@ -1421,6 +1423,20 @@ func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
|
||||
return false, netip.Prefix{}, nil
|
||||
}
|
||||
|
||||
func (e *Engine) stopDNSServer() {
|
||||
err := fmt.Errorf("DNS server stopped")
|
||||
nsGroupStates := e.statusRecorder.GetDNSStates()
|
||||
for i := range nsGroupStates {
|
||||
nsGroupStates[i].Enabled = false
|
||||
nsGroupStates[i].Error = err
|
||||
}
|
||||
e.statusRecorder.UpdateDNSStates(nsGroupStates)
|
||||
if e.dnsServer != nil {
|
||||
e.dnsServer.Stop()
|
||||
e.dnsServer = nil
|
||||
}
|
||||
}
|
||||
|
||||
// isChecksEqual checks if two slices of checks are equal.
|
||||
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
||||
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
|
||||
|
||||
@@ -24,7 +24,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
defer func() {
|
||||
err := unix.Close(fd)
|
||||
if err != nil && !errors.Is(err, unix.EBADF) {
|
||||
log.Errorf("Network monitor: failed to close routing socket: %v", err)
|
||||
log.Warnf("Network monitor: failed to close routing socket: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -32,7 +32,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
<-ctx.Done()
|
||||
err := unix.Close(fd)
|
||||
if err != nil && !errors.Is(err, unix.EBADF) {
|
||||
log.Debugf("Network monitor: closed routing socket")
|
||||
log.Debugf("Network monitor: closed routing socket: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -45,12 +45,12 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
n, err := unix.Read(fd, buf)
|
||||
if err != nil {
|
||||
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
|
||||
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
|
||||
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if n < unix.SizeofRtMsghdr {
|
||||
log.Errorf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||
route, err := parseRouteMessage(buf[:n])
|
||||
if err != nil {
|
||||
log.Errorf("Network monitor: error parsing routing message: %v", err)
|
||||
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -89,8 +89,8 @@ type Conn struct {
|
||||
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
||||
onDisconnected func(remotePeer string, wgIP string)
|
||||
|
||||
statusRelay ConnStatus
|
||||
statusICE ConnStatus
|
||||
statusRelay *AtomicConnStatus
|
||||
statusICE *AtomicConnStatus
|
||||
currentConnPriority ConnPriority
|
||||
opened bool // this flag is used to prevent close in case of not opened connection
|
||||
|
||||
@@ -131,8 +131,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
|
||||
signaler: signaler,
|
||||
relayManager: relayManager,
|
||||
allowedIPsIP: allowedIPsIP.String(),
|
||||
statusRelay: StatusDisconnected,
|
||||
statusICE: StatusDisconnected,
|
||||
statusRelay: NewAtomicConnStatus(),
|
||||
statusICE: NewAtomicConnStatus(),
|
||||
iCEDisconnected: make(chan bool, 1),
|
||||
relayDisconnected: make(chan bool, 1),
|
||||
}
|
||||
@@ -323,11 +323,11 @@ func (conn *Conn) reconnectLoopWithRetry() {
|
||||
}
|
||||
|
||||
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||
if conn.statusRelay == StatusDisconnected || conn.statusICE == StatusDisconnected {
|
||||
if conn.statusRelay.Get() == StatusDisconnected || conn.statusICE.Get() == StatusDisconnected {
|
||||
conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
|
||||
}
|
||||
} else {
|
||||
if conn.statusICE == StatusDisconnected {
|
||||
if conn.statusICE.Get() == StatusDisconnected {
|
||||
conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE)
|
||||
}
|
||||
}
|
||||
@@ -419,7 +419,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
|
||||
|
||||
conn.log.Debugf("ICE connection is ready")
|
||||
|
||||
conn.statusICE = StatusConnected
|
||||
conn.statusICE.Set(StatusConnected)
|
||||
|
||||
defer conn.updateIceState(iceConnInfo)
|
||||
|
||||
@@ -492,8 +492,8 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
||||
conn.currentConnPriority = connPriorityRelay
|
||||
}
|
||||
|
||||
changed := conn.statusICE != newState && newState != StatusConnecting
|
||||
conn.statusICE = newState
|
||||
changed := conn.statusICE.Get() != newState && newState != StatusConnecting
|
||||
conn.statusICE.Set(newState)
|
||||
|
||||
select {
|
||||
case conn.iCEDisconnected <- changed:
|
||||
@@ -518,18 +518,22 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
if conn.ctx.Err() != nil {
|
||||
if err := rci.relayedConn.Close(); err != nil {
|
||||
log.Warnf("failed to close unnecessary relayed connection: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
conn.log.Debugf("Relay connection is ready to use")
|
||||
conn.statusRelay = StatusConnected
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
|
||||
wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
|
||||
endpoint, err := wgProxy.AddTurnConn(rci.relayedConn)
|
||||
wgProxy := conn.wgProxyFactory.GetProxy()
|
||||
endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn)
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
||||
return
|
||||
}
|
||||
conn.log.Infof("created new wgProxy for relay connection: %s", endpoint)
|
||||
|
||||
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
||||
conn.endpointRelay = endpointUdpAddr
|
||||
@@ -538,7 +542,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
||||
defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
|
||||
if conn.currentConnPriority > connPriorityRelay {
|
||||
if conn.statusICE == StatusConnected {
|
||||
if conn.statusICE.Get() == StatusConnected {
|
||||
log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
||||
return
|
||||
}
|
||||
@@ -559,8 +563,8 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
||||
conn.log.Errorf("Failed to update wg peer configuration: %v", err)
|
||||
return
|
||||
}
|
||||
wgConfigWorkaround()
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
wgConfigWorkaround()
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
if err := conn.wgProxyRelay.CloseConn(); err != nil {
|
||||
@@ -594,8 +598,8 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
|
||||
conn.wgProxyRelay = nil
|
||||
}
|
||||
|
||||
changed := conn.statusRelay != StatusDisconnected
|
||||
conn.statusRelay = StatusDisconnected
|
||||
changed := conn.statusRelay.Get() != StatusDisconnected
|
||||
conn.statusRelay.Set(StatusDisconnected)
|
||||
|
||||
select {
|
||||
case conn.relayDisconnected <- changed:
|
||||
@@ -661,8 +665,8 @@ func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) {
|
||||
}
|
||||
|
||||
func (conn *Conn) setStatusToDisconnected() {
|
||||
conn.statusRelay = StatusDisconnected
|
||||
conn.statusICE = StatusDisconnected
|
||||
conn.statusRelay.Set(StatusDisconnected)
|
||||
conn.statusICE.Set(StatusDisconnected)
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
@@ -706,7 +710,7 @@ func (conn *Conn) waitInitialRandomSleepTime() {
|
||||
}
|
||||
|
||||
func (conn *Conn) isRelayed() bool {
|
||||
if conn.statusRelay == StatusDisconnected && (conn.statusICE == StatusDisconnected || conn.statusICE == StatusConnecting) {
|
||||
if conn.statusRelay.Get() == StatusDisconnected && (conn.statusICE.Get() == StatusDisconnected || conn.statusICE.Get() == StatusConnecting) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -718,11 +722,11 @@ func (conn *Conn) isRelayed() bool {
|
||||
}
|
||||
|
||||
func (conn *Conn) evalStatus() ConnStatus {
|
||||
if conn.statusRelay == StatusConnected || conn.statusICE == StatusConnected {
|
||||
if conn.statusRelay.Get() == StatusConnected || conn.statusICE.Get() == StatusConnected {
|
||||
return StatusConnected
|
||||
}
|
||||
|
||||
if conn.statusRelay == StatusConnecting || conn.statusICE == StatusConnecting {
|
||||
if conn.statusRelay.Get() == StatusConnecting || conn.statusICE.Get() == StatusConnecting {
|
||||
return StatusConnecting
|
||||
}
|
||||
|
||||
@@ -733,12 +737,12 @@ func (conn *Conn) isConnected() bool {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
if conn.statusICE != StatusConnected && conn.statusICE != StatusConnecting {
|
||||
if conn.statusICE.Get() != StatusConnected && conn.statusICE.Get() != StatusConnecting {
|
||||
return false
|
||||
}
|
||||
|
||||
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||
if conn.statusRelay != StatusConnected {
|
||||
if conn.statusRelay.Get() != StatusConnected {
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -771,13 +775,12 @@ func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr,
|
||||
return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil
|
||||
}
|
||||
conn.log.Debugf("setup ice turn connection")
|
||||
wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
|
||||
ep, err := wgProxy.AddTurnConn(iceConnInfo.RemoteConn)
|
||||
wgProxy := conn.wgProxyFactory.GetProxy()
|
||||
ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn)
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||
err = wgProxy.CloseConn()
|
||||
if err != nil {
|
||||
conn.log.Warnf("failed to close turn proxy connection: %v", err)
|
||||
if errClose := wgProxy.CloseConn(); errClose != nil {
|
||||
conn.log.Warnf("failed to close turn proxy connection: %v", errClose)
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package peer
|
||||
|
||||
import log "github.com/sirupsen/logrus"
|
||||
import (
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// StatusConnected indicate the peer is in connected state
|
||||
@@ -12,7 +16,34 @@ const (
|
||||
)
|
||||
|
||||
// ConnStatus describe the status of a peer's connection
|
||||
type ConnStatus int
|
||||
type ConnStatus int32
|
||||
|
||||
// AtomicConnStatus is a thread-safe wrapper for ConnStatus
|
||||
type AtomicConnStatus struct {
|
||||
status atomic.Int32
|
||||
}
|
||||
|
||||
// NewAtomicConnStatus creates a new AtomicConnStatus with the given initial status
|
||||
func NewAtomicConnStatus() *AtomicConnStatus {
|
||||
acs := &AtomicConnStatus{}
|
||||
acs.Set(StatusDisconnected)
|
||||
return acs
|
||||
}
|
||||
|
||||
// Get returns the current connection status
|
||||
func (acs *AtomicConnStatus) Get() ConnStatus {
|
||||
return ConnStatus(acs.status.Load())
|
||||
}
|
||||
|
||||
// Set updates the connection status
|
||||
func (acs *AtomicConnStatus) Set(status ConnStatus) {
|
||||
acs.status.Store(int32(status))
|
||||
}
|
||||
|
||||
// String returns the string representation of the current status
|
||||
func (acs *AtomicConnStatus) String() string {
|
||||
return acs.Get().String()
|
||||
}
|
||||
|
||||
func (s ConnStatus) String() string {
|
||||
switch s {
|
||||
|
||||
@@ -44,7 +44,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_GetKey(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
@@ -59,7 +59,7 @@ func TestConn_GetKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
@@ -96,7 +96,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
@@ -132,7 +132,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
func TestConn_Status(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
@@ -158,8 +158,13 @@ func TestConn_Status(t *testing.T) {
|
||||
|
||||
for _, table := range tables {
|
||||
t.Run(table.name, func(t *testing.T) {
|
||||
conn.statusICE = table.statusIce
|
||||
conn.statusRelay = table.statusRelay
|
||||
si := NewAtomicConnStatus()
|
||||
si.Set(table.statusIce)
|
||||
conn.statusICE = si
|
||||
|
||||
sr := NewAtomicConnStatus()
|
||||
sr.Set(table.statusRelay)
|
||||
conn.statusRelay = sr
|
||||
|
||||
got := conn.Status()
|
||||
assert.Equal(t, got, table.want, "they should be equal")
|
||||
|
||||
@@ -597,6 +597,8 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
||||
}
|
||||
|
||||
func (d *Status) GetRosenpassState() RosenpassState {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
return RosenpassState{
|
||||
d.rosenpassEnabled,
|
||||
d.rosenpassPermissive,
|
||||
@@ -604,6 +606,8 @@ func (d *Status) GetRosenpassState() RosenpassState {
|
||||
}
|
||||
|
||||
func (d *Status) GetManagementState() ManagementState {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
return ManagementState{
|
||||
d.mgmAddress,
|
||||
d.managementState,
|
||||
@@ -645,6 +649,8 @@ func (d *Status) IsLoginRequired() bool {
|
||||
}
|
||||
|
||||
func (d *Status) GetSignalState() SignalState {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
return SignalState{
|
||||
d.signalAddress,
|
||||
d.signalState,
|
||||
@@ -654,6 +660,8 @@ func (d *Status) GetSignalState() SignalState {
|
||||
|
||||
// GetRelayStates returns the stun/turn/permanent relay states
|
||||
func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
if d.relayMgr == nil {
|
||||
return d.relayStates
|
||||
}
|
||||
@@ -684,6 +692,8 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
}
|
||||
|
||||
func (d *Status) GetDNSStates() []NSGroupState {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
return d.nsGroupStates
|
||||
}
|
||||
|
||||
@@ -695,18 +705,19 @@ func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
|
||||
|
||||
// GetFullStatus gets full status
|
||||
func (d *Status) GetFullStatus() FullStatus {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
fullStatus := FullStatus{
|
||||
ManagementState: d.GetManagementState(),
|
||||
SignalState: d.GetSignalState(),
|
||||
LocalPeerState: d.localPeer,
|
||||
Relays: d.GetRelayStates(),
|
||||
RosenpassState: d.GetRosenpassState(),
|
||||
NSGroupStates: d.GetDNSStates(),
|
||||
}
|
||||
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
fullStatus.LocalPeerState = d.localPeer
|
||||
|
||||
for _, status := range d.peers {
|
||||
fullStatus.Peers = append(fullStatus.Peers, status)
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
wgHandshakePeriod = 2 * time.Minute
|
||||
wgHandshakePeriod = 3 * time.Minute
|
||||
wgHandshakeOvertime = 30 * time.Second
|
||||
)
|
||||
|
||||
@@ -109,10 +109,10 @@ func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
|
||||
}
|
||||
|
||||
ctx, ctxCancel := context.WithCancel(ctx)
|
||||
go w.wgStateCheck(ctx)
|
||||
w.ctxWgWatch = ctx
|
||||
w.ctxCancelWgWatch = ctxCancel
|
||||
|
||||
w.wgStateCheck(ctx, ctxCancel)
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) DisableWgWatcher() {
|
||||
@@ -157,37 +157,51 @@ func (w *WorkerRelay) CloseConn() {
|
||||
}
|
||||
}
|
||||
|
||||
// wgStateCheck help to check the state of the wireguard handshake and relay connection
|
||||
func (w *WorkerRelay) wgStateCheck(ctx context.Context) {
|
||||
timer := time.NewTimer(wgHandshakeOvertime)
|
||||
defer timer.Stop()
|
||||
expected := wgHandshakeOvertime
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
lastHandshake, err := w.wgState()
|
||||
if err != nil {
|
||||
w.log.Errorf("failed to read wg stats: %v", err)
|
||||
continue
|
||||
}
|
||||
w.log.Tracef("last handshake: %v", lastHandshake)
|
||||
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
||||
func (w *WorkerRelay) wgStateCheck(ctx context.Context, ctxCancel context.CancelFunc) {
|
||||
w.log.Debugf("WireGuard watcher started")
|
||||
lastHandshake, err := w.wgState()
|
||||
if err != nil {
|
||||
w.log.Warnf("failed to read wg stats: %v", err)
|
||||
lastHandshake = time.Time{}
|
||||
}
|
||||
|
||||
if time.Since(lastHandshake) > expected {
|
||||
w.log.Infof("Wireguard handshake timed out, closing relay connection")
|
||||
w.relayLock.Lock()
|
||||
_ = w.relayedConn.Close()
|
||||
w.relayLock.Unlock()
|
||||
w.callBacks.OnDisconnected()
|
||||
go func(lastHandshake time.Time) {
|
||||
timer := time.NewTimer(wgHandshakeOvertime)
|
||||
defer timer.Stop()
|
||||
defer ctxCancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
handshake, err := w.wgState()
|
||||
if err != nil {
|
||||
w.log.Errorf("failed to read wg stats: %v", err)
|
||||
timer.Reset(wgHandshakeOvertime)
|
||||
continue
|
||||
}
|
||||
|
||||
w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake)
|
||||
|
||||
if handshake.Equal(lastHandshake) {
|
||||
w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake)
|
||||
w.relayLock.Lock()
|
||||
_ = w.relayedConn.Close()
|
||||
w.relayLock.Unlock()
|
||||
w.callBacks.OnDisconnected()
|
||||
return
|
||||
}
|
||||
|
||||
resetTime := time.Until(handshake.Add(wgHandshakePeriod + wgHandshakeOvertime))
|
||||
lastHandshake = handshake
|
||||
timer.Reset(resetTime)
|
||||
case <-ctx.Done():
|
||||
w.log.Debugf("WireGuard watcher stopped")
|
||||
return
|
||||
}
|
||||
resetTime := time.Until(lastHandshake.Add(wgHandshakePeriod + wgHandshakeOvertime))
|
||||
timer.Reset(resetTime)
|
||||
expected = wgHandshakePeriod
|
||||
case <-ctx.Done():
|
||||
w.log.Debugf("WireGuard watcher stopped")
|
||||
return
|
||||
}
|
||||
}
|
||||
}(lastHandshake)
|
||||
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package wgproxy
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -1,4 +1,4 @@
|
||||
package wgproxy
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package wgproxy
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -13,47 +13,49 @@ import (
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
loopbackAddr = "127.0.0.1"
|
||||
)
|
||||
|
||||
// WGEBPFProxy definition for proxy with EBPF support
|
||||
type WGEBPFProxy struct {
|
||||
ebpfManager ebpfMgr.Manager
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
lastUsedPort uint16
|
||||
localWGListenPort int
|
||||
|
||||
ebpfManager ebpfMgr.Manager
|
||||
turnConnStore map[uint16]net.Conn
|
||||
turnConnMutex sync.Mutex
|
||||
|
||||
rawConn net.PacketConn
|
||||
conn transport.UDPConn
|
||||
lastUsedPort uint16
|
||||
rawConn net.PacketConn
|
||||
conn transport.UDPConn
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewWGEBPFProxy create new WGEBPFProxy instance
|
||||
func NewWGEBPFProxy(ctx context.Context, wgPort int) *WGEBPFProxy {
|
||||
func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
|
||||
log.Debugf("instantiate ebpf proxy")
|
||||
wgProxy := &WGEBPFProxy{
|
||||
localWGListenPort: wgPort,
|
||||
ebpfManager: ebpf.GetEbpfManagerInstance(),
|
||||
lastUsedPort: 0,
|
||||
turnConnStore: make(map[uint16]net.Conn),
|
||||
}
|
||||
wgProxy.ctx, wgProxy.cancel = context.WithCancel(ctx)
|
||||
|
||||
return wgProxy
|
||||
}
|
||||
|
||||
// listen load ebpf program and listen the proxy
|
||||
func (p *WGEBPFProxy) listen() error {
|
||||
// Listen load ebpf program and listen the proxy
|
||||
func (p *WGEBPFProxy) Listen() error {
|
||||
pl := portLookup{}
|
||||
wgPorxyPort, err := pl.searchFreePort()
|
||||
if err != nil {
|
||||
@@ -72,9 +74,11 @@ func (p *WGEBPFProxy) listen() error {
|
||||
|
||||
addr := net.UDPAddr{
|
||||
Port: wgPorxyPort,
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
IP: net.ParseIP(loopbackAddr),
|
||||
}
|
||||
|
||||
p.ctx, p.ctxCancel = context.WithCancel(context.Background())
|
||||
|
||||
conn, err := nbnet.ListenUDP("udp", &addr)
|
||||
if err != nil {
|
||||
cErr := p.Free()
|
||||
@@ -91,108 +95,112 @@ func (p *WGEBPFProxy) listen() error {
|
||||
}
|
||||
|
||||
// AddTurnConn add new turn connection for the proxy
|
||||
func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
|
||||
func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) {
|
||||
wgEndpointPort, err := p.storeTurnConn(turnConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go p.proxyToLocal(wgEndpointPort, turnConn)
|
||||
go p.proxyToLocal(ctx, wgEndpointPort, turnConn)
|
||||
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
|
||||
|
||||
wgEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
IP: net.ParseIP(loopbackAddr),
|
||||
Port: int(wgEndpointPort),
|
||||
}
|
||||
return wgEndpoint, nil
|
||||
}
|
||||
|
||||
// CloseConn doing nothing because this type of proxy implementation does not store the connection
|
||||
func (p *WGEBPFProxy) CloseConn() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Free resources
|
||||
// Free resources except the remoteConns will be keep open.
|
||||
func (p *WGEBPFProxy) Free() error {
|
||||
log.Debugf("free up ebpf wg proxy")
|
||||
var err1, err2, err3 error
|
||||
if p.conn != nil {
|
||||
err1 = p.conn.Close()
|
||||
if p.ctx != nil && p.ctx.Err() != nil {
|
||||
//nolint
|
||||
return nil
|
||||
}
|
||||
|
||||
err2 = p.ebpfManager.FreeWGProxy()
|
||||
if p.rawConn != nil {
|
||||
err3 = p.rawConn.Close()
|
||||
p.ctxCancel()
|
||||
|
||||
var result *multierror.Error
|
||||
if err := p.conn.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
if err1 != nil {
|
||||
return err1
|
||||
if err := p.ebpfManager.FreeWGProxy(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
if err2 != nil {
|
||||
return err2
|
||||
if err := p.rawConn.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
return err3
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
|
||||
func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) {
|
||||
defer p.removeTurnConn(endpointPort)
|
||||
|
||||
var (
|
||||
err error
|
||||
n int
|
||||
)
|
||||
buf := make([]byte, 1500)
|
||||
var err error
|
||||
defer func() {
|
||||
p.removeTurnConn(endpointPort)
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
default:
|
||||
var n int
|
||||
n, err = remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
|
||||
}
|
||||
for ctx.Err() == nil {
|
||||
n, err = remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
err = p.sendPkg(buf[:n], endpointPort)
|
||||
if err != nil {
|
||||
log.Errorf("failed to write out turn pkg to local conn: %v", err)
|
||||
if err != io.EOF {
|
||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err := p.sendPkg(buf[:n], endpointPort); err != nil {
|
||||
if ctx.Err() != nil || p.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to write out turn pkg to local conn: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
|
||||
// From this go routine has only one instance.
|
||||
func (p *WGEBPFProxy) proxyToRemote() {
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
default:
|
||||
n, addr, err := p.conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
log.Errorf("failed to read UDP pkg from WG: %s", err)
|
||||
for p.ctx.Err() == nil {
|
||||
if err := p.readAndForwardPacket(buf); err != nil {
|
||||
if p.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.turnConnMutex.Lock()
|
||||
conn, ok := p.turnConnStore[uint16(addr.Port)]
|
||||
p.turnConnMutex.Unlock()
|
||||
if !ok {
|
||||
log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port)
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = conn.Write(buf[:n])
|
||||
if err != nil {
|
||||
log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err)
|
||||
}
|
||||
log.Errorf("failed to proxy packet to remote conn: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) readAndForwardPacket(buf []byte) error {
|
||||
n, addr, err := p.conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read UDP packet from WG: %w", err)
|
||||
}
|
||||
|
||||
p.turnConnMutex.Lock()
|
||||
conn, ok := p.turnConnStore[uint16(addr.Port)]
|
||||
p.turnConnMutex.Unlock()
|
||||
if !ok {
|
||||
if p.ctx.Err() == nil {
|
||||
log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := conn.Write(buf[:n]); err != nil {
|
||||
return fmt.Errorf("failed to forward local WG packet (%d) to remote turn conn: %w", addr.Port, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
|
||||
p.turnConnMutex.Lock()
|
||||
defer p.turnConnMutex.Unlock()
|
||||
@@ -206,11 +214,14 @@ func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) {
|
||||
log.Debugf("remove turn conn from store by port: %d", turnConnID)
|
||||
p.turnConnMutex.Lock()
|
||||
defer p.turnConnMutex.Unlock()
|
||||
delete(p.turnConnStore, turnConnID)
|
||||
|
||||
_, ok := p.turnConnStore[turnConnID]
|
||||
if ok {
|
||||
log.Debugf("remove turn conn from store by port: %d", turnConnID)
|
||||
}
|
||||
delete(p.turnConnStore, turnConnID)
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) nextFreePort() (uint16, error) {
|
||||
@@ -1,14 +1,13 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package wgproxy
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWGEBPFProxy_connStore(t *testing.T) {
|
||||
wgProxy := NewWGEBPFProxy(context.Background(), 1)
|
||||
wgProxy := NewWGEBPFProxy(1)
|
||||
|
||||
p, _ := wgProxy.storeTurnConn(nil)
|
||||
if p != 1 {
|
||||
@@ -28,7 +27,7 @@ func TestWGEBPFProxy_connStore(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
|
||||
wgProxy := NewWGEBPFProxy(context.Background(), 1)
|
||||
wgProxy := NewWGEBPFProxy(1)
|
||||
|
||||
_, _ = wgProxy.storeTurnConn(nil)
|
||||
wgProxy.lastUsedPort = 65535
|
||||
@@ -44,7 +43,7 @@ func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) {
|
||||
wgProxy := NewWGEBPFProxy(context.Background(), 1)
|
||||
wgProxy := NewWGEBPFProxy(1)
|
||||
|
||||
for i := 0; i < 65535; i++ {
|
||||
_, _ = wgProxy.storeTurnConn(nil)
|
||||
44
client/internal/wgproxy/ebpf/wrapper.go
Normal file
44
client/internal/wgproxy/ebpf/wrapper.go
Normal file
@@ -0,0 +1,44 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||
type ProxyWrapper struct {
|
||||
WgeBPFProxy *WGEBPFProxy
|
||||
|
||||
remoteConn net.Conn
|
||||
cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread
|
||||
}
|
||||
|
||||
func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
|
||||
ctxConn, cancel := context.WithCancel(ctx)
|
||||
addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn)
|
||||
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("add turn conn: %w", err)
|
||||
}
|
||||
e.remoteConn = remoteConn
|
||||
e.cancel = cancel
|
||||
return addr, err
|
||||
}
|
||||
|
||||
// CloseConn close the remoteConn and automatically remove the conn instance from the map
|
||||
func (e *ProxyWrapper) CloseConn() error {
|
||||
if e.cancel == nil {
|
||||
return fmt.Errorf("proxy not started")
|
||||
}
|
||||
|
||||
e.cancel()
|
||||
|
||||
if err := e.remoteConn.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close remote conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
package wgproxy
|
||||
|
||||
import "context"
|
||||
|
||||
type Factory struct {
|
||||
wgPort int
|
||||
ebpfProxy Proxy
|
||||
}
|
||||
|
||||
func (w *Factory) GetProxy(ctx context.Context) Proxy {
|
||||
if w.ebpfProxy != nil {
|
||||
return w.ebpfProxy
|
||||
}
|
||||
return NewWGUserSpaceProxy(ctx, w.wgPort)
|
||||
}
|
||||
|
||||
func (w *Factory) Free() error {
|
||||
if w.ebpfProxy != nil {
|
||||
return w.ebpfProxy.Free()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -3,20 +3,26 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy/usp"
|
||||
)
|
||||
|
||||
func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
|
||||
type Factory struct {
|
||||
wgPort int
|
||||
ebpfProxy *ebpf.WGEBPFProxy
|
||||
}
|
||||
|
||||
func NewFactory(userspace bool, wgPort int) *Factory {
|
||||
f := &Factory{wgPort: wgPort}
|
||||
|
||||
if userspace {
|
||||
return f
|
||||
}
|
||||
|
||||
ebpfProxy := NewWGEBPFProxy(ctx, wgPort)
|
||||
err := ebpfProxy.listen()
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort)
|
||||
err := ebpfProxy.Listen()
|
||||
if err != nil {
|
||||
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
|
||||
return f
|
||||
@@ -25,3 +31,20 @@ func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
|
||||
f.ebpfProxy = ebpfProxy
|
||||
return f
|
||||
}
|
||||
|
||||
func (w *Factory) GetProxy() Proxy {
|
||||
if w.ebpfProxy != nil {
|
||||
p := &ebpf.ProxyWrapper{
|
||||
WgeBPFProxy: w.ebpfProxy,
|
||||
}
|
||||
return p
|
||||
}
|
||||
return usp.NewWGUserSpaceProxy(w.wgPort)
|
||||
}
|
||||
|
||||
func (w *Factory) Free() error {
|
||||
if w.ebpfProxy == nil {
|
||||
return nil
|
||||
}
|
||||
return w.ebpfProxy.Free()
|
||||
}
|
||||
|
||||
@@ -2,8 +2,20 @@
|
||||
|
||||
package wgproxy
|
||||
|
||||
import "context"
|
||||
import "github.com/netbirdio/netbird/client/internal/wgproxy/usp"
|
||||
|
||||
func NewFactory(ctx context.Context, _ bool, wgPort int) *Factory {
|
||||
type Factory struct {
|
||||
wgPort int
|
||||
}
|
||||
|
||||
func NewFactory(_ bool, wgPort int) *Factory {
|
||||
return &Factory{wgPort: wgPort}
|
||||
}
|
||||
|
||||
func (w *Factory) GetProxy() Proxy {
|
||||
return usp.NewWGUserSpaceProxy(w.wgPort)
|
||||
}
|
||||
|
||||
func (w *Factory) Free() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// Proxy is a transfer layer between the Turn connection and the WireGuard
|
||||
// Proxy is a transfer layer between the relayed connection and the WireGuard
|
||||
type Proxy interface {
|
||||
AddTurnConn(turnConn net.Conn) (net.Addr, error)
|
||||
AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error)
|
||||
CloseConn() error
|
||||
Free() error
|
||||
}
|
||||
|
||||
128
client/internal/wgproxy/proxy_test.go
Normal file
128
client/internal/wgproxy/proxy_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
//go:build linux
|
||||
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy/usp"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
_ = util.InitLog("trace", "console")
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
type mocConn struct {
|
||||
closeChan chan struct{}
|
||||
closed bool
|
||||
}
|
||||
|
||||
func newMockConn() *mocConn {
|
||||
return &mocConn{
|
||||
closeChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mocConn) Read(b []byte) (n int, err error) {
|
||||
<-m.closeChan
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (m *mocConn) Write(b []byte) (n int, err error) {
|
||||
<-m.closeChan
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (m *mocConn) Close() error {
|
||||
if m.closed == true {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.closed = true
|
||||
close(m.closeChan)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mocConn) LocalAddr() net.Addr {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mocConn) RemoteAddr() net.Addr {
|
||||
return &net.UDPAddr{
|
||||
IP: net.ParseIP("172.16.254.1"),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mocConn) SetDeadline(t time.Time) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mocConn) SetReadDeadline(t time.Time) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mocConn) SetWriteDeadline(t time.Time) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func TestProxyCloseByRemoteConn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
proxy Proxy
|
||||
}{
|
||||
{
|
||||
name: "userspace proxy",
|
||||
proxy: usp.NewWGUserSpaceProxy(51830),
|
||||
},
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" {
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(51831)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %s", err)
|
||||
}
|
||||
}()
|
||||
proxyWrapper := &ebpf.ProxyWrapper{
|
||||
WgeBPFProxy: ebpfProxy,
|
||||
}
|
||||
|
||||
tests = append(tests, struct {
|
||||
name string
|
||||
proxy Proxy
|
||||
}{
|
||||
name: "ebpf proxy",
|
||||
proxy: proxyWrapper,
|
||||
})
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
relayedConn := newMockConn()
|
||||
_, err := tt.proxy.AddTurnConn(ctx, relayedConn)
|
||||
if err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
|
||||
_ = relayedConn.Close()
|
||||
if err := tt.proxy.CloseConn(); err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,120 +0,0 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// WGUserSpaceProxy proxies
|
||||
type WGUserSpaceProxy struct {
|
||||
localWGListenPort int
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
remoteConn net.Conn
|
||||
localConn net.Conn
|
||||
}
|
||||
|
||||
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy
|
||||
func NewWGUserSpaceProxy(ctx context.Context, wgPort int) *WGUserSpaceProxy {
|
||||
log.Debugf("Initializing new user space proxy with port %d", wgPort)
|
||||
p := &WGUserSpaceProxy{
|
||||
localWGListenPort: wgPort,
|
||||
}
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
return p
|
||||
}
|
||||
|
||||
// AddTurnConn start the proxy with the given remote conn
|
||||
func (p *WGUserSpaceProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
|
||||
p.remoteConn = turnConn
|
||||
|
||||
var err error
|
||||
p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||
if err != nil {
|
||||
log.Errorf("failed dialing to local Wireguard port %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go p.proxyToRemote()
|
||||
go p.proxyToLocal()
|
||||
|
||||
return p.localConn.LocalAddr(), err
|
||||
}
|
||||
|
||||
// CloseConn close the localConn
|
||||
func (p *WGUserSpaceProxy) CloseConn() error {
|
||||
p.cancel()
|
||||
if p.localConn == nil {
|
||||
return nil
|
||||
}
|
||||
return p.localConn.Close()
|
||||
}
|
||||
|
||||
// Free doing nothing because this implementation of proxy does not have global state
|
||||
func (p *WGUserSpaceProxy) Free() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
|
||||
// blocks
|
||||
func (p *WGUserSpaceProxy) proxyToRemote() {
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
default:
|
||||
n, err := p.localConn.Read(buf)
|
||||
if err != nil {
|
||||
log.Debugf("failed to read from wg interface conn: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = p.remoteConn.Write(buf[:n])
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
p.cancel()
|
||||
} else {
|
||||
log.Debugf("failed to write to remote conn: %s", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
|
||||
// blocks
|
||||
func (p *WGUserSpaceProxy) proxyToLocal() {
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
default:
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
p.cancel()
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to read from remote conn: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = p.localConn.Write(buf[:n])
|
||||
if err != nil {
|
||||
log.Debugf("failed to write to wg interface conn: %s", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
146
client/internal/wgproxy/usp/proxy.go
Normal file
146
client/internal/wgproxy/usp/proxy.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package usp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/errors"
|
||||
)
|
||||
|
||||
// WGUserSpaceProxy proxies
|
||||
type WGUserSpaceProxy struct {
|
||||
localWGListenPort int
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
remoteConn net.Conn
|
||||
localConn net.Conn
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
}
|
||||
|
||||
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
|
||||
func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
|
||||
log.Debugf("Initializing new user space proxy with port %d", wgPort)
|
||||
p := &WGUserSpaceProxy{
|
||||
localWGListenPort: wgPort,
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// AddTurnConn start the proxy with the given remote conn
|
||||
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
|
||||
p.remoteConn = remoteConn
|
||||
|
||||
var err error
|
||||
dialer := net.Dialer{}
|
||||
p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||
if err != nil {
|
||||
log.Errorf("failed dialing to local Wireguard port %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go p.proxyToRemote()
|
||||
go p.proxyToLocal()
|
||||
|
||||
return p.localConn.LocalAddr(), err
|
||||
}
|
||||
|
||||
// CloseConn close the localConn
|
||||
func (p *WGUserSpaceProxy) CloseConn() error {
|
||||
if p.cancel == nil {
|
||||
return fmt.Errorf("proxy not started")
|
||||
}
|
||||
return p.close()
|
||||
}
|
||||
|
||||
func (p *WGUserSpaceProxy) close() error {
|
||||
p.closeMu.Lock()
|
||||
defer p.closeMu.Unlock()
|
||||
|
||||
// prevent double close
|
||||
if p.closed {
|
||||
return nil
|
||||
}
|
||||
p.closed = true
|
||||
|
||||
p.cancel()
|
||||
|
||||
var result *multierror.Error
|
||||
if err := p.remoteConn.Close(); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
|
||||
}
|
||||
|
||||
if err := p.localConn.Close(); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
|
||||
}
|
||||
return errors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// proxyToRemote proxies from Wireguard to the RemoteKey
|
||||
func (p *WGUserSpaceProxy) proxyToRemote() {
|
||||
defer func() {
|
||||
if err := p.close(); err != nil {
|
||||
log.Warnf("error in proxy to remote loop: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for p.ctx.Err() == nil {
|
||||
n, err := p.localConn.Read(buf)
|
||||
if err != nil {
|
||||
if p.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Debugf("failed to read from wg interface conn: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = p.remoteConn.Write(buf[:n])
|
||||
if err != nil {
|
||||
if p.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("failed to write to remote conn: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// proxyToLocal proxies from the Remote peer to local WireGuard
|
||||
func (p *WGUserSpaceProxy) proxyToLocal() {
|
||||
defer func() {
|
||||
if err := p.close(); err != nil {
|
||||
log.Warnf("error in proxy to local loop: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for p.ctx.Err() == nil {
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if p.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = p.localConn.Write(buf[:n])
|
||||
if err != nil {
|
||||
if p.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Debugf("failed to write to wg interface conn: %s", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
8
go.mod
8
go.mod
@@ -1,6 +1,6 @@
|
||||
module github.com/netbirdio/netbird
|
||||
|
||||
go 1.21.0
|
||||
go 1.23.0
|
||||
|
||||
require (
|
||||
cunicu.li/go-rosenpass v0.4.0
|
||||
@@ -95,9 +95,10 @@ require (
|
||||
golang.org/x/term v0.21.0
|
||||
google.golang.org/api v0.177.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/mysql v1.5.7
|
||||
gorm.io/driver/postgres v1.5.7
|
||||
gorm.io/driver/sqlite v1.5.3
|
||||
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde
|
||||
gorm.io/gorm v1.25.7
|
||||
nhooyr.io/websocket v1.8.11
|
||||
)
|
||||
|
||||
@@ -151,6 +152,7 @@ require (
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-ole/go-ole v1.3.0 // indirect
|
||||
github.com/go-redis/redis/v8 v8.11.5 // indirect
|
||||
github.com/go-sql-driver/mysql v1.7.0 // indirect
|
||||
github.com/go-text/render v0.1.0 // indirect
|
||||
github.com/go-text/typesetting v0.1.0 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
@@ -232,7 +234,7 @@ require (
|
||||
k8s.io/apimachinery v0.26.2 // indirect
|
||||
)
|
||||
|
||||
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240904111318-17777758453a
|
||||
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502
|
||||
|
||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||
|
||||
|
||||
12
go.sum
12
go.sum
@@ -238,6 +238,8 @@ github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7
|
||||
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
|
||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
|
||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
|
||||
@@ -523,8 +525,8 @@ github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6R
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
|
||||
github.com/netbirdio/service v0.0.0-20240904111318-17777758453a h1:2EcDFDT39Odz5EC38pOSyjCd3bLUjPi7pMQpH6k+zzk=
|
||||
github.com/netbirdio/service v0.0.0-20240904111318-17777758453a/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080 h1:mXJkoWLdqJTlkQ7DgQ536kcXHXIdUPeagkN8i4eFDdg=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
|
||||
@@ -1224,12 +1226,14 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo=
|
||||
gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
|
||||
gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM=
|
||||
gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA=
|
||||
gorm.io/driver/sqlite v1.5.3 h1:7/0dUgX28KAcopdfbRWWl68Rflh6osa4rDh+m51KL2g=
|
||||
gorm.io/driver/sqlite v1.5.3/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4=
|
||||
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde h1:9DShaph9qhkIYw7QF91I/ynrr4cOO2PZra2PFD7Mfeg=
|
||||
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
gorm.io/gorm v1.25.7 h1:VsD6acwRjz2zFxGO50gPO6AkNs7KKnvfzUjHQhZDz/A=
|
||||
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
|
||||
gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
|
||||
|
||||
@@ -56,8 +56,9 @@ func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepA
|
||||
return err
|
||||
}
|
||||
peer := wgtypes.PeerConfig{
|
||||
PublicKey: peerKeyParsed,
|
||||
ReplaceAllowedIPs: true,
|
||||
PublicKey: peerKeyParsed,
|
||||
ReplaceAllowedIPs: false,
|
||||
// don't replace allowed ips, wg will handle duplicated peer IP
|
||||
AllowedIPs: []net.IPNet{*ipNet},
|
||||
PersistentKeepaliveInterval: &keepAlive,
|
||||
Endpoint: endpoint,
|
||||
|
||||
@@ -64,8 +64,9 @@ func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAliv
|
||||
return err
|
||||
}
|
||||
peer := wgtypes.PeerConfig{
|
||||
PublicKey: peerKeyParsed,
|
||||
ReplaceAllowedIPs: true,
|
||||
PublicKey: peerKeyParsed,
|
||||
ReplaceAllowedIPs: false,
|
||||
// don't replace allowed ips, wg will handle duplicated peer IP
|
||||
AllowedIPs: []net.IPNet{*ipNet},
|
||||
PersistentKeepaliveInterval: &keepAlive,
|
||||
PresharedKey: preSharedKey,
|
||||
|
||||
@@ -54,7 +54,7 @@ func Execute() error {
|
||||
func init() {
|
||||
stopCh = make(chan int)
|
||||
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
|
||||
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 8081, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
|
||||
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
|
||||
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
|
||||
mgmtCmd.Flags().StringVar(&mgmtConfig, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
|
||||
mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
|
||||
|
||||
@@ -263,6 +263,11 @@ type AccountSettings struct {
|
||||
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
|
||||
}
|
||||
|
||||
// Subclass used in gorm to only load network and not whole account
|
||||
type AccountNetwork struct {
|
||||
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
|
||||
}
|
||||
|
||||
type UserPermissions struct {
|
||||
DashboardView string `json:"dashboard_view"`
|
||||
}
|
||||
@@ -700,14 +705,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string {
|
||||
return grps
|
||||
}
|
||||
|
||||
func (a *Account) getUserGroups(userID string) ([]string, error) {
|
||||
user, err := a.FindUser(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user.AutoGroups, nil
|
||||
}
|
||||
|
||||
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
|
||||
peerGroups := a.getPeerGroups(peerID)
|
||||
enabled := true
|
||||
@@ -734,14 +731,6 @@ func (a *Account) getPeerGroups(peerID string) lookupMap {
|
||||
return groupList
|
||||
}
|
||||
|
||||
func (a *Account) getSetupKeyGroups(setupKey string) ([]string, error) {
|
||||
key, err := a.FindSetupKey(setupKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return key.AutoGroups, nil
|
||||
}
|
||||
|
||||
func (a *Account) getTakenIPs() []net.IP {
|
||||
var takenIps []net.IP
|
||||
for _, existingPeer := range a.Peers {
|
||||
@@ -2082,7 +2071,7 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, peer.UserID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -2103,6 +2092,25 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Store, accountID string, peerHostName string) (string, error) {
|
||||
existingLabels, err := store.GetPeerLabelsInAccount(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get peer dns labels: %w", err)
|
||||
}
|
||||
|
||||
labelMap := ConvertSliceToMap(existingLabels)
|
||||
newLabel, err := getPeerHostLabel(peerHostName, labelMap)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get new host label: %w", err)
|
||||
}
|
||||
|
||||
if newLabel == "" {
|
||||
return "", fmt.Errorf("failed to get new host label: %w", err)
|
||||
}
|
||||
|
||||
return newLabel, nil
|
||||
}
|
||||
|
||||
// addAllGroup to account object if it doesn't exist
|
||||
func addAllGroup(account *Account) error {
|
||||
if len(account.Groups) == 0 {
|
||||
|
||||
@@ -6,13 +6,14 @@ import (
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"errors"
|
||||
)
|
||||
|
||||
var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
|
||||
|
||||
type FieldEncrypt struct {
|
||||
block cipher.Block
|
||||
gcm cipher.AEAD
|
||||
}
|
||||
|
||||
func GenerateKey() (string, error) {
|
||||
@@ -35,14 +36,21 @@ func NewFieldEncrypt(key string) (*FieldEncrypt, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ec := &FieldEncrypt{
|
||||
block: block,
|
||||
gcm: gcm,
|
||||
}
|
||||
|
||||
return ec, nil
|
||||
}
|
||||
|
||||
func (ec *FieldEncrypt) Encrypt(payload string) string {
|
||||
func (ec *FieldEncrypt) LegacyEncrypt(payload string) string {
|
||||
plainText := pkcs5Padding([]byte(payload))
|
||||
cipherText := make([]byte, len(plainText))
|
||||
cbc := cipher.NewCBCEncrypter(ec.block, iv)
|
||||
@@ -50,7 +58,22 @@ func (ec *FieldEncrypt) Encrypt(payload string) string {
|
||||
return base64.StdEncoding.EncodeToString(cipherText)
|
||||
}
|
||||
|
||||
func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
|
||||
// Encrypt encrypts plaintext using AES-GCM
|
||||
func (ec *FieldEncrypt) Encrypt(payload string) (string, error) {
|
||||
plaintext := []byte(payload)
|
||||
nonceSize := ec.gcm.NonceSize()
|
||||
|
||||
nonce := make([]byte, nonceSize, len(plaintext)+nonceSize+ec.gcm.Overhead())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ciphertext := ec.gcm.Seal(nonce, nonce, plaintext, nil)
|
||||
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
func (ec *FieldEncrypt) LegacyDecrypt(data string) (string, error) {
|
||||
cipherText, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -65,17 +88,49 @@ func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
|
||||
return string(payload), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts ciphertext using AES-GCM
|
||||
func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
|
||||
cipherText, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nonceSize := ec.gcm.NonceSize()
|
||||
if len(cipherText) < nonceSize {
|
||||
return "", errors.New("cipher text too short")
|
||||
}
|
||||
|
||||
nonce, cipherText := cipherText[:nonceSize], cipherText[nonceSize:]
|
||||
plainText, err := ec.gcm.Open(nil, nonce, cipherText, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(plainText), nil
|
||||
}
|
||||
|
||||
func pkcs5Padding(ciphertext []byte) []byte {
|
||||
padding := aes.BlockSize - len(ciphertext)%aes.BlockSize
|
||||
padText := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(ciphertext, padText...)
|
||||
}
|
||||
|
||||
func pkcs5UnPadding(src []byte) ([]byte, error) {
|
||||
srcLen := len(src)
|
||||
paddingLen := int(src[srcLen-1])
|
||||
if paddingLen >= srcLen || paddingLen > aes.BlockSize {
|
||||
return nil, fmt.Errorf("padding size error")
|
||||
if srcLen == 0 {
|
||||
return nil, errors.New("input data is empty")
|
||||
}
|
||||
|
||||
paddingLen := int(src[srcLen-1])
|
||||
if paddingLen == 0 || paddingLen > aes.BlockSize || paddingLen > srcLen {
|
||||
return nil, errors.New("invalid padding size")
|
||||
}
|
||||
|
||||
// Verify that all padding bytes are the same
|
||||
for i := 0; i < paddingLen; i++ {
|
||||
if src[srcLen-1-i] != byte(paddingLen) {
|
||||
return nil, errors.New("invalid padding")
|
||||
}
|
||||
}
|
||||
|
||||
return src[:srcLen-paddingLen], nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -15,7 +16,11 @@ func TestGenerateKey(t *testing.T) {
|
||||
t.Fatalf("failed to init email encryption: %s", err)
|
||||
}
|
||||
|
||||
encrypted := ee.Encrypt(testData)
|
||||
encrypted, err := ee.Encrypt(testData)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to encrypt data: %s", err)
|
||||
}
|
||||
|
||||
if encrypted == "" {
|
||||
t.Fatalf("invalid encrypted text")
|
||||
}
|
||||
@@ -30,6 +35,32 @@ func TestGenerateKey(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateKeyLegacy(t *testing.T) {
|
||||
testData := "exampl@netbird.io"
|
||||
key, err := GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate key: %s", err)
|
||||
}
|
||||
ee, err := NewFieldEncrypt(key)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to init email encryption: %s", err)
|
||||
}
|
||||
|
||||
encrypted := ee.LegacyEncrypt(testData)
|
||||
if encrypted == "" {
|
||||
t.Fatalf("invalid encrypted text")
|
||||
}
|
||||
|
||||
decrypted, err := ee.LegacyDecrypt(encrypted)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to decrypt data: %s", err)
|
||||
}
|
||||
|
||||
if decrypted != testData {
|
||||
t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorruptKey(t *testing.T) {
|
||||
testData := "exampl@netbird.io"
|
||||
key, err := GenerateKey()
|
||||
@@ -41,7 +72,11 @@ func TestCorruptKey(t *testing.T) {
|
||||
t.Fatalf("failed to init email encryption: %s", err)
|
||||
}
|
||||
|
||||
encrypted := ee.Encrypt(testData)
|
||||
encrypted, err := ee.Encrypt(testData)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to encrypt data: %s", err)
|
||||
}
|
||||
|
||||
if encrypted == "" {
|
||||
t.Fatalf("invalid encrypted text")
|
||||
}
|
||||
@@ -61,3 +96,215 @@ func TestCorruptKey(t *testing.T) {
|
||||
t.Fatalf("incorrect decryption, the result is: %s", res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
// Generate a key for encryption/decryption
|
||||
key, err := GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate key: %v", err)
|
||||
}
|
||||
|
||||
// Initialize the FieldEncrypt with the generated key
|
||||
ec, err := NewFieldEncrypt(key)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create FieldEncrypt: %v", err)
|
||||
}
|
||||
|
||||
// Test cases
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{
|
||||
name: "Empty String",
|
||||
input: "",
|
||||
},
|
||||
{
|
||||
name: "Short String",
|
||||
input: "Hello",
|
||||
},
|
||||
{
|
||||
name: "String with Spaces",
|
||||
input: "Hello, World!",
|
||||
},
|
||||
{
|
||||
name: "Long String",
|
||||
input: "The quick brown fox jumps over the lazy dog.",
|
||||
},
|
||||
{
|
||||
name: "Unicode Characters",
|
||||
input: "こんにちは世界",
|
||||
},
|
||||
{
|
||||
name: "Special Characters",
|
||||
input: "!@#$%^&*()_+-=[]{}|;':\",./<>?",
|
||||
},
|
||||
{
|
||||
name: "Numeric String",
|
||||
input: "1234567890",
|
||||
},
|
||||
{
|
||||
name: "Repeated Characters",
|
||||
input: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||
},
|
||||
{
|
||||
name: "Multi-block String",
|
||||
input: "This is a longer string that will span multiple blocks in the encryption algorithm.",
|
||||
},
|
||||
{
|
||||
name: "Non-ASCII and ASCII Mix",
|
||||
input: "Hello 世界 123",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name+" - Legacy", func(t *testing.T) {
|
||||
// Legacy Encryption
|
||||
encryptedLegacy := ec.LegacyEncrypt(tc.input)
|
||||
if encryptedLegacy == "" {
|
||||
t.Errorf("LegacyEncrypt returned empty string for input '%s'", tc.input)
|
||||
}
|
||||
|
||||
// Legacy Decryption
|
||||
decryptedLegacy, err := ec.LegacyDecrypt(encryptedLegacy)
|
||||
if err != nil {
|
||||
t.Errorf("LegacyDecrypt failed for input '%s': %v", tc.input, err)
|
||||
}
|
||||
|
||||
// Verify that the decrypted value matches the original input
|
||||
if decryptedLegacy != tc.input {
|
||||
t.Errorf("LegacyDecrypt output '%s' does not match original input '%s'", decryptedLegacy, tc.input)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run(tc.name+" - New", func(t *testing.T) {
|
||||
// New Encryption
|
||||
encryptedNew, err := ec.Encrypt(tc.input)
|
||||
if err != nil {
|
||||
t.Errorf("Encrypt failed for input '%s': %v", tc.input, err)
|
||||
}
|
||||
if encryptedNew == "" {
|
||||
t.Errorf("Encrypt returned empty string for input '%s'", tc.input)
|
||||
}
|
||||
|
||||
// New Decryption
|
||||
decryptedNew, err := ec.Decrypt(encryptedNew)
|
||||
if err != nil {
|
||||
t.Errorf("Decrypt failed for input '%s': %v", tc.input, err)
|
||||
}
|
||||
|
||||
// Verify that the decrypted value matches the original input
|
||||
if decryptedNew != tc.input {
|
||||
t.Errorf("Decrypt output '%s' does not match original input '%s'", decryptedNew, tc.input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPKCS5UnPadding(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expected []byte
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid Padding",
|
||||
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{4}, 4)...),
|
||||
expected: []byte("Hello, World!"),
|
||||
},
|
||||
{
|
||||
name: "Empty Input",
|
||||
input: []byte{},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Padding Length Zero",
|
||||
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{0}, 4)...),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Padding Length Exceeds Block Size",
|
||||
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{17}, 17)...),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Padding Length Exceeds Input Length",
|
||||
input: []byte{5, 5, 5},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Padding Bytes",
|
||||
input: append([]byte("Hello, World!"), []byte{2, 3, 4, 5}...),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Valid Single Byte Padding",
|
||||
input: append([]byte("Hello, World!"), byte(1)),
|
||||
expected: []byte("Hello, World!"),
|
||||
},
|
||||
{
|
||||
name: "Invalid Mixed Padding Bytes",
|
||||
input: append([]byte("Hello, World!"), []byte{3, 3, 2}...),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Valid Full Block Padding",
|
||||
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{16}, 16)...),
|
||||
expected: []byte("Hello, World!"),
|
||||
},
|
||||
{
|
||||
name: "Non-Padding Byte at End",
|
||||
input: append([]byte("Hello, World!"), []byte{4, 4, 4, 5}...),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Valid Padding with Different Text Length",
|
||||
input: append([]byte("Test"), bytes.Repeat([]byte{12}, 12)...),
|
||||
expected: []byte("Test"),
|
||||
},
|
||||
{
|
||||
name: "Padding Length Equal to Input Length",
|
||||
input: bytes.Repeat([]byte{8}, 8),
|
||||
expected: []byte{},
|
||||
},
|
||||
{
|
||||
name: "Invalid Padding Length Zero (Again)",
|
||||
input: append([]byte("Test"), byte(0)),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Padding Length Greater Than Input",
|
||||
input: []byte{10},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Input Length Not Multiple of Block Size",
|
||||
input: append([]byte("Invalid Length"), byte(1)),
|
||||
expected: []byte("Invalid Length"),
|
||||
},
|
||||
{
|
||||
name: "Valid Padding with Non-ASCII Characters",
|
||||
input: append([]byte("こんにちは"), bytes.Repeat([]byte{2}, 2)...),
|
||||
expected: []byte("こんにちは"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := pkcs5UnPadding(tt.input)
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error but got nil")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Did not expect error but got: %v", err)
|
||||
}
|
||||
if !bytes.Equal(result, tt.expected) {
|
||||
t.Errorf("Expected output %v, got %v", tt.expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
157
management/server/activity/sqlite/migration.go
Normal file
157
management/server/activity/sqlite/migration.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func migrate(ctx context.Context, crypt *FieldEncrypt, db *sql.DB) error {
|
||||
if _, err := db.Exec(createTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := db.Exec(creatTableDeletedUsersQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := updateDeletedUsersTable(ctx, db); err != nil {
|
||||
return fmt.Errorf("failed to update deleted_users table: %v", err)
|
||||
}
|
||||
|
||||
return migrateLegacyEncryptedUsersToGCM(ctx, crypt, db)
|
||||
}
|
||||
|
||||
// updateDeletedUsersTable checks and updates the deleted_users table schema to ensure required columns exist.
|
||||
func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error {
|
||||
exists, err := checkColumnExists(db, "deleted_users", "name")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !exists {
|
||||
log.WithContext(ctx).Debug("Adding name column to the deleted_users table")
|
||||
|
||||
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debug("Successfully added name column to the deleted_users table")
|
||||
}
|
||||
|
||||
exists, err = checkColumnExists(db, "deleted_users", "enc_algo")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !exists {
|
||||
log.WithContext(ctx).Debug("Adding enc_algo column to the deleted_users table")
|
||||
|
||||
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN enc_algo TEXT;`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debug("Successfully added enc_algo column to the deleted_users table")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateLegacyEncryptedUsersToGCM migrates previously encrypted data using,
|
||||
// legacy CBC encryption with a static IV to the new GCM encryption method.
|
||||
func migrateLegacyEncryptedUsersToGCM(ctx context.Context, crypt *FieldEncrypt, db *sql.DB) error {
|
||||
log.WithContext(ctx).Debug("Migrating CBC encrypted deleted users to GCM")
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
|
||||
rows, err := tx.Query(fmt.Sprintf(`SELECT id, email, name FROM deleted_users where enc_algo IS NULL OR enc_algo != '%s'`, gcmEncAlgo))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute select query: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
updateStmt, err := tx.Prepare(`UPDATE deleted_users SET email = ?, name = ?, enc_algo = ? WHERE id = ?`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prepare update statement: %v", err)
|
||||
}
|
||||
defer updateStmt.Close()
|
||||
|
||||
if err = processUserRows(ctx, crypt, rows, updateStmt); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return fmt.Errorf("failed to commit transaction: %v", err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debug("Successfully migrated CBC encrypted deleted users to GCM")
|
||||
return nil
|
||||
}
|
||||
|
||||
// processUserRows processes database rows of user data, decrypts legacy encryption fields, and re-encrypts them using GCM.
|
||||
func processUserRows(ctx context.Context, crypt *FieldEncrypt, rows *sql.Rows, updateStmt *sql.Stmt) error {
|
||||
for rows.Next() {
|
||||
var (
|
||||
id, decryptedEmail, decryptedName string
|
||||
email, name *string
|
||||
)
|
||||
|
||||
err := rows.Scan(&id, &email, &name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if email != nil {
|
||||
decryptedEmail, err = crypt.LegacyDecrypt(*email)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("skipping migrating deleted user %s: %v",
|
||||
id,
|
||||
fmt.Errorf("failed to decrypt email: %w", err),
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if name != nil {
|
||||
decryptedName, err = crypt.LegacyDecrypt(*name)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("skipping migrating deleted user %s: %v",
|
||||
id,
|
||||
fmt.Errorf("failed to decrypt name: %w", err),
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
encryptedEmail, err := crypt.Encrypt(decryptedEmail)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt email: %w", err)
|
||||
}
|
||||
|
||||
encryptedName, err := crypt.Encrypt(decryptedName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt name: %w", err)
|
||||
}
|
||||
|
||||
_, err = updateStmt.Exec(encryptedEmail, encryptedName, gcmEncAlgo, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
84
management/server/activity/sqlite/migration_test.go
Normal file
84
management/server/activity/sqlite/migration_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func setupDatabase(t *testing.T) *sql.DB {
|
||||
t.Helper()
|
||||
|
||||
dbFile := filepath.Join(t.TempDir(), eventSinkDB)
|
||||
db, err := sql.Open("sqlite3", dbFile)
|
||||
require.NoError(t, err, "Failed to open database")
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = db.Close()
|
||||
})
|
||||
|
||||
_, err = db.Exec(createTableQuery)
|
||||
require.NoError(t, err, "Failed to create events table")
|
||||
|
||||
_, err = db.Exec(`CREATE TABLE deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);`)
|
||||
require.NoError(t, err, "Failed to create deleted_users table")
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
key, err := GenerateKey()
|
||||
require.NoError(t, err, "Failed to generate key")
|
||||
|
||||
crypt, err := NewFieldEncrypt(key)
|
||||
require.NoError(t, err, "Failed to initialize FieldEncrypt")
|
||||
|
||||
legacyEmail := crypt.LegacyEncrypt("testaccount@test.com")
|
||||
legacyName := crypt.LegacyEncrypt("Test Account")
|
||||
|
||||
_, err = db.Exec(`INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) VALUES(?, ?, ?, ?, ?, ?)`,
|
||||
activity.UserDeleted, time.Now(), "initiatorID", "targetID", "accountID", "")
|
||||
require.NoError(t, err, "Failed to insert event")
|
||||
|
||||
_, err = db.Exec(`INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)`, "targetID", legacyEmail, legacyName)
|
||||
require.NoError(t, err, "Failed to insert legacy encrypted data")
|
||||
|
||||
colExists, err := checkColumnExists(db, "deleted_users", "enc_algo")
|
||||
require.NoError(t, err, "Failed to check if enc_algo column exists")
|
||||
require.False(t, colExists, "enc_algo column should not exist before migration")
|
||||
|
||||
err = migrate(context.Background(), crypt, db)
|
||||
require.NoError(t, err, "Migration failed")
|
||||
|
||||
colExists, err = checkColumnExists(db, "deleted_users", "enc_algo")
|
||||
require.NoError(t, err, "Failed to check if enc_algo column exists after migration")
|
||||
require.True(t, colExists, "enc_algo column should exist after migration")
|
||||
|
||||
var encAlgo string
|
||||
err = db.QueryRow(`SELECT enc_algo FROM deleted_users LIMIT 1`, "").Scan(&encAlgo)
|
||||
require.NoError(t, err, "Failed to select updated data")
|
||||
require.Equal(t, gcmEncAlgo, encAlgo, "enc_algo should be set to 'GCM' after migration")
|
||||
|
||||
store, err := createStore(crypt, db)
|
||||
require.NoError(t, err, "Failed to create store")
|
||||
|
||||
events, err := store.Get(context.Background(), "accountID", 0, 1, false)
|
||||
require.NoError(t, err, "Failed to get events")
|
||||
|
||||
require.Len(t, events, 1, "Should have one event")
|
||||
require.Equal(t, activity.UserDeleted, events[0].Activity, "activity should match")
|
||||
require.Equal(t, "initiatorID", events[0].InitiatorID, "initiator id should match")
|
||||
require.Equal(t, "targetID", events[0].TargetID, "target id should match")
|
||||
require.Equal(t, "accountID", events[0].AccountID, "account id should match")
|
||||
require.Equal(t, "testaccount@test.com", events[0].Meta["email"], "email should match")
|
||||
require.Equal(t, "Test Account", events[0].Meta["username"], "username should match")
|
||||
}
|
||||
@@ -26,7 +26,7 @@ const (
|
||||
"meta TEXT," +
|
||||
" target_id TEXT);"
|
||||
|
||||
creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);`
|
||||
creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT, enc_algo TEXT NOT NULL);`
|
||||
|
||||
selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta
|
||||
FROM events
|
||||
@@ -69,10 +69,12 @@ const (
|
||||
and some selfhosted deployments might have duplicates already so we need to clean the table first.
|
||||
*/
|
||||
|
||||
insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)`
|
||||
insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name, enc_algo) VALUES(?, ?, ?, ?)`
|
||||
|
||||
fallbackName = "unknown"
|
||||
fallbackEmail = "unknown@unknown.com"
|
||||
|
||||
gcmEncAlgo = "GCM"
|
||||
)
|
||||
|
||||
// Store is the implementation of the activity.Store interface backed by SQLite
|
||||
@@ -100,58 +102,12 @@ func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = db.Exec(createTableQuery)
|
||||
if err != nil {
|
||||
if err = migrate(ctx, crypt, db); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("events database migration: %w", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(creatTableDeletedUsersQuery)
|
||||
if err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = updateDeletedUsersTable(ctx, db)
|
||||
if err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
insertStmt, err := db.Prepare(insertQuery)
|
||||
if err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
selectDescStmt, err := db.Prepare(selectDescQuery)
|
||||
if err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
selectAscStmt, err := db.Prepare(selectAscQuery)
|
||||
if err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
|
||||
if err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s := &Store{
|
||||
db: db,
|
||||
fieldEncrypt: crypt,
|
||||
insertStatement: insertStmt,
|
||||
selectDescStatement: selectDescStmt,
|
||||
selectAscStatement: selectAscStmt,
|
||||
deleteUserStmt: deleteUserStmt,
|
||||
}
|
||||
|
||||
return s, nil
|
||||
return createStore(crypt, db)
|
||||
}
|
||||
|
||||
func (store *Store) processResult(ctx context.Context, result *sql.Rows) ([]*activity.Event, error) {
|
||||
@@ -302,9 +258,16 @@ func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event
|
||||
return event.Meta, nil
|
||||
}
|
||||
|
||||
encryptedEmail := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email))
|
||||
encryptedName := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name))
|
||||
_, err := store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName)
|
||||
encryptedEmail, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
encryptedName, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName, gcmEncAlgo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -325,43 +288,70 @@ func (store *Store) Close(_ context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error {
|
||||
log.WithContext(ctx).Debugf("check deleted_users table version")
|
||||
rows, err := db.Query(`PRAGMA table_info(deleted_users);`)
|
||||
// createStore initializes and returns a new Store instance with prepared SQL statements.
|
||||
func createStore(crypt *FieldEncrypt, db *sql.DB) (*Store, error) {
|
||||
insertStmt, err := db.Prepare(insertQuery)
|
||||
if err != nil {
|
||||
return err
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
selectDescStmt, err := db.Prepare(selectDescQuery)
|
||||
if err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
selectAscStmt, err := db.Prepare(selectAscQuery)
|
||||
if err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
|
||||
if err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Store{
|
||||
db: db,
|
||||
fieldEncrypt: crypt,
|
||||
insertStatement: insertStmt,
|
||||
selectDescStatement: selectDescStmt,
|
||||
selectAscStatement: selectAscStmt,
|
||||
deleteUserStmt: deleteUserStmt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// checkColumnExists checks if a column exists in a specified table
|
||||
func checkColumnExists(db *sql.DB, tableName, columnName string) (bool, error) {
|
||||
query := fmt.Sprintf("PRAGMA table_info(%s);", tableName)
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to query table info: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
found := false
|
||||
|
||||
for rows.Next() {
|
||||
var (
|
||||
cid int
|
||||
name string
|
||||
dataType string
|
||||
notNull int
|
||||
dfltVal sql.NullString
|
||||
pk int
|
||||
)
|
||||
err := rows.Scan(&cid, &name, &dataType, ¬Null, &dfltVal, &pk)
|
||||
var cid int
|
||||
var name, ctype string
|
||||
var notnull, pk int
|
||||
var dfltValue sql.NullString
|
||||
|
||||
err = rows.Scan(&cid, &name, &ctype, ¬null, &dfltValue, &pk)
|
||||
if err != nil {
|
||||
return err
|
||||
return false, fmt.Errorf("failed to scan row: %w", err)
|
||||
}
|
||||
if name == "name" {
|
||||
found = true
|
||||
break
|
||||
|
||||
if name == columnName {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
err = rows.Err()
|
||||
if err != nil {
|
||||
return err
|
||||
if err = rows.Err(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if found {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("update delted_users table")
|
||||
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
|
||||
return err
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
type MockStore struct {
|
||||
@@ -24,7 +25,7 @@ func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Accou
|
||||
return s.account, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("account not found")
|
||||
return nil, status.NewPeerNotFoundError(peerId)
|
||||
}
|
||||
|
||||
type MocAccountManager struct {
|
||||
|
||||
@@ -2,6 +2,8 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -46,6 +48,158 @@ type FileStore struct {
|
||||
metrics telemetry.AppMetrics `json:"-"`
|
||||
}
|
||||
|
||||
func (s *FileStore) ExecuteInTransaction(ctx context.Context, f func(store Store) error) error {
|
||||
return f(s)
|
||||
}
|
||||
|
||||
func (s *FileStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKeyID)]
|
||||
if !ok {
|
||||
return status.NewSetupKeyNotFoundError()
|
||||
}
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
account.SetupKeys[setupKeyID].UsedTimes++
|
||||
|
||||
return s.SaveAccount(ctx, account)
|
||||
}
|
||||
|
||||
func (s *FileStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil || allGroup == nil {
|
||||
return errors.New("all group not found")
|
||||
}
|
||||
|
||||
allGroup.Peers = append(allGroup.Peers, peerID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *FileStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
account, err := s.getAccount(accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
account.Groups[groupID].Peers = append(account.Groups[groupID].Peers, peerId)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *FileStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
account, ok := s.Accounts[peer.AccountID]
|
||||
if !ok {
|
||||
return status.NewAccountNotFoundError(peer.AccountID)
|
||||
}
|
||||
|
||||
account.Peers[peer.ID] = peer
|
||||
return s.SaveAccount(ctx, account)
|
||||
}
|
||||
|
||||
func (s *FileStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
account, ok := s.Accounts[accountId]
|
||||
if !ok {
|
||||
return status.NewAccountNotFoundError(accountId)
|
||||
}
|
||||
|
||||
account.Network.Serial++
|
||||
|
||||
return s.SaveAccount(ctx, account)
|
||||
}
|
||||
|
||||
func (s *FileStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(key)]
|
||||
if !ok {
|
||||
return nil, status.NewSetupKeyNotFoundError()
|
||||
}
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
setupKey, ok := account.SetupKeys[key]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||
}
|
||||
|
||||
return setupKey, nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var takenIps []net.IP
|
||||
for _, existingPeer := range account.Peers {
|
||||
takenIps = append(takenIps, existingPeer.IP)
|
||||
}
|
||||
|
||||
return takenIps, nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existingLabels := []string{}
|
||||
for _, peer := range account.Peers {
|
||||
if peer.DNSLabel != "" {
|
||||
existingLabels = append(existingLabels, peer.DNSLabel)
|
||||
}
|
||||
}
|
||||
return existingLabels, nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return account.Network, nil
|
||||
}
|
||||
|
||||
type StoredAccount struct{}
|
||||
|
||||
// NewFileStore restores a store from the file located in the datadir
|
||||
@@ -422,7 +576,7 @@ func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*A
|
||||
|
||||
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
|
||||
return nil, status.NewSetupKeyNotFoundError()
|
||||
}
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
@@ -469,7 +623,7 @@ func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User,
|
||||
return account.Users[userID].Copy(), nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetUserByUserID(_ context.Context, userID string) (*User, error) {
|
||||
func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID string) (*User, error) {
|
||||
accountID, ok := s.UserID2AccountID[userID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
|
||||
@@ -513,7 +667,7 @@ func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
|
||||
func (s *FileStore) getAccount(accountID string) (*Account, error) {
|
||||
account, ok := s.Accounts[accountID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "account not found")
|
||||
return nil, status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
|
||||
return account, nil
|
||||
@@ -639,13 +793,13 @@ func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (
|
||||
|
||||
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
|
||||
if !ok {
|
||||
return "", status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
|
||||
return "", status.NewSetupKeyNotFoundError()
|
||||
}
|
||||
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbpeer.Peer, error) {
|
||||
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, _ LockingStrength, peerKey string) (*nbpeer.Peer, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@@ -668,7 +822,7 @@ func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbp
|
||||
return nil, status.NewPeerNotFoundError(peerKey)
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountSettings(_ context.Context, accountID string) (*Settings, error) {
|
||||
func (s *FileStore) GetAccountSettings(_ context.Context, _ LockingStrength, accountID string) (*Settings, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@@ -758,7 +912,7 @@ func (s *FileStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.
|
||||
}
|
||||
|
||||
// SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things.
|
||||
func (s *FileStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
|
||||
func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID string, lastLogin time.Time) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
|
||||
@@ -251,7 +251,7 @@ components:
|
||||
- name
|
||||
- ssh_enabled
|
||||
- login_expiration_enabled
|
||||
PeerBase:
|
||||
Peer:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/PeerMinimum'
|
||||
- type: object
|
||||
@@ -378,25 +378,40 @@ components:
|
||||
description: User ID of the user that enrolled this peer
|
||||
type: string
|
||||
example: google-oauth2|277474792786460067937
|
||||
os:
|
||||
description: Peer's operating system and version
|
||||
type: string
|
||||
example: linux
|
||||
country_code:
|
||||
$ref: '#/components/schemas/CountryCode'
|
||||
city_name:
|
||||
$ref: '#/components/schemas/CityName'
|
||||
geoname_id:
|
||||
description: Unique identifier from the GeoNames database for a specific geographical location.
|
||||
type: integer
|
||||
example: 2643743
|
||||
connected:
|
||||
description: Peer to Management connection status
|
||||
type: boolean
|
||||
example: true
|
||||
last_seen:
|
||||
description: Last time peer connected to Netbird's management service
|
||||
type: string
|
||||
format: date-time
|
||||
example: "2023-05-05T10:05:26.420578Z"
|
||||
required:
|
||||
- ip
|
||||
- dns_label
|
||||
- user_id
|
||||
Peer:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/PeerBase'
|
||||
- type: object
|
||||
properties:
|
||||
accessible_peers:
|
||||
description: List of accessible peers
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/AccessiblePeer'
|
||||
required:
|
||||
- accessible_peers
|
||||
- os
|
||||
- country_code
|
||||
- city_name
|
||||
- geoname_id
|
||||
- connected
|
||||
- last_seen
|
||||
PeerBatch:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/PeerBase'
|
||||
- $ref: '#/components/schemas/Peer'
|
||||
- type: object
|
||||
properties:
|
||||
accessible_peers_count:
|
||||
@@ -935,7 +950,7 @@ components:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
example: ["192.168.1.0/24", "10.0.0.0/8", "2001:db8:1234:1a00::/56"]
|
||||
example: ["192.168.1.0/24", "10.0.0.0/8", "2001:db8:1234:1a00::/56"]
|
||||
action:
|
||||
description: Action to take upon policy match
|
||||
type: string
|
||||
@@ -1806,6 +1821,38 @@ paths:
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/peers/{peerId}/accessible-peers:
|
||||
get:
|
||||
summary: List accessible Peers
|
||||
description: Returns a list of peers that the specified peer can connect to within the network.
|
||||
tags: [ Peers ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: peerId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a peer
|
||||
responses:
|
||||
'200':
|
||||
description: A JSON Array of Accessible Peers
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/AccessiblePeer'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/setup-keys:
|
||||
get:
|
||||
summary: List all Setup Keys
|
||||
|
||||
@@ -152,18 +152,36 @@ const (
|
||||
|
||||
// AccessiblePeer defines model for AccessiblePeer.
|
||||
type AccessiblePeer struct {
|
||||
// CityName Commonly used English name of the city
|
||||
CityName CityName `json:"city_name"`
|
||||
|
||||
// Connected Peer to Management connection status
|
||||
Connected bool `json:"connected"`
|
||||
|
||||
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
|
||||
CountryCode CountryCode `json:"country_code"`
|
||||
|
||||
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
||||
DnsLabel string `json:"dns_label"`
|
||||
|
||||
// GeonameId Unique identifier from the GeoNames database for a specific geographical location.
|
||||
GeonameId int `json:"geoname_id"`
|
||||
|
||||
// Id Peer ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// Ip Peer's IP address
|
||||
Ip string `json:"ip"`
|
||||
|
||||
// LastSeen Last time peer connected to Netbird's management service
|
||||
LastSeen time.Time `json:"last_seen"`
|
||||
|
||||
// Name Peer's hostname
|
||||
Name string `json:"name"`
|
||||
|
||||
// Os Peer's operating system and version
|
||||
Os string `json:"os"`
|
||||
|
||||
// UserId User ID of the user that enrolled this peer
|
||||
UserId string `json:"user_id"`
|
||||
}
|
||||
@@ -490,81 +508,6 @@ type OSVersionCheck struct {
|
||||
|
||||
// Peer defines model for Peer.
|
||||
type Peer struct {
|
||||
// AccessiblePeers List of accessible peers
|
||||
AccessiblePeers []AccessiblePeer `json:"accessible_peers"`
|
||||
|
||||
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
|
||||
ApprovalRequired bool `json:"approval_required"`
|
||||
|
||||
// CityName Commonly used English name of the city
|
||||
CityName CityName `json:"city_name"`
|
||||
|
||||
// Connected Peer to Management connection status
|
||||
Connected bool `json:"connected"`
|
||||
|
||||
// ConnectionIp Peer's public connection IP address
|
||||
ConnectionIp string `json:"connection_ip"`
|
||||
|
||||
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
|
||||
CountryCode CountryCode `json:"country_code"`
|
||||
|
||||
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
||||
DnsLabel string `json:"dns_label"`
|
||||
|
||||
// GeonameId Unique identifier from the GeoNames database for a specific geographical location.
|
||||
GeonameId int `json:"geoname_id"`
|
||||
|
||||
// Groups Groups that the peer belongs to
|
||||
Groups []GroupMinimum `json:"groups"`
|
||||
|
||||
// Hostname Hostname of the machine
|
||||
Hostname string `json:"hostname"`
|
||||
|
||||
// Id Peer ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// Ip Peer's IP address
|
||||
Ip string `json:"ip"`
|
||||
|
||||
// KernelVersion Peer's operating system kernel version
|
||||
KernelVersion string `json:"kernel_version"`
|
||||
|
||||
// LastLogin Last time this peer performed log in (authentication). E.g., user authenticated.
|
||||
LastLogin time.Time `json:"last_login"`
|
||||
|
||||
// LastSeen Last time peer connected to Netbird's management service
|
||||
LastSeen time.Time `json:"last_seen"`
|
||||
|
||||
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
|
||||
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
||||
|
||||
// LoginExpired Indicates whether peer's login expired or not
|
||||
LoginExpired bool `json:"login_expired"`
|
||||
|
||||
// Name Peer's hostname
|
||||
Name string `json:"name"`
|
||||
|
||||
// Os Peer's operating system and version
|
||||
Os string `json:"os"`
|
||||
|
||||
// SerialNumber System serial number
|
||||
SerialNumber string `json:"serial_number"`
|
||||
|
||||
// SshEnabled Indicates whether SSH server is enabled on this peer
|
||||
SshEnabled bool `json:"ssh_enabled"`
|
||||
|
||||
// UiVersion Peer's desktop UI version
|
||||
UiVersion string `json:"ui_version"`
|
||||
|
||||
// UserId User ID of the user that enrolled this peer
|
||||
UserId string `json:"user_id"`
|
||||
|
||||
// Version Peer's daemon or cli version
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// PeerBase defines model for PeerBase.
|
||||
type PeerBase struct {
|
||||
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
|
||||
ApprovalRequired bool `json:"approval_required"`
|
||||
|
||||
|
||||
@@ -115,6 +115,7 @@ func (apiHandler *apiHandler) addPeersEndpoint() {
|
||||
apiHandler.Router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
|
||||
apiHandler.Router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
|
||||
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
||||
apiHandler.Router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
func (apiHandler *apiHandler) addUsersEndpoint() {
|
||||
|
||||
@@ -7,8 +7,6 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
@@ -16,6 +14,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// PeersHandler is a handler that returns peers of the account
|
||||
@@ -71,12 +70,8 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
|
||||
return
|
||||
}
|
||||
|
||||
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
|
||||
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
|
||||
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
|
||||
|
||||
_, valid := validPeers[peer.ID]
|
||||
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid))
|
||||
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
|
||||
}
|
||||
|
||||
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||
@@ -117,13 +112,9 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
|
||||
return
|
||||
}
|
||||
|
||||
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
|
||||
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
|
||||
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
|
||||
|
||||
_, valid := validPeers[peer.ID]
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid))
|
||||
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, valid))
|
||||
}
|
||||
|
||||
func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
|
||||
@@ -220,32 +211,81 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
|
||||
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
if len(peerID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
// If the user is regular user and does not own the peer
|
||||
// with the given peerID return an empty list
|
||||
if !user.HasAdminPower() && !user.IsServiceUser {
|
||||
peer, ok := account.Peers[peerID]
|
||||
if !ok {
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "peer not found"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if peer.UserID != user.Id {
|
||||
util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(account)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
|
||||
customZone := account.GetPeersCustomZone(r.Context(), h.accountManager.GetDNSDomain())
|
||||
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil)
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
||||
}
|
||||
|
||||
func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.AccessiblePeer {
|
||||
accessiblePeers := make([]api.AccessiblePeer, 0, len(netMap.Peers)+len(netMap.OfflinePeers))
|
||||
for _, p := range netMap.Peers {
|
||||
ap := api.AccessiblePeer{
|
||||
Id: p.ID,
|
||||
Name: p.Name,
|
||||
Ip: p.IP.String(),
|
||||
DnsLabel: fqdn(p, dnsDomain),
|
||||
UserId: p.UserID,
|
||||
}
|
||||
accessiblePeers = append(accessiblePeers, ap)
|
||||
accessiblePeers = append(accessiblePeers, peerToAccessiblePeer(p, dnsDomain))
|
||||
}
|
||||
|
||||
for _, p := range netMap.OfflinePeers {
|
||||
ap := api.AccessiblePeer{
|
||||
Id: p.ID,
|
||||
Name: p.Name,
|
||||
Ip: p.IP.String(),
|
||||
DnsLabel: fqdn(p, dnsDomain),
|
||||
UserId: p.UserID,
|
||||
}
|
||||
accessiblePeers = append(accessiblePeers, ap)
|
||||
accessiblePeers = append(accessiblePeers, peerToAccessiblePeer(p, dnsDomain))
|
||||
}
|
||||
|
||||
return accessiblePeers
|
||||
}
|
||||
|
||||
func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePeer {
|
||||
return api.AccessiblePeer{
|
||||
CityName: peer.Location.CityName,
|
||||
Connected: peer.Status.Connected,
|
||||
CountryCode: peer.Location.CountryCode,
|
||||
DnsLabel: fqdn(peer, dnsDomain),
|
||||
GeonameId: int(peer.Location.GeoNameID),
|
||||
Id: peer.ID,
|
||||
Ip: peer.IP.String(),
|
||||
LastSeen: peer.Status.LastSeen,
|
||||
Name: peer.Name,
|
||||
Os: peer.Meta.OS,
|
||||
UserId: peer.UserID,
|
||||
}
|
||||
}
|
||||
|
||||
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
|
||||
var groupsInfo []api.GroupMinimum
|
||||
groupsChecked := make(map[string]struct{})
|
||||
@@ -270,7 +310,7 @@ func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMi
|
||||
return groupsInfo
|
||||
}
|
||||
|
||||
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeer []api.AccessiblePeer, approved bool) *api.Peer {
|
||||
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer {
|
||||
osVersion := peer.Meta.OSVersion
|
||||
if osVersion == "" {
|
||||
osVersion = peer.Meta.Core
|
||||
@@ -296,7 +336,6 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
|
||||
LoginExpirationEnabled: peer.LoginExpirationEnabled,
|
||||
LastLogin: peer.LastLogin,
|
||||
LoginExpired: peer.Status.LoginExpired,
|
||||
AccessiblePeers: accessiblePeer,
|
||||
ApprovalRequired: !approved,
|
||||
CountryCode: peer.Location.CountryCode,
|
||||
CityName: peer.Location.CityName,
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -12,20 +13,30 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
|
||||
"github.com/magiconair/properties/assert"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
)
|
||||
|
||||
const testPeerID = "test_peer"
|
||||
const noUpdateChannelTestPeerID = "no-update-channel"
|
||||
type ctxKey string
|
||||
|
||||
const (
|
||||
testPeerID = "test_peer"
|
||||
noUpdateChannelTestPeerID = "no-update-channel"
|
||||
|
||||
adminUser = "admin_user"
|
||||
regularUser = "regular_user"
|
||||
serviceUser = "service_user"
|
||||
userIDKey ctxKey = "user_id"
|
||||
)
|
||||
|
||||
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
return &PeersHandler{
|
||||
@@ -60,21 +71,57 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
return "netbird.selfhosted"
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
user := server.NewAdminUser("test_user")
|
||||
return &server.Account{
|
||||
peersMap := make(map[string]*nbpeer.Peer)
|
||||
for _, peer := range peers {
|
||||
peersMap[peer.ID] = peer.Copy()
|
||||
}
|
||||
|
||||
policy := &server.Policy{
|
||||
ID: "policy",
|
||||
AccountID: claims.AccountId,
|
||||
Name: "policy",
|
||||
Enabled: true,
|
||||
Rules: []*server.PolicyRule{
|
||||
{
|
||||
ID: "rule",
|
||||
Name: "rule",
|
||||
Enabled: true,
|
||||
Action: "accept",
|
||||
Destinations: []string{"group1"},
|
||||
Sources: []string{"group1"},
|
||||
Bidirectional: true,
|
||||
Protocol: "all",
|
||||
Ports: []string{"80"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
srvUser := server.NewRegularUser(serviceUser)
|
||||
srvUser.IsServiceUser = true
|
||||
|
||||
account := &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Domain: "hotmail.com",
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
peers[0].ID: peers[0],
|
||||
peers[1].ID: peers[1],
|
||||
},
|
||||
Peers: peersMap,
|
||||
Users: map[string]*server.User{
|
||||
"test_user": user,
|
||||
adminUser: server.NewAdminUser(adminUser),
|
||||
regularUser: server.NewRegularUser(regularUser),
|
||||
serviceUser: srvUser,
|
||||
},
|
||||
Groups: map[string]*nbgroup.Group{
|
||||
"group1": {
|
||||
ID: "group1",
|
||||
AccountID: claims.AccountId,
|
||||
Name: "group1",
|
||||
Issued: "api",
|
||||
Peers: maps.Keys(peersMap),
|
||||
},
|
||||
},
|
||||
Settings: &server.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: time.Hour,
|
||||
},
|
||||
Policies: []*server.Policy{policy},
|
||||
Network: &server.Network{
|
||||
Identifier: "ciclqisab2ss43jdn8q0",
|
||||
Net: net.IPNet{
|
||||
@@ -83,7 +130,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
},
|
||||
Serial: 51,
|
||||
},
|
||||
}, user, nil
|
||||
}
|
||||
|
||||
return account, account.Users[claims.UserId], nil
|
||||
},
|
||||
HasConnectedChannelFunc: func(peerID string) bool {
|
||||
statuses := make(map[string]struct{})
|
||||
@@ -99,8 +148,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
userID := r.Context().Value(userIDKey).(string)
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: "test_user",
|
||||
UserId: userID,
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
}
|
||||
@@ -197,6 +247,8 @@ func TestGetPeers(t *testing.T) {
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
ctx := context.WithValue(context.Background(), userIDKey, "admin_user")
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
|
||||
@@ -251,3 +303,119 @@ func TestGetPeers(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccessiblePeers(t *testing.T) {
|
||||
peer1 := &nbpeer.Peer{
|
||||
ID: "peer1",
|
||||
Key: "key1",
|
||||
IP: net.ParseIP("100.64.0.1"),
|
||||
Status: &nbpeer.PeerStatus{Connected: true},
|
||||
Name: "peer1",
|
||||
LoginExpirationEnabled: false,
|
||||
UserID: regularUser,
|
||||
}
|
||||
|
||||
peer2 := &nbpeer.Peer{
|
||||
ID: "peer2",
|
||||
Key: "key2",
|
||||
IP: net.ParseIP("100.64.0.2"),
|
||||
Status: &nbpeer.PeerStatus{Connected: true},
|
||||
Name: "peer2",
|
||||
LoginExpirationEnabled: false,
|
||||
UserID: adminUser,
|
||||
}
|
||||
|
||||
peer3 := &nbpeer.Peer{
|
||||
ID: "peer3",
|
||||
Key: "key3",
|
||||
IP: net.ParseIP("100.64.0.3"),
|
||||
Status: &nbpeer.PeerStatus{Connected: true},
|
||||
Name: "peer3",
|
||||
LoginExpirationEnabled: false,
|
||||
UserID: regularUser,
|
||||
}
|
||||
|
||||
p := initTestMetaData(peer1, peer2, peer3)
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
peerID string
|
||||
callerUserID string
|
||||
expectedStatus int
|
||||
expectedPeers []string
|
||||
}{
|
||||
{
|
||||
name: "non admin user can access owned peer",
|
||||
peerID: "peer1",
|
||||
callerUserID: regularUser,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedPeers: []string{"peer2", "peer3"},
|
||||
},
|
||||
{
|
||||
name: "non admin user can't access unowned peer",
|
||||
peerID: "peer2",
|
||||
callerUserID: regularUser,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedPeers: []string{},
|
||||
},
|
||||
{
|
||||
name: "admin user can access owned peer",
|
||||
peerID: "peer2",
|
||||
callerUserID: adminUser,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedPeers: []string{"peer1", "peer3"},
|
||||
},
|
||||
{
|
||||
name: "admin user can access unowned peer",
|
||||
peerID: "peer3",
|
||||
callerUserID: adminUser,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedPeers: []string{"peer1", "peer2"},
|
||||
},
|
||||
{
|
||||
name: "service user can access unowned peer",
|
||||
peerID: "peer3",
|
||||
callerUserID: serviceUser,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedPeers: []string{"peer1", "peer2"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
|
||||
ctx := context.WithValue(context.Background(), userIDKey, tc.callerUserID)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
if res.StatusCode != tc.expectedStatus {
|
||||
t.Fatalf("handler returned wrong status code: got %v want %v", res.StatusCode, tc.expectedStatus)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read response body: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
var accessiblePeers []api.AccessiblePeer
|
||||
err = json.Unmarshal(body, &accessiblePeers)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
peerIDs := make([]string, len(accessiblePeers))
|
||||
for i, peer := range accessiblePeers {
|
||||
peerIDs[i] = peer.Id
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, peerIDs, tc.expectedPeers)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -627,7 +627,7 @@ func testSyncStatusRace(t *testing.T) {
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), peerWithInvalidStatus.PublicKey().String())
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peerWithInvalidStatus.PublicKey().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -638,8 +638,8 @@ func testSyncStatusRace(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_LoginPerformance(t *testing.T) {
|
||||
if os.Getenv("CI") == "true" {
|
||||
t.Skip("Skipping on CI")
|
||||
if os.Getenv("CI") == "true" || runtime.GOOS == "windows" {
|
||||
t.Skip("Skipping test on CI or Windows")
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
|
||||
@@ -655,7 +655,7 @@ func Test_LoginPerformance(t *testing.T) {
|
||||
// {"M", 250, 1},
|
||||
// {"L", 500, 1},
|
||||
// {"XL", 750, 1},
|
||||
{"XXL", 2000, 1},
|
||||
{"XXL", 5000, 1},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
@@ -700,15 +700,18 @@ func Test_LoginPerformance(t *testing.T) {
|
||||
}
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
t.Logf("management setup complete, start registering peers")
|
||||
|
||||
var counter int32
|
||||
var counterStart int32
|
||||
var wg sync.WaitGroup
|
||||
var wgAccount sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
messageCalls := []func() error{}
|
||||
for j := 0; j < bc.accounts; j++ {
|
||||
wg.Add(1)
|
||||
wgAccount.Add(1)
|
||||
var wgPeer sync.WaitGroup
|
||||
go func(j int, counter *int32, counterStart *int32) {
|
||||
defer wg.Done()
|
||||
defer wgAccount.Done()
|
||||
|
||||
account, err := createAccount(am, fmt.Sprintf("account-%d", j), fmt.Sprintf("user-%d", j), fmt.Sprintf("domain-%d", j))
|
||||
if err != nil {
|
||||
@@ -722,7 +725,9 @@ func Test_LoginPerformance(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
for i := 0; i < bc.peers; i++ {
|
||||
wgPeer.Add(1)
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Logf("failed to generate key: %v", err)
|
||||
@@ -763,21 +768,29 @@ func Test_LoginPerformance(t *testing.T) {
|
||||
mu.Lock()
|
||||
messageCalls = append(messageCalls, login)
|
||||
mu.Unlock()
|
||||
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
|
||||
if err != nil {
|
||||
t.Logf("failed to login peer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
atomic.AddInt32(counterStart, 1)
|
||||
if *counterStart%100 == 0 {
|
||||
t.Logf("registered %d peers", *counterStart)
|
||||
}
|
||||
go func(peerLogin PeerLogin, counterStart *int32) {
|
||||
defer wgPeer.Done()
|
||||
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
|
||||
if err != nil {
|
||||
t.Logf("failed to login peer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
atomic.AddInt32(counterStart, 1)
|
||||
if *counterStart%100 == 0 {
|
||||
t.Logf("registered %d peers", *counterStart)
|
||||
}
|
||||
}(peerLogin, counterStart)
|
||||
|
||||
}
|
||||
wgPeer.Wait()
|
||||
|
||||
t.Logf("Time for registration: %s", time.Since(startTime))
|
||||
}(j, &counter, &counterStart)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
wgAccount.Wait()
|
||||
|
||||
t.Logf("prepared %d login calls", len(messageCalls))
|
||||
testLoginPerformance(t, messageCalls)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
@@ -371,164 +372,175 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
}
|
||||
}()
|
||||
|
||||
var account *Account
|
||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
||||
account, err = am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
|
||||
if am.idpManager != nil {
|
||||
userdata, err := am.lookupUserInCache(ctx, userID, account)
|
||||
if err == nil && userdata != nil {
|
||||
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
|
||||
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
|
||||
// and the peer disconnects with a timeout and tries to register again.
|
||||
// We just check if this machine has been registered before and reject the second registration.
|
||||
// The connecting peer should be able to recover with a retry.
|
||||
_, err = account.FindPeerByPubKey(peer.Key)
|
||||
_, err = am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peer.Key)
|
||||
if err == nil {
|
||||
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
|
||||
}
|
||||
|
||||
opEvent := &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
AccountID: account.Id,
|
||||
AccountID: accountID,
|
||||
}
|
||||
|
||||
var ephemeral bool
|
||||
setupKeyName := ""
|
||||
if !addedByUser {
|
||||
// validate the setup key if adding with a key
|
||||
sk, err := account.FindSetupKey(upperKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
var newPeer *nbpeer.Peer
|
||||
|
||||
if !sk.IsValid() {
|
||||
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
|
||||
}
|
||||
|
||||
account.SetupKeys[sk.Key] = sk.IncrementUsage()
|
||||
opEvent.InitiatorID = sk.Id
|
||||
opEvent.Activity = activity.PeerAddedWithSetupKey
|
||||
ephemeral = sk.Ephemeral
|
||||
setupKeyName = sk.Name
|
||||
} else {
|
||||
opEvent.InitiatorID = userID
|
||||
opEvent.Activity = activity.PeerAddedByUser
|
||||
}
|
||||
|
||||
takenIps := account.getTakenIPs()
|
||||
existingLabels := account.getPeerDNSLabels()
|
||||
|
||||
newLabel, err := getPeerHostLabel(peer.Meta.Hostname, existingLabels)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
peer.DNSLabel = newLabel
|
||||
network := account.Network
|
||||
nextIp, err := AllocatePeerIP(network.Net, takenIps)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
registrationTime := time.Now().UTC()
|
||||
|
||||
newPeer := &nbpeer.Peer{
|
||||
ID: xid.New().String(),
|
||||
Key: peer.Key,
|
||||
SetupKey: upperKey,
|
||||
IP: nextIp,
|
||||
Meta: peer.Meta,
|
||||
Name: peer.Meta.Hostname,
|
||||
DNSLabel: newLabel,
|
||||
UserID: userID,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
|
||||
SSHEnabled: false,
|
||||
SSHKey: peer.SSHKey,
|
||||
LastLogin: registrationTime,
|
||||
CreatedAt: registrationTime,
|
||||
LoginExpirationEnabled: addedByUser,
|
||||
Ephemeral: ephemeral,
|
||||
Location: peer.Location,
|
||||
}
|
||||
|
||||
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
|
||||
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
var groupsToAdd []string
|
||||
var setupKeyID string
|
||||
var setupKeyName string
|
||||
var ephemeral bool
|
||||
if addedByUser {
|
||||
user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user groups: %w", err)
|
||||
}
|
||||
groupsToAdd = user.AutoGroups
|
||||
opEvent.InitiatorID = userID
|
||||
opEvent.Activity = activity.PeerAddedByUser
|
||||
} else {
|
||||
newPeer.Location.CountryCode = location.Country.ISOCode
|
||||
newPeer.Location.CityName = location.City.Names.En
|
||||
newPeer.Location.GeoNameID = location.City.GeonameID
|
||||
}
|
||||
}
|
||||
// Validate the setup key
|
||||
sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, upperKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get setup key: %w", err)
|
||||
}
|
||||
|
||||
// add peer to 'All' group
|
||||
group, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
group.Peers = append(group.Peers, newPeer.ID)
|
||||
if !sk.IsValid() {
|
||||
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
|
||||
}
|
||||
|
||||
var groupsToAdd []string
|
||||
if addedByUser {
|
||||
groupsToAdd, err = account.getUserGroups(userID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
opEvent.InitiatorID = sk.Id
|
||||
opEvent.Activity = activity.PeerAddedWithSetupKey
|
||||
groupsToAdd = sk.AutoGroups
|
||||
ephemeral = sk.Ephemeral
|
||||
setupKeyID = sk.Id
|
||||
setupKeyName = sk.Name
|
||||
}
|
||||
} else {
|
||||
groupsToAdd, err = account.getSetupKeyGroups(upperKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(groupsToAdd) > 0 {
|
||||
for _, s := range groupsToAdd {
|
||||
if g, ok := account.Groups[s]; ok && g.Name != "All" {
|
||||
g.Peers = append(g.Peers, newPeer.ID)
|
||||
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
|
||||
if am.idpManager != nil {
|
||||
userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
|
||||
if err == nil && userdata != nil {
|
||||
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
newPeer = am.integratedPeerValidator.PreparePeer(ctx, account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra)
|
||||
|
||||
if addedByUser {
|
||||
user, err := account.FindUser(userID)
|
||||
freeLabel, err := am.getFreeDNSLabel(ctx, transaction, accountID, peer.Meta.Hostname)
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user")
|
||||
return fmt.Errorf("failed to get free DNS label: %w", err)
|
||||
}
|
||||
user.updateLastLogin(newPeer.LastLogin)
|
||||
}
|
||||
|
||||
account.Peers[newPeer.ID] = newPeer
|
||||
account.Network.IncSerial()
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
freeIP, err := am.getFreeIP(ctx, transaction, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get free IP: %w", err)
|
||||
}
|
||||
|
||||
registrationTime := time.Now().UTC()
|
||||
newPeer = &nbpeer.Peer{
|
||||
ID: xid.New().String(),
|
||||
AccountID: accountID,
|
||||
Key: peer.Key,
|
||||
SetupKey: upperKey,
|
||||
IP: freeIP,
|
||||
Meta: peer.Meta,
|
||||
Name: peer.Meta.Hostname,
|
||||
DNSLabel: freeLabel,
|
||||
UserID: userID,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
|
||||
SSHEnabled: false,
|
||||
SSHKey: peer.SSHKey,
|
||||
LastLogin: registrationTime,
|
||||
CreatedAt: registrationTime,
|
||||
LoginExpirationEnabled: addedByUser,
|
||||
Ephemeral: ephemeral,
|
||||
Location: peer.Location,
|
||||
}
|
||||
opEvent.TargetID = newPeer.ID
|
||||
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
|
||||
if !addedByUser {
|
||||
opEvent.Meta["setup_key_name"] = setupKeyName
|
||||
}
|
||||
|
||||
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
|
||||
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
|
||||
} else {
|
||||
newPeer.Location.CountryCode = location.Country.ISOCode
|
||||
newPeer.Location.CityName = location.City.Names.En
|
||||
newPeer.Location.GeoNameID = location.City.GeonameID
|
||||
}
|
||||
}
|
||||
|
||||
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account settings: %w", err)
|
||||
}
|
||||
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
|
||||
|
||||
err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed adding peer to All group: %w", err)
|
||||
}
|
||||
|
||||
if len(groupsToAdd) > 0 {
|
||||
for _, g := range groupsToAdd {
|
||||
err = transaction.AddPeerToGroup(ctx, accountID, newPeer.ID, g)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = transaction.AddPeerToAccount(ctx, newPeer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add peer to account: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
if addedByUser {
|
||||
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.LastLogin)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update user last login: %w", err)
|
||||
}
|
||||
} else {
|
||||
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment setup key usage: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err)
|
||||
}
|
||||
|
||||
// Account is saved, we can release the lock
|
||||
unlock()
|
||||
unlock = nil
|
||||
|
||||
opEvent.TargetID = newPeer.ID
|
||||
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
|
||||
if !addedByUser {
|
||||
opEvent.Meta["setup_key_name"] = setupKeyName
|
||||
if newPeer == nil {
|
||||
return nil, nil, nil, fmt.Errorf("new peer is nil")
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||
|
||||
unlock()
|
||||
unlock = nil
|
||||
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||
@@ -536,12 +548,31 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
postureChecks := am.getPeerPostureChecks(account, peer)
|
||||
postureChecks := am.getPeerPostureChecks(account, newPeer)
|
||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
||||
return newPeer, networkMap, postureChecks, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) {
|
||||
takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get taken IPs: %w", err)
|
||||
}
|
||||
|
||||
network, err := store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed getting network: %w", err)
|
||||
}
|
||||
|
||||
nextIp, err := AllocatePeerIP(network.Net, takenIps)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to allocate new peer ip: %w", err)
|
||||
}
|
||||
|
||||
return nextIp, nil
|
||||
}
|
||||
|
||||
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
||||
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
|
||||
@@ -647,12 +678,12 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
}
|
||||
}()
|
||||
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, accountID)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -730,7 +761,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
|
||||
// and before starting the engine, we do the checks without an account lock to avoid piling up requests.
|
||||
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error {
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -741,7 +772,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
|
||||
return nil
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, accountID)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -786,7 +817,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.Store.SaveUserLastLogin(user.AccountID, user.Id, peer.LastLogin)
|
||||
err = am.Store.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -969,3 +1000,11 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
|
||||
labelMap := make(map[string]struct{}, len(existingLabels))
|
||||
for _, label := range existingLabels {
|
||||
labelMap[label] = struct{}{}
|
||||
}
|
||||
return labelMap
|
||||
}
|
||||
|
||||
@@ -7,20 +7,24 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -995,3 +999,184 @@ func TestToSyncResponse(t *testing.T) {
|
||||
assert.Equal(t, 1, len(response.Checks))
|
||||
assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0])
|
||||
}
|
||||
|
||||
func Test_RegisterPeerByUser(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
assert.NoError(t, err)
|
||||
|
||||
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
existingUserID := "edafee4e-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
_, err = store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
newPeer := &nbpeer.Peer{
|
||||
ID: xid.New().String(),
|
||||
AccountID: existingAccountID,
|
||||
Key: "newPeerKey",
|
||||
SetupKey: "",
|
||||
IP: net.IP{123, 123, 123, 123},
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "newPeer",
|
||||
GoOS: "linux",
|
||||
},
|
||||
Name: "newPeerName",
|
||||
DNSLabel: "newPeer.test",
|
||||
UserID: existingUserID,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||
SSHEnabled: false,
|
||||
LastLogin: time.Now(),
|
||||
}
|
||||
|
||||
addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, addedPeer.Key)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, peer.AccountID, existingAccountID)
|
||||
assert.Equal(t, peer.UserID, existingUserID)
|
||||
|
||||
account, err := store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, account.Peers, addedPeer.ID)
|
||||
assert.Equal(t, peer.Meta.Hostname, newPeer.Meta.Hostname)
|
||||
assert.Contains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, addedPeer.ID)
|
||||
assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
|
||||
|
||||
assert.Equal(t, uint64(1), account.Network.Serial)
|
||||
|
||||
lastLogin, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, lastLogin, account.Users[existingUserID].LastLogin)
|
||||
}
|
||||
|
||||
func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
assert.NoError(t, err)
|
||||
|
||||
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
existingSetupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
||||
_, err = store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
newPeer := &nbpeer.Peer{
|
||||
ID: xid.New().String(),
|
||||
AccountID: existingAccountID,
|
||||
Key: "newPeerKey",
|
||||
SetupKey: "existingSetupKey",
|
||||
UserID: "",
|
||||
IP: net.IP{123, 123, 123, 123},
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "newPeer",
|
||||
GoOS: "linux",
|
||||
},
|
||||
Name: "newPeerName",
|
||||
DNSLabel: "newPeer.test",
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||
SSHEnabled: false,
|
||||
}
|
||||
|
||||
addedPeer, _, _, err := am.AddPeer(context.Background(), existingSetupKeyID, "", newPeer)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, peer.AccountID, existingAccountID)
|
||||
assert.Equal(t, peer.SetupKey, existingSetupKeyID)
|
||||
|
||||
account, err := store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, account.Peers, addedPeer.ID)
|
||||
assert.Contains(t, account.Groups["cfefqs706sqkneg59g2g"].Peers, addedPeer.ID)
|
||||
assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
|
||||
|
||||
assert.Equal(t, uint64(1), account.Network.Serial)
|
||||
|
||||
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, lastUsed, account.SetupKeys[existingSetupKeyID].LastUsed)
|
||||
assert.Equal(t, 1, account.SetupKeys[existingSetupKeyID].UsedTimes)
|
||||
|
||||
}
|
||||
|
||||
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
assert.NoError(t, err)
|
||||
|
||||
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
faultyKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBC"
|
||||
|
||||
_, err = store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
newPeer := &nbpeer.Peer{
|
||||
ID: xid.New().String(),
|
||||
AccountID: existingAccountID,
|
||||
Key: "newPeerKey",
|
||||
SetupKey: "existingSetupKey",
|
||||
UserID: "",
|
||||
IP: net.IP{123, 123, 123, 123},
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "newPeer",
|
||||
GoOS: "linux",
|
||||
},
|
||||
Name: "newPeerName",
|
||||
DNSLabel: "newPeer.test",
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||
SSHEnabled: false,
|
||||
}
|
||||
|
||||
_, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
|
||||
require.Error(t, err)
|
||||
|
||||
_, err = store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
|
||||
require.Error(t, err)
|
||||
|
||||
account, err := store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.NotContains(t, account.Peers, newPeer.ID)
|
||||
assert.NotContains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, newPeer.ID)
|
||||
assert.NotContains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, newPeer.ID)
|
||||
|
||||
assert.Equal(t, uint64(0), account.Network.Serial)
|
||||
|
||||
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed)
|
||||
assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
@@ -33,6 +35,7 @@ import (
|
||||
const (
|
||||
storeSqliteFileName = "store.db"
|
||||
idQueryCondition = "id = ?"
|
||||
keyQueryCondition = "key = ?"
|
||||
accountAndIDQueryCondition = "account_id = ? and id = ?"
|
||||
peerNotFoundFMT = "peer %s not found"
|
||||
)
|
||||
@@ -415,13 +418,12 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
|
||||
|
||||
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
|
||||
var key SetupKey
|
||||
result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey))
|
||||
result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, strings.ToUpper(setupKey))
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting setup key from store")
|
||||
return nil, status.NewSetupKeyNotFoundError()
|
||||
}
|
||||
|
||||
if key.AccountID == "" {
|
||||
@@ -474,15 +476,15 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, userID string) (*User, error) {
|
||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
|
||||
var user User
|
||||
result := s.db.First(&user, idQueryCondition, userID)
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&user, idQueryCondition, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "user not found: index lookup failed")
|
||||
return nil, status.NewUserNotFoundError(userID)
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting user from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting user from store")
|
||||
return nil, status.NewGetUserFromStoreError()
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
@@ -535,7 +537,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found")
|
||||
return nil, status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
||||
}
|
||||
@@ -595,7 +597,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
|
||||
|
||||
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) {
|
||||
var user User
|
||||
result := s.db.Select("account_id").First(&user, idQueryCondition, userID)
|
||||
result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
@@ -612,12 +614,11 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
|
||||
|
||||
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
|
||||
var peer nbpeer.Peer
|
||||
result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID)
|
||||
result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
||||
}
|
||||
|
||||
@@ -631,12 +632,11 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
|
||||
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
|
||||
var peer nbpeer.Peer
|
||||
|
||||
result := s.db.Select("account_id").First(&peer, "key = ?", peerKey)
|
||||
result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
||||
}
|
||||
|
||||
@@ -650,12 +650,11 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
|
||||
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
|
||||
var peer nbpeer.Peer
|
||||
var accountID string
|
||||
result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID)
|
||||
result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
|
||||
return "", status.Errorf(status.Internal, "issue getting account from store")
|
||||
}
|
||||
|
||||
@@ -677,61 +676,117 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
|
||||
var key SetupKey
|
||||
var accountID string
|
||||
result := s.db.Model(&key).Select("account_id").Where("key = ?", strings.ToUpper(setupKey)).First(&accountID)
|
||||
result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, strings.ToUpper(setupKey)).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error)
|
||||
return "", status.Errorf(status.Internal, "issue getting setup key from store")
|
||||
return "", status.NewSetupKeyNotFoundError()
|
||||
}
|
||||
|
||||
if accountID == "" {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) {
|
||||
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
|
||||
var ipJSONStrings []string
|
||||
|
||||
// Fetch the IP addresses as JSON strings
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
|
||||
Where("account_id = ?", accountID).
|
||||
Pluck("ip", &ipJSONStrings)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "no peers found for the account")
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "issue getting IPs from store")
|
||||
}
|
||||
|
||||
// Convert the JSON strings to net.IP objects
|
||||
ips := make([]net.IP, len(ipJSONStrings))
|
||||
for i, ipJSON := range ipJSONStrings {
|
||||
var ip net.IP
|
||||
if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil {
|
||||
return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store")
|
||||
}
|
||||
ips[i] = ip
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
|
||||
var labels []string
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
|
||||
Where("account_id = ?", accountID).
|
||||
Pluck("dns_label", &labels)
|
||||
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "no peers found for the account")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting dns labels from store")
|
||||
}
|
||||
|
||||
return labels, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
|
||||
var accountNetwork AccountNetwork
|
||||
|
||||
if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "issue getting network from store")
|
||||
}
|
||||
return accountNetwork.Network, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
|
||||
var peer nbpeer.Peer
|
||||
result := s.db.First(&peer, "key = ?", peerKey)
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "peer not found")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting peer from store")
|
||||
}
|
||||
|
||||
return &peer, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) {
|
||||
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
|
||||
var accountSettings AccountSettings
|
||||
if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
|
||||
if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "settings not found")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting settings from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "issue getting settings from store")
|
||||
}
|
||||
return accountSettings.Settings, nil
|
||||
}
|
||||
|
||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||
func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
|
||||
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
||||
var user User
|
||||
|
||||
result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID)
|
||||
result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.Errorf(status.NotFound, "user %s not found", userID)
|
||||
return status.NewUserNotFoundError(userID)
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue getting user from store")
|
||||
return status.NewGetUserFromStoreError()
|
||||
}
|
||||
|
||||
user.LastLogin = lastLogin
|
||||
|
||||
return s.db.Save(user).Error
|
||||
return s.db.Save(&user).Error
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
||||
@@ -790,6 +845,16 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe
|
||||
return NewSqlStore(ctx, db, PostgresStoreEngine, metrics)
|
||||
}
|
||||
|
||||
// NewMysqlStore creates a new MySql store.
|
||||
func NewMysqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
||||
db, err := gorm.Open(mysql.Open(dsn), getGormConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewSqlStore(ctx, db, MySqlStoreEngine, metrics)
|
||||
}
|
||||
|
||||
func getGormConfig() *gorm.Config {
|
||||
return &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
@@ -807,6 +872,15 @@ func newPostgresStore(ctx context.Context, metrics telemetry.AppMetrics) (Store,
|
||||
return NewPostgresqlStore(ctx, dsn, metrics)
|
||||
}
|
||||
|
||||
// newMySqlStore initializes a new MySql store.
|
||||
func newMySqlStore(ctx context.Context, metrics telemetry.AppMetrics) (Store, error) {
|
||||
dsn, ok := os.LookupEnv(mySqlDsnEnv)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%s is not set", mySqlDsnEnv)
|
||||
}
|
||||
return NewMysqlStore(ctx, dsn, metrics)
|
||||
}
|
||||
|
||||
// NewSqliteStoreFromFileStore restores a store from FileStore and stores SQLite DB in the file located in datadir.
|
||||
func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
||||
store, err := NewSqliteStore(ctx, dataDir, metrics)
|
||||
@@ -850,3 +924,127 @@ func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore,
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
|
||||
var setupKey SetupKey
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&setupKey, keyQueryCondition, strings.ToUpper(key))
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||
}
|
||||
return nil, status.NewSetupKeyNotFoundError()
|
||||
}
|
||||
return &setupKey, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
|
||||
result := s.db.WithContext(ctx).Model(&SetupKey{}).
|
||||
Where(idQueryCondition, setupKeyID).
|
||||
Updates(map[string]interface{}{
|
||||
"used_times": gorm.Expr("used_times + 1"),
|
||||
"last_used": time.Now(),
|
||||
})
|
||||
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "setup key not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
|
||||
var group nbgroup.Group
|
||||
|
||||
result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.Errorf(status.NotFound, "group 'All' not found for account")
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue finding group 'All'")
|
||||
}
|
||||
|
||||
for _, existingPeerID := range group.Peers {
|
||||
if existingPeerID == peerID {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
group.Peers = append(group.Peers, peerID)
|
||||
|
||||
if err := s.db.Save(&group).Error; err != nil {
|
||||
return status.Errorf(status.Internal, "issue updating group 'All'")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
|
||||
var group nbgroup.Group
|
||||
|
||||
result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.Errorf(status.NotFound, "group not found for account")
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue finding group")
|
||||
}
|
||||
|
||||
for _, existingPeerID := range group.Peers {
|
||||
if existingPeerID == peerId {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
group.Peers = append(group.Peers, peerId)
|
||||
|
||||
if err := s.db.Save(&group).Error; err != nil {
|
||||
return status.Errorf(status.Internal, "issue updating group")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
||||
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
|
||||
return status.Errorf(status.Internal, "issue adding peer to account")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
||||
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "issue incrementing network serial count")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
|
||||
tx := s.db.WithContext(ctx).Begin()
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
}
|
||||
repo := s.withTx(tx)
|
||||
err := operation(repo)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
func (s *SqlStore) withTx(tx *gorm.DB) Store {
|
||||
return &SqlStore{
|
||||
db: tx,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetDB() *gorm.DB {
|
||||
return s.db
|
||||
}
|
||||
|
||||
@@ -1003,3 +1003,163 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, user.PATs[id].ID)
|
||||
}
|
||||
|
||||
func TestSqlite_GetTakenIPs(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||
defer store.Close(context.Background())
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []net.IP{}, takenIPs)
|
||||
|
||||
peer1 := &nbpeer.Peer{
|
||||
ID: "peer1",
|
||||
AccountID: existingAccountID,
|
||||
IP: net.IP{1, 1, 1, 1},
|
||||
}
|
||||
err = store.AddPeerToAccount(context.Background(), peer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
require.NoError(t, err)
|
||||
ip1 := net.IP{1, 1, 1, 1}.To16()
|
||||
assert.Equal(t, []net.IP{ip1}, takenIPs)
|
||||
|
||||
peer2 := &nbpeer.Peer{
|
||||
ID: "peer2",
|
||||
AccountID: existingAccountID,
|
||||
IP: net.IP{2, 2, 2, 2},
|
||||
}
|
||||
err = store.AddPeerToAccount(context.Background(), peer2)
|
||||
require.NoError(t, err)
|
||||
|
||||
takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
require.NoError(t, err)
|
||||
ip2 := net.IP{2, 2, 2, 2}.To16()
|
||||
assert.Equal(t, []net.IP{ip1, ip2}, takenIPs)
|
||||
|
||||
}
|
||||
|
||||
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||
defer store.Close(context.Background())
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{}, labels)
|
||||
|
||||
peer1 := &nbpeer.Peer{
|
||||
ID: "peer1",
|
||||
AccountID: existingAccountID,
|
||||
DNSLabel: "peer1.domain.test",
|
||||
}
|
||||
err = store.AddPeerToAccount(context.Background(), peer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"peer1.domain.test"}, labels)
|
||||
|
||||
peer2 := &nbpeer.Peer{
|
||||
ID: "peer2",
|
||||
AccountID: existingAccountID,
|
||||
DNSLabel: "peer2.domain.test",
|
||||
}
|
||||
err = store.AddPeerToAccount(context.Background(), peer2)
|
||||
require.NoError(t, err)
|
||||
|
||||
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels)
|
||||
}
|
||||
|
||||
func TestSqlite_GetAccountNetwork(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||
defer store.Close(context.Background())
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
network, err := store.GetAccountNetwork(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
require.NoError(t, err)
|
||||
ip := net.IP{100, 64, 0, 0}.To16()
|
||||
assert.Equal(t, ip, network.Net.IP)
|
||||
assert.Equal(t, net.IPMask{255, 255, 0, 0}, network.Net.Mask)
|
||||
assert.Equal(t, "", network.Dns)
|
||||
assert.Equal(t, "af1c8024-ha40-4ce2-9418-34653101fc3c", network.Identifier)
|
||||
assert.Equal(t, uint64(0), network.Serial)
|
||||
}
|
||||
|
||||
func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||
defer store.Close(context.Background())
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", setupKey.Key)
|
||||
assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID)
|
||||
assert.Equal(t, "Default key", setupKey.Name)
|
||||
}
|
||||
|
||||
func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
|
||||
defer store.Close(context.Background())
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, setupKey.UsedTimes)
|
||||
|
||||
err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, setupKey.UsedTimes)
|
||||
|
||||
err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, setupKey.UsedTimes)
|
||||
}
|
||||
|
||||
@@ -100,3 +100,13 @@ func NewPeerNotRegisteredError() error {
|
||||
func NewPeerLoginExpiredError() error {
|
||||
return Errorf(PermissionDenied, "peer login has expired, please log in once more")
|
||||
}
|
||||
|
||||
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
|
||||
func NewSetupKeyNotFoundError() error {
|
||||
return Errorf(NotFound, "setup key not found")
|
||||
}
|
||||
|
||||
// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store
|
||||
func NewGetUserFromStoreError() error {
|
||||
return Errorf(Internal, "issue getting user from store")
|
||||
}
|
||||
|
||||
@@ -27,6 +27,15 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type LockingStrength string
|
||||
|
||||
const (
|
||||
LockingStrengthUpdate LockingStrength = "UPDATE" // Strongest lock, preventing any changes by other transactions until your transaction completes.
|
||||
LockingStrengthShare LockingStrength = "SHARE" // Allows reading but prevents changes by other transactions.
|
||||
LockingStrengthNoKeyUpdate LockingStrength = "NO KEY UPDATE" // Similar to UPDATE but allows changes to related rows.
|
||||
LockingStrengthKeyShare LockingStrength = "KEY SHARE" // Protects against changes to primary/unique keys but allows other updates.
|
||||
)
|
||||
|
||||
type Store interface {
|
||||
GetAllAccounts(ctx context.Context) []*Account
|
||||
GetAccount(ctx context.Context, accountID string) (*Account, error)
|
||||
@@ -41,7 +50,7 @@ type Store interface {
|
||||
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||
GetUserByUserID(ctx context.Context, userID string) (*User, error)
|
||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
SaveAccount(ctx context.Context, account *Account) error
|
||||
@@ -60,14 +69,24 @@ type Store interface {
|
||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
||||
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
||||
SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error
|
||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||
// Close should close the store persisting all unsaved data.
|
||||
Close(ctx context.Context) error
|
||||
// GetStoreEngine should return StoreEngine of the current store implementation.
|
||||
// This is also a method of metrics.DataSource interface.
|
||||
GetStoreEngine() StoreEngine
|
||||
GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error)
|
||||
GetAccountSettings(ctx context.Context, accountID string) (*Settings, error)
|
||||
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
||||
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
|
||||
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
|
||||
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
|
||||
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
||||
IncrementNetworkSerial(ctx context.Context, accountId string) error
|
||||
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
|
||||
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
|
||||
}
|
||||
|
||||
type StoreEngine string
|
||||
@@ -76,8 +95,10 @@ const (
|
||||
FileStoreEngine StoreEngine = "jsonfile"
|
||||
SqliteStoreEngine StoreEngine = "sqlite"
|
||||
PostgresStoreEngine StoreEngine = "postgres"
|
||||
MySqlStoreEngine StoreEngine = "mysql"
|
||||
|
||||
postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
||||
mySqlDsnEnv = "NETBIRD_STORE_ENGINE_MYSQL_DSN"
|
||||
)
|
||||
|
||||
func getStoreEngineFromEnv() StoreEngine {
|
||||
@@ -88,11 +109,12 @@ func getStoreEngineFromEnv() StoreEngine {
|
||||
}
|
||||
|
||||
value := StoreEngine(strings.ToLower(kind))
|
||||
if value == SqliteStoreEngine || value == PostgresStoreEngine {
|
||||
switch value {
|
||||
case SqliteStoreEngine, PostgresStoreEngine, MySqlStoreEngine:
|
||||
return value
|
||||
default:
|
||||
return SqliteStoreEngine
|
||||
}
|
||||
|
||||
return SqliteStoreEngine
|
||||
}
|
||||
|
||||
// getStoreEngine determines the store engine to use.
|
||||
@@ -139,6 +161,9 @@ func NewStore(ctx context.Context, kind StoreEngine, dataDir string, metrics tel
|
||||
case PostgresStoreEngine:
|
||||
log.WithContext(ctx).Info("using Postgres store engine")
|
||||
return newPostgresStore(ctx, metrics)
|
||||
case MySqlStoreEngine:
|
||||
log.WithContext(ctx).Info("using MySQL store engine")
|
||||
return newMySqlStore(ctx, metrics)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported kind of store: %s", kind)
|
||||
}
|
||||
|
||||
120
management/server/testdata/extended-store.json
vendored
Normal file
120
management/server/testdata/extended-store.json
vendored
Normal file
@@ -0,0 +1,120 @@
|
||||
{
|
||||
"Accounts": {
|
||||
"bf1c8084-ba50-4ce7-9439-34653001fc3b": {
|
||||
"Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
"CreatedBy": "",
|
||||
"Domain": "test.com",
|
||||
"DomainCategory": "private",
|
||||
"IsDomainPrimaryAccount": true,
|
||||
"SetupKeys": {
|
||||
"A2C8E62B-38F5-4553-B31E-DD66C696CEBB": {
|
||||
"Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
|
||||
"AccountID": "",
|
||||
"Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
|
||||
"Name": "Default key",
|
||||
"Type": "reusable",
|
||||
"CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
|
||||
"ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
|
||||
"UpdatedAt": "0001-01-01T00:00:00Z",
|
||||
"Revoked": false,
|
||||
"UsedTimes": 0,
|
||||
"LastUsed": "0001-01-01T00:00:00Z",
|
||||
"AutoGroups": ["cfefqs706sqkneg59g2g"],
|
||||
"UsageLimit": 0,
|
||||
"Ephemeral": false
|
||||
},
|
||||
"A2C8E62B-38F5-4553-B31E-DD66C696CEBC": {
|
||||
"Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
|
||||
"AccountID": "",
|
||||
"Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
|
||||
"Name": "Faulty key with non existing group",
|
||||
"Type": "reusable",
|
||||
"CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
|
||||
"ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
|
||||
"UpdatedAt": "0001-01-01T00:00:00Z",
|
||||
"Revoked": false,
|
||||
"UsedTimes": 0,
|
||||
"LastUsed": "0001-01-01T00:00:00Z",
|
||||
"AutoGroups": ["abcd"],
|
||||
"UsageLimit": 0,
|
||||
"Ephemeral": false
|
||||
}
|
||||
},
|
||||
"Network": {
|
||||
"id": "af1c8024-ha40-4ce2-9418-34653101fc3c",
|
||||
"Net": {
|
||||
"IP": "100.64.0.0",
|
||||
"Mask": "//8AAA=="
|
||||
},
|
||||
"Dns": "",
|
||||
"Serial": 0
|
||||
},
|
||||
"Peers": {},
|
||||
"Users": {
|
||||
"edafee4e-63fb-11ec-90d6-0242ac120003": {
|
||||
"Id": "edafee4e-63fb-11ec-90d6-0242ac120003",
|
||||
"AccountID": "",
|
||||
"Role": "admin",
|
||||
"IsServiceUser": false,
|
||||
"ServiceUserName": "",
|
||||
"AutoGroups": ["cfefqs706sqkneg59g3g"],
|
||||
"PATs": {},
|
||||
"Blocked": false,
|
||||
"LastLogin": "0001-01-01T00:00:00Z"
|
||||
},
|
||||
"f4f6d672-63fb-11ec-90d6-0242ac120003": {
|
||||
"Id": "f4f6d672-63fb-11ec-90d6-0242ac120003",
|
||||
"AccountID": "",
|
||||
"Role": "user",
|
||||
"IsServiceUser": false,
|
||||
"ServiceUserName": "",
|
||||
"AutoGroups": null,
|
||||
"PATs": {
|
||||
"9dj38s35-63fb-11ec-90d6-0242ac120003": {
|
||||
"ID": "9dj38s35-63fb-11ec-90d6-0242ac120003",
|
||||
"UserID": "",
|
||||
"Name": "",
|
||||
"HashedToken": "SoMeHaShEdToKeN",
|
||||
"ExpirationDate": "2023-02-27T00:00:00Z",
|
||||
"CreatedBy": "user",
|
||||
"CreatedAt": "2023-01-01T00:00:00Z",
|
||||
"LastUsed": "2023-02-01T00:00:00Z"
|
||||
}
|
||||
},
|
||||
"Blocked": false,
|
||||
"LastLogin": "0001-01-01T00:00:00Z"
|
||||
}
|
||||
},
|
||||
"Groups": {
|
||||
"cfefqs706sqkneg59g4g": {
|
||||
"ID": "cfefqs706sqkneg59g4g",
|
||||
"Name": "All",
|
||||
"Peers": []
|
||||
},
|
||||
"cfefqs706sqkneg59g3g": {
|
||||
"ID": "cfefqs706sqkneg59g3g",
|
||||
"Name": "AwesomeGroup1",
|
||||
"Peers": []
|
||||
},
|
||||
"cfefqs706sqkneg59g2g": {
|
||||
"ID": "cfefqs706sqkneg59g2g",
|
||||
"Name": "AwesomeGroup2",
|
||||
"Peers": []
|
||||
}
|
||||
},
|
||||
"Rules": null,
|
||||
"Policies": [],
|
||||
"Routes": null,
|
||||
"NameServerGroups": null,
|
||||
"DNSSettings": null,
|
||||
"Settings": {
|
||||
"PeerLoginExpirationEnabled": false,
|
||||
"PeerLoginExpiration": 86400000000000,
|
||||
"GroupsPropagationEnabled": false,
|
||||
"JWTGroupsEnabled": false,
|
||||
"JWTGroupsClaimName": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"InstallationID": ""
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
|
||||
)
|
||||
|
||||
const defaultDuration = 12 * time.Hour
|
||||
@@ -30,7 +32,7 @@ type TimeBasedAuthSecretsManager struct {
|
||||
turnCfg *TURNConfig
|
||||
relayCfg *Relay
|
||||
turnHmacToken *auth.TimedHMAC
|
||||
relayHmacToken *auth.TimedHMAC
|
||||
relayHmacToken *authv2.Generator
|
||||
updateManager *PeersUpdateManager
|
||||
turnCancelMap map[string]chan struct{}
|
||||
relayCancelMap map[string]chan struct{}
|
||||
@@ -63,7 +65,11 @@ func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *
|
||||
duration = defaultDuration
|
||||
}
|
||||
|
||||
mgr.relayHmacToken = auth.NewTimedHMAC(relayCfg.Secret, duration)
|
||||
hashedSecret := sha256.Sum256([]byte(relayCfg.Secret))
|
||||
var err error
|
||||
if mgr.relayHmacToken, err = authv2.NewGenerator(authv2.AuthAlgoHMACSHA256, hashedSecret[:], duration); err != nil {
|
||||
log.Errorf("failed to create relay token generator: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return mgr
|
||||
@@ -76,7 +82,7 @@ func (m *TimeBasedAuthSecretsManager) GenerateTurnToken() (*Token, error) {
|
||||
}
|
||||
turnToken, err := m.turnHmacToken.GenerateToken(sha1.New)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate TURN token: %s", err)
|
||||
return nil, fmt.Errorf("generate TURN token: %s", err)
|
||||
}
|
||||
return (*Token)(turnToken), nil
|
||||
}
|
||||
@@ -86,11 +92,15 @@ func (m *TimeBasedAuthSecretsManager) GenerateRelayToken() (*Token, error) {
|
||||
if m.relayHmacToken == nil {
|
||||
return nil, fmt.Errorf("relay configuration is not set")
|
||||
}
|
||||
relayToken, err := m.relayHmacToken.GenerateToken(sha256.New)
|
||||
relayToken, err := m.relayHmacToken.GenerateToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate relay token: %s", err)
|
||||
return nil, fmt.Errorf("generate relay token: %s", err)
|
||||
}
|
||||
return (*Token)(relayToken), nil
|
||||
|
||||
return &Token{
|
||||
Payload: string(relayToken.Payload),
|
||||
Signature: base64.StdEncoding.EncodeToString(relayToken.Signature),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *TimeBasedAuthSecretsManager) cancelTURN(peerID string) {
|
||||
@@ -200,7 +210,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNTokens(ctx context.Context, pee
|
||||
}
|
||||
|
||||
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, peerID string) {
|
||||
relayToken, err := m.relayHmacToken.GenerateToken(sha256.New)
|
||||
relayToken, err := m.relayHmacToken.GenerateToken()
|
||||
if err != nil {
|
||||
log.Errorf("failed to generate relay token for peer '%s': %s", peerID, err)
|
||||
return
|
||||
@@ -210,8 +220,8 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, pe
|
||||
WiretrusteeConfig: &proto.WiretrusteeConfig{
|
||||
Relay: &proto.RelayConfig{
|
||||
Urls: m.relayCfg.Addresses,
|
||||
TokenPayload: relayToken.Payload,
|
||||
TokenSignature: relayToken.Signature,
|
||||
TokenPayload: string(relayToken.Payload),
|
||||
TokenSignature: base64.StdEncoding.EncodeToString(relayToken.Signature),
|
||||
},
|
||||
// omit Turns to avoid updates there
|
||||
},
|
||||
|
||||
@@ -63,7 +63,8 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
|
||||
t.Errorf("expected generated relay signature not to be empty, got empty")
|
||||
}
|
||||
|
||||
validateMAC(t, sha256.New, relayCredentials.Payload, relayCredentials.Signature, []byte(secret))
|
||||
hashedSecret := sha256.Sum256([]byte(secret))
|
||||
validateMAC(t, sha256.New, relayCredentials.Payload, relayCredentials.Signature, hashedSecret[:])
|
||||
}
|
||||
|
||||
func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
|
||||
|
||||
@@ -70,7 +70,7 @@ type User struct {
|
||||
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
|
||||
Blocked bool
|
||||
// LastLogin is the last time the user logged in to IdP
|
||||
LastLogin time.Time
|
||||
LastLogin time.Time `gorm:"type:TIMESTAMP;null;default:null"`
|
||||
// CreatedAt records the time the user was created
|
||||
CreatedAt time.Time
|
||||
|
||||
@@ -89,10 +89,6 @@ func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool {
|
||||
return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero()
|
||||
}
|
||||
|
||||
func (u *User) updateLastLogin(login time.Time) {
|
||||
u.LastLogin = login
|
||||
}
|
||||
|
||||
// HasAdminPower returns true if the user has admin or owner roles, false otherwise
|
||||
func (u *User) HasAdminPower() bool {
|
||||
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
|
||||
@@ -386,7 +382,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A
|
||||
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
|
||||
newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
|
||||
|
||||
err = am.Store.SaveUserLastLogin(account.Id, claims.UserId, claims.LastLogin)
|
||||
err = am.Store.SaveUserLastLogin(ctx, account.Id, claims.UserId, claims.LastLogin)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
|
||||
}
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package allow
|
||||
|
||||
import "hash"
|
||||
|
||||
// Auth is a Validator that allows all connections.
|
||||
// Used this for testing purposes only.
|
||||
type Auth struct {
|
||||
}
|
||||
|
||||
func (a *Auth) Validate(func() hash.Hash, any) error {
|
||||
func (a *Auth) Validate(any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Auth) ValidateHelloMsgType(any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package hmac
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
v2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
|
||||
)
|
||||
|
||||
// TokenStore is a simple in-memory store for token
|
||||
@@ -20,12 +22,18 @@ func (a *TokenStore) UpdateToken(token *Token) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
t, err := marshalToken(*token)
|
||||
sig, err := base64.StdEncoding.DecodeString(token.Signature)
|
||||
if err != nil {
|
||||
log.Debugf("failed to marshal token: %s", err)
|
||||
return err
|
||||
return fmt.Errorf("decode signature: %w", err)
|
||||
}
|
||||
a.token = t
|
||||
|
||||
tok := v2.Token{
|
||||
AuthAlgo: v2.AuthAlgoHMACSHA256,
|
||||
Signature: sig,
|
||||
Payload: []byte(token.Payload),
|
||||
}
|
||||
|
||||
a.token = tok.Marshal()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -18,17 +18,6 @@ type Token struct {
|
||||
Signature string
|
||||
}
|
||||
|
||||
func marshalToken(token Token) ([]byte, error) {
|
||||
var buffer bytes.Buffer
|
||||
encoder := gob.NewEncoder(&buffer)
|
||||
err := encoder.Encode(token)
|
||||
if err != nil {
|
||||
log.Debugf("failed to marshal token: %s", err)
|
||||
return nil, fmt.Errorf("failed to marshal token: %w", err)
|
||||
}
|
||||
return buffer.Bytes(), nil
|
||||
}
|
||||
|
||||
func unmarshalToken(payload []byte) (Token, error) {
|
||||
var creds Token
|
||||
buffer := bytes.NewBuffer(payload)
|
||||
|
||||
40
relay/auth/hmac/v2/algo.go
Normal file
40
relay/auth/hmac/v2/algo.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"hash"
|
||||
)
|
||||
|
||||
const (
|
||||
AuthAlgoUnknown AuthAlgo = iota
|
||||
AuthAlgoHMACSHA256
|
||||
)
|
||||
|
||||
type AuthAlgo uint8
|
||||
|
||||
func (a AuthAlgo) String() string {
|
||||
switch a {
|
||||
case AuthAlgoHMACSHA256:
|
||||
return "HMAC-SHA256"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (a AuthAlgo) New() func() hash.Hash {
|
||||
switch a {
|
||||
case AuthAlgoHMACSHA256:
|
||||
return sha256.New
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a AuthAlgo) Size() int {
|
||||
switch a {
|
||||
case AuthAlgoHMACSHA256:
|
||||
return sha256.Size
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
45
relay/auth/hmac/v2/generator.go
Normal file
45
relay/auth/hmac/v2/generator.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"fmt"
|
||||
"hash"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Generator struct {
|
||||
algo func() hash.Hash
|
||||
algoType AuthAlgo
|
||||
secret []byte
|
||||
timeToLive time.Duration
|
||||
}
|
||||
|
||||
func NewGenerator(algo AuthAlgo, secret []byte, timeToLive time.Duration) (*Generator, error) {
|
||||
algoFunc := algo.New()
|
||||
if algoFunc == nil {
|
||||
return nil, fmt.Errorf("unsupported auth algorithm: %s", algo)
|
||||
}
|
||||
return &Generator{
|
||||
algo: algoFunc,
|
||||
algoType: algo,
|
||||
secret: secret,
|
||||
timeToLive: timeToLive,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (g *Generator) GenerateToken() (*Token, error) {
|
||||
expirationTime := time.Now().Add(g.timeToLive).Unix()
|
||||
|
||||
payload := []byte(strconv.FormatInt(expirationTime, 10))
|
||||
|
||||
h := hmac.New(g.algo, g.secret)
|
||||
h.Write(payload)
|
||||
signature := h.Sum(nil)
|
||||
|
||||
return &Token{
|
||||
AuthAlgo: g.algoType,
|
||||
Signature: signature,
|
||||
Payload: payload,
|
||||
}, nil
|
||||
}
|
||||
110
relay/auth/hmac/v2/hmac_test.go
Normal file
110
relay/auth/hmac/v2/hmac_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGenerateCredentials(t *testing.T) {
|
||||
secret := "supersecret"
|
||||
timeToLive := 1 * time.Hour
|
||||
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create generator: %v", err)
|
||||
}
|
||||
|
||||
token, err := g.GenerateToken()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if len(token.Payload) == 0 {
|
||||
t.Fatalf("expected non-empty payload")
|
||||
}
|
||||
|
||||
_, err = strconv.ParseInt(string(token.Payload), 10, 64)
|
||||
if err != nil {
|
||||
t.Fatalf("expected payload to be a valid unix timestamp, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateCredentials(t *testing.T) {
|
||||
secret := "supersecret"
|
||||
timeToLive := 1 * time.Hour
|
||||
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create generator: %v", err)
|
||||
}
|
||||
|
||||
token, err := g.GenerateToken()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
v := NewValidator([]byte(secret))
|
||||
if err := v.Validate(token.Marshal()); err != nil {
|
||||
t.Fatalf("expected valid token: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidSignature(t *testing.T) {
|
||||
secret := "supersecret"
|
||||
timeToLive := 1 * time.Hour
|
||||
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create generator: %v", err)
|
||||
}
|
||||
|
||||
token, err := g.GenerateToken()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
token.Signature = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
|
||||
|
||||
v := NewValidator([]byte(secret))
|
||||
if err := v.Validate(token.Marshal()); err == nil {
|
||||
t.Fatalf("expected valid token: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpired(t *testing.T) {
|
||||
secret := "supersecret"
|
||||
timeToLive := -1 * time.Hour
|
||||
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create generator: %v", err)
|
||||
}
|
||||
|
||||
token, err := g.GenerateToken()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
v := NewValidator([]byte(secret))
|
||||
if err := v.Validate(token.Marshal()); err == nil {
|
||||
t.Fatalf("expected valid token: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidPayload(t *testing.T) {
|
||||
secret := "supersecret"
|
||||
timeToLive := 1 * time.Hour
|
||||
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create generator: %v", err)
|
||||
}
|
||||
|
||||
token, err := g.GenerateToken()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
token.Payload = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
|
||||
|
||||
v := NewValidator([]byte(secret))
|
||||
if err := v.Validate(token.Marshal()); err == nil {
|
||||
t.Fatalf("expected invalid token due to invalid payload")
|
||||
}
|
||||
}
|
||||
39
relay/auth/hmac/v2/token.go
Normal file
39
relay/auth/hmac/v2/token.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package v2
|
||||
|
||||
import "errors"
|
||||
|
||||
type Token struct {
|
||||
AuthAlgo AuthAlgo
|
||||
Signature []byte
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
func (t *Token) Marshal() []byte {
|
||||
size := 1 + len(t.Signature) + len(t.Payload)
|
||||
|
||||
buf := make([]byte, size)
|
||||
|
||||
buf[0] = byte(t.AuthAlgo)
|
||||
copy(buf[1:], t.Signature)
|
||||
copy(buf[1+len(t.Signature):], t.Payload)
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
func UnmarshalToken(data []byte) (*Token, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("invalid token data")
|
||||
}
|
||||
|
||||
algo := AuthAlgo(data[0])
|
||||
sigSize := algo.Size()
|
||||
if len(data) < 1+sigSize {
|
||||
return nil, errors.New("invalid token data: insufficient length")
|
||||
}
|
||||
|
||||
return &Token{
|
||||
AuthAlgo: algo,
|
||||
Signature: data[1 : 1+sigSize],
|
||||
Payload: data[1+sigSize:],
|
||||
}, nil
|
||||
}
|
||||
59
relay/auth/hmac/v2/validator.go
Normal file
59
relay/auth/hmac/v2/validator.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
const minLengthUnixTimestamp = 10
|
||||
|
||||
type Validator struct {
|
||||
secret []byte
|
||||
}
|
||||
|
||||
func NewValidator(secret []byte) *Validator {
|
||||
return &Validator{secret: secret}
|
||||
}
|
||||
|
||||
func (v *Validator) Validate(data any) error {
|
||||
d, ok := data.([]byte)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid data type")
|
||||
}
|
||||
|
||||
token, err := UnmarshalToken(d)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unmarshal token: %w", err)
|
||||
}
|
||||
|
||||
if len(token.Payload) < minLengthUnixTimestamp {
|
||||
return errors.New("invalid payload: insufficient length")
|
||||
}
|
||||
|
||||
hashFunc := token.AuthAlgo.New()
|
||||
if hashFunc == nil {
|
||||
return fmt.Errorf("unsupported auth algorithm: %s", token.AuthAlgo)
|
||||
}
|
||||
|
||||
h := hmac.New(hashFunc, v.secret)
|
||||
h.Write(token.Payload)
|
||||
expectedMAC := h.Sum(nil)
|
||||
|
||||
if !hmac.Equal(token.Signature, expectedMAC) {
|
||||
return errors.New("invalid signature")
|
||||
}
|
||||
|
||||
timestamp, err := strconv.ParseInt(string(token.Payload), 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid payload: %w", err)
|
||||
}
|
||||
|
||||
if time.Now().Unix() > timestamp {
|
||||
return fmt.Errorf("expired token")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,8 +1,8 @@
|
||||
package hmac
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"hash"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -19,7 +19,7 @@ func NewTimedHMACValidator(secret string, duration time.Duration) *TimedHMACVali
|
||||
}
|
||||
}
|
||||
|
||||
func (a *TimedHMACValidator) Validate(algo func() hash.Hash, credentials any) error {
|
||||
func (a *TimedHMACValidator) Validate(credentials any) error {
|
||||
b, ok := credentials.([]byte)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid credentials type")
|
||||
@@ -29,5 +29,5 @@ func (a *TimedHMACValidator) Validate(algo func() hash.Hash, credentials any) er
|
||||
log.Debugf("failed to unmarshal token: %s", err)
|
||||
return err
|
||||
}
|
||||
return a.TimedHMAC.Validate(algo, c)
|
||||
return a.TimedHMAC.Validate(sha256.New, c)
|
||||
}
|
||||
|
||||
@@ -1,8 +1,35 @@
|
||||
package auth
|
||||
|
||||
import "hash"
|
||||
import (
|
||||
"time"
|
||||
|
||||
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
|
||||
)
|
||||
|
||||
// Validator is an interface that defines the Validate method.
|
||||
type Validator interface {
|
||||
Validate(func() hash.Hash, any) error
|
||||
Validate(any) error
|
||||
// Deprecated: Use Validate instead.
|
||||
ValidateHelloMsgType(any) error
|
||||
}
|
||||
|
||||
type TimedHMACValidator struct {
|
||||
authenticatorV2 *authv2.Validator
|
||||
authenticator *auth.TimedHMACValidator
|
||||
}
|
||||
|
||||
func NewTimedHMACValidator(secret []byte, duration time.Duration) *TimedHMACValidator {
|
||||
return &TimedHMACValidator{
|
||||
authenticatorV2: authv2.NewValidator(secret),
|
||||
authenticator: auth.NewTimedHMACValidator(string(secret), duration),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *TimedHMACValidator) Validate(credentials any) error {
|
||||
return a.authenticatorV2.Validate(credentials)
|
||||
}
|
||||
|
||||
func (a *TimedHMACValidator) ValidateHelloMsgType(credentials any) error {
|
||||
return a.authenticator.Validate(credentials)
|
||||
}
|
||||
|
||||
@@ -14,8 +14,6 @@ import (
|
||||
"github.com/netbirdio/netbird/relay/client/dialer/ws"
|
||||
"github.com/netbirdio/netbird/relay/healthcheck"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
"github.com/netbirdio/netbird/relay/messages/address"
|
||||
auth2 "github.com/netbirdio/netbird/relay/messages/auth"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -60,37 +58,65 @@ func (m *Msg) Free() {
|
||||
m.bufPool.Put(m.bufPtr)
|
||||
}
|
||||
|
||||
// connContainer is a container for the connection to the peer. It is responsible for managing the messages from the
|
||||
// server and forwarding them to the upper layer content reader.
|
||||
type connContainer struct {
|
||||
log *log.Entry
|
||||
conn *Conn
|
||||
messages chan Msg
|
||||
msgChanLock sync.Mutex
|
||||
closed bool // flag to check if channel is closed
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
|
||||
func newConnContainer(log *log.Entry, conn *Conn, messages chan Msg) *connContainer {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &connContainer{
|
||||
log: log,
|
||||
conn: conn,
|
||||
messages: messages,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func (cc *connContainer) writeMsg(msg Msg) {
|
||||
cc.msgChanLock.Lock()
|
||||
defer cc.msgChanLock.Unlock()
|
||||
|
||||
if cc.closed {
|
||||
msg.Free()
|
||||
return
|
||||
}
|
||||
cc.messages <- msg
|
||||
|
||||
select {
|
||||
case cc.messages <- msg:
|
||||
case <-cc.ctx.Done():
|
||||
msg.Free()
|
||||
default:
|
||||
msg.Free()
|
||||
cc.log.Infof("message queue is full")
|
||||
// todo consider to close the connection
|
||||
}
|
||||
}
|
||||
|
||||
func (cc *connContainer) close() {
|
||||
cc.cancel()
|
||||
|
||||
cc.msgChanLock.Lock()
|
||||
defer cc.msgChanLock.Unlock()
|
||||
|
||||
if cc.closed {
|
||||
return
|
||||
}
|
||||
close(cc.messages)
|
||||
|
||||
cc.closed = true
|
||||
close(cc.messages)
|
||||
|
||||
for msg := range cc.messages {
|
||||
msg.Free()
|
||||
}
|
||||
}
|
||||
|
||||
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
|
||||
@@ -122,8 +148,8 @@ type Client struct {
|
||||
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
||||
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
|
||||
hashedID, hashedStringId := messages.HashID(peerID)
|
||||
return &Client{
|
||||
log: log.WithField("client_id", hashedStringId),
|
||||
c := &Client{
|
||||
log: log.WithFields(log.Fields{"relay": serverURL}),
|
||||
parentCtx: ctx,
|
||||
connectionURL: serverURL,
|
||||
authTokenStore: authTokenStore,
|
||||
@@ -136,11 +162,13 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token
|
||||
},
|
||||
conns: make(map[string]*connContainer),
|
||||
}
|
||||
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedStringId)
|
||||
return c
|
||||
}
|
||||
|
||||
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
|
||||
func (c *Client) Connect() error {
|
||||
c.log.Infof("connecting to relay server: %s", c.connectionURL)
|
||||
c.log.Infof("connecting to relay server")
|
||||
c.readLoopMutex.Lock()
|
||||
defer c.readLoopMutex.Unlock()
|
||||
|
||||
@@ -161,7 +189,7 @@ func (c *Client) Connect() error {
|
||||
c.wgReadLoop.Add(1)
|
||||
go c.readLoop(c.relayConn)
|
||||
|
||||
c.log.Infof("relay connection established with: %s", c.connectionURL)
|
||||
c.log.Infof("relay connection established")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -183,11 +211,11 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
|
||||
return nil, ErrConnAlreadyExists
|
||||
}
|
||||
|
||||
log.Infof("open connection to peer: %s", hashedStringID)
|
||||
msgChannel := make(chan Msg, 2)
|
||||
c.log.Infof("open connection to peer: %s", hashedStringID)
|
||||
msgChannel := make(chan Msg, 100)
|
||||
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
|
||||
|
||||
c.conns[hashedStringID] = newConnContainer(conn, msgChannel)
|
||||
c.conns[hashedStringID] = newConnContainer(c.log, conn, msgChannel)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
@@ -231,7 +259,7 @@ func (c *Client) connect() error {
|
||||
if err != nil {
|
||||
cErr := conn.Close()
|
||||
if cErr != nil {
|
||||
log.Errorf("failed to close connection: %s", cErr)
|
||||
c.log.Errorf("failed to close connection: %s", cErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -240,31 +268,21 @@ func (c *Client) connect() error {
|
||||
}
|
||||
|
||||
func (c *Client) handShake() error {
|
||||
authMsg := &auth2.Msg{
|
||||
AuthAlgorithm: auth2.AlgoHMACSHA256,
|
||||
AdditionalData: c.authTokenStore.TokenBinary(),
|
||||
}
|
||||
|
||||
authData, err := authMsg.Marshal()
|
||||
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal auth message: %w", err)
|
||||
}
|
||||
|
||||
msg, err := messages.MarshalHelloMsg(c.hashedID, authData)
|
||||
if err != nil {
|
||||
log.Errorf("failed to marshal hello message: %s", err)
|
||||
c.log.Errorf("failed to marshal auth message: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = c.relayConn.Write(msg)
|
||||
if err != nil {
|
||||
log.Errorf("failed to send hello message: %s", err)
|
||||
c.log.Errorf("failed to send auth message: %s", err)
|
||||
return err
|
||||
}
|
||||
buf := make([]byte, messages.MaxHandshakeSize)
|
||||
buf := make([]byte, messages.MaxHandshakeRespSize)
|
||||
n, err := c.readWithTimeout(buf)
|
||||
if err != nil {
|
||||
log.Errorf("failed to read hello response: %s", err)
|
||||
c.log.Errorf("failed to read auth response: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -275,34 +293,29 @@ func (c *Client) handShake() error {
|
||||
|
||||
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
|
||||
if err != nil {
|
||||
log.Errorf("failed to determine message type: %s", err)
|
||||
c.log.Errorf("failed to determine message type: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if msgType != messages.MsgTypeHelloResponse {
|
||||
log.Errorf("unexpected message type: %s", msgType)
|
||||
if msgType != messages.MsgTypeAuthResponse {
|
||||
c.log.Errorf("unexpected message type: %s", msgType)
|
||||
return fmt.Errorf("unexpected message type")
|
||||
}
|
||||
|
||||
additionalData, err := messages.UnmarshalHelloResponse(buf[messages.SizeOfProtoHeader:n])
|
||||
addr, err := messages.UnmarshalAuthResponse(buf[messages.SizeOfProtoHeader:n])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
addr, err := address.Unmarshal(additionalData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unmarshal address: %w", err)
|
||||
}
|
||||
|
||||
c.muInstanceURL.Lock()
|
||||
c.instanceURL = &RelayAddr{addr: addr.URL}
|
||||
c.instanceURL = &RelayAddr{addr: addr}
|
||||
c.muInstanceURL.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) readLoop(relayConn net.Conn) {
|
||||
internallyStoppedFlag := newInternalStopFlag()
|
||||
hc := healthcheck.NewReceiver()
|
||||
hc := healthcheck.NewReceiver(c.log)
|
||||
go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
|
||||
|
||||
var (
|
||||
@@ -314,6 +327,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
||||
buf := *bufPtr
|
||||
n, errExit = relayConn.Read(buf)
|
||||
if errExit != nil {
|
||||
c.log.Infof("start to Relay read loop exit")
|
||||
c.mu.Lock()
|
||||
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
|
||||
c.log.Debugf("failed to read message from relay server: %s", errExit)
|
||||
@@ -360,7 +374,7 @@ func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte,
|
||||
case messages.MsgTypeTransport:
|
||||
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
|
||||
case messages.MsgTypeClose:
|
||||
log.Debugf("relay connection close by server")
|
||||
c.log.Debugf("relay connection close by server")
|
||||
c.bufPool.Put(bufPtr)
|
||||
return false
|
||||
}
|
||||
@@ -429,14 +443,14 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [
|
||||
// todo: use buffer pool instead of create new transport msg.
|
||||
msg, err := messages.MarshalTransportMsg(dstID, payload)
|
||||
if err != nil {
|
||||
log.Errorf("failed to marshal transport message: %s", err)
|
||||
c.log.Errorf("failed to marshal transport message: %s", err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// the write always return with 0 length because the underling does not support the size feedback.
|
||||
_, err = c.relayConn.Write(msg)
|
||||
if err != nil {
|
||||
log.Errorf("failed to write transport message: %s", err)
|
||||
c.log.Errorf("failed to write transport message: %s", err)
|
||||
}
|
||||
return len(payload), err
|
||||
}
|
||||
@@ -450,12 +464,15 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in
|
||||
}
|
||||
c.log.Errorf("health check timeout")
|
||||
internalStopFlag.set()
|
||||
_ = conn.Close() // ignore the err because the readLoop will handle it
|
||||
if err := conn.Close(); err != nil {
|
||||
// ignore the err handling because the readLoop will handle it
|
||||
c.log.Warnf("failed to close connection: %s", err)
|
||||
}
|
||||
return
|
||||
case <-c.parentCtx.Done():
|
||||
err := c.close(true)
|
||||
if err != nil {
|
||||
log.Errorf("failed to teardown connection: %s", err)
|
||||
c.log.Errorf("failed to teardown connection: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -481,8 +498,9 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
|
||||
if container.conn != connReference {
|
||||
return fmt.Errorf("conn reference mismatch")
|
||||
}
|
||||
container.close()
|
||||
c.log.Infof("free up connection to peer: %s", id)
|
||||
delete(c.conns, id)
|
||||
container.close()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -495,10 +513,12 @@ func (c *Client) close(gracefullyExit bool) error {
|
||||
var err error
|
||||
if !c.serviceIsRunning {
|
||||
c.mu.Unlock()
|
||||
c.log.Warn("relay connection was already marked as not running")
|
||||
return nil
|
||||
}
|
||||
|
||||
c.serviceIsRunning = false
|
||||
c.log.Infof("closing all peer connections")
|
||||
c.closeAllConns()
|
||||
if gracefullyExit {
|
||||
c.writeCloseMsg()
|
||||
@@ -506,8 +526,9 @@ func (c *Client) close(gracefullyExit bool) error {
|
||||
err = c.relayConn.Close()
|
||||
c.mu.Unlock()
|
||||
|
||||
c.log.Infof("waiting for read loop to close")
|
||||
c.wgReadLoop.Wait()
|
||||
c.log.Infof("relay connection closed with: %s", c.connectionURL)
|
||||
c.log.Infof("relay connection closed")
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -618,6 +618,87 @@ func TestCloseByClient(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseNotDrainedChannel(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
idAlice := "alice"
|
||||
idBob := "bob"
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for servers to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
|
||||
err = clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
err := clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close Alice client: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob)
|
||||
err = clientBob.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
err := clientBob.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close Bob client: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
connAliceToBob, err := clientAlice.OpenConn(idBob)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
connBobToAlice, err := clientBob.OpenConn(idAlice)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
payload := "hello bob, I am alice"
|
||||
// the internal channel buffer size is 2. So we should overflow it
|
||||
for i := 0; i < 5; i++ {
|
||||
_, err = connAliceToBob.Write([]byte(payload))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write to channel: %s", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// wait for delivery
|
||||
time.Sleep(1 * time.Second)
|
||||
err = connBobToAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close channel: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func waitForServerToStart(errChan chan error) error {
|
||||
select {
|
||||
case err := <-errChan:
|
||||
|
||||
@@ -3,7 +3,6 @@ package client
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
@@ -17,8 +16,6 @@ import (
|
||||
|
||||
var (
|
||||
relayCleanupInterval = 60 * time.Second
|
||||
connectionTimeout = 30 * time.Second
|
||||
maxConcurrentServers = 7
|
||||
|
||||
ErrRelayClientNotConnected = fmt.Errorf("relay client not connected")
|
||||
)
|
||||
@@ -92,67 +89,23 @@ func (m *Manager) Serve() error {
|
||||
}
|
||||
log.Debugf("starting relay client manager with %v relay servers", m.serverURLs)
|
||||
|
||||
totalServers := len(m.serverURLs)
|
||||
|
||||
successChan := make(chan *Client, 1)
|
||||
errChan := make(chan error, len(m.serverURLs))
|
||||
|
||||
ctx, cancel := context.WithTimeout(m.ctx, connectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
sem := make(chan struct{}, maxConcurrentServers)
|
||||
|
||||
for _, url := range m.serverURLs {
|
||||
sem <- struct{}{}
|
||||
go func(url string) {
|
||||
defer func() { <-sem }()
|
||||
m.connect(m.ctx, url, successChan, errChan)
|
||||
}(url)
|
||||
sp := ServerPicker{
|
||||
TokenStore: m.tokenStore,
|
||||
PeerID: m.peerID,
|
||||
}
|
||||
|
||||
var errCount int
|
||||
|
||||
for {
|
||||
select {
|
||||
case client := <-successChan:
|
||||
log.Infof("Successfully connected to relay server: %s", client.connectionURL)
|
||||
|
||||
m.relayClient = client
|
||||
|
||||
m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
|
||||
m.relayClient.SetOnDisconnectListener(func() {
|
||||
m.onServerDisconnected(client.connectionURL)
|
||||
})
|
||||
m.startCleanupLoop()
|
||||
return nil
|
||||
case err := <-errChan:
|
||||
errCount++
|
||||
log.Warnf("Connection attempt failed: %v", err)
|
||||
if errCount == totalServers {
|
||||
return errors.New("failed to connect to any relay server: all attempts failed")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
|
||||
}
|
||||
client, err := sp.PickServer(m.ctx, m.serverURLs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
m.relayClient = client
|
||||
|
||||
func (m *Manager) connect(ctx context.Context, serverURL string, successChan chan<- *Client, errChan chan<- error) {
|
||||
// TODO: abort the connection if another connection was successful
|
||||
relayClient := NewClient(ctx, serverURL, m.tokenStore, m.peerID)
|
||||
if err := relayClient.Connect(); err != nil {
|
||||
errChan <- fmt.Errorf("failed to connect to %s: %w", serverURL, err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case successChan <- relayClient:
|
||||
// This client was the first to connect successfully
|
||||
default:
|
||||
if err := relayClient.Close(); err != nil {
|
||||
log.Debugf("failed to close relay client: %s", err)
|
||||
}
|
||||
}
|
||||
m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
|
||||
m.relayClient.SetOnDisconnectListener(func() {
|
||||
m.onServerDisconnected(client.connectionURL)
|
||||
})
|
||||
m.startCleanupLoop()
|
||||
return nil
|
||||
}
|
||||
|
||||
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
|
||||
|
||||
98
relay/client/picker.go
Normal file
98
relay/client/picker.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
)
|
||||
|
||||
const (
|
||||
connectionTimeout = 30 * time.Second
|
||||
maxConcurrentServers = 7
|
||||
)
|
||||
|
||||
type connResult struct {
|
||||
RelayClient *Client
|
||||
Url string
|
||||
Err error
|
||||
}
|
||||
|
||||
type ServerPicker struct {
|
||||
TokenStore *auth.TokenStore
|
||||
PeerID string
|
||||
}
|
||||
|
||||
func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*Client, error) {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
totalServers := len(urls)
|
||||
|
||||
connResultChan := make(chan connResult, totalServers)
|
||||
successChan := make(chan connResult, 1)
|
||||
concurrentLimiter := make(chan struct{}, maxConcurrentServers)
|
||||
|
||||
for _, url := range urls {
|
||||
// todo check if we have a successful connection so we do not need to connect to other servers
|
||||
concurrentLimiter <- struct{}{}
|
||||
go func(url string) {
|
||||
defer func() {
|
||||
<-concurrentLimiter
|
||||
}()
|
||||
sp.startConnection(parentCtx, connResultChan, url)
|
||||
}(url)
|
||||
}
|
||||
|
||||
go sp.processConnResults(connResultChan, successChan)
|
||||
|
||||
select {
|
||||
case cr, ok := <-successChan:
|
||||
if !ok {
|
||||
return nil, errors.New("failed to connect to any relay server: all attempts failed")
|
||||
}
|
||||
log.Infof("chosen home Relay server: %s", cr.Url)
|
||||
return cr.RelayClient, nil
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
|
||||
log.Infof("try to connecting to relay server: %s", url)
|
||||
relayClient := NewClient(ctx, url, sp.TokenStore, sp.PeerID)
|
||||
err := relayClient.Connect()
|
||||
resultChan <- connResult{
|
||||
RelayClient: relayClient,
|
||||
Url: url,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult) {
|
||||
var hasSuccess bool
|
||||
for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ {
|
||||
cr := <-resultChan
|
||||
if cr.Err != nil {
|
||||
log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
|
||||
continue
|
||||
}
|
||||
log.Infof("connected to Relay server: %s", cr.Url)
|
||||
|
||||
if hasSuccess {
|
||||
log.Infof("closing unnecessary Relay connection to: %s", cr.Url)
|
||||
if err := cr.RelayClient.Close(); err != nil {
|
||||
log.Errorf("failed to close connection to %s: %v", cr.Url, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
hasSuccess = true
|
||||
successChan <- cr
|
||||
}
|
||||
close(successChan)
|
||||
}
|
||||
31
relay/client/picker_test.go
Normal file
31
relay/client/picker_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestServerPicker_UnavailableServers(t *testing.T) {
|
||||
sp := ServerPicker{
|
||||
TokenStore: nil,
|
||||
PeerID: "test",
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
_, err := sp.PickServer(ctx, []string{"rel://dummy1", "rel://dummy2"})
|
||||
if err == nil {
|
||||
t.Error(err)
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
|
||||
<-ctx.Done()
|
||||
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
||||
t.Errorf("PickServer() took too long to complete")
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -16,21 +17,18 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
"github.com/netbirdio/netbird/relay/auth"
|
||||
"github.com/netbirdio/netbird/relay/server"
|
||||
"github.com/netbirdio/netbird/signal/metrics"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const (
|
||||
metricsPort = 9090
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
ListenAddress string
|
||||
// in HA every peer connect to a common domain, the instance domain has been distributed during the p2p connection
|
||||
// it is a domain:port or ip:port
|
||||
ExposedAddress string
|
||||
MetricsPort int
|
||||
LetsencryptEmail string
|
||||
LetsencryptDataDir string
|
||||
LetsencryptDomains []string
|
||||
@@ -79,6 +77,7 @@ func init() {
|
||||
cobraConfig = &Config{}
|
||||
rootCmd.PersistentFlags().StringVarP(&cobraConfig.ListenAddress, "listen-address", "l", ":443", "listen address")
|
||||
rootCmd.PersistentFlags().StringVarP(&cobraConfig.ExposedAddress, "exposed-address", "e", "", "instance domain address (or ip) and port, it will be distributes between peers")
|
||||
rootCmd.PersistentFlags().IntVar(&cobraConfig.MetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
|
||||
rootCmd.PersistentFlags().StringVarP(&cobraConfig.LetsencryptDataDir, "letsencrypt-data-dir", "d", "", "a directory to store Let's Encrypt data. Required if Let's Encrypt is enabled.")
|
||||
rootCmd.PersistentFlags().StringSliceVarP(&cobraConfig.LetsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
|
||||
rootCmd.PersistentFlags().StringVar(&cobraConfig.LetsencryptEmail, "letsencrypt-email", "", "email address to use for Let's Encrypt certificate registration")
|
||||
@@ -115,7 +114,7 @@ func execute(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("failed to initialize log: %s", err)
|
||||
}
|
||||
|
||||
metricsServer, err := metrics.NewServer(metricsPort, "")
|
||||
metricsServer, err := metrics.NewServer(cobraConfig.MetricsPort, "")
|
||||
if err != nil {
|
||||
log.Debugf("setup metrics: %v", err)
|
||||
return fmt.Errorf("setup metrics: %v", err)
|
||||
@@ -139,7 +138,9 @@ func execute(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
srvListenerCfg.TLSConfig = tlsConfig
|
||||
|
||||
authenticator := auth.NewTimedHMACValidator(cobraConfig.AuthSecret, 24*time.Hour)
|
||||
hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
|
||||
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
|
||||
|
||||
srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator)
|
||||
if err != nil {
|
||||
log.Debugf("failed to create relay server: %v", err)
|
||||
|
||||
@@ -3,10 +3,12 @@ package healthcheck
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
heartbeatTimeout = healthCheckInterval + 3*time.Second
|
||||
heartbeatTimeout = healthCheckInterval + 10*time.Second
|
||||
)
|
||||
|
||||
// Receiver is a healthcheck receiver
|
||||
@@ -14,23 +16,26 @@ var (
|
||||
// If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work
|
||||
// The heartbeat timeout is a bit longer than the sender's healthcheck interval
|
||||
type Receiver struct {
|
||||
OnTimeout chan struct{}
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
heartbeat chan struct{}
|
||||
alive bool
|
||||
OnTimeout chan struct{}
|
||||
log *log.Entry
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
heartbeat chan struct{}
|
||||
alive bool
|
||||
attemptThreshold int
|
||||
}
|
||||
|
||||
// NewReceiver creates a new healthcheck receiver and start the timer in the background
|
||||
func NewReceiver() *Receiver {
|
||||
func NewReceiver(log *log.Entry) *Receiver {
|
||||
ctx, ctxCancel := context.WithCancel(context.Background())
|
||||
|
||||
r := &Receiver{
|
||||
OnTimeout: make(chan struct{}, 1),
|
||||
ctx: ctx,
|
||||
ctxCancel: ctxCancel,
|
||||
heartbeat: make(chan struct{}, 1),
|
||||
OnTimeout: make(chan struct{}, 1),
|
||||
log: log,
|
||||
ctx: ctx,
|
||||
ctxCancel: ctxCancel,
|
||||
heartbeat: make(chan struct{}, 1),
|
||||
attemptThreshold: getAttemptThresholdFromEnv(),
|
||||
}
|
||||
|
||||
go r.waitForHealthcheck()
|
||||
@@ -56,16 +61,23 @@ func (r *Receiver) waitForHealthcheck() {
|
||||
defer r.ctxCancel()
|
||||
defer close(r.OnTimeout)
|
||||
|
||||
failureCounter := 0
|
||||
for {
|
||||
select {
|
||||
case <-r.heartbeat:
|
||||
r.alive = true
|
||||
failureCounter = 0
|
||||
case <-ticker.C:
|
||||
if r.alive {
|
||||
r.alive = false
|
||||
continue
|
||||
}
|
||||
|
||||
failureCounter++
|
||||
if failureCounter < r.attemptThreshold {
|
||||
r.log.Warnf("healthcheck failed, attempt %d", failureCounter)
|
||||
continue
|
||||
}
|
||||
r.notifyTimeout()
|
||||
return
|
||||
case <-r.ctx.Done():
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func TestNewReceiver(t *testing.T) {
|
||||
heartbeatTimeout = 5 * time.Second
|
||||
r := NewReceiver()
|
||||
r := NewReceiver(log.WithContext(context.Background()))
|
||||
|
||||
select {
|
||||
case <-r.OnTimeout:
|
||||
@@ -19,7 +24,7 @@ func TestNewReceiver(t *testing.T) {
|
||||
|
||||
func TestNewReceiverNotReceive(t *testing.T) {
|
||||
heartbeatTimeout = 1 * time.Second
|
||||
r := NewReceiver()
|
||||
r := NewReceiver(log.WithContext(context.Background()))
|
||||
|
||||
select {
|
||||
case <-r.OnTimeout:
|
||||
@@ -30,7 +35,7 @@ func TestNewReceiverNotReceive(t *testing.T) {
|
||||
|
||||
func TestNewReceiverAck(t *testing.T) {
|
||||
heartbeatTimeout = 2 * time.Second
|
||||
r := NewReceiver()
|
||||
r := NewReceiver(log.WithContext(context.Background()))
|
||||
|
||||
r.Heartbeat()
|
||||
|
||||
@@ -40,3 +45,53 @@ func TestNewReceiverAck(t *testing.T) {
|
||||
case <-time.After(3 * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
|
||||
testsCases := []struct {
|
||||
name string
|
||||
threshold int
|
||||
resetCounterOnce bool
|
||||
}{
|
||||
{"Default attempt threshold", defaultAttemptThreshold, false},
|
||||
{"Custom attempt threshold", 3, false},
|
||||
{"Should reset threshold once", 2, true},
|
||||
}
|
||||
|
||||
for _, tc := range testsCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
originalInterval := healthCheckInterval
|
||||
originalTimeout := heartbeatTimeout
|
||||
healthCheckInterval = 1 * time.Second
|
||||
heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
|
||||
defer func() {
|
||||
healthCheckInterval = originalInterval
|
||||
heartbeatTimeout = originalTimeout
|
||||
}()
|
||||
//nolint:tenv
|
||||
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
|
||||
defer os.Unsetenv(defaultAttemptThresholdEnv)
|
||||
|
||||
receiver := NewReceiver(log.WithField("test_name", tc.name))
|
||||
|
||||
testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval
|
||||
|
||||
if tc.resetCounterOnce {
|
||||
receiver.Heartbeat()
|
||||
t.Logf("reset counter once")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-receiver.OnTimeout:
|
||||
if tc.resetCounterOnce {
|
||||
t.Fatalf("should not have timed out before %s", testTimeout)
|
||||
}
|
||||
case <-time.After(testTimeout):
|
||||
if tc.resetCounterOnce {
|
||||
return
|
||||
}
|
||||
t.Fatalf("should have timed out before %s", testTimeout)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,12 +2,21 @@ package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultAttemptThreshold = 1
|
||||
defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD"
|
||||
)
|
||||
|
||||
var (
|
||||
healthCheckInterval = 25 * time.Second
|
||||
healthCheckTimeout = 5 * time.Second
|
||||
healthCheckTimeout = 20 * time.Second
|
||||
)
|
||||
|
||||
// Sender is a healthcheck sender
|
||||
@@ -15,20 +24,25 @@ var (
|
||||
// If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work
|
||||
// It will also stop if the context is canceled
|
||||
type Sender struct {
|
||||
log *log.Entry
|
||||
// HealthCheck is a channel to send health check signal to the peer
|
||||
HealthCheck chan struct{}
|
||||
// Timeout is a channel to the health check signal is not received in a certain time
|
||||
Timeout chan struct{}
|
||||
|
||||
ack chan struct{}
|
||||
ack chan struct{}
|
||||
alive bool
|
||||
attemptThreshold int
|
||||
}
|
||||
|
||||
// NewSender creates a new healthcheck sender
|
||||
func NewSender() *Sender {
|
||||
func NewSender(log *log.Entry) *Sender {
|
||||
hc := &Sender{
|
||||
HealthCheck: make(chan struct{}, 1),
|
||||
Timeout: make(chan struct{}, 1),
|
||||
ack: make(chan struct{}, 1),
|
||||
log: log,
|
||||
HealthCheck: make(chan struct{}, 1),
|
||||
Timeout: make(chan struct{}, 1),
|
||||
ack: make(chan struct{}, 1),
|
||||
attemptThreshold: getAttemptThresholdFromEnv(),
|
||||
}
|
||||
|
||||
return hc
|
||||
@@ -46,23 +60,51 @@ func (hc *Sender) StartHealthCheck(ctx context.Context) {
|
||||
ticker := time.NewTicker(healthCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeoutTimer := time.NewTimer(healthCheckInterval + healthCheckTimeout)
|
||||
defer timeoutTimer.Stop()
|
||||
timeoutTicker := time.NewTicker(hc.getTimeoutTime())
|
||||
defer timeoutTicker.Stop()
|
||||
|
||||
defer close(hc.HealthCheck)
|
||||
defer close(hc.Timeout)
|
||||
|
||||
failureCounter := 0
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
hc.HealthCheck <- struct{}{}
|
||||
case <-timeoutTimer.C:
|
||||
case <-timeoutTicker.C:
|
||||
if hc.alive {
|
||||
hc.alive = false
|
||||
continue
|
||||
}
|
||||
|
||||
failureCounter++
|
||||
if failureCounter < hc.attemptThreshold {
|
||||
hc.log.Warnf("Health check failed attempt %d.", failureCounter)
|
||||
continue
|
||||
}
|
||||
hc.Timeout <- struct{}{}
|
||||
return
|
||||
case <-hc.ack:
|
||||
timeoutTimer.Reset(healthCheckInterval + healthCheckTimeout)
|
||||
failureCounter = 0
|
||||
hc.alive = true
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hc *Sender) getTimeoutTime() time.Duration {
|
||||
return healthCheckInterval + healthCheckTimeout
|
||||
}
|
||||
|
||||
func getAttemptThresholdFromEnv() int {
|
||||
if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" {
|
||||
threshold, err := strconv.ParseInt(attemptThreshold, 10, 64)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold)
|
||||
return defaultAttemptThreshold
|
||||
}
|
||||
return int(threshold)
|
||||
}
|
||||
return defaultAttemptThreshold
|
||||
}
|
||||
|
||||
@@ -2,9 +2,12 @@ package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
@@ -18,7 +21,7 @@ func TestMain(m *testing.M) {
|
||||
func TestNewHealthPeriod(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
hc := NewSender()
|
||||
hc := NewSender(log.WithContext(ctx))
|
||||
go hc.StartHealthCheck(ctx)
|
||||
|
||||
iterations := 0
|
||||
@@ -38,7 +41,7 @@ func TestNewHealthPeriod(t *testing.T) {
|
||||
func TestNewHealthFailed(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
hc := NewSender()
|
||||
hc := NewSender(log.WithContext(ctx))
|
||||
go hc.StartHealthCheck(ctx)
|
||||
|
||||
select {
|
||||
@@ -50,7 +53,7 @@ func TestNewHealthFailed(t *testing.T) {
|
||||
|
||||
func TestNewHealthcheckStop(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
hc := NewSender()
|
||||
hc := NewSender(log.WithContext(ctx))
|
||||
go hc.StartHealthCheck(ctx)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
@@ -75,7 +78,7 @@ func TestNewHealthcheckStop(t *testing.T) {
|
||||
func TestTimeoutReset(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
hc := NewSender()
|
||||
hc := NewSender(log.WithContext(ctx))
|
||||
go hc.StartHealthCheck(ctx)
|
||||
|
||||
iterations := 0
|
||||
@@ -101,3 +104,102 @@ func TestTimeoutReset(t *testing.T) {
|
||||
t.Fatalf("is not exited")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
|
||||
testsCases := []struct {
|
||||
name string
|
||||
threshold int
|
||||
resetCounterOnce bool
|
||||
}{
|
||||
{"Default attempt threshold", defaultAttemptThreshold, false},
|
||||
{"Custom attempt threshold", 3, false},
|
||||
{"Should reset threshold once", 2, true},
|
||||
}
|
||||
|
||||
for _, tc := range testsCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
originalInterval := healthCheckInterval
|
||||
originalTimeout := healthCheckTimeout
|
||||
healthCheckInterval = 1 * time.Second
|
||||
healthCheckTimeout = 500 * time.Millisecond
|
||||
defer func() {
|
||||
healthCheckInterval = originalInterval
|
||||
healthCheckTimeout = originalTimeout
|
||||
}()
|
||||
|
||||
//nolint:tenv
|
||||
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
|
||||
defer os.Unsetenv(defaultAttemptThresholdEnv)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
sender := NewSender(log.WithField("test_name", tc.name))
|
||||
go sender.StartHealthCheck(ctx)
|
||||
|
||||
go func() {
|
||||
responded := false
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case _, ok := <-sender.HealthCheck:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if tc.resetCounterOnce && !responded {
|
||||
responded = true
|
||||
sender.OnHCResponse()
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval
|
||||
|
||||
select {
|
||||
case <-sender.Timeout:
|
||||
if tc.resetCounterOnce {
|
||||
t.Fatalf("should not have timed out before %s", testTimeout)
|
||||
}
|
||||
case <-time.After(testTimeout):
|
||||
if tc.resetCounterOnce {
|
||||
return
|
||||
}
|
||||
t.Fatalf("should have timed out before %s", testTimeout)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
//nolint:tenv
|
||||
func TestGetAttemptThresholdFromEnv(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envValue string
|
||||
expected int
|
||||
}{
|
||||
{"Default attempt threshold when env is not set", "", defaultAttemptThreshold},
|
||||
{"Custom attempt threshold when env is set to a valid integer", "3", 3},
|
||||
{"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.envValue == "" {
|
||||
os.Unsetenv(defaultAttemptThresholdEnv)
|
||||
} else {
|
||||
os.Setenv(defaultAttemptThresholdEnv, tt.envValue)
|
||||
}
|
||||
|
||||
result := getAttemptThresholdFromEnv()
|
||||
if result != tt.expected {
|
||||
t.Fatalf("Expected %d, got %d", tt.expected, result)
|
||||
}
|
||||
|
||||
os.Unsetenv(defaultAttemptThresholdEnv)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Deprecated: This package is deprecated and will be removed in a future release.
|
||||
package address
|
||||
|
||||
import (
|
||||
@@ -18,13 +19,3 @@ func (addr *Address) Marshal() ([]byte, error) {
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func Unmarshal(data []byte) (*Address, error) {
|
||||
var addr Address
|
||||
buf := bytes.NewBuffer(data)
|
||||
dec := gob.NewDecoder(buf)
|
||||
if err := dec.Decode(&addr); err != nil {
|
||||
return nil, fmt.Errorf("decode Address: %w", err)
|
||||
}
|
||||
return &addr, nil
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Deprecated: This package is deprecated and will be removed in a future release.
|
||||
package auth
|
||||
|
||||
import (
|
||||
@@ -30,15 +31,6 @@ type Msg struct {
|
||||
AdditionalData []byte
|
||||
}
|
||||
|
||||
func (msg *Msg) Marshal() ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
enc := gob.NewEncoder(&buf)
|
||||
if err := enc.Encode(msg); err != nil {
|
||||
return nil, fmt.Errorf("encode Msg: %w", err)
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func UnmarshalMsg(data []byte) (*Msg, error) {
|
||||
var msg *Msg
|
||||
|
||||
|
||||
@@ -7,12 +7,21 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
MsgTypeUnknown MsgType = 0
|
||||
MsgTypeHello MsgType = 1
|
||||
MaxHandshakeSize = 212
|
||||
MaxHandshakeRespSize = 8192
|
||||
|
||||
CurrentProtocolVersion = 1
|
||||
|
||||
MsgTypeUnknown MsgType = 0
|
||||
// Deprecated: Use MsgTypeAuth instead.
|
||||
MsgTypeHello MsgType = 1
|
||||
// Deprecated: Use MsgTypeAuthResponse instead.
|
||||
MsgTypeHelloResponse MsgType = 2
|
||||
MsgTypeTransport MsgType = 3
|
||||
MsgTypeClose MsgType = 4
|
||||
MsgTypeHealthCheck MsgType = 5
|
||||
MsgTypeAuth = 6
|
||||
MsgTypeAuthResponse = 7
|
||||
|
||||
SizeOfVersionByte = 1
|
||||
SizeOfMsgType = 1
|
||||
@@ -22,12 +31,12 @@ const (
|
||||
sizeOfMagicByte = 4
|
||||
|
||||
headerSizeTransport = IDSize
|
||||
|
||||
headerSizeHello = sizeOfMagicByte + IDSize
|
||||
headerSizeHelloResp = 0
|
||||
|
||||
MaxHandshakeSize = 8192
|
||||
|
||||
CurrentProtocolVersion = 1
|
||||
headerSizeAuth = sizeOfMagicByte + IDSize
|
||||
headerSizeAuthResp = 0
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -47,6 +56,10 @@ func (m MsgType) String() string {
|
||||
return "hello"
|
||||
case MsgTypeHelloResponse:
|
||||
return "hello response"
|
||||
case MsgTypeAuth:
|
||||
return "auth"
|
||||
case MsgTypeAuthResponse:
|
||||
return "auth response"
|
||||
case MsgTypeTransport:
|
||||
return "transport"
|
||||
case MsgTypeClose:
|
||||
@@ -58,10 +71,6 @@ func (m MsgType) String() string {
|
||||
}
|
||||
}
|
||||
|
||||
type HelloResponse struct {
|
||||
InstanceAddress string
|
||||
}
|
||||
|
||||
// ValidateVersion checks if the given version is supported by the protocol
|
||||
func ValidateVersion(msg []byte) (int, error) {
|
||||
if len(msg) < SizeOfVersionByte {
|
||||
@@ -84,6 +93,7 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) {
|
||||
switch msgType {
|
||||
case
|
||||
MsgTypeHello,
|
||||
MsgTypeAuth,
|
||||
MsgTypeTransport,
|
||||
MsgTypeClose,
|
||||
MsgTypeHealthCheck:
|
||||
@@ -103,6 +113,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
|
||||
switch msgType {
|
||||
case
|
||||
MsgTypeHelloResponse,
|
||||
MsgTypeAuthResponse,
|
||||
MsgTypeTransport,
|
||||
MsgTypeClose,
|
||||
MsgTypeHealthCheck:
|
||||
@@ -112,6 +123,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Deprecated: Use MarshalAuthMsg instead.
|
||||
// MarshalHelloMsg initial hello message
|
||||
// The Hello message is the first message sent by a client after establishing a connection with the Relay server. This
|
||||
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
|
||||
@@ -135,6 +147,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// Deprecated: Use UnmarshalAuthMsg instead.
|
||||
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
|
||||
// authenticate the client with the server.
|
||||
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
|
||||
@@ -148,6 +161,7 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
|
||||
return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil
|
||||
}
|
||||
|
||||
// Deprecated: Use MarshalAuthResponse instead.
|
||||
// MarshalHelloResponse creates a response message to the hello message.
|
||||
// In case of success connection the server response with a Hello Response message. This message contains the server's
|
||||
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
|
||||
@@ -163,6 +177,7 @@ func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// Deprecated: Use UnmarshalAuthResponse instead.
|
||||
// UnmarshalHelloResponse extracts the additional data from the hello response message.
|
||||
func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
|
||||
if len(msg) < headerSizeHelloResp {
|
||||
@@ -171,6 +186,69 @@ func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// MarshalAuthMsg initial authentication message
|
||||
// The Auth message is the first message sent by a client after establishing a connection with the Relay server. This
|
||||
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
|
||||
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
|
||||
// close the network connection without any response.
|
||||
func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
|
||||
if len(peerID) != IDSize {
|
||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
||||
}
|
||||
|
||||
msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeAuth+len(authPayload))
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeAuth)
|
||||
|
||||
copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
|
||||
|
||||
msg = append(msg, peerID...)
|
||||
msg = append(msg, authPayload...)
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// UnmarshalAuthMsg extracts peerID and the auth payload from the message
|
||||
func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
|
||||
if len(msg) < headerSizeAuth {
|
||||
return nil, nil, ErrInvalidMessageLength
|
||||
}
|
||||
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
|
||||
return nil, nil, errors.New("invalid magic header")
|
||||
}
|
||||
|
||||
return msg[sizeOfMagicByte:headerSizeAuth], msg[headerSizeAuth:], nil
|
||||
}
|
||||
|
||||
// MarshalAuthResponse creates a response message to the auth.
|
||||
// In case of success connection the server response with a AuthResponse message. This message contains the server's
|
||||
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
|
||||
// servers.
|
||||
func MarshalAuthResponse(address string) ([]byte, error) {
|
||||
ab := []byte(address)
|
||||
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeAuthResp+len(ab))
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeAuthResponse)
|
||||
|
||||
msg = append(msg, ab...)
|
||||
|
||||
if len(msg) > MaxHandshakeRespSize {
|
||||
return nil, fmt.Errorf("invalid message length: %d", len(msg))
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// UnmarshalAuthResponse it is a confirmation message to auth success
|
||||
func UnmarshalAuthResponse(msg []byte) (string, error) {
|
||||
if len(msg) < headerSizeAuthResp+1 {
|
||||
return "", ErrInvalidMessageLength
|
||||
}
|
||||
return string(msg), nil
|
||||
}
|
||||
|
||||
// MarshalCloseMsg creates a close message.
|
||||
// The close message is used to close the connection gracefully between the client and the server. The server and the
|
||||
// client can send this message. After receiving this message, the server or client will close the connection.
|
||||
|
||||
@@ -20,6 +20,22 @@ func TestMarshalHelloMsg(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalAuthMsg(t *testing.T) {
|
||||
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||
bHello, err := MarshalAuthMsg(peerID, []byte{})
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
receivedPeerID, _, err := UnmarshalAuthMsg(bHello[SizeOfProtoHeader:])
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
if string(receivedPeerID) != string(peerID) {
|
||||
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalTransportMsg(t *testing.T) {
|
||||
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||
payload := []byte("payload")
|
||||
|
||||
@@ -103,7 +103,7 @@ func (m *Metrics) PeerActivity(peerID string) {
|
||||
select {
|
||||
case m.peerActivityChan <- peerID:
|
||||
default:
|
||||
log.Errorf("peer activity channel is full, dropping activity metrics for peer %s", peerID)
|
||||
log.Tracef("peer activity channel is full, dropping activity metrics for peer %s", peerID)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ func (p *Peer) Work() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
hc := healthcheck.NewSender()
|
||||
hc := healthcheck.NewSender(p.log)
|
||||
go hc.StartHealthCheck(ctx)
|
||||
go p.handleHealthcheckEvents(ctx, hc)
|
||||
|
||||
@@ -115,6 +115,7 @@ func (p *Peer) Write(b []byte) (int, error) {
|
||||
// connection.
|
||||
func (p *Peer) CloseGracefully(ctx context.Context) {
|
||||
p.connMu.Lock()
|
||||
defer p.connMu.Unlock()
|
||||
err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg())
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to send close message to peer: %s", p.String())
|
||||
@@ -124,8 +125,15 @@ func (p *Peer) CloseGracefully(ctx context.Context) {
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to close connection to peer: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Peer) Close() {
|
||||
p.connMu.Lock()
|
||||
defer p.connMu.Unlock()
|
||||
|
||||
if err := p.conn.Close(); err != nil {
|
||||
p.log.Errorf("failed to close connection to peer: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// String returns the peer ID
|
||||
@@ -167,6 +175,7 @@ func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Send
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to close connection to peer: %s", err)
|
||||
}
|
||||
p.log.Info("peer connection closed due healthcheck timeout")
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
@@ -184,7 +193,7 @@ func (p *Peer) handleTransportMsg(msg []byte) {
|
||||
stringPeerID := messages.HashIDToString(peerID)
|
||||
dp, ok := p.store.Peer(stringPeerID)
|
||||
if !ok {
|
||||
p.log.Errorf("peer not found: %s", stringPeerID)
|
||||
p.log.Debugf("peer not found: %s", stringPeerID)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
@@ -14,7 +13,9 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/relay/auth"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
//nolint:staticcheck
|
||||
"github.com/netbirdio/netbird/relay/messages/address"
|
||||
//nolint:staticcheck
|
||||
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
|
||||
"github.com/netbirdio/netbird/relay/metrics"
|
||||
)
|
||||
@@ -168,39 +169,81 @@ func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
|
||||
return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
if msgType != messages.MsgTypeHello {
|
||||
return nil, fmt.Errorf("invalid message type from %s", conn.RemoteAddr())
|
||||
var (
|
||||
responseMsg []byte
|
||||
peerID []byte
|
||||
)
|
||||
switch msgType {
|
||||
//nolint:staticcheck
|
||||
case messages.MsgTypeHello:
|
||||
peerID, responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
|
||||
case messages.MsgTypeAuth:
|
||||
peerID, responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr())
|
||||
}
|
||||
|
||||
peerID, authData, err := messages.UnmarshalHelloMsg(buf[messages.SizeOfProtoHeader:n])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unmarshal hello message: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authMsg, err := authmsg.UnmarshalMsg(authData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unmarshal auth message: %w", err)
|
||||
}
|
||||
|
||||
if err := r.validator.Validate(sha256.New, authMsg.AdditionalData); err != nil {
|
||||
return nil, fmt.Errorf("validate %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
addr := &address.Address{URL: r.instanceURL}
|
||||
addrData, err := addr.Marshal()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
msg, err := messages.MarshalHelloResponse(addrData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
_, err = conn.Write(msg)
|
||||
_, err = conn.Write(responseMsg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
return peerID, nil
|
||||
}
|
||||
|
||||
func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) {
|
||||
//nolint:staticcheck
|
||||
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
|
||||
}
|
||||
|
||||
peerID := messages.HashIDToString(rawPeerID)
|
||||
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr)
|
||||
|
||||
authMsg, err := authmsg.UnmarshalMsg(authData)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unmarshal auth message: %w", err)
|
||||
}
|
||||
|
||||
//nolint:staticcheck
|
||||
if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
|
||||
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err)
|
||||
}
|
||||
|
||||
addr := &address.Address{URL: r.instanceURL}
|
||||
addrData, err := addr.Marshal()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err)
|
||||
}
|
||||
|
||||
//nolint:staticcheck
|
||||
responseMsg, err := messages.MarshalHelloResponse(addrData)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err)
|
||||
}
|
||||
return rawPeerID, responseMsg, nil
|
||||
}
|
||||
|
||||
func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) {
|
||||
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
|
||||
}
|
||||
|
||||
peerID := messages.HashIDToString(rawPeerID)
|
||||
|
||||
if err := r.validator.Validate(authPayload); err != nil {
|
||||
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err)
|
||||
}
|
||||
|
||||
responseMsg, err := messages.MarshalAuthResponse(r.instanceURL)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err)
|
||||
}
|
||||
|
||||
return rawPeerID, responseMsg, nil
|
||||
}
|
||||
|
||||
@@ -19,10 +19,14 @@ func NewStore() *Store {
|
||||
}
|
||||
|
||||
// AddPeer adds a peer to the store
|
||||
// todo: consider to close peer conn if the peer already exists
|
||||
func (s *Store) AddPeer(peer *Peer) {
|
||||
s.peersLock.Lock()
|
||||
defer s.peersLock.Unlock()
|
||||
odlPeer, ok := s.peers[peer.String()]
|
||||
if ok {
|
||||
odlPeer.Close()
|
||||
}
|
||||
|
||||
s.peers[peer.String()] = peer
|
||||
}
|
||||
|
||||
|
||||
@@ -2,13 +2,57 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/metrics"
|
||||
)
|
||||
|
||||
type mockConn struct {
|
||||
}
|
||||
|
||||
func (m mockConn) Read(b []byte) (n int, err error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m mockConn) Write(b []byte) (n int, err error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m mockConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockConn) LocalAddr() net.Addr {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m mockConn) RemoteAddr() net.Addr {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m mockConn) SetDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m mockConn) SetReadDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m mockConn) SetWriteDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func TestStore_DeletePeer(t *testing.T) {
|
||||
s := NewStore()
|
||||
|
||||
@@ -27,8 +71,9 @@ func TestStore_DeleteDeprecatedPeer(t *testing.T) {
|
||||
|
||||
m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
|
||||
|
||||
p1 := NewPeer(m, []byte("peer_id"), nil, nil)
|
||||
p2 := NewPeer(m, []byte("peer_id"), nil, nil)
|
||||
conn := &mockConn{}
|
||||
p1 := NewPeer(m, []byte("peer_id"), conn, nil)
|
||||
p2 := NewPeer(m, []byte("peer_id"), conn, nil)
|
||||
|
||||
s.AddPeer(p1)
|
||||
s.AddPeer(p2)
|
||||
|
||||
@@ -300,11 +300,13 @@ install_netbird() {
|
||||
echo "package_manager=$PACKAGE_MANAGER" | ${SUDO} tee "$CONFIG_FILE" > /dev/null
|
||||
|
||||
# Load and start netbird service
|
||||
if ! ${SUDO} netbird service install 2>&1; then
|
||||
echo "NetBird service has already been loaded"
|
||||
fi
|
||||
if ! ${SUDO} netbird service start 2>&1; then
|
||||
echo "NetBird service has already been started"
|
||||
if [ "$PACKAGE_MANAGER" != "rpm-ostree" ]; then
|
||||
if ! ${SUDO} netbird service install 2>&1; then
|
||||
echo "NetBird service has already been loaded"
|
||||
fi
|
||||
if ! ${SUDO} netbird service start 2>&1; then
|
||||
echo "NetBird service has already been started"
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
|
||||
@@ -29,12 +29,9 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
)
|
||||
|
||||
const (
|
||||
metricsPort = 9090
|
||||
)
|
||||
|
||||
var (
|
||||
signalPort int
|
||||
metricsPort int
|
||||
signalLetsencryptDomain string
|
||||
signalSSLDir string
|
||||
defaultSignalSSLDir string
|
||||
@@ -288,6 +285,7 @@ func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
|
||||
|
||||
func init() {
|
||||
runCmd.PersistentFlags().IntVar(&signalPort, "port", 80, "Server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
|
||||
runCmd.Flags().IntVar(&metricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
|
||||
runCmd.Flags().StringVar(&signalSSLDir, "ssl-dir", defaultSignalSSLDir, "server ssl directory location. *Required only for Let's Encrypt certificates.")
|
||||
runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
|
||||
runCmd.Flags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
|
||||
|
||||
@@ -82,8 +82,11 @@ func (registry *Registry) Register(peer *Peer) {
|
||||
log.Warnf("peer [%s] is already registered [new streamID %d, previous StreamID %d]. Will override stream.",
|
||||
peer.Id, peer.StreamID, pp.StreamID)
|
||||
registry.Peers.Store(peer.Id, peer)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("peer registered [%s]", peer.Id)
|
||||
registry.metrics.ActivePeers.Add(context.Background(), 1)
|
||||
|
||||
// record time as milliseconds
|
||||
registry.metrics.RegistrationDelay.Record(context.Background(), float64(time.Since(start).Nanoseconds())/1e6)
|
||||
@@ -105,8 +108,8 @@ func (registry *Registry) Deregister(peer *Peer) {
|
||||
peer.Id, pp.StreamID, peer.StreamID)
|
||||
return
|
||||
}
|
||||
registry.metrics.ActivePeers.Add(context.Background(), -1)
|
||||
log.Debugf("peer deregistered [%s]", peer.Id)
|
||||
registry.metrics.Deregistrations.Add(context.Background(), 1)
|
||||
}
|
||||
log.Debugf("peer deregistered [%s]", peer.Id)
|
||||
|
||||
registry.metrics.Deregistrations.Add(context.Background(), 1)
|
||||
}
|
||||
|
||||
@@ -133,8 +133,6 @@ func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (
|
||||
s.registry.Register(p)
|
||||
s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer)
|
||||
|
||||
s.metrics.ActivePeers.Add(stream.Context(), 1)
|
||||
|
||||
return p, nil
|
||||
} else {
|
||||
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId)))
|
||||
@@ -151,7 +149,6 @@ func (s *Server) DeregisterPeer(p *peer.Peer) {
|
||||
s.registry.Deregister(p)
|
||||
|
||||
s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds()))
|
||||
s.metrics.ActivePeers.Add(context.Background(), -1)
|
||||
}
|
||||
|
||||
func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) {
|
||||
|
||||
87
util/file.go
87
util/file.go
@@ -10,51 +10,30 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// WriteJson writes JSON config object to a file creating parent directories if required
|
||||
// The output JSON is pretty-formatted
|
||||
func WriteJson(file string, obj interface{}) error {
|
||||
|
||||
// WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory
|
||||
func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
|
||||
configDir, configFileName, err := prepareConfigFileDir(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// make it pretty
|
||||
bs, err := json.MarshalIndent(obj, "", " ")
|
||||
err = EnforcePermission(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
|
||||
return writeJson(file, obj, configDir, configFileName)
|
||||
}
|
||||
|
||||
// WriteJson writes JSON config object to a file creating parent directories if required
|
||||
// The output JSON is pretty-formatted
|
||||
func WriteJson(file string, obj interface{}) error {
|
||||
configDir, configFileName, err := prepareConfigFileDir(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tempFileName := tempFile.Name()
|
||||
// closing file ops as windows doesn't allow to move it
|
||||
err = tempFile.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_, err = os.Stat(tempFileName)
|
||||
if err == nil {
|
||||
os.Remove(tempFileName)
|
||||
}
|
||||
}()
|
||||
|
||||
err = os.WriteFile(tempFileName, bs, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = os.Rename(tempFileName, file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return writeJson(file, obj, configDir, configFileName)
|
||||
}
|
||||
|
||||
// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file
|
||||
@@ -96,6 +75,46 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeJson(file string, obj interface{}, configDir string, configFileName string) error {
|
||||
|
||||
// make it pretty
|
||||
bs, err := json.MarshalIndent(obj, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tempFileName := tempFile.Name()
|
||||
// closing file ops as windows doesn't allow to move it
|
||||
err = tempFile.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_, err = os.Stat(tempFileName)
|
||||
if err == nil {
|
||||
os.Remove(tempFileName)
|
||||
}
|
||||
}()
|
||||
|
||||
err = os.WriteFile(tempFileName, bs, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = os.Rename(tempFileName, file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func openOrCreateFile(file string) (*os.File, error) {
|
||||
s, err := os.Stat(file)
|
||||
if err == nil {
|
||||
@@ -172,5 +191,9 @@ func prepareConfigFileDir(file string) (string, string, error) {
|
||||
}
|
||||
|
||||
err := os.MkdirAll(configDir, 0750)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return configDir, configFileName, err
|
||||
}
|
||||
|
||||
23
util/log.go
23
util/log.go
@@ -5,6 +5,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
@@ -12,6 +13,8 @@ import (
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
)
|
||||
|
||||
const defaultLogSize = 5
|
||||
|
||||
// InitLog parses and sets log-level input
|
||||
func InitLog(logLevel string, logPath string) error {
|
||||
level, err := log.ParseLevel(logLevel)
|
||||
@@ -19,13 +22,14 @@ func InitLog(logLevel string, logPath string) error {
|
||||
log.Errorf("Failed parsing log-level %s: %s", logLevel, err)
|
||||
return err
|
||||
}
|
||||
customOutputs := []string{"console", "syslog"};
|
||||
customOutputs := []string{"console", "syslog"}
|
||||
|
||||
if logPath != "" && !slices.Contains(customOutputs, logPath) {
|
||||
maxLogSize := getLogMaxSize()
|
||||
lumberjackLogger := &lumberjack.Logger{
|
||||
// Log file absolute path, os agnostic
|
||||
Filename: filepath.ToSlash(logPath),
|
||||
MaxSize: 5, // MB
|
||||
MaxSize: maxLogSize, // MB
|
||||
MaxBackups: 10,
|
||||
MaxAge: 30, // days
|
||||
Compress: true,
|
||||
@@ -46,3 +50,18 @@ func InitLog(logLevel string, logPath string) error {
|
||||
log.SetLevel(level)
|
||||
return nil
|
||||
}
|
||||
|
||||
func getLogMaxSize() int {
|
||||
if sizeVar, ok := os.LookupEnv("NB_LOG_MAX_SIZE_MB"); ok {
|
||||
size, err := strconv.ParseInt(sizeVar, 10, 64)
|
||||
if err != nil {
|
||||
log.Errorf("Failed parsing log-size %s: %s. Should be just an integer", sizeVar, err)
|
||||
return defaultLogSize
|
||||
}
|
||||
|
||||
log.Infof("Setting log file max size to %d MB", size)
|
||||
|
||||
return int(size)
|
||||
}
|
||||
return defaultLogSize
|
||||
}
|
||||
|
||||
7
util/permission.go
Normal file
7
util/permission.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !windows
|
||||
|
||||
package util
|
||||
|
||||
func EnforcePermission(dirPath string) error {
|
||||
return nil
|
||||
}
|
||||
86
util/permission_windows.go
Normal file
86
util/permission_windows.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const (
|
||||
securityFlags = windows.OWNER_SECURITY_INFORMATION |
|
||||
windows.GROUP_SECURITY_INFORMATION |
|
||||
windows.DACL_SECURITY_INFORMATION |
|
||||
windows.PROTECTED_DACL_SECURITY_INFORMATION
|
||||
)
|
||||
|
||||
func EnforcePermission(file string) error {
|
||||
dirPath := filepath.Dir(file)
|
||||
|
||||
user, group, err := sids()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
adminGroupSid, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
explicitAccess := []windows.EXPLICIT_ACCESS{
|
||||
{
|
||||
AccessPermissions: windows.GENERIC_ALL,
|
||||
AccessMode: windows.SET_ACCESS,
|
||||
Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT,
|
||||
Trustee: windows.TRUSTEE{
|
||||
MultipleTrusteeOperation: windows.NO_MULTIPLE_TRUSTEE,
|
||||
TrusteeForm: windows.TRUSTEE_IS_SID,
|
||||
TrusteeType: windows.TRUSTEE_IS_USER,
|
||||
TrusteeValue: windows.TrusteeValueFromSID(user),
|
||||
},
|
||||
},
|
||||
{
|
||||
AccessPermissions: windows.GENERIC_ALL,
|
||||
AccessMode: windows.SET_ACCESS,
|
||||
Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT,
|
||||
Trustee: windows.TRUSTEE{
|
||||
MultipleTrusteeOperation: windows.NO_MULTIPLE_TRUSTEE,
|
||||
TrusteeForm: windows.TRUSTEE_IS_SID,
|
||||
TrusteeType: windows.TRUSTEE_IS_WELL_KNOWN_GROUP,
|
||||
TrusteeValue: windows.TrusteeValueFromSID(adminGroupSid),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
dacl, err := windows.ACLFromEntries(explicitAccess, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return windows.SetNamedSecurityInfo(dirPath, windows.SE_FILE_OBJECT, securityFlags, user, group, dacl, nil)
|
||||
}
|
||||
|
||||
func sids() (*windows.SID, *windows.SID, error) {
|
||||
var token windows.Token
|
||||
err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err := token.Close(); err != nil {
|
||||
log.Errorf("failed to close process token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
tu, err := token.GetTokenUser()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pg, err := token.GetTokenPrimaryGroup()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return tu.User.Sid, pg.PrimaryGroup, nil
|
||||
}
|
||||
Reference in New Issue
Block a user